1use crate::array::owned::Array;
15use crate::dimension::Dimension;
16use crate::dtype::Element;
17use crate::error::{FerrayError, FerrayResult};
18
19fn elementwise_binary<T, D, F>(
23 a: &Array<T, D>,
24 b: &Array<T, D>,
25 op: F,
26 op_name: &str,
27) -> FerrayResult<Array<T, D>>
28where
29 T: Element + Copy,
30 D: Dimension,
31 F: Fn(T, T) -> T,
32{
33 if a.shape() != b.shape() {
34 return Err(FerrayError::shape_mismatch(format!(
35 "operator {}: shapes {:?} and {:?} are not compatible",
36 op_name,
37 a.shape(),
38 b.shape()
39 )));
40 }
41 let data: Vec<T> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
42 Array::from_vec(a.dim().clone(), data)
43}
44
45macro_rules! impl_binary_op {
53 ($trait:ident, $method:ident, $op_fn:expr, $op_name:expr) => {
54 impl<T, D> std::ops::$trait<&Array<T, D>> for &Array<T, D>
56 where
57 T: Element + Copy + std::ops::$trait<Output = T>,
58 D: Dimension,
59 {
60 type Output = FerrayResult<Array<T, D>>;
61
62 fn $method(self, rhs: &Array<T, D>) -> Self::Output {
63 elementwise_binary(self, rhs, $op_fn, $op_name)
64 }
65 }
66
67 impl<T, D> std::ops::$trait<Array<T, D>> for Array<T, D>
69 where
70 T: Element + Copy + std::ops::$trait<Output = T>,
71 D: Dimension,
72 {
73 type Output = FerrayResult<Array<T, D>>;
74
75 fn $method(self, rhs: Array<T, D>) -> Self::Output {
76 elementwise_binary(&self, &rhs, $op_fn, $op_name)
77 }
78 }
79
80 impl<T, D> std::ops::$trait<&Array<T, D>> for Array<T, D>
82 where
83 T: Element + Copy + std::ops::$trait<Output = T>,
84 D: Dimension,
85 {
86 type Output = FerrayResult<Array<T, D>>;
87
88 fn $method(self, rhs: &Array<T, D>) -> Self::Output {
89 elementwise_binary(&self, rhs, $op_fn, $op_name)
90 }
91 }
92
93 impl<T, D> std::ops::$trait<Array<T, D>> for &Array<T, D>
95 where
96 T: Element + Copy + std::ops::$trait<Output = T>,
97 D: Dimension,
98 {
99 type Output = FerrayResult<Array<T, D>>;
100
101 fn $method(self, rhs: Array<T, D>) -> Self::Output {
102 elementwise_binary(self, &rhs, $op_fn, $op_name)
103 }
104 }
105 };
106}
107
108impl_binary_op!(Add, add, |a, b| a + b, "+");
109impl_binary_op!(Sub, sub, |a, b| a - b, "-");
110impl_binary_op!(Mul, mul, |a, b| a * b, "*");
111impl_binary_op!(Div, div, |a, b| a / b, "/");
112impl_binary_op!(Rem, rem, |a, b| a % b, "%");
113
114impl<T, D> std::ops::Neg for &Array<T, D>
116where
117 T: Element + Copy + std::ops::Neg<Output = T>,
118 D: Dimension,
119{
120 type Output = FerrayResult<Array<T, D>>;
121
122 fn neg(self) -> Self::Output {
123 let data: Vec<T> = self.iter().map(|&x| -x).collect();
124 Array::from_vec(self.dim().clone(), data)
125 }
126}
127
128impl<T, D> std::ops::Neg for Array<T, D>
129where
130 T: Element + Copy + std::ops::Neg<Output = T>,
131 D: Dimension,
132{
133 type Output = FerrayResult<Array<T, D>>;
134
135 fn neg(self) -> Self::Output {
136 -&self
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::Ix1;

    // Builds a 1-D f64 array whose length matches `data`.
    fn arr(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // Builds a 1-D i32 array whose length matches `data`.
    fn arr_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    #[test]
    fn test_add_ref_ref() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![4.0, 5.0, 6.0]);
        let c = (&a + &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_owned_owned() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![3.0, 4.0]);
        let c = (a + b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);
    }

    #[test]
    fn test_add_mixed() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![3.0, 4.0]);
        let c = (a + &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);

        let d = arr(vec![10.0, 20.0]);
        let e = (&b + d).unwrap();
        assert_eq!(e.as_slice().unwrap(), &[13.0, 24.0]);
    }

    #[test]
    fn test_sub() {
        let a = arr(vec![5.0, 7.0]);
        let b = arr(vec![1.0, 2.0]);
        let c = (&a - &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 5.0]);
    }

    #[test]
    fn test_mul() {
        let a = arr(vec![2.0, 3.0]);
        let b = arr(vec![4.0, 5.0]);
        let c = (&a * &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[8.0, 15.0]);
    }

    #[test]
    fn test_div() {
        let a = arr(vec![10.0, 20.0]);
        let b = arr(vec![2.0, 5.0]);
        let c = (&a / &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 4.0]);
    }

    #[test]
    fn test_rem() {
        let a = arr_i32(vec![7, 10]);
        let b = arr_i32(vec![3, 4]);
        let c = (&a % &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1, 2]);
    }

    // Integer `/` and `%` use Rust's truncate-toward-zero semantics (not
    // NumPy's floor semantics); pin that so a change is noticed.
    #[test]
    fn test_int_div_rem_truncate_toward_zero() {
        let a = arr_i32(vec![7, -7]);
        let b = arr_i32(vec![2, 3]);
        let q = (&a / &b).unwrap();
        assert_eq!(q.as_slice().unwrap(), &[3, -2]);
        let r = (&a % &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[1, -1]);
    }

    #[test]
    fn test_neg() {
        let a = arr(vec![1.0, -2.0, 3.0]);
        let b = (-&a).unwrap();
        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_neg_owned() {
        let a = arr(vec![1.0, -2.0]);
        let b = (-a).unwrap();
        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0]);
    }

    #[test]
    fn test_shape_mismatch_errors() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![1.0, 2.0, 3.0]);
        let result = &a + &b;
        assert!(result.is_err());
    }

    // Every ownership pairing and every operator must reject mismatched
    // shapes, not just `&a + &b`.
    #[test]
    fn test_shape_mismatch_all_operators_and_ownership() {
        let short = arr(vec![1.0, 2.0]);
        let long = arr(vec![1.0, 2.0, 3.0]);

        assert!((&short - &long).is_err());
        assert!((&long - &short).is_err());
        assert!((arr(vec![1.0, 2.0]) * arr(vec![1.0, 2.0, 3.0])).is_err());
        assert!((arr(vec![1.0, 2.0]) / &long).is_err());
        assert!((&short % arr(vec![1.0, 2.0, 3.0])).is_err());
        assert!((arr(vec![1.0, 2.0]) + arr(vec![1.0, 2.0, 3.0])).is_err());
    }

    #[test]
    fn test_empty_operands() {
        let a = arr(vec![]);
        let b = arr(vec![]);
        let c = (&a + &b).unwrap();
        assert!(c.as_slice().unwrap().is_empty());
        let d = (-&a).unwrap();
        assert!(d.as_slice().unwrap().is_empty());
    }

    #[test]
    fn test_chained_ops() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![4.0, 5.0, 6.0]);
        let c = arr(vec![10.0, 10.0, 10.0]);
        let result = (&(&a + &b).unwrap() * &c).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[50.0, 70.0, 90.0]);
    }

    #[test]
    fn test_chained_owned_ops() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![3.0, 4.0]);
        let c = arr(vec![0.5, 1.0]);
        let result = ((a + b).unwrap() - c).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[3.5, 5.0]);
    }
}