ghostflow_core/ops/arithmetic.rs

//! Element-wise arithmetic operations

use crate::tensor::Tensor;
use crate::error::Result;
use rayon::prelude::*;

impl Tensor {
    /// Element-wise addition
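    ///
    /// A usage sketch (marked `ignore` because the crate's public paths are
    /// an assumption here, not confirmed by this file):
    ///
    /// ```ignore
    /// let a = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
    /// let b = Tensor::from_slice(&[3.0f32, 4.0], &[2]).unwrap();
    /// assert_eq!(a.add(&b).unwrap().data_f32(), vec![4.0, 6.0]);
    /// ```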
    pub fn add(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        // Handle broadcasting
        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x + y)?;
        Tensor::from_slice(&result, &shape)
    }

    /// Element-wise subtraction
    pub fn sub(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x - y)?;
        Tensor::from_slice(&result, &shape)
    }

    /// Element-wise multiplication
    pub fn mul(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x * y)?;
        Tensor::from_slice(&result, &shape)
    }

    /// Element-wise division
    pub fn div(&self, other: &Tensor) -> Result<Tensor> {
        let a = self.data_f32();
        let b = other.data_f32();

        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x / y)?;
        Tensor::from_slice(&result, &shape)
    }

    /// Add scalar
    pub fn add_scalar(&self, scalar: f32) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x + scalar).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Subtract scalar
    pub fn sub_scalar(&self, scalar: f32) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x - scalar).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Multiply by scalar
    pub fn mul_scalar(&self, scalar: f32) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x * scalar).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Divide by scalar
    pub fn div_scalar(&self, scalar: f32) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x / scalar).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Negation
    pub fn neg(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| -x).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Absolute value
    pub fn abs(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.abs()).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Element-wise power `x^exp`
    pub fn pow(&self, exp: f32) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.powf(exp)).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Square root
    pub fn sqrt(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.sqrt()).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Exponential
    pub fn exp(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.exp()).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Natural logarithm
    pub fn log(&self) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.ln()).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }

    /// Clamp values to `[min, max]`
    pub fn clamp(&self, min: f32, max: f32) -> Tensor {
        let data: Vec<f32> = self.data_f32().par_iter().map(|&x| x.clamp(min, max)).collect();
        Tensor::from_slice(&data, self.dims()).unwrap()
    }
}
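
// A chaining sketch using the methods above (the names `x`, `mean`, `std`,
// and `diff` are hypothetical): `x.sub_scalar(mean).div_scalar(std)`
// standardizes values, and `diff.pow(2.0).mul_scalar(0.5)` computes a
// scaled squared error.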

/// Broadcast and apply binary operation
fn broadcast_binary_op<F>(
    a: &[f32],
    a_shape: &[usize],
    b: &[f32],
    b_shape: &[usize],
    op: F,
) -> Result<(Vec<f32>, Vec<usize>)>
where
    F: Fn(f32, f32) -> f32 + Sync,
{
    use crate::shape::Shape;

    let shape_a = Shape::new(a_shape);
    let shape_b = Shape::new(b_shape);
    let result_shape = shape_a.broadcast_with(&shape_b)?;
    let result_dims = result_shape.dims().to_vec();
    let numel = result_shape.numel();

    // Fast path: same shape
    if a_shape == b_shape {
        let result: Vec<f32> = a.par_iter()
            .zip(b.par_iter())
            .map(|(&x, &y)| op(x, y))
            .collect();
        return Ok((result, result_dims));
    }

    // Broadcast path
    let a_strides = compute_broadcast_strides(a_shape, &result_dims);
    let b_strides = compute_broadcast_strides(b_shape, &result_dims);

    let result: Vec<f32> = (0..numel)
        .into_par_iter()
        .map(|i| {
            let a_idx = compute_broadcast_index(i, &result_dims, &a_strides);
            let b_idx = compute_broadcast_index(i, &result_dims, &b_strides);
            op(a[a_idx], b[b_idx])
        })
        .collect();

    Ok((result, result_dims))
}
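
// Worked example for the broadcast path above: adding shapes [3, 1] and
// [1, 2] yields shape [3, 2]. With a = [1, 2, 3] and b = [10, 20], the
// helpers below produce a_strides = [1, 0] and b_strides = [0, 1], so flat
// index 3 (row 1, col 1) reads a[1] = 2.0 and b[1] = 20.0, giving 22.0.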

/// Compute strides for broadcasting
fn compute_broadcast_strides(shape: &[usize], target_shape: &[usize]) -> Vec<usize> {
    let ndim = target_shape.len();
    let offset = ndim - shape.len();

    let mut strides = vec![0usize; ndim];
    let mut stride = 1usize;

    for i in (0..shape.len()).rev() {
        if shape[i] == target_shape[i + offset] {
            strides[i + offset] = stride;
            stride *= shape[i];
        } else {
            // Broadcast dimension (size 1)
            strides[i + offset] = 0;
        }
    }

    strides
}

/// Compute source index for broadcast
fn compute_broadcast_index(flat_idx: usize, shape: &[usize], strides: &[usize]) -> usize {
    let mut idx = 0;
    let mut remaining = flat_idx;

    for i in (0..shape.len()).rev() {
        let coord = remaining % shape[i];
        remaining /= shape[i];
        idx += coord * strides[i];
    }

    idx
}

// Operator overloads. These delegate to the inherent methods; the binary
// ones unwrap, so `&a + &b` panics if the shapes cannot be broadcast.
impl std::ops::Add for &Tensor {
    type Output = Tensor;
    fn add(self, other: &Tensor) -> Tensor {
        self.add(other).unwrap()
    }
}

impl std::ops::Sub for &Tensor {
    type Output = Tensor;
    fn sub(self, other: &Tensor) -> Tensor {
        self.sub(other).unwrap()
    }
}

impl std::ops::Mul for &Tensor {
    type Output = Tensor;
    fn mul(self, other: &Tensor) -> Tensor {
        self.mul(other).unwrap()
    }
}

impl std::ops::Div for &Tensor {
    type Output = Tensor;
    fn div(self, other: &Tensor) -> Tensor {
        self.div(other).unwrap()
    }
}

impl std::ops::Neg for &Tensor {
    type Output = Tensor;
    fn neg(self) -> Tensor {
        self.neg()
    }
}
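
// A usage sketch for the overloads above (values are illustrative):
//
//     let a = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
//     let b = Tensor::from_slice(&[3.0f32, 4.0], &[2]).unwrap();
//     let sum = &a + &b;   // [4.0, 6.0]
//     let neg = -&sum;     // [-4.0, -6.0]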

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.data_f32(), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_broadcast_add() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3, 1]).unwrap();
        let b = Tensor::from_slice(&[10.0f32, 20.0], &[1, 2]).unwrap();
        let c = a.add(&b).unwrap();
        assert_eq!(c.dims(), &[3, 2]);
        // Row-major: [1+10, 1+20, 2+10, 2+20, 3+10, 3+20]
        assert_eq!(c.data_f32(), vec![11.0, 21.0, 12.0, 22.0, 13.0, 23.0]);
    }

    #[test]
    fn test_scalar_ops() {
        let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let b = a.mul_scalar(2.0);
        assert_eq!(b.data_f32(), vec![2.0, 4.0, 6.0]);
    }
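
    // Additional sketches: exercise the private stride helper and a few of
    // the unary ops. Expected values follow from the row-major layout used
    // throughout this file.
    #[test]
    fn test_broadcast_strides() {
        // Size-1 (or missing) dims broadcast with stride 0; matching dims
        // get ordinary row-major strides.
        assert_eq!(compute_broadcast_strides(&[3, 1], &[3, 2]), vec![1, 0]);
        assert_eq!(compute_broadcast_strides(&[1, 2], &[3, 2]), vec![0, 1]);
        assert_eq!(compute_broadcast_strides(&[2], &[3, 2]), vec![0, 1]);
    }

    #[test]
    fn test_unary_ops() {
        let a = Tensor::from_slice(&[-1.0f32, 4.0], &[2]).unwrap();
        assert_eq!(a.neg().data_f32(), vec![1.0, -4.0]);
        assert_eq!(a.abs().data_f32(), vec![1.0, 4.0]);
        assert_eq!(a.abs().sqrt().data_f32(), vec![1.0, 2.0]);
        assert_eq!(a.clamp(0.0, 2.0).data_f32(), vec![0.0, 2.0]);
    }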
}