Skip to main content

ferray_core/
ops.rs

1// ferray-core: Operator overloading for Array<T, D>
2//
3// Implements std::ops::{Add, Sub, Mul, Div, Rem, Neg} with
4// Output = FerrayResult<Array<T, D>>.
5//
6// Users write `(a + b)?` to get the result, maintaining the zero-panic
7// guarantee while enabling natural math syntax.
8//
9// These operate elementwise on same-shape arrays. For broadcasting,
10// use the functions in ferray-ufunc (e.g. ferray::add(&a, &b)).
11//
12// See: https://github.com/dollspace-gay/ferray/issues/7
13
14use crate::array::owned::Array;
15use crate::dimension::Dimension;
16use crate::dtype::Element;
17use crate::error::{FerrayError, FerrayResult};
18
19/// Elementwise binary operation on two same-shape arrays.
20///
21/// Returns an error if the shapes don't match.
22fn elementwise_binary<T, D, F>(
23    a: &Array<T, D>,
24    b: &Array<T, D>,
25    op: F,
26    op_name: &str,
27) -> FerrayResult<Array<T, D>>
28where
29    T: Element + Copy,
30    D: Dimension,
31    F: Fn(T, T) -> T,
32{
33    if a.shape() != b.shape() {
34        return Err(FerrayError::shape_mismatch(format!(
35            "operator {}: shapes {:?} and {:?} are not compatible",
36            op_name,
37            a.shape(),
38            b.shape()
39        )));
40    }
41    let data: Vec<T> = a.iter().zip(b.iter()).map(|(&x, &y)| op(x, y)).collect();
42    Array::from_vec(a.dim().clone(), data)
43}
44
/// Generate the `std::ops` impls of one binary operator for `Array`,
/// covering every owned/borrowed pairing of the two operands.
///
/// Macro arguments, in order: the operator trait, its method name, the
/// per-element closure, and the operator symbol passed to
/// `elementwise_binary` for its error message.
///
/// Impls produced (each with `Output = FerrayResult<Array<T, D>>`):
///   &Array op &Array
///   Array  op Array
///   Array  op &Array
///   &Array op Array
macro_rules! impl_binary_op {
    ($op_trait:ident, $op_method:ident, $elem_fn:expr, $symbol:expr) => {
        // &Array op &Array -- the impl that performs the computation.
        impl<T, D> std::ops::$op_trait<&Array<T, D>> for &Array<T, D>
        where
            D: Dimension,
            T: Element + Copy + std::ops::$op_trait<Output = T>,
        {
            type Output = FerrayResult<Array<T, D>>;

            fn $op_method(self, other: &Array<T, D>) -> Self::Output {
                elementwise_binary(self, other, $elem_fn, $symbol)
            }
        }

        // Array op Array -- borrow both operands and forward.
        impl<T, D> std::ops::$op_trait<Array<T, D>> for Array<T, D>
        where
            D: Dimension,
            T: Element + Copy + std::ops::$op_trait<Output = T>,
        {
            type Output = FerrayResult<Array<T, D>>;

            fn $op_method(self, other: Array<T, D>) -> Self::Output {
                <&Array<T, D> as std::ops::$op_trait<&Array<T, D>>>::$op_method(&self, &other)
            }
        }

        // Array op &Array -- borrow the left operand and forward.
        impl<T, D> std::ops::$op_trait<&Array<T, D>> for Array<T, D>
        where
            D: Dimension,
            T: Element + Copy + std::ops::$op_trait<Output = T>,
        {
            type Output = FerrayResult<Array<T, D>>;

            fn $op_method(self, other: &Array<T, D>) -> Self::Output {
                <&Array<T, D> as std::ops::$op_trait<&Array<T, D>>>::$op_method(&self, other)
            }
        }

        // &Array op Array -- borrow the right operand and forward.
        impl<T, D> std::ops::$op_trait<Array<T, D>> for &Array<T, D>
        where
            D: Dimension,
            T: Element + Copy + std::ops::$op_trait<Output = T>,
        {
            type Output = FerrayResult<Array<T, D>>;

            fn $op_method(self, other: Array<T, D>) -> Self::Output {
                <&Array<T, D> as std::ops::$op_trait<&Array<T, D>>>::$op_method(self, &other)
            }
        }
    };
}
107
108impl_binary_op!(Add, add, |a, b| a + b, "+");
109impl_binary_op!(Sub, sub, |a, b| a - b, "-");
110impl_binary_op!(Mul, mul, |a, b| a * b, "*");
111impl_binary_op!(Div, div, |a, b| a / b, "/");
112impl_binary_op!(Rem, rem, |a, b| a % b, "%");
113
114// Unary negation: -&Array and -Array
115impl<T, D> std::ops::Neg for &Array<T, D>
116where
117    T: Element + Copy + std::ops::Neg<Output = T>,
118    D: Dimension,
119{
120    type Output = FerrayResult<Array<T, D>>;
121
122    fn neg(self) -> Self::Output {
123        let data: Vec<T> = self.iter().map(|&x| -x).collect();
124        Array::from_vec(self.dim().clone(), data)
125    }
126}
127
128impl<T, D> std::ops::Neg for Array<T, D>
129where
130    T: Element + Copy + std::ops::Neg<Output = T>,
131    D: Dimension,
132{
133    type Output = FerrayResult<Array<T, D>>;
134
135    fn neg(self) -> Self::Output {
136        -&self
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::Ix1;

    /// Build a one-dimensional `f64` array holding `data`.
    fn arr(data: Vec<f64>) -> Array<f64, Ix1> {
        let shape = Ix1::new([data.len()]);
        Array::from_vec(shape, data).unwrap()
    }

    /// Build a one-dimensional `i32` array holding `data`.
    fn arr_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let shape = Ix1::new([data.len()]);
        Array::from_vec(shape, data).unwrap()
    }

    #[test]
    fn test_add_ref_ref() {
        let lhs = arr(vec![1.0, 2.0, 3.0]);
        let rhs = arr(vec![4.0, 5.0, 6.0]);

        let sum = (&lhs + &rhs).unwrap();

        assert_eq!(sum.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_owned_owned() {
        let lhs = arr(vec![1.0, 2.0]);
        let rhs = arr(vec![3.0, 4.0]);

        let sum = (lhs + rhs).unwrap();

        assert_eq!(sum.as_slice().unwrap(), &[4.0, 6.0]);
    }

    #[test]
    fn test_add_mixed() {
        // Owned on the left, borrowed on the right.
        let owned_lhs = arr(vec![1.0, 2.0]);
        let shared = arr(vec![3.0, 4.0]);
        let first = (owned_lhs + &shared).unwrap();
        assert_eq!(first.as_slice().unwrap(), &[4.0, 6.0]);

        // Borrowed on the left, owned on the right.
        let owned_rhs = arr(vec![10.0, 20.0]);
        let second = (&shared + owned_rhs).unwrap();
        assert_eq!(second.as_slice().unwrap(), &[13.0, 24.0]);
    }

    #[test]
    fn test_sub() {
        let minuend = arr(vec![5.0, 7.0]);
        let subtrahend = arr(vec![1.0, 2.0]);

        let difference = (&minuend - &subtrahend).unwrap();

        assert_eq!(difference.as_slice().unwrap(), &[4.0, 5.0]);
    }

    #[test]
    fn test_mul() {
        let lhs = arr(vec![2.0, 3.0]);
        let rhs = arr(vec![4.0, 5.0]);

        let product = (&lhs * &rhs).unwrap();

        assert_eq!(product.as_slice().unwrap(), &[8.0, 15.0]);
    }

    #[test]
    fn test_div() {
        let dividend = arr(vec![10.0, 20.0]);
        let divisor = arr(vec![2.0, 5.0]);

        let quotient = (&dividend / &divisor).unwrap();

        assert_eq!(quotient.as_slice().unwrap(), &[5.0, 4.0]);
    }

    #[test]
    fn test_rem() {
        let dividend = arr_i32(vec![7, 10]);
        let divisor = arr_i32(vec![3, 4]);

        let remainder = (&dividend % &divisor).unwrap();

        assert_eq!(remainder.as_slice().unwrap(), &[1, 2]);
    }

    #[test]
    fn test_neg() {
        let input = arr(vec![1.0, -2.0, 3.0]);

        let negated = (-&input).unwrap();

        assert_eq!(negated.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_neg_owned() {
        let input = arr(vec![1.0, -2.0]);

        let negated = (-input).unwrap();

        assert_eq!(negated.as_slice().unwrap(), &[-1.0, 2.0]);
    }

    #[test]
    fn test_shape_mismatch_errors() {
        let two = arr(vec![1.0, 2.0]);
        let three = arr(vec![1.0, 2.0, 3.0]);

        // Lengths 2 and 3 differ, so the operator must report an error.
        assert!((&two + &three).is_err());
    }

    #[test]
    fn test_chained_ops() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![4.0, 5.0, 6.0]);
        let scale = arr(vec![10.0, 10.0, 10.0]);

        // Same computation as `((a + b)? * c)?`, with `unwrap` in place of `?`.
        let sum = (&a + &b).unwrap();
        let result = (&sum * &scale).unwrap();

        assert_eq!(result.as_slice().unwrap(), &[50.0, 70.0, 90.0]);
    }
}
246}