Skip to main content

ferray_ufunc/ops/
comparison.rs

1// ferray-ufunc: Comparison functions
2//
3// equal, not_equal, less, less_equal, greater, greater_equal,
4// array_equal, array_equiv, allclose, isclose
5//
6// Cross-rank broadcasting variants (`equal_broadcast`, …) added for #387:
7// they accept different dimension types and return `Array<bool, IxDyn>`,
8// matching the `add_broadcast` / `subtract_broadcast` pattern from
9// arithmetic.rs.
10
11use ferray_core::Array;
12use ferray_core::dimension::{Dimension, IxDyn};
13use ferray_core::dtype::Element;
14use ferray_core::error::FerrayResult;
15use num_traits::Float;
16
17use crate::helpers::{binary_broadcast_map_op, binary_map_op};
18
19/// Elementwise equality test.
20pub fn equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
21where
22    T: Element + PartialEq + Copy,
23    D: Dimension,
24{
25    binary_map_op(a, b, |x, y| x == y)
26}
27
28/// Elementwise inequality test.
29pub fn not_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
30where
31    T: Element + PartialEq + Copy,
32    D: Dimension,
33{
34    binary_map_op(a, b, |x, y| x != y)
35}
36
37/// Elementwise less-than test.
38pub fn less<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
39where
40    T: Element + PartialOrd + Copy,
41    D: Dimension,
42{
43    binary_map_op(a, b, |x, y| x < y)
44}
45
46/// Elementwise less-than-or-equal test.
47pub fn less_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
48where
49    T: Element + PartialOrd + Copy,
50    D: Dimension,
51{
52    binary_map_op(a, b, |x, y| x <= y)
53}
54
55/// Elementwise greater-than test.
56pub fn greater<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
57where
58    T: Element + PartialOrd + Copy,
59    D: Dimension,
60{
61    binary_map_op(a, b, |x, y| x > y)
62}
63
64/// Elementwise greater-than-or-equal test.
65pub fn greater_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
66where
67    T: Element + PartialOrd + Copy,
68    D: Dimension,
69{
70    binary_map_op(a, b, |x, y| x >= y)
71}
72
73// ---------------------------------------------------------------------------
74// Cross-rank broadcasting variants (#387)
75//
76// Same semantics as the same-D versions above but accept distinct dimension
77// types `D1` / `D2`, broadcasting both inputs into a common shape and
78// returning `Array<bool, IxDyn>`. Mirrors the `add_broadcast` /
79// `subtract_broadcast` family in arithmetic.rs so callers can use either
80// `equal(a, b)` (same shape / same rank) or `equal_broadcast(a, b)` (Ix2
81// against Ix1, scalar threshold against ND, etc.).
82// ---------------------------------------------------------------------------
83
84/// Cross-rank broadcasting equality test.
85pub fn equal_broadcast<T, D1, D2>(
86    a: &Array<T, D1>,
87    b: &Array<T, D2>,
88) -> FerrayResult<Array<bool, IxDyn>>
89where
90    T: Element + PartialEq + Copy,
91    D1: Dimension,
92    D2: Dimension,
93{
94    binary_broadcast_map_op(a, b, |x, y| x == y)
95}
96
97/// Cross-rank broadcasting inequality test.
98pub fn not_equal_broadcast<T, D1, D2>(
99    a: &Array<T, D1>,
100    b: &Array<T, D2>,
101) -> FerrayResult<Array<bool, IxDyn>>
102where
103    T: Element + PartialEq + Copy,
104    D1: Dimension,
105    D2: Dimension,
106{
107    binary_broadcast_map_op(a, b, |x, y| x != y)
108}
109
110/// Cross-rank broadcasting less-than test.
111pub fn less_broadcast<T, D1, D2>(
112    a: &Array<T, D1>,
113    b: &Array<T, D2>,
114) -> FerrayResult<Array<bool, IxDyn>>
115where
116    T: Element + PartialOrd + Copy,
117    D1: Dimension,
118    D2: Dimension,
119{
120    binary_broadcast_map_op(a, b, |x, y| x < y)
121}
122
123/// Cross-rank broadcasting less-than-or-equal test.
124pub fn less_equal_broadcast<T, D1, D2>(
125    a: &Array<T, D1>,
126    b: &Array<T, D2>,
127) -> FerrayResult<Array<bool, IxDyn>>
128where
129    T: Element + PartialOrd + Copy,
130    D1: Dimension,
131    D2: Dimension,
132{
133    binary_broadcast_map_op(a, b, |x, y| x <= y)
134}
135
136/// Cross-rank broadcasting greater-than test.
137pub fn greater_broadcast<T, D1, D2>(
138    a: &Array<T, D1>,
139    b: &Array<T, D2>,
140) -> FerrayResult<Array<bool, IxDyn>>
141where
142    T: Element + PartialOrd + Copy,
143    D1: Dimension,
144    D2: Dimension,
145{
146    binary_broadcast_map_op(a, b, |x, y| x > y)
147}
148
149/// Cross-rank broadcasting greater-than-or-equal test.
150pub fn greater_equal_broadcast<T, D1, D2>(
151    a: &Array<T, D1>,
152    b: &Array<T, D2>,
153) -> FerrayResult<Array<bool, IxDyn>>
154where
155    T: Element + PartialOrd + Copy,
156    D1: Dimension,
157    D2: Dimension,
158{
159    binary_broadcast_map_op(a, b, |x, y| x >= y)
160}
161
162/// Cross-rank broadcasting close-within-tolerance test.
163///
164/// Same `|a - b| <= atol + rtol * |b|` semantics as [`isclose`], but
165/// accepts inputs with distinct ranks. Returns `Array<bool, IxDyn>`.
166pub fn isclose_broadcast<T, D1, D2>(
167    a: &Array<T, D1>,
168    b: &Array<T, D2>,
169    rtol: T,
170    atol: T,
171    equal_nan: bool,
172) -> FerrayResult<Array<bool, IxDyn>>
173where
174    T: Element + Float,
175    D1: Dimension,
176    D2: Dimension,
177{
178    binary_broadcast_map_op(a, b, |x, y| {
179        if equal_nan && x.is_nan() && y.is_nan() {
180            return true;
181        }
182        if x.is_nan() || y.is_nan() {
183            return false;
184        }
185        (x - y).abs() <= atol + rtol * y.abs()
186    })
187}
188
189/// Test whether two arrays have the same shape and elements.
190pub fn array_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> bool
191where
192    T: Element + PartialEq,
193    D: Dimension,
194{
195    if a.shape() != b.shape() {
196        return false;
197    }
198    a.iter().zip(b.iter()).all(|(x, y)| x == y)
199}
200
201/// Test whether two arrays are element-wise equal within a tolerance,
202/// or broadcastable to the same shape and element-wise equal.
203///
204/// For arrays of the same shape, this is the same as `array_equal`.
205pub fn array_equiv<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> bool
206where
207    T: Element + PartialEq,
208    D: Dimension,
209{
210    // For same-dimension arrays, just check equality
211    array_equal(a, b)
212}
213
214/// Test whether two arrays are element-wise close within tolerances.
215///
216/// |a - b| <= atol + rtol * |b|
217pub fn allclose<T, D>(a: &Array<T, D>, b: &Array<T, D>, rtol: T, atol: T) -> FerrayResult<bool>
218where
219    T: Element + Float,
220    D: Dimension,
221{
222    let close = isclose(a, b, rtol, atol, false)?;
223    Ok(close.iter().all(|&x| x))
224}
225
226/// Elementwise close-within-tolerance test.
227///
228/// |a - b| <= atol + rtol * |b|
229///
230/// If `equal_nan` is true, NaN values in corresponding positions are considered close.
231pub fn isclose<T, D>(
232    a: &Array<T, D>,
233    b: &Array<T, D>,
234    rtol: T,
235    atol: T,
236    equal_nan: bool,
237) -> FerrayResult<Array<bool, D>>
238where
239    T: Element + Float,
240    D: Dimension,
241{
242    binary_map_op(a, b, |x, y| {
243        if equal_nan && x.is_nan() && y.is_nan() {
244            return true;
245        }
246        if x.is_nan() || y.is_nan() {
247            return false;
248        }
249        (x - y).abs() <= atol + rtol * y.abs()
250    })
251}
252
// Unit tests for the comparison ufuncs. The sections are:
//   1. same-shape tests for each function,
//   2. same-rank broadcasting through the `D`-typed functions (#379),
//   3. cross-rank `*_broadcast` variants returning `IxDyn` (#387).
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    use crate::test_util::arr1;

    // Builds a 1-D i32 array whose length equals `data.len()`.
    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    #[test]
    fn test_equal() {
        let a = arr1_i32(vec![1, 2, 3]);
        let b = arr1_i32(vec![1, 5, 3]);
        let r = equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_not_equal() {
        let a = arr1_i32(vec![1, 2, 3]);
        let b = arr1_i32(vec![1, 5, 3]);
        let r = not_equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn test_less() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = less(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, false]);
    }

    #[test]
    fn test_less_equal() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = less_equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_greater() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = greater(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn test_greater_equal() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = greater_equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, true]);
    }

    #[test]
    fn test_array_equal() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.0, 3.0]);
        let c = arr1(vec![1.0, 2.0, 4.0]);
        assert!(array_equal(&a, &b));
        assert!(!array_equal(&a, &c));
    }

    #[test]
    fn test_array_equal_different_shapes() {
        // array_equal never broadcasts: a length mismatch is simply false.
        let a = arr1(vec![1.0, 2.0]);
        let b = arr1(vec![1.0, 2.0, 3.0]);
        assert!(!array_equal(&a, &b));
    }

    #[test]
    fn test_allclose() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0 + 1e-9, 2.0 + 1e-9, 3.0 + 1e-9]);
        assert!(allclose(&a, &b, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_allclose_not_close() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.0, 4.0]);
        assert!(!allclose(&a, &b, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_isclose() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.1, 3.0]);
        let r = isclose(&a, &b, 1e-5, 1e-8, false).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_isclose_equal_nan() {
        let a = arr1(vec![f64::NAN, 1.0]);
        let b = arr1(vec![f64::NAN, 1.0]);
        let r = isclose(&a, &b, 1e-5, 1e-8, true).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, true]);
    }

    // -----------------------------------------------------------------------
    // Broadcasting tests for comparison ops (issue #379)
    // -----------------------------------------------------------------------

    #[test]
    fn test_equal_broadcasts() {
        // Same rank (Ix2), different shapes: [2, 3] vs [1, 3].
        use ferray_core::dimension::Ix2;
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 2, 6]).unwrap();
        let b = Array::<i32, Ix2>::from_vec(Ix2::new([1, 3]), vec![1, 2, 3]).unwrap();
        let r = equal(&a, &b).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, false, true, false]
        );
    }

    #[test]
    fn test_less_broadcasts() {
        // Column [3, 1] against row [1, 3] produces an outer comparison [3, 3].
        use ferray_core::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 5.0, 10.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![3.0, 5.0, 7.0]).unwrap();
        let r = less(&a, &b).unwrap();
        assert_eq!(r.shape(), &[3, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![
                true, true, true, // 1 < {3,5,7}
                false, false, true, // 5 < {3,5,7}
                false, false, false, // 10 < {3,5,7}
            ]
        );
    }

    #[test]
    fn test_isclose_broadcasts() {
        use ferray_core::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![1.0, 2.0, 3.0, 1.0001, 2.0001, 3.0001],
        )
        .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let r = isclose(&a, &b, 1e-3, 1e-8, false).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, true, true, true]
        );
    }

    // -----------------------------------------------------------------------
    // Cross-rank broadcasting comparison ops (#387)
    // -----------------------------------------------------------------------

    #[test]
    fn equal_broadcast_ix2_against_ix1() {
        use ferray_core::dimension::Ix2;
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 2, 6]).unwrap();
        let b = arr1_i32(vec![1, 2, 3]);
        let r = equal_broadcast(&a, &b).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, false, true, false]
        );
    }

    #[test]
    fn not_equal_broadcast_ix2_against_ix1() {
        use ferray_core::dimension::Ix2;
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 2, 6]).unwrap();
        let b = arr1_i32(vec![1, 2, 3]);
        let r = not_equal_broadcast(&a, &b).unwrap();
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![false, false, false, true, false, true]
        );
    }

    #[test]
    fn less_broadcast_ix2_against_scalar_like_ix1() {
        // The most common pattern: arr < threshold, where threshold is a
        // length-1 1-D stand-in for a scalar.
        use ferray_core::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let threshold = arr1(vec![3.0]);
        let r = less_broadcast(&a, &threshold).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, false, false, false, false]
        );
    }

    #[test]
    fn greater_broadcast_ix2_against_ix1() {
        use ferray_core::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0, 5.0, 10.0, 1.0, 5.0, 9.0])
            .unwrap();
        let b = arr1(vec![1.0, 5.0, 9.0]);
        let r = greater_broadcast(&a, &b).unwrap();
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![false, false, true, false, false, false]
        );
    }

    #[test]
    fn less_equal_broadcast_ix1_against_ix2() {
        // The reverse direction: 1-D LHS broadcast against 2-D RHS.
        use ferray_core::dimension::Ix2;
        let a = arr1(vec![1.0, 5.0, 9.0]);
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 5.0, 9.0, 0.5, 5.0, 10.0])
            .unwrap();
        let r = less_equal_broadcast(&a, &b).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, false, true, true]
        );
    }

    #[test]
    fn greater_equal_broadcast_ix1_against_ix2() {
        use ferray_core::dimension::Ix2;
        let a = arr1(vec![5.0, 5.0, 5.0]);
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![5.0, 4.0, 6.0, 5.0, 5.0, 5.0])
            .unwrap();
        let r = greater_equal_broadcast(&a, &b).unwrap();
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, false, true, true, true]
        );
    }

    #[test]
    fn isclose_broadcast_ix2_against_ix1() {
        // 2.5 vs 2.0 lies outside rtol = 1e-3; the 1.0001 / 3.0001 entries lie inside it.
        use ferray_core::dimension::Ix2;
        let a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 1.0001, 2.5, 3.0001])
                .unwrap();
        let b = arr1(vec![1.0, 2.0, 3.0]);
        let r = isclose_broadcast(&a, &b, 1e-3, 1e-8, false).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, true, false, true]
        );
    }

    #[test]
    fn equal_broadcast_returns_ixdyn_dim_type() {
        // The return type must be Array<bool, IxDyn> regardless of the
        // input dimension types — that's the whole point of the *_broadcast
        // family.
        use ferray_core::dimension::{Ix2, IxDyn};
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let b = arr1_i32(vec![1, 2]);
        let r: Array<bool, IxDyn> = equal_broadcast(&a, &b).unwrap();
        assert_eq!(r.ndim(), 2);
    }

    #[test]
    fn equal_broadcast_incompatible_shapes_errors() {
        // Lengths 3 and 2 differ and neither is 1, so broadcasting must fail.
        let a = arr1_i32(vec![1, 2, 3]);
        let b = arr1_i32(vec![1, 2]);
        assert!(equal_broadcast(&a, &b).is_err());
    }
}