jax_rs/ops/
comparison.rs

//! Comparison operations on arrays.
2
3use crate::{buffer::Buffer, Array, DType, Device};
4
5#[cfg(test)]
6use crate::Shape;
7
8/// Apply a comparison function element-wise to two arrays with broadcasting.
9fn compare_op<F>(lhs: &Array, rhs: &Array, f: F) -> Array
10where
11    F: Fn(f32, f32) -> bool,
12{
13    assert_eq!(lhs.dtype(), DType::Float32, "Only Float32 supported");
14    assert_eq!(rhs.dtype(), DType::Float32, "Only Float32 supported");
15    assert_eq!(lhs.device(), Device::Cpu, "Only CPU supported for now");
16    assert_eq!(rhs.device(), Device::Cpu, "Only CPU supported for now");
17
18    // Check if shapes are broadcast-compatible
19    let result_shape = lhs
20        .shape()
21        .broadcast_with(rhs.shape())
22        .expect("Shapes are not broadcast-compatible");
23
24    let lhs_data = lhs.to_vec();
25    let rhs_data = rhs.to_vec();
26
27    let result_data: Vec<f32> = if lhs.shape() == rhs.shape() {
28        // Same shape - simple element-wise operation
29        lhs_data
30            .iter()
31            .zip(rhs_data.iter())
32            .map(|(&a, &b)| if f(a, b) { 1.0 } else { 0.0 })
33            .collect()
34    } else {
35        // Need broadcasting
36        let size = result_shape.size();
37        (0..size)
38            .map(|i| {
39                let lhs_idx = crate::ops::binary::broadcast_index(
40                    i,
41                    &result_shape,
42                    lhs.shape(),
43                );
44                let rhs_idx = crate::ops::binary::broadcast_index(
45                    i,
46                    &result_shape,
47                    rhs.shape(),
48                );
49                if f(lhs_data[lhs_idx], rhs_data[rhs_idx]) {
50                    1.0
51                } else {
52                    0.0
53                }
54            })
55            .collect()
56    };
57
58    let buffer = Buffer::from_f32(result_data, Device::Cpu);
59    Array::from_buffer(buffer, result_shape)
60}
61
62impl Array {
63    /// Element-wise less than comparison.
64    ///
65    /// Returns an array of 1.0 where condition is true, 0.0 otherwise.
66    ///
67    /// # Examples
68    ///
69    /// ```
70    /// # use jax_rs::{Array, Shape};
71    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
72    /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
73    /// let c = a.lt(&b);
74    /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 0.0]);
75    /// ```
76    pub fn lt(&self, other: &Array) -> Array {
77        compare_op(self, other, |a, b| a < b)
78    }
79
80    /// Element-wise less than or equal comparison.
81    pub fn le(&self, other: &Array) -> Array {
82        compare_op(self, other, |a, b| a <= b)
83    }
84
85    /// Element-wise greater than comparison.
86    pub fn gt(&self, other: &Array) -> Array {
87        compare_op(self, other, |a, b| a > b)
88    }
89
90    /// Element-wise greater than or equal comparison.
91    pub fn ge(&self, other: &Array) -> Array {
92        compare_op(self, other, |a, b| a >= b)
93    }
94
95    /// Element-wise equality comparison.
96    ///
97    /// Note: For floating point, this is exact equality. Use `allclose` for
98    /// approximate equality.
99    pub fn eq(&self, other: &Array) -> Array {
100        compare_op(self, other, |a, b| a == b)
101    }
102
103    /// Element-wise equality comparison with a scalar.
104    ///
105    /// Returns an array where each element is 1.0 if equal to the scalar, 0.0 otherwise.
106    pub fn eq_scalar(&self, value: f32) -> Array {
107        let data = self.to_vec();
108        let result: Vec<f32> = data
109            .iter()
110            .map(|&x| if x == value { 1.0 } else { 0.0 })
111            .collect();
112        Array::from_vec(result, self.shape().clone())
113    }
114
115    /// Element-wise inequality comparison.
116    pub fn ne(&self, other: &Array) -> Array {
117        compare_op(self, other, |a, b| a != b)
118    }
119
120    /// Logical NOT element-wise.
121    ///
122    /// Treats 0.0 as false, non-zero as true.
123    ///
124    /// # Examples
125    ///
126    /// ```
127    /// # use jax_rs::{Array, Shape};
128    /// let a = Array::from_vec(vec![0.0, 1.0, 0.0], Shape::new(vec![3]));
129    /// let b = a.logical_not();
130    /// assert_eq!(b.to_vec(), vec![1.0, 0.0, 1.0]);
131    /// ```
132    pub fn logical_not(&self) -> Array {
133        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
134        let data = self.to_vec();
135        let result: Vec<f32> = data
136            .iter()
137            .map(|&x| if x == 0.0 { 1.0 } else { 0.0 })
138            .collect();
139        Array::from_vec(result, self.shape().clone())
140    }
141
142    /// Logical AND element-wise.
143    ///
144    /// # Examples
145    ///
146    /// ```
147    /// # use jax_rs::{Array, Shape};
148    /// let a = Array::from_vec(vec![1.0, 1.0, 0.0], Shape::new(vec![3]));
149    /// let b = Array::from_vec(vec![1.0, 0.0, 0.0], Shape::new(vec![3]));
150    /// let c = a.logical_and(&b);
151    /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 0.0]);
152    /// ```
153    pub fn logical_and(&self, other: &Array) -> Array {
154        compare_op(self, other, |a, b| a != 0.0 && b != 0.0)
155    }
156
157    /// Logical OR element-wise.
158    ///
159    /// # Examples
160    ///
161    /// ```
162    /// # use jax_rs::{Array, Shape};
163    /// let a = Array::from_vec(vec![1.0, 1.0, 0.0], Shape::new(vec![3]));
164    /// let b = Array::from_vec(vec![1.0, 0.0, 0.0], Shape::new(vec![3]));
165    /// let c = a.logical_or(&b);
166    /// assert_eq!(c.to_vec(), vec![1.0, 1.0, 0.0]);
167    /// ```
168    pub fn logical_or(&self, other: &Array) -> Array {
169        compare_op(self, other, |a, b| a != 0.0 || b != 0.0)
170    }
171
172    /// Logical XOR element-wise.
173    ///
174    /// # Examples
175    ///
176    /// ```
177    /// # use jax_rs::{Array, Shape};
178    /// let a = Array::from_vec(vec![1.0, 1.0, 0.0], Shape::new(vec![3]));
179    /// let b = Array::from_vec(vec![1.0, 0.0, 0.0], Shape::new(vec![3]));
180    /// let c = a.logical_xor(&b);
181    /// assert_eq!(c.to_vec(), vec![0.0, 1.0, 0.0]);
182    /// ```
183    pub fn logical_xor(&self, other: &Array) -> Array {
184        compare_op(self, other, |a, b| (a != 0.0) != (b != 0.0))
185    }
186
187    /// Test if all elements are true (non-zero).
188    ///
189    /// # Examples
190    ///
191    /// ```
192    /// # use jax_rs::{Array, Shape};
193    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
194    /// assert!(a.all());
195    /// let b = Array::from_vec(vec![1.0, 0.0, 3.0], Shape::new(vec![3]));
196    /// assert!(!b.all());
197    /// ```
198    pub fn all(&self) -> bool {
199        let data = self.to_vec();
200        data.iter().all(|&x| x != 0.0)
201    }
202
203    /// Test if any element is true (non-zero).
204    ///
205    /// # Examples
206    ///
207    /// ```
208    /// # use jax_rs::{Array, Shape};
209    /// let a = Array::from_vec(vec![0.0, 0.0, 1.0], Shape::new(vec![3]));
210    /// assert!(a.any());
211    /// let b = Array::from_vec(vec![0.0, 0.0, 0.0], Shape::new(vec![3]));
212    /// assert!(!b.any());
213    /// ```
214    pub fn any(&self) -> bool {
215        let data = self.to_vec();
216        data.iter().any(|&x| x != 0.0)
217    }
218
219    /// Count the number of true (non-zero) elements.
220    ///
221    /// # Examples
222    ///
223    /// ```
224    /// # use jax_rs::{Array, Shape};
225    /// let a = Array::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0], Shape::new(vec![5]));
226    /// assert_eq!(a.count_nonzero(), 3);
227    /// ```
228    pub fn count_nonzero(&self) -> usize {
229        let data = self.to_vec();
230        data.iter().filter(|&&x| x != 0.0).count()
231    }
232
233    /// Test if two arrays are element-wise equal within a tolerance.
234    ///
235    /// Returns true if all elements satisfy: |a - b| <= atol + rtol * |b|
236    ///
237    /// # Arguments
238    ///
239    /// * `other` - Array to compare with
240    /// * `rtol` - Relative tolerance
241    /// * `atol` - Absolute tolerance
242    ///
243    /// # Examples
244    ///
245    /// ```
246    /// # use jax_rs::{Array, Shape};
247    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
248    /// let b = Array::from_vec(vec![1.0001, 2.0001, 3.0001], Shape::new(vec![3]));
249    /// assert!(a.allclose(&b, 1e-3, 1e-3));
250    /// assert!(!a.allclose(&b, 1e-5, 1e-5));
251    /// ```
252    pub fn allclose(&self, other: &Array, rtol: f32, atol: f32) -> bool {
253        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
254        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
255
256        // Check if shapes are broadcast-compatible
257        let result_shape = match self.shape().broadcast_with(other.shape()) {
258            Some(shape) => shape,
259            None => return false,
260        };
261
262        let self_data = self.to_vec();
263        let other_data = other.to_vec();
264
265        if self.shape() == other.shape() {
266            // Same shape - simple element-wise comparison
267            self_data.iter().zip(other_data.iter()).all(|(&a, &b)| {
268                let diff = (a - b).abs();
269                diff <= atol + rtol * b.abs()
270            })
271        } else {
272            // Need broadcasting
273            let size = result_shape.size();
274            (0..size).all(|i| {
275                let self_idx =
276                    crate::ops::binary::broadcast_index(i, &result_shape, self.shape());
277                let other_idx =
278                    crate::ops::binary::broadcast_index(i, &result_shape, other.shape());
279                let a = self_data[self_idx];
280                let b = other_data[other_idx];
281                let diff = (a - b).abs();
282                diff <= atol + rtol * b.abs()
283            })
284        }
285    }
286
287    /// Element-wise test if values are close within a tolerance.
288    ///
289    /// Returns an array of 1.0 where |a - b| <= atol + rtol * |b|, 0.0 otherwise.
290    ///
291    /// # Examples
292    ///
293    /// ```
294    /// # use jax_rs::{Array, Shape};
295    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
296    /// let b = Array::from_vec(vec![1.0001, 2.1, 3.0001], Shape::new(vec![3]));
297    /// let c = a.isclose(&b, 1e-3, 1e-3);
298    /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 1.0]);
299    /// ```
300    pub fn isclose(&self, other: &Array, rtol: f32, atol: f32) -> Array {
301        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
302        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
303
304        // Check if shapes are broadcast-compatible
305        let result_shape = self
306            .shape()
307            .broadcast_with(other.shape())
308            .expect("Shapes are not broadcast-compatible");
309
310        let self_data = self.to_vec();
311        let other_data = other.to_vec();
312
313        let result_data: Vec<f32> = if self.shape() == other.shape() {
314            // Same shape - simple element-wise operation
315            self_data
316                .iter()
317                .zip(other_data.iter())
318                .map(|(&a, &b)| {
319                    let diff = (a - b).abs();
320                    if diff <= atol + rtol * b.abs() {
321                        1.0
322                    } else {
323                        0.0
324                    }
325                })
326                .collect()
327        } else {
328            // Need broadcasting
329            let size = result_shape.size();
330            (0..size)
331                .map(|i| {
332                    let self_idx =
333                        crate::ops::binary::broadcast_index(i, &result_shape, self.shape());
334                    let other_idx =
335                        crate::ops::binary::broadcast_index(i, &result_shape, other.shape());
336                    let a = self_data[self_idx];
337                    let b = other_data[other_idx];
338                    let diff = (a - b).abs();
339                    if diff <= atol + rtol * b.abs() {
340                        1.0
341                    } else {
342                        0.0
343                    }
344                })
345                .collect()
346        };
347
348        let buffer = Buffer::from_f32(result_data, Device::Cpu);
349        Array::from_buffer(buffer, result_shape)
350    }
351
352    /// Test if two arrays have the same shape and elements.
353    ///
354    /// This is exact equality - for approximate equality use `allclose`.
355    ///
356    /// # Examples
357    ///
358    /// ```
359    /// # use jax_rs::{Array, Shape};
360    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
361    /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
362    /// let c = Array::from_vec(vec![1.0, 2.0, 3.1], Shape::new(vec![3]));
363    /// assert!(a.array_equal(&b));
364    /// assert!(!a.array_equal(&c));
365    /// ```
366    pub fn array_equal(&self, other: &Array) -> bool {
367        if self.shape() != other.shape() {
368            return false;
369        }
370        if self.dtype() != other.dtype() {
371            return false;
372        }
373
374        let self_data = self.to_vec();
375        let other_data = other.to_vec();
376        self_data == other_data
377    }
378
379    /// Test if arrays can be broadcast to the same shape and are equal.
380    ///
381    /// Unlike `array_equal`, this allows broadcasting.
382    ///
383    /// # Examples
384    ///
385    /// ```
386    /// # use jax_rs::{Array, Shape};
387    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
388    /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![1, 3]));
389    /// assert!(a.array_equiv(&b));
390    /// ```
391    pub fn array_equiv(&self, other: &Array) -> bool {
392        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
393        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
394
395        // Check if shapes are broadcast-compatible
396        let result_shape = match self.shape().broadcast_with(other.shape()) {
397            Some(shape) => shape,
398            None => return false,
399        };
400
401        let self_data = self.to_vec();
402        let other_data = other.to_vec();
403
404        if self.shape() == other.shape() {
405            // Same shape - simple element-wise comparison
406            self_data == other_data
407        } else {
408            // Need broadcasting
409            let size = result_shape.size();
410            (0..size).all(|i| {
411                let self_idx =
412                    crate::ops::binary::broadcast_index(i, &result_shape, self.shape());
413                let other_idx =
414                    crate::ops::binary::broadcast_index(i, &result_shape, other.shape());
415                self_data[self_idx] == other_data[other_idx]
416            })
417        }
418    }
419
420    /// Element-wise greater comparison (alias for gt).
421    ///
422    /// # Examples
423    ///
424    /// ```
425    /// # use jax_rs::{Array, Shape};
426    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
427    /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
428    /// let c = a.greater(&b);
429    /// assert_eq!(c.to_vec(), vec![0.0, 0.0, 1.0]);
430    /// ```
431    pub fn greater(&self, other: &Array) -> Array {
432        self.gt(other)
433    }
434
435    /// Element-wise less comparison (alias for lt).
436    ///
437    /// # Examples
438    ///
439    /// ```
440    /// # use jax_rs::{Array, Shape};
441    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
442    /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
443    /// let c = a.less(&b);
444    /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 0.0]);
445    /// ```
446    pub fn less(&self, other: &Array) -> Array {
447        self.lt(other)
448    }
449
450    /// Element-wise greater-or-equal comparison (alias for ge).
451    ///
452    /// # Examples
453    ///
454    /// ```
455    /// # use jax_rs::{Array, Shape};
456    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
457    /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
458    /// let c = a.greater_equal(&b);
459    /// assert_eq!(c.to_vec(), vec![0.0, 1.0, 1.0]);
460    /// ```
461    pub fn greater_equal(&self, other: &Array) -> Array {
462        self.ge(other)
463    }
464
465    /// Element-wise less-or-equal comparison (alias for le).
466    ///
467    /// # Examples
468    ///
469    /// ```
470    /// # use jax_rs::{Array, Shape};
471    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
472    /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
473    /// let c = a.less_equal(&b);
474    /// assert_eq!(c.to_vec(), vec![1.0, 1.0, 0.0]);
475    /// ```
476    pub fn less_equal(&self, other: &Array) -> Array {
477        self.le(other)
478    }
479
480    /// Test element-wise for real numbers (not infinity or NaN).
481    /// For Float32, returns true for all finite values.
482    ///
483    /// # Examples
484    ///
485    /// ```
486    /// # use jax_rs::{Array, Shape};
487    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
488    /// let r = a.isreal();
489    /// assert_eq!(r.to_vec(), vec![1.0, 1.0, 1.0]);
490    /// ```
491    pub fn isreal(&self) -> Array {
492        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
493
494        let data = self.to_vec();
495        let result_data: Vec<f32> = data
496            .iter()
497            .map(|&x| if x.is_finite() { 1.0 } else { 0.0 })
498            .collect();
499
500        Array::from_vec(result_data, self.shape().clone())
501    }
502
503    /// Test element-wise for complex numbers.
504    /// For Float32 arrays, always returns false (0.0).
505    ///
506    /// # Examples
507    ///
508    /// ```
509    /// # use jax_rs::{Array, Shape};
510    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
511    /// let c = a.iscomplex();
512    /// assert_eq!(c.to_vec(), vec![0.0, 0.0, 0.0]);
513    /// ```
514    pub fn iscomplex(&self) -> Array {
515        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
516        Array::zeros(self.shape().clone(), DType::Float32)
517    }
518
519    /// Test element-wise if values are in an open interval.
520    /// Returns 1.0 where lower < x < upper, 0.0 otherwise.
521    ///
522    /// # Examples
523    ///
524    /// ```
525    /// # use jax_rs::{Array, Shape};
526    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
527    /// let b = a.isin_range(1.5, 4.5);
528    /// assert_eq!(b.to_vec(), vec![0.0, 1.0, 1.0, 1.0, 0.0]);
529    /// ```
530    pub fn isin_range(&self, lower: f32, upper: f32) -> Array {
531        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
532
533        let data = self.to_vec();
534        let result_data: Vec<f32> = data
535            .iter()
536            .map(|&x| if x > lower && x < upper { 1.0 } else { 0.0 })
537            .collect();
538
539        Array::from_vec(result_data, self.shape().clone())
540    }
541
542    /// Test element-wise if values are subnormal (denormalized).
543    /// Returns 1.0 where value is subnormal, 0.0 otherwise.
544    ///
545    /// # Examples
546    ///
547    /// ```
548    /// # use jax_rs::{Array, Shape};
549    /// let a = Array::from_vec(vec![1.0, 0.0, 1e-40], Shape::new(vec![3]));
550    /// let b = a.issubnormal();
551    /// // Only 1e-40 is subnormal
552    /// assert_eq!(b.to_vec()[0], 0.0);
553    /// assert_eq!(b.to_vec()[1], 0.0);
554    /// assert_eq!(b.to_vec()[2], 1.0);
555    /// ```
556    pub fn issubnormal(&self) -> Array {
557        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
558
559        let data = self.to_vec();
560        let result_data: Vec<f32> = data
561            .iter()
562            .map(|&x| if x.is_subnormal() { 1.0 } else { 0.0 })
563            .collect();
564
565        Array::from_vec(result_data, self.shape().clone())
566    }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// `lt` marks only positions where the left value is strictly smaller.
    #[test]
    fn test_lt() {
        let lhs = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let rhs = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
        assert_eq!(lhs.lt(&rhs).to_vec(), vec![1.0, 0.0, 0.0]);
    }

    /// `le` also marks positions where both values are equal.
    #[test]
    fn test_le() {
        let lhs = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let rhs = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
        assert_eq!(lhs.le(&rhs).to_vec(), vec![1.0, 1.0, 0.0]);
    }

    /// `gt` marks only positions where the left value is strictly larger.
    #[test]
    fn test_gt() {
        let lhs = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let rhs = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
        assert_eq!(lhs.gt(&rhs).to_vec(), vec![0.0, 0.0, 1.0]);
    }

    /// `ge` also marks positions where both values are equal.
    #[test]
    fn test_ge() {
        let lhs = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let rhs = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
        assert_eq!(lhs.ge(&rhs).to_vec(), vec![0.0, 1.0, 1.0]);
    }

    /// `eq` marks positions holding identical values.
    #[test]
    fn test_eq() {
        let lhs = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let rhs = Array::from_vec(vec![1.0, 2.0, 4.0], Shape::new(vec![3]));
        assert_eq!(lhs.eq(&rhs).to_vec(), vec![1.0, 1.0, 0.0]);
    }

    /// `ne` marks positions holding different values.
    #[test]
    fn test_ne() {
        let lhs = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let rhs = Array::from_vec(vec![1.0, 2.0, 4.0], Shape::new(vec![3]));
        assert_eq!(lhs.ne(&rhs).to_vec(), vec![0.0, 0.0, 1.0]);
    }

    /// A one-element operand is broadcast across the longer one.
    #[test]
    fn test_comparison_broadcast() {
        let values = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let threshold = Array::from_vec(vec![2.0], Shape::new(vec![1]));
        assert_eq!(values.lt(&threshold).to_vec(), vec![1.0, 0.0, 0.0]);
    }

    /// `allclose` depends on the tolerances and supports broadcasting.
    #[test]
    fn test_allclose() {
        let exact = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let nudged = Array::from_vec(vec![1.0001, 2.0001, 3.0001], Shape::new(vec![3]));
        assert!(exact.allclose(&nudged, 1e-3, 1e-3));
        assert!(!exact.allclose(&nudged, 1e-5, 1e-5));

        // Test with broadcasting
        let single = Array::from_vec(vec![1.0001], Shape::new(vec![1]));
        let ones = Array::from_vec(vec![1.0, 1.0, 1.0], Shape::new(vec![3]));
        assert!(single.allclose(&ones, 1e-3, 1e-3));
    }

    /// `isclose` reports closeness per element.
    #[test]
    fn test_isclose() {
        let exact = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let nudged = Array::from_vec(vec![1.0001, 2.1, 3.0001], Shape::new(vec![3]));
        assert_eq!(
            exact.isclose(&nudged, 1e-3, 1e-3).to_vec(),
            vec![1.0, 0.0, 1.0]
        );
    }

    /// `array_equal` requires identical contents and identical shape.
    #[test]
    fn test_array_equal() {
        let base = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let same = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let off = Array::from_vec(vec![1.0, 2.0, 3.1], Shape::new(vec![3]));
        assert!(base.array_equal(&same));
        assert!(!base.array_equal(&off));

        // Different shapes
        let reshaped = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![1, 3]));
        assert!(!base.array_equal(&reshaped));
    }

    /// `array_equiv` tolerates broadcast-compatible shapes but not different values.
    #[test]
    fn test_array_equiv() {
        let base = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let reshaped = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![1, 3]));
        assert!(base.array_equiv(&reshaped));

        let off = Array::from_vec(vec![1.0, 2.0, 3.1], Shape::new(vec![1, 3]));
        assert!(!base.array_equiv(&off));
    }
}