Skip to main content

ferray_ufunc/ops/
comparison.rs

// ferray-ufunc: Comparison functions
//
// equal, not_equal, less, less_equal, greater, greater_equal,
// array_equal, array_equiv, allclose, isclose
5
6use ferray_core::Array;
7use ferray_core::dimension::Dimension;
8use ferray_core::dtype::Element;
9use ferray_core::error::FerrayResult;
10use num_traits::Float;
11
12use crate::helpers::binary_map_op;
13
14/// Elementwise equality test.
15pub fn equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
16where
17    T: Element + PartialEq + Copy,
18    D: Dimension,
19{
20    binary_map_op(a, b, |x, y| x == y)
21}
22
23/// Elementwise inequality test.
24pub fn not_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
25where
26    T: Element + PartialEq + Copy,
27    D: Dimension,
28{
29    binary_map_op(a, b, |x, y| x != y)
30}
31
32/// Elementwise less-than test.
33pub fn less<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
34where
35    T: Element + PartialOrd + Copy,
36    D: Dimension,
37{
38    binary_map_op(a, b, |x, y| x < y)
39}
40
41/// Elementwise less-than-or-equal test.
42pub fn less_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
43where
44    T: Element + PartialOrd + Copy,
45    D: Dimension,
46{
47    binary_map_op(a, b, |x, y| x <= y)
48}
49
50/// Elementwise greater-than test.
51pub fn greater<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
52where
53    T: Element + PartialOrd + Copy,
54    D: Dimension,
55{
56    binary_map_op(a, b, |x, y| x > y)
57}
58
59/// Elementwise greater-than-or-equal test.
60pub fn greater_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
61where
62    T: Element + PartialOrd + Copy,
63    D: Dimension,
64{
65    binary_map_op(a, b, |x, y| x >= y)
66}
67
68/// Test whether two arrays have the same shape and elements.
69pub fn array_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> bool
70where
71    T: Element + PartialEq,
72    D: Dimension,
73{
74    if a.shape() != b.shape() {
75        return false;
76    }
77    a.iter().zip(b.iter()).all(|(x, y)| x == y)
78}
79
80/// Test whether two arrays are element-wise equal within a tolerance,
81/// or broadcastable to the same shape and element-wise equal.
82///
83/// For arrays of the same shape, this is the same as `array_equal`.
84pub fn array_equiv<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> bool
85where
86    T: Element + PartialEq,
87    D: Dimension,
88{
89    // For same-dimension arrays, just check equality
90    array_equal(a, b)
91}
92
93/// Test whether two arrays are element-wise close within tolerances.
94///
95/// |a - b| <= atol + rtol * |b|
96pub fn allclose<T, D>(a: &Array<T, D>, b: &Array<T, D>, rtol: T, atol: T) -> FerrayResult<bool>
97where
98    T: Element + Float,
99    D: Dimension,
100{
101    let close = isclose(a, b, rtol, atol, false)?;
102    Ok(close.iter().all(|&x| x))
103}
104
105/// Elementwise close-within-tolerance test.
106///
107/// |a - b| <= atol + rtol * |b|
108///
109/// If `equal_nan` is true, NaN values in corresponding positions are considered close.
110pub fn isclose<T, D>(
111    a: &Array<T, D>,
112    b: &Array<T, D>,
113    rtol: T,
114    atol: T,
115    equal_nan: bool,
116) -> FerrayResult<Array<bool, D>>
117where
118    T: Element + Float,
119    D: Dimension,
120{
121    binary_map_op(a, b, |x, y| {
122        if equal_nan && x.is_nan() && y.is_nan() {
123            return true;
124        }
125        if x.is_nan() || y.is_nan() {
126            return false;
127        }
128        (x - y).abs() <= atol + rtol * y.abs()
129    })
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Build a 1-D `f64` array from `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Build a 1-D `i32` array from `data`.
    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    #[test]
    fn test_equal() {
        let a = arr1_i32(vec![1, 2, 3]);
        let b = arr1_i32(vec![1, 5, 3]);
        let r = equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_not_equal() {
        let a = arr1_i32(vec![1, 2, 3]);
        let b = arr1_i32(vec![1, 5, 3]);
        let r = not_equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn test_comparisons_with_nan() {
        // IEEE semantics: NaN is unequal to everything and unordered.
        let a = arr1(vec![f64::NAN, 1.0]);
        let b = arr1(vec![f64::NAN, f64::NAN]);
        assert_eq!(equal(&a, &b).unwrap().as_slice().unwrap(), &[false, false]);
        assert_eq!(not_equal(&a, &b).unwrap().as_slice().unwrap(), &[true, true]);
        assert_eq!(less(&a, &b).unwrap().as_slice().unwrap(), &[false, false]);
        assert_eq!(greater_equal(&a, &b).unwrap().as_slice().unwrap(), &[false, false]);
    }

    #[test]
    fn test_less() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = less(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, false]);
    }

    #[test]
    fn test_less_equal() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = less_equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_greater() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = greater(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn test_greater_equal() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = greater_equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, true]);
    }

    #[test]
    fn test_array_equal() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.0, 3.0]);
        let c = arr1(vec![1.0, 2.0, 4.0]);
        assert!(array_equal(&a, &b));
        assert!(!array_equal(&a, &c));
    }

    #[test]
    fn test_array_equal_different_shapes() {
        let a = arr1(vec![1.0, 2.0]);
        let b = arr1(vec![1.0, 2.0, 3.0]);
        assert!(!array_equal(&a, &b));
    }

    #[test]
    fn test_array_equiv_same_shape() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.0, 3.0]);
        let c = arr1(vec![1.0, 2.0, 4.0]);
        assert!(array_equiv(&a, &b));
        assert!(!array_equiv(&a, &c));
    }

    #[test]
    fn test_array_equiv_incompatible_shapes() {
        // Lengths 2 and 3 cannot broadcast together.
        let a = arr1(vec![1.0, 2.0]);
        let b = arr1(vec![1.0, 2.0, 3.0]);
        assert!(!array_equiv(&a, &b));
    }

    #[test]
    fn test_allclose() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0 + 1e-9, 2.0 + 1e-9, 3.0 + 1e-9]);
        assert!(allclose(&a, &b, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_allclose_not_close() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.0, 4.0]);
        assert!(!allclose(&a, &b, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_allclose_with_nan_is_false() {
        // allclose uses equal_nan = false, so matching NaNs are not close.
        let a = arr1(vec![1.0, f64::NAN]);
        let b = arr1(vec![1.0, f64::NAN]);
        assert!(!allclose(&a, &b, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_isclose() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.1, 3.0]);
        let r = isclose(&a, &b, 1e-5, 1e-8, false).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_isclose_equal_nan() {
        let a = arr1(vec![f64::NAN, 1.0]);
        let b = arr1(vec![f64::NAN, 1.0]);
        let r = isclose(&a, &b, 1e-5, 1e-8, true).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, true]);
    }

    #[test]
    fn test_isclose_nan_without_equal_nan() {
        let a = arr1(vec![f64::NAN, 1.0]);
        let b = arr1(vec![f64::NAN, 1.0]);
        let r = isclose(&a, &b, 1e-5, 1e-8, false).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true]);
    }
}