jax_rs/ops/comparison.rs
1//! Comparison operations on arrays.
2
3use crate::{buffer::Buffer, Array, DType, Device};
4
5#[cfg(test)]
6use crate::Shape;
7
8/// Apply a comparison function element-wise to two arrays with broadcasting.
9fn compare_op<F>(lhs: &Array, rhs: &Array, f: F) -> Array
10where
11 F: Fn(f32, f32) -> bool,
12{
13 assert_eq!(lhs.dtype(), DType::Float32, "Only Float32 supported");
14 assert_eq!(rhs.dtype(), DType::Float32, "Only Float32 supported");
15 assert_eq!(lhs.device(), Device::Cpu, "Only CPU supported for now");
16 assert_eq!(rhs.device(), Device::Cpu, "Only CPU supported for now");
17
18 // Check if shapes are broadcast-compatible
19 let result_shape = lhs
20 .shape()
21 .broadcast_with(rhs.shape())
22 .expect("Shapes are not broadcast-compatible");
23
24 let lhs_data = lhs.to_vec();
25 let rhs_data = rhs.to_vec();
26
27 let result_data: Vec<f32> = if lhs.shape() == rhs.shape() {
28 // Same shape - simple element-wise operation
29 lhs_data
30 .iter()
31 .zip(rhs_data.iter())
32 .map(|(&a, &b)| if f(a, b) { 1.0 } else { 0.0 })
33 .collect()
34 } else {
35 // Need broadcasting
36 let size = result_shape.size();
37 (0..size)
38 .map(|i| {
39 let lhs_idx = crate::ops::binary::broadcast_index(
40 i,
41 &result_shape,
42 lhs.shape(),
43 );
44 let rhs_idx = crate::ops::binary::broadcast_index(
45 i,
46 &result_shape,
47 rhs.shape(),
48 );
49 if f(lhs_data[lhs_idx], rhs_data[rhs_idx]) {
50 1.0
51 } else {
52 0.0
53 }
54 })
55 .collect()
56 };
57
58 let buffer = Buffer::from_f32(result_data, Device::Cpu);
59 Array::from_buffer(buffer, result_shape)
60}
61
62impl Array {
63 /// Element-wise less than comparison.
64 ///
65 /// Returns an array of 1.0 where condition is true, 0.0 otherwise.
66 ///
67 /// # Examples
68 ///
69 /// ```
70 /// # use jax_rs::{Array, Shape};
71 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
72 /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
73 /// let c = a.lt(&b);
74 /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 0.0]);
75 /// ```
76 pub fn lt(&self, other: &Array) -> Array {
77 compare_op(self, other, |a, b| a < b)
78 }
79
80 /// Element-wise less than or equal comparison.
81 pub fn le(&self, other: &Array) -> Array {
82 compare_op(self, other, |a, b| a <= b)
83 }
84
85 /// Element-wise greater than comparison.
86 pub fn gt(&self, other: &Array) -> Array {
87 compare_op(self, other, |a, b| a > b)
88 }
89
90 /// Element-wise greater than or equal comparison.
91 pub fn ge(&self, other: &Array) -> Array {
92 compare_op(self, other, |a, b| a >= b)
93 }
94
95 /// Element-wise equality comparison.
96 ///
97 /// Note: For floating point, this is exact equality. Use `allclose` for
98 /// approximate equality.
99 pub fn eq(&self, other: &Array) -> Array {
100 compare_op(self, other, |a, b| a == b)
101 }
102
103 /// Element-wise equality comparison with a scalar.
104 ///
105 /// Returns an array where each element is 1.0 if equal to the scalar, 0.0 otherwise.
106 pub fn eq_scalar(&self, value: f32) -> Array {
107 let data = self.to_vec();
108 let result: Vec<f32> = data
109 .iter()
110 .map(|&x| if x == value { 1.0 } else { 0.0 })
111 .collect();
112 Array::from_vec(result, self.shape().clone())
113 }
114
115 /// Element-wise inequality comparison.
116 pub fn ne(&self, other: &Array) -> Array {
117 compare_op(self, other, |a, b| a != b)
118 }
119
120 /// Logical NOT element-wise.
121 ///
122 /// Treats 0.0 as false, non-zero as true.
123 ///
124 /// # Examples
125 ///
126 /// ```
127 /// # use jax_rs::{Array, Shape};
128 /// let a = Array::from_vec(vec![0.0, 1.0, 0.0], Shape::new(vec![3]));
129 /// let b = a.logical_not();
130 /// assert_eq!(b.to_vec(), vec![1.0, 0.0, 1.0]);
131 /// ```
132 pub fn logical_not(&self) -> Array {
133 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
134 let data = self.to_vec();
135 let result: Vec<f32> = data
136 .iter()
137 .map(|&x| if x == 0.0 { 1.0 } else { 0.0 })
138 .collect();
139 Array::from_vec(result, self.shape().clone())
140 }
141
142 /// Logical AND element-wise.
143 ///
144 /// # Examples
145 ///
146 /// ```
147 /// # use jax_rs::{Array, Shape};
148 /// let a = Array::from_vec(vec![1.0, 1.0, 0.0], Shape::new(vec![3]));
149 /// let b = Array::from_vec(vec![1.0, 0.0, 0.0], Shape::new(vec![3]));
150 /// let c = a.logical_and(&b);
151 /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 0.0]);
152 /// ```
153 pub fn logical_and(&self, other: &Array) -> Array {
154 compare_op(self, other, |a, b| a != 0.0 && b != 0.0)
155 }
156
157 /// Logical OR element-wise.
158 ///
159 /// # Examples
160 ///
161 /// ```
162 /// # use jax_rs::{Array, Shape};
163 /// let a = Array::from_vec(vec![1.0, 1.0, 0.0], Shape::new(vec![3]));
164 /// let b = Array::from_vec(vec![1.0, 0.0, 0.0], Shape::new(vec![3]));
165 /// let c = a.logical_or(&b);
166 /// assert_eq!(c.to_vec(), vec![1.0, 1.0, 0.0]);
167 /// ```
168 pub fn logical_or(&self, other: &Array) -> Array {
169 compare_op(self, other, |a, b| a != 0.0 || b != 0.0)
170 }
171
172 /// Logical XOR element-wise.
173 ///
174 /// # Examples
175 ///
176 /// ```
177 /// # use jax_rs::{Array, Shape};
178 /// let a = Array::from_vec(vec![1.0, 1.0, 0.0], Shape::new(vec![3]));
179 /// let b = Array::from_vec(vec![1.0, 0.0, 0.0], Shape::new(vec![3]));
180 /// let c = a.logical_xor(&b);
181 /// assert_eq!(c.to_vec(), vec![0.0, 1.0, 0.0]);
182 /// ```
183 pub fn logical_xor(&self, other: &Array) -> Array {
184 compare_op(self, other, |a, b| (a != 0.0) != (b != 0.0))
185 }
186
187 /// Test if all elements are true (non-zero).
188 ///
189 /// # Examples
190 ///
191 /// ```
192 /// # use jax_rs::{Array, Shape};
193 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
194 /// assert!(a.all());
195 /// let b = Array::from_vec(vec![1.0, 0.0, 3.0], Shape::new(vec![3]));
196 /// assert!(!b.all());
197 /// ```
198 pub fn all(&self) -> bool {
199 let data = self.to_vec();
200 data.iter().all(|&x| x != 0.0)
201 }
202
203 /// Test if any element is true (non-zero).
204 ///
205 /// # Examples
206 ///
207 /// ```
208 /// # use jax_rs::{Array, Shape};
209 /// let a = Array::from_vec(vec![0.0, 0.0, 1.0], Shape::new(vec![3]));
210 /// assert!(a.any());
211 /// let b = Array::from_vec(vec![0.0, 0.0, 0.0], Shape::new(vec![3]));
212 /// assert!(!b.any());
213 /// ```
214 pub fn any(&self) -> bool {
215 let data = self.to_vec();
216 data.iter().any(|&x| x != 0.0)
217 }
218
219 /// Count the number of true (non-zero) elements.
220 ///
221 /// # Examples
222 ///
223 /// ```
224 /// # use jax_rs::{Array, Shape};
225 /// let a = Array::from_vec(vec![1.0, 0.0, 1.0, 0.0, 1.0], Shape::new(vec![5]));
226 /// assert_eq!(a.count_nonzero(), 3);
227 /// ```
228 pub fn count_nonzero(&self) -> usize {
229 let data = self.to_vec();
230 data.iter().filter(|&&x| x != 0.0).count()
231 }
232
233 /// Test if two arrays are element-wise equal within a tolerance.
234 ///
235 /// Returns true if all elements satisfy: |a - b| <= atol + rtol * |b|
236 ///
237 /// # Arguments
238 ///
239 /// * `other` - Array to compare with
240 /// * `rtol` - Relative tolerance
241 /// * `atol` - Absolute tolerance
242 ///
243 /// # Examples
244 ///
245 /// ```
246 /// # use jax_rs::{Array, Shape};
247 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
248 /// let b = Array::from_vec(vec![1.0001, 2.0001, 3.0001], Shape::new(vec![3]));
249 /// assert!(a.allclose(&b, 1e-3, 1e-3));
250 /// assert!(!a.allclose(&b, 1e-5, 1e-5));
251 /// ```
252 pub fn allclose(&self, other: &Array, rtol: f32, atol: f32) -> bool {
253 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
254 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
255
256 // Check if shapes are broadcast-compatible
257 let result_shape = match self.shape().broadcast_with(other.shape()) {
258 Some(shape) => shape,
259 None => return false,
260 };
261
262 let self_data = self.to_vec();
263 let other_data = other.to_vec();
264
265 if self.shape() == other.shape() {
266 // Same shape - simple element-wise comparison
267 self_data.iter().zip(other_data.iter()).all(|(&a, &b)| {
268 let diff = (a - b).abs();
269 diff <= atol + rtol * b.abs()
270 })
271 } else {
272 // Need broadcasting
273 let size = result_shape.size();
274 (0..size).all(|i| {
275 let self_idx =
276 crate::ops::binary::broadcast_index(i, &result_shape, self.shape());
277 let other_idx =
278 crate::ops::binary::broadcast_index(i, &result_shape, other.shape());
279 let a = self_data[self_idx];
280 let b = other_data[other_idx];
281 let diff = (a - b).abs();
282 diff <= atol + rtol * b.abs()
283 })
284 }
285 }
286
287 /// Element-wise test if values are close within a tolerance.
288 ///
289 /// Returns an array of 1.0 where |a - b| <= atol + rtol * |b|, 0.0 otherwise.
290 ///
291 /// # Examples
292 ///
293 /// ```
294 /// # use jax_rs::{Array, Shape};
295 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
296 /// let b = Array::from_vec(vec![1.0001, 2.1, 3.0001], Shape::new(vec![3]));
297 /// let c = a.isclose(&b, 1e-3, 1e-3);
298 /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 1.0]);
299 /// ```
300 pub fn isclose(&self, other: &Array, rtol: f32, atol: f32) -> Array {
301 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
302 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
303
304 // Check if shapes are broadcast-compatible
305 let result_shape = self
306 .shape()
307 .broadcast_with(other.shape())
308 .expect("Shapes are not broadcast-compatible");
309
310 let self_data = self.to_vec();
311 let other_data = other.to_vec();
312
313 let result_data: Vec<f32> = if self.shape() == other.shape() {
314 // Same shape - simple element-wise operation
315 self_data
316 .iter()
317 .zip(other_data.iter())
318 .map(|(&a, &b)| {
319 let diff = (a - b).abs();
320 if diff <= atol + rtol * b.abs() {
321 1.0
322 } else {
323 0.0
324 }
325 })
326 .collect()
327 } else {
328 // Need broadcasting
329 let size = result_shape.size();
330 (0..size)
331 .map(|i| {
332 let self_idx =
333 crate::ops::binary::broadcast_index(i, &result_shape, self.shape());
334 let other_idx =
335 crate::ops::binary::broadcast_index(i, &result_shape, other.shape());
336 let a = self_data[self_idx];
337 let b = other_data[other_idx];
338 let diff = (a - b).abs();
339 if diff <= atol + rtol * b.abs() {
340 1.0
341 } else {
342 0.0
343 }
344 })
345 .collect()
346 };
347
348 let buffer = Buffer::from_f32(result_data, Device::Cpu);
349 Array::from_buffer(buffer, result_shape)
350 }
351
352 /// Test if two arrays have the same shape and elements.
353 ///
354 /// This is exact equality - for approximate equality use `allclose`.
355 ///
356 /// # Examples
357 ///
358 /// ```
359 /// # use jax_rs::{Array, Shape};
360 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
361 /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
362 /// let c = Array::from_vec(vec![1.0, 2.0, 3.1], Shape::new(vec![3]));
363 /// assert!(a.array_equal(&b));
364 /// assert!(!a.array_equal(&c));
365 /// ```
366 pub fn array_equal(&self, other: &Array) -> bool {
367 if self.shape() != other.shape() {
368 return false;
369 }
370 if self.dtype() != other.dtype() {
371 return false;
372 }
373
374 let self_data = self.to_vec();
375 let other_data = other.to_vec();
376 self_data == other_data
377 }
378
379 /// Test if arrays can be broadcast to the same shape and are equal.
380 ///
381 /// Unlike `array_equal`, this allows broadcasting.
382 ///
383 /// # Examples
384 ///
385 /// ```
386 /// # use jax_rs::{Array, Shape};
387 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
388 /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![1, 3]));
389 /// assert!(a.array_equiv(&b));
390 /// ```
391 pub fn array_equiv(&self, other: &Array) -> bool {
392 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
393 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
394
395 // Check if shapes are broadcast-compatible
396 let result_shape = match self.shape().broadcast_with(other.shape()) {
397 Some(shape) => shape,
398 None => return false,
399 };
400
401 let self_data = self.to_vec();
402 let other_data = other.to_vec();
403
404 if self.shape() == other.shape() {
405 // Same shape - simple element-wise comparison
406 self_data == other_data
407 } else {
408 // Need broadcasting
409 let size = result_shape.size();
410 (0..size).all(|i| {
411 let self_idx =
412 crate::ops::binary::broadcast_index(i, &result_shape, self.shape());
413 let other_idx =
414 crate::ops::binary::broadcast_index(i, &result_shape, other.shape());
415 self_data[self_idx] == other_data[other_idx]
416 })
417 }
418 }
419
420 /// Element-wise greater comparison (alias for gt).
421 ///
422 /// # Examples
423 ///
424 /// ```
425 /// # use jax_rs::{Array, Shape};
426 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
427 /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
428 /// let c = a.greater(&b);
429 /// assert_eq!(c.to_vec(), vec![0.0, 0.0, 1.0]);
430 /// ```
431 pub fn greater(&self, other: &Array) -> Array {
432 self.gt(other)
433 }
434
435 /// Element-wise less comparison (alias for lt).
436 ///
437 /// # Examples
438 ///
439 /// ```
440 /// # use jax_rs::{Array, Shape};
441 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
442 /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
443 /// let c = a.less(&b);
444 /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 0.0]);
445 /// ```
446 pub fn less(&self, other: &Array) -> Array {
447 self.lt(other)
448 }
449
450 /// Element-wise greater-or-equal comparison (alias for ge).
451 ///
452 /// # Examples
453 ///
454 /// ```
455 /// # use jax_rs::{Array, Shape};
456 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
457 /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
458 /// let c = a.greater_equal(&b);
459 /// assert_eq!(c.to_vec(), vec![0.0, 1.0, 1.0]);
460 /// ```
461 pub fn greater_equal(&self, other: &Array) -> Array {
462 self.ge(other)
463 }
464
465 /// Element-wise less-or-equal comparison (alias for le).
466 ///
467 /// # Examples
468 ///
469 /// ```
470 /// # use jax_rs::{Array, Shape};
471 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
472 /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
473 /// let c = a.less_equal(&b);
474 /// assert_eq!(c.to_vec(), vec![1.0, 1.0, 0.0]);
475 /// ```
476 pub fn less_equal(&self, other: &Array) -> Array {
477 self.le(other)
478 }
479
480 /// Test element-wise for real numbers (not infinity or NaN).
481 /// For Float32, returns true for all finite values.
482 ///
483 /// # Examples
484 ///
485 /// ```
486 /// # use jax_rs::{Array, Shape};
487 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
488 /// let r = a.isreal();
489 /// assert_eq!(r.to_vec(), vec![1.0, 1.0, 1.0]);
490 /// ```
491 pub fn isreal(&self) -> Array {
492 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
493
494 let data = self.to_vec();
495 let result_data: Vec<f32> = data
496 .iter()
497 .map(|&x| if x.is_finite() { 1.0 } else { 0.0 })
498 .collect();
499
500 Array::from_vec(result_data, self.shape().clone())
501 }
502
503 /// Test element-wise for complex numbers.
504 /// For Float32 arrays, always returns false (0.0).
505 ///
506 /// # Examples
507 ///
508 /// ```
509 /// # use jax_rs::{Array, Shape};
510 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
511 /// let c = a.iscomplex();
512 /// assert_eq!(c.to_vec(), vec![0.0, 0.0, 0.0]);
513 /// ```
514 pub fn iscomplex(&self) -> Array {
515 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
516 Array::zeros(self.shape().clone(), DType::Float32)
517 }
518
519 /// Test element-wise if values are in an open interval.
520 /// Returns 1.0 where lower < x < upper, 0.0 otherwise.
521 ///
522 /// # Examples
523 ///
524 /// ```
525 /// # use jax_rs::{Array, Shape};
526 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
527 /// let b = a.isin_range(1.5, 4.5);
528 /// assert_eq!(b.to_vec(), vec![0.0, 1.0, 1.0, 1.0, 0.0]);
529 /// ```
530 pub fn isin_range(&self, lower: f32, upper: f32) -> Array {
531 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
532
533 let data = self.to_vec();
534 let result_data: Vec<f32> = data
535 .iter()
536 .map(|&x| if x > lower && x < upper { 1.0 } else { 0.0 })
537 .collect();
538
539 Array::from_vec(result_data, self.shape().clone())
540 }
541
542 /// Test element-wise if values are subnormal (denormalized).
543 /// Returns 1.0 where value is subnormal, 0.0 otherwise.
544 ///
545 /// # Examples
546 ///
547 /// ```
548 /// # use jax_rs::{Array, Shape};
549 /// let a = Array::from_vec(vec![1.0, 0.0, 1e-40], Shape::new(vec![3]));
550 /// let b = a.issubnormal();
551 /// // Only 1e-40 is subnormal
552 /// assert_eq!(b.to_vec()[0], 0.0);
553 /// assert_eq!(b.to_vec()[1], 0.0);
554 /// assert_eq!(b.to_vec()[2], 1.0);
555 /// ```
556 pub fn issubnormal(&self) -> Array {
557 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
558
559 let data = self.to_vec();
560 let result_data: Vec<f32> = data
561 .iter()
562 .map(|&x| if x.is_subnormal() { 1.0 } else { 0.0 })
563 .collect();
564
565 Array::from_vec(result_data, self.shape().clone())
566 }
567}
568
569#[cfg(test)]
570mod tests {
571 use super::*;
572
573 #[test]
574 fn test_lt() {
575 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
576 let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
577 let c = a.lt(&b);
578 assert_eq!(c.to_vec(), vec![1.0, 0.0, 0.0]);
579 }
580
581 #[test]
582 fn test_le() {
583 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
584 let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
585 let c = a.le(&b);
586 assert_eq!(c.to_vec(), vec![1.0, 1.0, 0.0]);
587 }
588
589 #[test]
590 fn test_gt() {
591 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
592 let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
593 let c = a.gt(&b);
594 assert_eq!(c.to_vec(), vec![0.0, 0.0, 1.0]);
595 }
596
597 #[test]
598 fn test_ge() {
599 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
600 let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
601 let c = a.ge(&b);
602 assert_eq!(c.to_vec(), vec![0.0, 1.0, 1.0]);
603 }
604
605 #[test]
606 fn test_eq() {
607 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
608 let b = Array::from_vec(vec![1.0, 2.0, 4.0], Shape::new(vec![3]));
609 let c = a.eq(&b);
610 assert_eq!(c.to_vec(), vec![1.0, 1.0, 0.0]);
611 }
612
613 #[test]
614 fn test_ne() {
615 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
616 let b = Array::from_vec(vec![1.0, 2.0, 4.0], Shape::new(vec![3]));
617 let c = a.ne(&b);
618 assert_eq!(c.to_vec(), vec![0.0, 0.0, 1.0]);
619 }
620
621 #[test]
622 fn test_comparison_broadcast() {
623 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
624 let b = Array::from_vec(vec![2.0], Shape::new(vec![1]));
625 let c = a.lt(&b);
626 assert_eq!(c.to_vec(), vec![1.0, 0.0, 0.0]);
627 }
628
629 #[test]
630 fn test_allclose() {
631 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
632 let b = Array::from_vec(vec![1.0001, 2.0001, 3.0001], Shape::new(vec![3]));
633 assert!(a.allclose(&b, 1e-3, 1e-3));
634 assert!(!a.allclose(&b, 1e-5, 1e-5));
635
636 // Test with broadcasting
637 let c = Array::from_vec(vec![1.0001], Shape::new(vec![1]));
638 let d = Array::from_vec(vec![1.0, 1.0, 1.0], Shape::new(vec![3]));
639 assert!(c.allclose(&d, 1e-3, 1e-3));
640 }
641
642 #[test]
643 fn test_isclose() {
644 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
645 let b = Array::from_vec(vec![1.0001, 2.1, 3.0001], Shape::new(vec![3]));
646 let c = a.isclose(&b, 1e-3, 1e-3);
647 assert_eq!(c.to_vec(), vec![1.0, 0.0, 1.0]);
648 }
649
650 #[test]
651 fn test_array_equal() {
652 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
653 let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
654 let c = Array::from_vec(vec![1.0, 2.0, 3.1], Shape::new(vec![3]));
655 assert!(a.array_equal(&b));
656 assert!(!a.array_equal(&c));
657
658 // Different shapes
659 let d = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![1, 3]));
660 assert!(!a.array_equal(&d));
661 }
662
663 #[test]
664 fn test_array_equiv() {
665 let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
666 let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![1, 3]));
667 assert!(a.array_equiv(&b));
668
669 let c = Array::from_vec(vec![1.0, 2.0, 3.1], Shape::new(vec![1, 3]));
670 assert!(!a.array_equiv(&c));
671 }
672}