Skip to main content

ronn_core/ops/
arithmetic.rs

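//! Arithmetic tensor operations with broadcasting support.
//!
//! This module provides element-wise binary arithmetic (add, sub, mul, div) with
//! broadcasting, scalar and unary math operations, and a few element-wise helpers
//! (clamp, ReLU, sigmoid, tanh, GELU), all implemented on top of the Candle backend.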
5
6use crate::tensor::Tensor;
7use anyhow::{Result, anyhow};
8
9/// Trait for arithmetic operations on tensors.
10pub trait ArithmeticOps {
11    /// Element-wise addition with broadcasting.
12    fn add(&self, other: &Tensor) -> Result<Tensor>;
13
14    /// Element-wise subtraction with broadcasting.
15    fn sub(&self, other: &Tensor) -> Result<Tensor>;
16
17    /// Element-wise multiplication with broadcasting.
18    fn mul(&self, other: &Tensor) -> Result<Tensor>;
19
20    /// Element-wise division with broadcasting.
21    fn div(&self, other: &Tensor) -> Result<Tensor>;
22
23    /// Add a scalar to all elements.
24    fn add_scalar(&self, scalar: f32) -> Result<Tensor>;
25
26    /// Subtract a scalar from all elements.
27    fn sub_scalar(&self, scalar: f32) -> Result<Tensor>;
28
29    /// Multiply all elements by a scalar.
30    fn mul_scalar(&self, scalar: f32) -> Result<Tensor>;
31
32    /// Divide all elements by a scalar.
33    fn div_scalar(&self, scalar: f32) -> Result<Tensor>;
34
35    /// Element-wise negation.
36    fn neg(&self) -> Result<Tensor>;
37
38    /// Element-wise absolute value.
39    fn abs(&self) -> Result<Tensor>;
40
41    /// Element-wise power operation.
42    fn pow(&self, exponent: f32) -> Result<Tensor>;
43
44    /// Element-wise square root.
45    fn sqrt(&self) -> Result<Tensor>;
46
47    /// Element-wise exponential.
48    fn exp(&self) -> Result<Tensor>;
49
50    /// Element-wise natural logarithm.
51    fn log(&self) -> Result<Tensor>;
52}
53
54impl ArithmeticOps for Tensor {
55    fn add(&self, other: &Tensor) -> Result<Tensor> {
56        // Check if tensors can be broadcast together
57        if !self.is_broadcastable_with(other) {
58            return Err(anyhow!(
59                "Cannot broadcast tensors with shapes {:?} and {:?}",
60                self.shape(),
61                other.shape()
62            ));
63        }
64
65        let result_candle = self.candle_tensor().broadcast_add(other.candle_tensor())?;
66
67        Ok(Tensor::from_candle(
68            result_candle,
69            self.dtype(),
70            self.layout(),
71        ))
72    }
73
74    fn sub(&self, other: &Tensor) -> Result<Tensor> {
75        if !self.is_broadcastable_with(other) {
76            return Err(anyhow!(
77                "Cannot broadcast tensors with shapes {:?} and {:?}",
78                self.shape(),
79                other.shape()
80            ));
81        }
82
83        let result_candle = self.candle_tensor().broadcast_sub(other.candle_tensor())?;
84
85        Ok(Tensor::from_candle(
86            result_candle,
87            self.dtype(),
88            self.layout(),
89        ))
90    }
91
92    fn mul(&self, other: &Tensor) -> Result<Tensor> {
93        if !self.is_broadcastable_with(other) {
94            return Err(anyhow!(
95                "Cannot broadcast tensors with shapes {:?} and {:?}",
96                self.shape(),
97                other.shape()
98            ));
99        }
100
101        let result_candle = self.candle_tensor().broadcast_mul(other.candle_tensor())?;
102
103        Ok(Tensor::from_candle(
104            result_candle,
105            self.dtype(),
106            self.layout(),
107        ))
108    }
109
110    fn div(&self, other: &Tensor) -> Result<Tensor> {
111        if !self.is_broadcastable_with(other) {
112            return Err(anyhow!(
113                "Cannot broadcast tensors with shapes {:?} and {:?}",
114                self.shape(),
115                other.shape()
116            ));
117        }
118
119        let result_candle = self.candle_tensor().broadcast_div(other.candle_tensor())?;
120
121        Ok(Tensor::from_candle(
122            result_candle,
123            self.dtype(),
124            self.layout(),
125        ))
126    }
127
128    fn add_scalar(&self, scalar: f32) -> Result<Tensor> {
129        let result_candle = (self.candle_tensor() + scalar as f64)?;
130
131        Ok(Tensor::from_candle(
132            result_candle,
133            self.dtype(),
134            self.layout(),
135        ))
136    }
137
138    fn sub_scalar(&self, scalar: f32) -> Result<Tensor> {
139        let result_candle = (self.candle_tensor() - scalar as f64)?;
140
141        Ok(Tensor::from_candle(
142            result_candle,
143            self.dtype(),
144            self.layout(),
145        ))
146    }
147
148    fn mul_scalar(&self, scalar: f32) -> Result<Tensor> {
149        let result_candle = (self.candle_tensor() * scalar as f64)?;
150
151        Ok(Tensor::from_candle(
152            result_candle,
153            self.dtype(),
154            self.layout(),
155        ))
156    }
157
158    fn div_scalar(&self, scalar: f32) -> Result<Tensor> {
159        if scalar == 0.0 {
160            return Err(anyhow!("Division by zero"));
161        }
162
163        let result_candle = (self.candle_tensor() / scalar as f64)?;
164
165        Ok(Tensor::from_candle(
166            result_candle,
167            self.dtype(),
168            self.layout(),
169        ))
170    }
171
172    fn neg(&self) -> Result<Tensor> {
173        let result_candle = self.candle_tensor().neg()?;
174
175        Ok(Tensor::from_candle(
176            result_candle,
177            self.dtype(),
178            self.layout(),
179        ))
180    }
181
182    fn abs(&self) -> Result<Tensor> {
183        let result_candle = self.candle_tensor().abs()?;
184
185        Ok(Tensor::from_candle(
186            result_candle,
187            self.dtype(),
188            self.layout(),
189        ))
190    }
191
192    fn pow(&self, exponent: f32) -> Result<Tensor> {
193        let result_candle = self.candle_tensor().powf(exponent as f64)?;
194
195        Ok(Tensor::from_candle(
196            result_candle,
197            self.dtype(),
198            self.layout(),
199        ))
200    }
201
202    fn sqrt(&self) -> Result<Tensor> {
203        let result_candle = self.candle_tensor().sqrt()?;
204
205        Ok(Tensor::from_candle(
206            result_candle,
207            self.dtype(),
208            self.layout(),
209        ))
210    }
211
212    fn exp(&self) -> Result<Tensor> {
213        let result_candle = self.candle_tensor().exp()?;
214
215        Ok(Tensor::from_candle(
216            result_candle,
217            self.dtype(),
218            self.layout(),
219        ))
220    }
221
222    fn log(&self) -> Result<Tensor> {
223        let result_candle = self.candle_tensor().log()?;
224
225        Ok(Tensor::from_candle(
226            result_candle,
227            self.dtype(),
228            self.layout(),
229        ))
230    }
231}
232
233/// Convenience functions for arithmetic operations.
234impl Tensor {
235    /// Clamp tensor values between min and max.
236    pub fn clamp(&self, min: f32, max: f32) -> Result<Tensor> {
237        if min > max {
238            return Err(anyhow!(
239                "Min value {} is greater than max value {}",
240                min,
241                max
242            ));
243        }
244
245        let result_candle = self.candle_tensor().clamp(min as f64, max as f64)?;
246
247        Ok(Tensor::from_candle(
248            result_candle,
249            self.dtype(),
250            self.layout(),
251        ))
252    }
253
254    /// Apply ReLU activation function.
255    pub fn relu(&self) -> Result<Tensor> {
256        self.clamp(0.0, f32::INFINITY)
257    }
258
259    /// Apply Sigmoid activation function.
260    pub fn sigmoid(&self) -> Result<Tensor> {
261        // sigmoid(x) = 1 / (1 + exp(-x))
262        let neg_x = self.neg()?;
263        let exp_neg_x = neg_x.exp()?;
264        let one = Tensor::ones(vec![1], self.dtype(), self.layout())?;
265        let one_plus_exp = one.add(&exp_neg_x)?;
266        one.div(&one_plus_exp)
267    }
268
269    /// Apply Tanh activation function.
270    pub fn tanh(&self) -> Result<Tensor> {
271        let result_candle = self.candle_tensor().tanh()?;
272
273        Ok(Tensor::from_candle(
274            result_candle,
275            self.dtype(),
276            self.layout(),
277        ))
278    }
279
280    /// Apply GELU activation function.
281    pub fn gelu(&self) -> Result<Tensor> {
282        // GELU(x) = x * Φ(x) where Φ is the standard Gaussian CDF
283        // Approximation: GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
284        let x = self;
285        let x_cubed = x.pow(3.0)?;
286        let term1 = x_cubed.mul_scalar(0.044715)?;
287        let term2 = x.add(&term1)?;
288        let sqrt_2_over_pi = (2.0 / std::f32::consts::PI).sqrt();
289        let term3 = term2.mul_scalar(sqrt_2_over_pi)?;
290        let tanh_term = term3.tanh()?;
291        let one = Tensor::ones(vec![1], self.dtype(), self.layout())?;
292        let one_plus_tanh = one.add(&tanh_term)?;
293        let half = Tensor::from_data(vec![0.5], vec![1], self.dtype(), self.layout())?;
294        let result = x.mul(&half)?.mul(&one_plus_tanh)?;
295        Ok(result)
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DataType, TensorLayout};

    /// Assert that two float slices have equal length and agree element-wise within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() <= tol, "element {i}: got {a}, expected {e}");
        }
    }

    #[test]
    fn test_arithmetic_operations() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;
        let b = Tensor::from_data(
            vec![2.0, 1.0, 1.0, 2.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        // Test addition
        let sum = a.add(&b)?;
        let sum_data = sum.to_vec()?;
        assert_eq!(sum_data, vec![3.0, 3.0, 4.0, 6.0]);

        // Test subtraction
        let diff = a.sub(&b)?;
        let diff_data = diff.to_vec()?;
        assert_eq!(diff_data, vec![-1.0, 1.0, 2.0, 2.0]);

        // Test multiplication
        let product = a.mul(&b)?;
        let product_data = product.to_vec()?;
        assert_eq!(product_data, vec![2.0, 2.0, 3.0, 8.0]);

        // Test division
        let quotient = a.div(&b)?;
        let quotient_data = quotient.to_vec()?;
        assert_eq!(quotient_data, vec![0.5, 2.0, 3.0, 2.0]);

        Ok(())
    }

    #[test]
    fn test_scalar_operations() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        // Test scalar addition
        let sum = a.add_scalar(5.0)?;
        let sum_data = sum.to_vec()?;
        assert_eq!(sum_data, vec![6.0, 7.0, 8.0, 9.0]);

        // Test scalar subtraction
        let diff = a.sub_scalar(1.0)?;
        let diff_data = diff.to_vec()?;
        assert_eq!(diff_data, vec![0.0, 1.0, 2.0, 3.0]);

        // Test scalar multiplication
        let product = a.mul_scalar(2.0)?;
        let product_data = product.to_vec()?;
        assert_eq!(product_data, vec![2.0, 4.0, 6.0, 8.0]);

        // Test scalar division
        let quotient = a.div_scalar(2.0)?;
        let quotient_data = quotient.to_vec()?;
        assert_eq!(quotient_data, vec![0.5, 1.0, 1.5, 2.0]);

        Ok(())
    }

    #[test]
    fn test_unary_math() -> Result<()> {
        let squares = Tensor::from_data(
            vec![1.0, 4.0, 9.0, 16.0],
            vec![4],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;
        assert_close(&squares.sqrt()?.to_vec()?, &[1.0, 2.0, 3.0, 4.0], 1e-6);

        let zeros = Tensor::from_data(vec![0.0, 0.0], vec![2], DataType::F32, TensorLayout::RowMajor)?;
        assert_close(&zeros.exp()?.to_vec()?, &[1.0, 1.0], 1e-6);

        let ones = Tensor::from_data(vec![1.0, 1.0], vec![2], DataType::F32, TensorLayout::RowMajor)?;
        assert_close(&ones.log()?.to_vec()?, &[0.0, 0.0], 1e-6);

        Ok(())
    }

    #[test]
    fn test_broadcasting() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;
        let b = Tensor::from_data(
            vec![10.0, 20.0, 30.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let sum = a.add(&b)?;
        let sum_data = sum.to_vec()?;
        assert_eq!(sum_data, vec![11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);

        Ok(())
    }

    #[test]
    fn test_activation_functions() -> Result<()> {
        let a = Tensor::from_data(
            vec![-1.0, 0.0, 1.0, 2.0],
            vec![4],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        // Test ReLU
        let relu_result = a.relu()?;
        let relu_data = relu_result.to_vec()?;
        assert_eq!(relu_data, vec![0.0, 0.0, 1.0, 2.0]);

        // Test absolute value
        let abs_result = a.abs()?;
        let abs_data = abs_result.to_vec()?;
        assert_eq!(abs_data, vec![1.0, 0.0, 1.0, 2.0]);

        // Test negation
        let neg_result = a.neg()?;
        let neg_data = neg_result.to_vec()?;
        assert_eq!(neg_data, vec![1.0, 0.0, -1.0, -2.0]);

        Ok(())
    }

    #[test]
    fn test_sigmoid() -> Result<()> {
        let x = Tensor::from_data(vec![0.0], vec![1], DataType::F32, TensorLayout::RowMajor)?;
        let sigmoid_result = x.sigmoid()?;
        let sigmoid_data = sigmoid_result.to_vec()?;

        // sigmoid(0) should be 0.5
        assert!((sigmoid_data[0] - 0.5).abs() < 1e-6);

        // Large magnitudes saturate to 0 and 1 without producing NaN.
        let extremes = Tensor::from_data(
            vec![-100.0, 100.0],
            vec![2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;
        assert_close(&extremes.sigmoid()?.to_vec()?, &[0.0, 1.0], 1e-6);

        Ok(())
    }

    #[test]
    fn test_tanh_and_gelu() -> Result<()> {
        let x = Tensor::from_data(
            vec![-1.0, 0.0, 1.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        // tanh is odd: tanh(-1) = -tanh(1) ≈ -0.761594
        assert_close(&x.tanh()?.to_vec()?, &[-0.761_594, 0.0, 0.761_594], 1e-5);

        // Tanh-approximated GELU: GELU(1) ≈ 0.841192 and GELU(-1) = GELU(1) - 1.
        assert_close(&x.gelu()?.to_vec()?, &[-0.158_808, 0.0, 0.841_192], 1e-4);

        Ok(())
    }

    #[test]
    fn test_error_handling() {
        let a = Tensor::from_data(
            vec![1.0, 2.0],
            vec![2],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        let b = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();

        // Should fail because shapes are not broadcastable
        assert!(a.add(&b).is_err());

        // Division by zero should fail
        assert!(a.div_scalar(0.0).is_err());

        // Invalid clamp should fail
        assert!(a.clamp(5.0, 1.0).is_err());
    }
}