Skip to main content

ferray_ufunc/
operator_overloads.rs

1// ferray-ufunc: Operator-style convenience functions
2//
3// REQ-9: +, -, *, /, % on arrays — provided as functions since the orphan rule
4// prevents implementing std::ops traits on ferray_core::Array from this crate.
5// The operator trait impls themselves should live in ferray-core; these functions
6// serve as the underlying implementations that ferray-core can delegate to.
7//
8// REQ-13: &, |, ^, !, <<, >> on integer arrays — same approach.
9//
10// Users can call these directly: `ufunc::array_add(&a, &b)` etc.
11
12use ferray_core::Array;
13use ferray_core::dimension::Dimension;
14use ferray_core::dtype::Element;
15use ferray_core::error::FerrayResult;
16use num_traits::Float;
17
18use crate::ops::bitwise::{BitwiseOps, ShiftOps};
19
20/// Array addition (delegates to `arithmetic::add`).
21pub fn array_add<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
22where
23    T: Element + std::ops::Add<Output = T> + Copy,
24    D: Dimension,
25{
26    crate::ops::arithmetic::add(a, b)
27}
28
29/// Array subtraction (delegates to `arithmetic::subtract`).
30pub fn array_sub<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
31where
32    T: Element + std::ops::Sub<Output = T> + Copy,
33    D: Dimension,
34{
35    crate::ops::arithmetic::subtract(a, b)
36}
37
38/// Array multiplication (delegates to `arithmetic::multiply`).
39pub fn array_mul<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
40where
41    T: Element + std::ops::Mul<Output = T> + Copy,
42    D: Dimension,
43{
44    crate::ops::arithmetic::multiply(a, b)
45}
46
47/// Array division (delegates to `arithmetic::divide`).
48pub fn array_div<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
49where
50    T: Element + std::ops::Div<Output = T> + Copy,
51    D: Dimension,
52{
53    crate::ops::arithmetic::divide(a, b)
54}
55
56/// Array remainder (delegates to `arithmetic::remainder`).
57pub fn array_rem<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
58where
59    T: Element + Float,
60    D: Dimension,
61{
62    crate::ops::arithmetic::remainder(a, b)
63}
64
65/// Array negation (delegates to `arithmetic::negative`).
66pub fn array_neg<T, D>(a: &Array<T, D>) -> FerrayResult<Array<T, D>>
67where
68    T: Element + Float,
69    D: Dimension,
70{
71    crate::ops::arithmetic::negative(a)
72}
73
74/// Array bitwise AND (delegates to `bitwise::bitwise_and`).
75pub fn array_bitand<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
76where
77    T: Element + BitwiseOps,
78    D: Dimension,
79{
80    crate::ops::bitwise::bitwise_and(a, b)
81}
82
83/// Array bitwise OR (delegates to `bitwise::bitwise_or`).
84pub fn array_bitor<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
85where
86    T: Element + BitwiseOps,
87    D: Dimension,
88{
89    crate::ops::bitwise::bitwise_or(a, b)
90}
91
92/// Array bitwise XOR (delegates to `bitwise::bitwise_xor`).
93pub fn array_bitxor<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<T, D>>
94where
95    T: Element + BitwiseOps,
96    D: Dimension,
97{
98    crate::ops::bitwise::bitwise_xor(a, b)
99}
100
101/// Array bitwise NOT (delegates to `bitwise::bitwise_not`).
102pub fn array_bitnot<T, D>(a: &Array<T, D>) -> FerrayResult<Array<T, D>>
103where
104    T: Element + BitwiseOps,
105    D: Dimension,
106{
107    crate::ops::bitwise::bitwise_not(a)
108}
109
110/// Array left shift (delegates to `bitwise::left_shift`).
111pub fn array_shl<T, D>(a: &Array<T, D>, b: &Array<u32, D>) -> FerrayResult<Array<T, D>>
112where
113    T: Element + ShiftOps,
114    D: Dimension,
115{
116    crate::ops::bitwise::left_shift(a, b)
117}
118
119/// Array right shift (delegates to `bitwise::right_shift`).
120pub fn array_shr<T, D>(a: &Array<T, D>, b: &Array<u32, D>) -> FerrayResult<Array<T, D>>
121where
122    T: Element + ShiftOps,
123    D: Dimension,
124{
125    crate::ops::bitwise::right_shift(a, b)
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Builds a 1-D `f64` array holding `data`.
    fn arr1_f64(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Builds a 1-D `i32` array holding `data`.
    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Builds a 1-D `u32` array holding `data` (used for shift amounts).
    fn arr1_u32(data: Vec<u32>) -> Array<u32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // AC-8: operator functions produce identical results to ufunc functions.
    // Every test checks a hand-computed expected value AND exact agreement
    // with the ufunc the operator function delegates to.

    #[test]
    fn test_array_add() {
        let a = arr1_f64(vec![1.0, 2.0, 3.0]);
        let b = arr1_f64(vec![4.0, 5.0, 6.0]);
        let r = array_add(&a, &b).unwrap();
        let via_ufunc = crate::ops::arithmetic::add(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_sub() {
        let a = arr1_f64(vec![5.0, 7.0, 9.0]);
        let b = arr1_f64(vec![1.0, 2.0, 3.0]);
        let r = array_sub(&a, &b).unwrap();
        let via_ufunc = crate::ops::arithmetic::subtract(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[4.0, 5.0, 6.0]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_mul() {
        let a = arr1_f64(vec![2.0, 3.0]);
        let b = arr1_f64(vec![4.0, 5.0]);
        let r = array_mul(&a, &b).unwrap();
        let via_ufunc = crate::ops::arithmetic::multiply(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[8.0, 15.0]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_div() {
        let a = arr1_f64(vec![10.0, 20.0]);
        let b = arr1_f64(vec![2.0, 5.0]);
        let r = array_div(&a, &b).unwrap();
        let via_ufunc = crate::ops::arithmetic::divide(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[5.0, 4.0]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_rem() {
        let a = arr1_f64(vec![7.0, 10.0]);
        let b = arr1_f64(vec![3.0, 4.0]);
        let r = array_rem(&a, &b).unwrap();
        let via_ufunc = crate::ops::arithmetic::remainder(&a, &b).unwrap();
        let s = r.as_slice().unwrap();
        assert!((s[0] - 1.0).abs() < 1e-12);
        assert!((s[1] - 2.0).abs() < 1e-12);
        // Same computation on the same inputs: results must match bit-for-bit.
        assert_eq!(s, via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_neg() {
        let a = arr1_f64(vec![1.0, -2.0, 3.0]);
        let r = array_neg(&a).unwrap();
        let via_ufunc = crate::ops::arithmetic::negative(&a).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_bitand() {
        let a = arr1_i32(vec![0b1100, 0b1010]);
        let b = arr1_i32(vec![0b1010, 0b1010]);
        let r = array_bitand(&a, &b).unwrap();
        let via_ufunc = crate::ops::bitwise::bitwise_and(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[0b1000, 0b1010]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_bitor() {
        let a = arr1_i32(vec![0b1100, 0b1010]);
        let b = arr1_i32(vec![0b1010, 0b0101]);
        let r = array_bitor(&a, &b).unwrap();
        let via_ufunc = crate::ops::bitwise::bitwise_or(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[0b1110, 0b1111]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_bitxor() {
        let a = arr1_i32(vec![0b1100]);
        let b = arr1_i32(vec![0b1010]);
        let r = array_bitxor(&a, &b).unwrap();
        let via_ufunc = crate::ops::bitwise::bitwise_xor(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[0b0110]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_bitnot() {
        let a = Array::<u8, Ix1>::from_vec(Ix1::new([1]), vec![0b0000_1111]).unwrap();
        let r = array_bitnot(&a).unwrap();
        let via_ufunc = crate::ops::bitwise::bitwise_not(&a).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[0b1111_0000]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_shl() {
        let a = arr1_i32(vec![1, 2, 4]);
        let s = arr1_u32(vec![1, 2, 3]);
        let r = array_shl(&a, &s).unwrap();
        let via_ufunc = crate::ops::bitwise::left_shift(&a, &s).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[2, 8, 32]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }

    #[test]
    fn test_array_shr() {
        let a = arr1_i32(vec![8, 16, 32]);
        let s = arr1_u32(vec![1, 2, 3]);
        let r = array_shr(&a, &s).unwrap();
        let via_ufunc = crate::ops::bitwise::right_shift(&a, &s).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[4, 4, 4]);
        assert_eq!(r.as_slice().unwrap(), via_ufunc.as_slice().unwrap());
    }
}