1use crate::array::Array;
16use crate::ops;
17use std::ops::{Add, Div, Mul, Neg, Sub};
18
19impl Add for Array {
24 type Output = Array;
25
26 fn add(self, rhs: Self) -> Self::Output {
28 ops::add(&self, &rhs).expect("Addition failed")
29 }
30}
31
32impl Add<&Array> for Array {
33 type Output = Array;
34
35 fn add(self, rhs: &Array) -> Self::Output {
37 ops::add(&self, rhs).expect("Addition failed")
38 }
39}
40
41impl Add<Array> for &Array {
42 type Output = Array;
43
44 fn add(self, rhs: Array) -> Self::Output {
46 ops::add(self, &rhs).expect("Addition failed")
47 }
48}
49
50impl Add for &Array {
51 type Output = Array;
52
53 fn add(self, rhs: Self) -> Self::Output {
55 ops::add(self, rhs).expect("Addition failed")
56 }
57}
58
59impl Sub for Array {
62 type Output = Array;
63
64 fn sub(self, rhs: Self) -> Self::Output {
66 ops::sub(&self, &rhs).expect("Subtraction failed")
67 }
68}
69
70impl Sub<&Array> for Array {
71 type Output = Array;
72
73 fn sub(self, rhs: &Array) -> Self::Output {
74 ops::sub(&self, rhs).expect("Subtraction failed")
75 }
76}
77
78impl Sub<Array> for &Array {
79 type Output = Array;
80
81 fn sub(self, rhs: Array) -> Self::Output {
82 ops::sub(self, &rhs).expect("Subtraction failed")
83 }
84}
85
86impl Sub for &Array {
87 type Output = Array;
88
89 fn sub(self, rhs: Self) -> Self::Output {
91 ops::sub(self, rhs).expect("Subtraction failed")
92 }
93}
94
95impl Mul for Array {
98 type Output = Array;
99
100 fn mul(self, rhs: Self) -> Self::Output {
102 ops::mul(&self, &rhs).expect("Multiplication failed")
103 }
104}
105
106impl Mul<&Array> for Array {
107 type Output = Array;
108
109 fn mul(self, rhs: &Array) -> Self::Output {
110 ops::mul(&self, rhs).expect("Multiplication failed")
111 }
112}
113
114impl Mul<Array> for &Array {
115 type Output = Array;
116
117 fn mul(self, rhs: Array) -> Self::Output {
118 ops::mul(self, &rhs).expect("Multiplication failed")
119 }
120}
121
122impl Mul for &Array {
123 type Output = Array;
124
125 fn mul(self, rhs: Self) -> Self::Output {
127 ops::mul(self, rhs).expect("Multiplication failed")
128 }
129}
130
131impl Div for Array {
134 type Output = Array;
135
136 fn div(self, rhs: Self) -> Self::Output {
138 ops::div(&self, &rhs).expect("Division failed")
139 }
140}
141
142impl Div<&Array> for Array {
143 type Output = Array;
144
145 fn div(self, rhs: &Array) -> Self::Output {
146 ops::div(&self, rhs).expect("Division failed")
147 }
148}
149
150impl Div<Array> for &Array {
151 type Output = Array;
152
153 fn div(self, rhs: Array) -> Self::Output {
154 ops::div(self, &rhs).expect("Division failed")
155 }
156}
157
158impl Div for &Array {
159 type Output = Array;
160
161 fn div(self, rhs: Self) -> Self::Output {
163 ops::div(self, rhs).expect("Division failed")
164 }
165}
166
167impl Neg for Array {
172 type Output = Array;
173
174 fn neg(self) -> Self::Output {
176 ops::neg(&self).expect("Negation failed")
177 }
178}
179
180impl Neg for &Array {
181 type Output = Array;
182
183 fn neg(self) -> Self::Output {
185 ops::neg(self).expect("Negation failed")
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_operator() {
        let a = Array::new(vec![3], vec![1.0f32, 2.0, 3.0]);
        let b = Array::new(vec![3], vec![4.0f32, 5.0, 6.0]);

        let c: Array<f32> = &a + &b;
        assert_eq!(c.data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_sub_operator() {
        let a = Array::new(vec![3], vec![5.0f32, 7.0, 9.0]);
        let b = Array::new(vec![3], vec![1.0f32, 2.0, 3.0]);

        let c: Array<f32> = &a - &b;
        assert_eq!(c.data, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_mul_operator() {
        let a = Array::new(vec![3], vec![2.0f32, 3.0, 4.0]);
        let b = Array::new(vec![3], vec![5.0f32, 6.0, 7.0]);

        let c: Array<f32> = &a * &b;
        assert_eq!(c.data, vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_div_operator() {
        let a = Array::new(vec![3], vec![10.0f32, 20.0, 30.0]);
        let b = Array::new(vec![3], vec![2.0f32, 4.0, 5.0]);

        let c: Array<f32> = &a / &b;
        assert_eq!(c.data, vec![5.0, 5.0, 6.0]);
    }

    #[test]
    fn test_neg_operator() {
        let a = Array::new(vec![3], vec![1.0f32, -2.0, 3.0]);

        let b: Array<f32> = -&a;
        assert_eq!(b.data, vec![-1.0, 2.0, -3.0]);
    }

    // The tests above only reach the `&Array op &Array` and `-&Array` impls.
    // Each owned / mixed owned-borrowed combination is a separate trait impl,
    // so the tests below exercise every one of them explicitly.

    #[test]
    fn test_add_owned_and_mixed_operands() {
        // Fresh operands per expression, since owned forms consume them.
        let lhs = || Array::new(vec![3], vec![1.0f32, 2.0, 3.0]);
        let rhs = || Array::new(vec![3], vec![4.0f32, 5.0, 6.0]);

        let c: Array<f32> = lhs() + rhs();
        assert_eq!(c.data, vec![5.0, 7.0, 9.0]);
        let c: Array<f32> = lhs() + &rhs();
        assert_eq!(c.data, vec![5.0, 7.0, 9.0]);
        let c: Array<f32> = &lhs() + rhs();
        assert_eq!(c.data, vec![5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_sub_owned_and_mixed_operands() {
        let lhs = || Array::new(vec![3], vec![5.0f32, 7.0, 9.0]);
        let rhs = || Array::new(vec![3], vec![1.0f32, 2.0, 3.0]);

        let c: Array<f32> = lhs() - rhs();
        assert_eq!(c.data, vec![4.0, 5.0, 6.0]);
        let c: Array<f32> = lhs() - &rhs();
        assert_eq!(c.data, vec![4.0, 5.0, 6.0]);
        let c: Array<f32> = &lhs() - rhs();
        assert_eq!(c.data, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_mul_owned_and_mixed_operands() {
        let lhs = || Array::new(vec![3], vec![2.0f32, 3.0, 4.0]);
        let rhs = || Array::new(vec![3], vec![5.0f32, 6.0, 7.0]);

        let c: Array<f32> = lhs() * rhs();
        assert_eq!(c.data, vec![10.0, 18.0, 28.0]);
        let c: Array<f32> = lhs() * &rhs();
        assert_eq!(c.data, vec![10.0, 18.0, 28.0]);
        let c: Array<f32> = &lhs() * rhs();
        assert_eq!(c.data, vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_div_owned_and_mixed_operands() {
        let lhs = || Array::new(vec![3], vec![10.0f32, 20.0, 30.0]);
        let rhs = || Array::new(vec![3], vec![2.0f32, 4.0, 5.0]);

        let c: Array<f32> = lhs() / rhs();
        assert_eq!(c.data, vec![5.0, 5.0, 6.0]);
        let c: Array<f32> = lhs() / &rhs();
        assert_eq!(c.data, vec![5.0, 5.0, 6.0]);
        let c: Array<f32> = &lhs() / rhs();
        assert_eq!(c.data, vec![5.0, 5.0, 6.0]);
    }

    #[test]
    fn test_neg_owned_operand() {
        let a = Array::new(vec![3], vec![1.0f32, -2.0, 3.0]);

        let b: Array<f32> = -a;
        assert_eq!(b.data, vec![-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_chained_operators() {
        let a = Array::new(vec![3], vec![1.0f32, 2.0, 3.0]);
        let b = Array::new(vec![3], vec![4.0f32, 5.0, 6.0]);
        let c = Array::new(vec![3], vec![2.0f32, 2.0, 2.0]);

        let sum: Array<f32> = &a + &b;
        let result: Array<f32> = &sum * &c;
        assert_eq!(result.data, vec![10.0, 14.0, 18.0]);

        // Same computation as one expression. The intermediate sum is an owned
        // temporary, so this goes through `Mul<&Array> for Array`.
        let chained: Array<f32> = (&a + &b) * &c;
        assert_eq!(chained.data, vec![10.0, 14.0, 18.0]);
    }

    #[test]
    fn test_different_dtypes() {
        use crate::array::dtype::DType;
        use crate::ops;

        // NOTE(review): results are asserted to be F32 even for f64 and i32
        // inputs; confirm this promotion is the intended contract of `ops`.
        let a_f64 = Array::new(vec![2], vec![1.0f64, 2.0]);
        let b_f64 = Array::new(vec![2], vec![3.0f64, 4.0]);

        let c: Array = ops::add(&a_f64, &b_f64).unwrap();
        assert_eq!(c.dtype, DType::F32);

        let a_i32 = Array::new(vec![2], vec![10i32, 20]);
        let b_i32 = Array::new(vec![2], vec![5i32, 10]);

        let d: Array = ops::sub(&a_i32, &b_i32).unwrap();
        assert_eq!(d.dtype, DType::F32);
    }
}