numrs/array/
ops_traits.rs

1//! Operator overloading for Array
2//!
3//! This module implements std::ops traits to enable natural operator syntax:
4//! ```rust
5//! # use numrs::array::Array;
6//! # use numrs::ops;
7//! let a = Array::new(vec![1], vec![1.0]);
8//! let b = Array::new(vec![1], vec![2.0]);
9//! let c = &a + &b;  // instead of ops::add(&a, &b)
10//! let d = &a * &b;  // instead of ops::mul(&a, &b)
11//! ```
12//!
13//! Note: Operators return Array (with default dtype f32) after type promotion.
14
15use crate::array::Array;
16use crate::ops;
17use std::ops::{Add, Div, Mul, Neg, Sub};
18
19// ============================================================================
20// Binary Operations: Add, Sub, Mul, Div
21// ============================================================================
22
23impl Add for Array {
24    type Output = Array;
25
26    /// Addition: `a + b`
27    fn add(self, rhs: Self) -> Self::Output {
28        ops::add(&self, &rhs).expect("Addition failed")
29    }
30}
31
32impl Add<&Array> for Array {
33    type Output = Array;
34
35    /// Addition: `a + &b`
36    fn add(self, rhs: &Array) -> Self::Output {
37        ops::add(&self, rhs).expect("Addition failed")
38    }
39}
40
41impl Add<Array> for &Array {
42    type Output = Array;
43
44    /// Addition: `&a + b`
45    fn add(self, rhs: Array) -> Self::Output {
46        ops::add(self, &rhs).expect("Addition failed")
47    }
48}
49
50impl Add for &Array {
51    type Output = Array;
52
53    /// Addition: `&a + &b` (most common, avoids moves)
54    fn add(self, rhs: Self) -> Self::Output {
55        ops::add(self, rhs).expect("Addition failed")
56    }
57}
58
59// ----------------------------------------------------------------------------
60
61impl Sub for Array {
62    type Output = Array;
63
64    /// Subtraction: `a - b`
65    fn sub(self, rhs: Self) -> Self::Output {
66        ops::sub(&self, &rhs).expect("Subtraction failed")
67    }
68}
69
70impl Sub<&Array> for Array {
71    type Output = Array;
72
73    fn sub(self, rhs: &Array) -> Self::Output {
74        ops::sub(&self, rhs).expect("Subtraction failed")
75    }
76}
77
78impl Sub<Array> for &Array {
79    type Output = Array;
80
81    fn sub(self, rhs: Array) -> Self::Output {
82        ops::sub(self, &rhs).expect("Subtraction failed")
83    }
84}
85
86impl Sub for &Array {
87    type Output = Array;
88
89    /// Subtraction: `&a - &b`
90    fn sub(self, rhs: Self) -> Self::Output {
91        ops::sub(self, rhs).expect("Subtraction failed")
92    }
93}
94
95// ----------------------------------------------------------------------------
96
97impl Mul for Array {
98    type Output = Array;
99
100    /// Element-wise multiplication: `a * b`
101    fn mul(self, rhs: Self) -> Self::Output {
102        ops::mul(&self, &rhs).expect("Multiplication failed")
103    }
104}
105
106impl Mul<&Array> for Array {
107    type Output = Array;
108
109    fn mul(self, rhs: &Array) -> Self::Output {
110        ops::mul(&self, rhs).expect("Multiplication failed")
111    }
112}
113
114impl Mul<Array> for &Array {
115    type Output = Array;
116
117    fn mul(self, rhs: Array) -> Self::Output {
118        ops::mul(self, &rhs).expect("Multiplication failed")
119    }
120}
121
122impl Mul for &Array {
123    type Output = Array;
124
125    /// Element-wise multiplication: `&a * &b`
126    fn mul(self, rhs: Self) -> Self::Output {
127        ops::mul(self, rhs).expect("Multiplication failed")
128    }
129}
130
131// ----------------------------------------------------------------------------
132
133impl Div for Array {
134    type Output = Array;
135
136    /// Element-wise division: `a / b`
137    fn div(self, rhs: Self) -> Self::Output {
138        ops::div(&self, &rhs).expect("Division failed")
139    }
140}
141
142impl Div<&Array> for Array {
143    type Output = Array;
144
145    fn div(self, rhs: &Array) -> Self::Output {
146        ops::div(&self, rhs).expect("Division failed")
147    }
148}
149
150impl Div<Array> for &Array {
151    type Output = Array;
152
153    fn div(self, rhs: Array) -> Self::Output {
154        ops::div(self, &rhs).expect("Division failed")
155    }
156}
157
158impl Div for &Array {
159    type Output = Array;
160
161    /// Element-wise division: `&a / &b`
162    fn div(self, rhs: Self) -> Self::Output {
163        ops::div(self, rhs).expect("Division failed")
164    }
165}
166
167// ============================================================================
168// Unary Operations
169// ============================================================================
170
171impl Neg for Array {
172    type Output = Array;
173
174    /// Negation: `-a`
175    fn neg(self) -> Self::Output {
176        ops::neg(&self).expect("Negation failed")
177    }
178}
179
180impl Neg for &Array {
181    type Output = Array;
182
183    /// Negation: `-&a`
184    fn neg(self) -> Self::Output {
185        ops::neg(self).expect("Negation failed")
186    }
187}
188
189// ============================================================================
190// Tests
191// ============================================================================
192
193#[cfg(test)]
194mod tests {
195    use super::*;
196
197    #[test]
198    fn test_add_operator() {
199        let a = Array::new(vec![3], vec![1.0f32, 2.0, 3.0]);
200        let b = Array::new(vec![3], vec![4.0f32, 5.0, 6.0]);
201
202        // Test with references (most common)
203        let c: Array<f32> = &a + &b;
204        assert_eq!(c.data, vec![5.0, 7.0, 9.0]);
205    }
206
207    #[test]
208    fn test_sub_operator() {
209        let a = Array::new(vec![3], vec![5.0f32, 7.0, 9.0]);
210        let b = Array::new(vec![3], vec![1.0f32, 2.0, 3.0]);
211
212        let c: Array<f32> = &a - &b;
213        assert_eq!(c.data, vec![4.0, 5.0, 6.0]);
214    }
215
216    #[test]
217    fn test_mul_operator() {
218        let a = Array::new(vec![3], vec![2.0f32, 3.0, 4.0]);
219        let b = Array::new(vec![3], vec![5.0f32, 6.0, 7.0]);
220
221        let c: Array<f32> = &a * &b;
222        assert_eq!(c.data, vec![10.0, 18.0, 28.0]);
223    }
224
225    #[test]
226    fn test_div_operator() {
227        let a = Array::new(vec![3], vec![10.0f32, 20.0, 30.0]);
228        let b = Array::new(vec![3], vec![2.0f32, 4.0, 5.0]);
229
230        let c: Array<f32> = &a / &b;
231        assert_eq!(c.data, vec![5.0, 5.0, 6.0]);
232    }
233
234    #[test]
235    fn test_neg_operator() {
236        let a = Array::new(vec![3], vec![1.0f32, -2.0, 3.0]);
237
238        let b: Array<f32> = -&a;
239        assert_eq!(b.data, vec![-1.0, 2.0, -3.0]);
240    }
241
242    #[test]
243    fn test_chained_operators() {
244        let a = Array::new(vec![3], vec![1.0f32, 2.0, 3.0]);
245        let b = Array::new(vec![3], vec![4.0f32, 5.0, 6.0]);
246        let c = Array::new(vec![3], vec![2.0f32, 2.0, 2.0]);
247
248        // Test: (a + b) * c
249        let sum: Array<f32> = &a + &b;
250        let result: Array<f32> = &sum * &c;
251        assert_eq!(result.data, vec![10.0, 14.0, 18.0]);
252    }
253
254    #[test]
255    fn test_different_dtypes() {
256        // Operators work with any dtype through type promotion
257        use crate::array::dtype::DType;
258        use crate::ops;
259
260        let a_f64 = Array::new(vec![2], vec![1.0f64, 2.0]);
261        let b_f64 = Array::new(vec![2], vec![3.0f64, 4.0]);
262
263        // Result is always f32 currently
264        let c: Array = ops::add(&a_f64, &b_f64).unwrap();
265        assert_eq!(c.dtype, DType::F32);
266        // assert_eq!(c.data, vec![4.0, 6.0]);
267
268        let a_i32 = Array::new(vec![2], vec![10i32, 20]);
269        let b_i32 = Array::new(vec![2], vec![5i32, 10]);
270
271        // Result is always f32 currently
272        let d: Array = ops::sub(&a_i32, &b_i32).unwrap();
273        assert_eq!(d.dtype, DType::F32); // Promoted or defaulted to F32
274                                         // assert_eq!(d.data, vec![5.0, 10.0]);
275    }
276}