ghostflow_core/ops/
arithmetic.rs

1//! Element-wise arithmetic operations
2
3use crate::tensor::Tensor;
4use crate::error::Result;
5#[cfg(feature = "rayon")]
6use rayon::prelude::*;
7
/// Map `$op` over the elements of `$data`, collecting the results into a
/// new collection. When the `rayon` feature is enabled this uses a
/// parallel iterator; otherwise it falls back to a sequential one.
///
/// `$op` receives items by reference (callers write `|&x| ...`); with
/// the `rayon` feature the closure must also be `Sync + Send`.
macro_rules! map_elements {
    ($data:expr, $op:expr) => {{
        // Exactly one of these blocks is compiled, chosen at build time.
        #[cfg(feature = "rayon")]
        { $data.par_iter().map($op).collect() }
        #[cfg(not(feature = "rayon"))]
        { $data.iter().map($op).collect() }
    }};
}
17
18impl Tensor {
19    /// Element-wise addition
20    pub fn add(&self, other: &Tensor) -> Result<Tensor> {
21        let a = self.data_f32();
22        let b = other.data_f32();
23        
24        // Handle broadcasting
25        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x + y)?;
26        Tensor::from_slice(&result, &shape)
27    }
28
29    /// Element-wise subtraction
30    pub fn sub(&self, other: &Tensor) -> Result<Tensor> {
31        let a = self.data_f32();
32        let b = other.data_f32();
33        
34        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x - y)?;
35        Tensor::from_slice(&result, &shape)
36    }
37
38    /// Element-wise multiplication
39    pub fn mul(&self, other: &Tensor) -> Result<Tensor> {
40        let a = self.data_f32();
41        let b = other.data_f32();
42        
43        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x * y)?;
44        Tensor::from_slice(&result, &shape)
45    }
46
47    /// Element-wise division
48    pub fn div(&self, other: &Tensor) -> Result<Tensor> {
49        let a = self.data_f32();
50        let b = other.data_f32();
51        
52        let (result, shape) = broadcast_binary_op(&a, self.dims(), &b, other.dims(), |x, y| x / y)?;
53        Tensor::from_slice(&result, &shape)
54    }
55
56    /// Add scalar
57    pub fn add_scalar(&self, scalar: f32) -> Tensor {
58        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x + scalar);
59        Tensor::from_slice(&data, self.dims()).unwrap()
60    }
61
62    /// Subtract scalar
63    pub fn sub_scalar(&self, scalar: f32) -> Tensor {
64        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x - scalar);
65        Tensor::from_slice(&data, self.dims()).unwrap()
66    }
67
68    /// Multiply by scalar
69    pub fn mul_scalar(&self, scalar: f32) -> Tensor {
70        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x * scalar);
71        Tensor::from_slice(&data, self.dims()).unwrap()
72    }
73
74    /// Divide by scalar
75    pub fn div_scalar(&self, scalar: f32) -> Tensor {
76        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x / scalar);
77        Tensor::from_slice(&data, self.dims()).unwrap()
78    }
79
80    /// Negation
81    pub fn neg(&self) -> Tensor {
82        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| -x);
83        Tensor::from_slice(&data, self.dims()).unwrap()
84    }
85
86    /// Absolute value
87    pub fn abs(&self) -> Tensor {
88        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.abs());
89        Tensor::from_slice(&data, self.dims()).unwrap()
90    }
91
92    /// Power
93    pub fn pow(&self, exp: f32) -> Tensor {
94        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.powf(exp));
95        Tensor::from_slice(&data, self.dims()).unwrap()
96    }
97
98    /// Square root
99    pub fn sqrt(&self) -> Tensor {
100        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.sqrt());
101        Tensor::from_slice(&data, self.dims()).unwrap()
102    }
103
104    /// Exponential
105    pub fn exp(&self) -> Tensor {
106        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.exp());
107        Tensor::from_slice(&data, self.dims()).unwrap()
108    }
109
110    /// Natural logarithm
111    pub fn log(&self) -> Tensor {
112        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.ln());
113        Tensor::from_slice(&data, self.dims()).unwrap()
114    }
115
116    /// Clamp values to range
117    pub fn clamp(&self, min: f32, max: f32) -> Tensor {
118        let data: Vec<f32> = map_elements!(self.data_f32(), |&x| x.clamp(min, max));
119        Tensor::from_slice(&data, self.dims()).unwrap()
120    }
121}
122
123/// Broadcast and apply binary operation
124fn broadcast_binary_op<F>(
125    a: &[f32],
126    a_shape: &[usize],
127    b: &[f32],
128    b_shape: &[usize],
129    op: F,
130) -> Result<(Vec<f32>, Vec<usize>)>
131where
132    F: Fn(f32, f32) -> f32 + Sync,
133{
134    use crate::shape::Shape;
135    
136    let shape_a = Shape::new(a_shape);
137    let shape_b = Shape::new(b_shape);
138    let result_shape = shape_a.broadcast_with(&shape_b)?;
139    let result_dims = result_shape.dims().to_vec();
140    let numel = result_shape.numel();
141
142    // Fast path: same shape
143    if a_shape == b_shape {
144        #[cfg(feature = "rayon")]
145        let result: Vec<f32> = a.par_iter()
146            .zip(b.par_iter())
147            .map(|(&x, &y)| op(x, y))
148            .collect();
149        #[cfg(not(feature = "rayon"))]
150        let result: Vec<f32> = a.iter()
151            .zip(b.iter())
152            .map(|(&x, &y)| op(x, y))
153            .collect();
154        return Ok((result, result_dims));
155    }
156
157    // Broadcast path
158    let a_strides = compute_broadcast_strides(a_shape, &result_dims);
159    let b_strides = compute_broadcast_strides(b_shape, &result_dims);
160
161    #[cfg(feature = "rayon")]
162    let result: Vec<f32> = (0..numel)
163        .into_par_iter()
164        .map(|i| {
165            let a_idx = compute_broadcast_index(i, &result_dims, &a_strides);
166            let b_idx = compute_broadcast_index(i, &result_dims, &b_strides);
167            op(a[a_idx], b[b_idx])
168        })
169        .collect();
170    #[cfg(not(feature = "rayon"))]
171    let result: Vec<f32> = (0..numel)
172        .map(|i| {
173            let a_idx = compute_broadcast_index(i, &result_dims, &a_strides);
174            let b_idx = compute_broadcast_index(i, &result_dims, &b_strides);
175            op(a[a_idx], b[b_idx])
176        })
177        .collect();
178
179    Ok((result, result_dims))
180}
181
/// Row-major strides for reading `shape` as if it were broadcast up to
/// `target_shape`.
///
/// `shape` is right-aligned against `target_shape`; any leading target
/// axes it lacks, and any size-1 axes being broadcast, get stride 0 so
/// the single source element is reused along that axis. Assumes
/// `shape.len() <= target_shape.len()` and broadcast compatibility
/// (checked by the caller via `Shape::broadcast_with`).
fn compute_broadcast_strides(shape: &[usize], target_shape: &[usize]) -> Vec<usize> {
    let ndim = target_shape.len();
    let offset = ndim - shape.len();
    let mut strides = vec![0usize; ndim];

    // Walk the source dims from innermost to outermost, accumulating
    // the row-major step size.
    let mut step = 1usize;
    for (i, &dim) in shape.iter().enumerate().rev() {
        if dim == target_shape[i + offset] {
            strides[i + offset] = step;
            step *= dim;
        }
        // else: size-1 axis being broadcast — its stride stays 0.
    }

    strides
}
202
/// Translate a flat index into the row-major result tensor of `shape`
/// into a flat index into a source buffer described by `strides` (as
/// produced by `compute_broadcast_strides`; stride-0 axes collapse to
/// the same source element).
fn compute_broadcast_index(flat_idx: usize, shape: &[usize], strides: &[usize]) -> usize {
    let mut rest = flat_idx;
    let mut source = 0usize;

    // Peel off coordinates from the innermost axis outward.
    for (&dim, &stride) in shape.iter().zip(strides.iter()).rev() {
        source += (rest % dim) * stride;
        rest /= dim;
    }

    source
}
216
217// Operator overloads
218impl std::ops::Add for &Tensor {
219    type Output = Tensor;
220    fn add(self, other: &Tensor) -> Tensor {
221        self.add(other).unwrap()
222    }
223}
224
225impl std::ops::Sub for &Tensor {
226    type Output = Tensor;
227    fn sub(self, other: &Tensor) -> Tensor {
228        self.sub(other).unwrap()
229    }
230}
231
232impl std::ops::Mul for &Tensor {
233    type Output = Tensor;
234    fn mul(self, other: &Tensor) -> Tensor {
235        self.mul(other).unwrap()
236    }
237}
238
239impl std::ops::Div for &Tensor {
240    type Output = Tensor;
241    fn div(self, other: &Tensor) -> Tensor {
242        self.div(other).unwrap()
243    }
244}
245
246impl std::ops::Neg for &Tensor {
247    type Output = Tensor;
248    fn neg(self) -> Tensor {
249        self.neg()
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        // Same-shape inputs exercise the fast path.
        let lhs = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let rhs = Tensor::from_slice(&[4.0f32, 5.0, 6.0], &[3]).unwrap();
        let sum = lhs.add(&rhs).unwrap();
        assert_eq!(sum.data_f32(), vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_broadcast_add() {
        // (3, 1) + (1, 2) broadcasts to (3, 2).
        let col = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3, 1]).unwrap();
        let row = Tensor::from_slice(&[10.0f32, 20.0], &[1, 2]).unwrap();
        let out = col.add(&row).unwrap();
        assert_eq!(out.dims(), &[3, 2]);
    }

    #[test]
    fn test_scalar_ops() {
        let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let doubled = t.mul_scalar(2.0);
        assert_eq!(doubled.data_f32(), vec![2.0, 4.0, 6.0]);
    }
}