1use ferray_core::Array;
12use ferray_core::dimension::{Dimension, IxDyn};
13use ferray_core::dtype::Element;
14use ferray_core::error::FerrayResult;
15use num_traits::Float;
16
17use crate::helpers::{binary_broadcast_map_op, binary_map_op};
18
19pub fn equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
21where
22 T: Element + PartialEq + Copy,
23 D: Dimension,
24{
25 binary_map_op(a, b, |x, y| x == y)
26}
27
28pub fn not_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
30where
31 T: Element + PartialEq + Copy,
32 D: Dimension,
33{
34 binary_map_op(a, b, |x, y| x != y)
35}
36
37pub fn less<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
39where
40 T: Element + PartialOrd + Copy,
41 D: Dimension,
42{
43 binary_map_op(a, b, |x, y| x < y)
44}
45
46pub fn less_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
48where
49 T: Element + PartialOrd + Copy,
50 D: Dimension,
51{
52 binary_map_op(a, b, |x, y| x <= y)
53}
54
55pub fn greater<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
57where
58 T: Element + PartialOrd + Copy,
59 D: Dimension,
60{
61 binary_map_op(a, b, |x, y| x > y)
62}
63
64pub fn greater_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
66where
67 T: Element + PartialOrd + Copy,
68 D: Dimension,
69{
70 binary_map_op(a, b, |x, y| x >= y)
71}
72
73pub fn equal_broadcast<T, D1, D2>(
86 a: &Array<T, D1>,
87 b: &Array<T, D2>,
88) -> FerrayResult<Array<bool, IxDyn>>
89where
90 T: Element + PartialEq + Copy,
91 D1: Dimension,
92 D2: Dimension,
93{
94 binary_broadcast_map_op(a, b, |x, y| x == y)
95}
96
97pub fn not_equal_broadcast<T, D1, D2>(
99 a: &Array<T, D1>,
100 b: &Array<T, D2>,
101) -> FerrayResult<Array<bool, IxDyn>>
102where
103 T: Element + PartialEq + Copy,
104 D1: Dimension,
105 D2: Dimension,
106{
107 binary_broadcast_map_op(a, b, |x, y| x != y)
108}
109
110pub fn less_broadcast<T, D1, D2>(
112 a: &Array<T, D1>,
113 b: &Array<T, D2>,
114) -> FerrayResult<Array<bool, IxDyn>>
115where
116 T: Element + PartialOrd + Copy,
117 D1: Dimension,
118 D2: Dimension,
119{
120 binary_broadcast_map_op(a, b, |x, y| x < y)
121}
122
123pub fn less_equal_broadcast<T, D1, D2>(
125 a: &Array<T, D1>,
126 b: &Array<T, D2>,
127) -> FerrayResult<Array<bool, IxDyn>>
128where
129 T: Element + PartialOrd + Copy,
130 D1: Dimension,
131 D2: Dimension,
132{
133 binary_broadcast_map_op(a, b, |x, y| x <= y)
134}
135
136pub fn greater_broadcast<T, D1, D2>(
138 a: &Array<T, D1>,
139 b: &Array<T, D2>,
140) -> FerrayResult<Array<bool, IxDyn>>
141where
142 T: Element + PartialOrd + Copy,
143 D1: Dimension,
144 D2: Dimension,
145{
146 binary_broadcast_map_op(a, b, |x, y| x > y)
147}
148
149pub fn greater_equal_broadcast<T, D1, D2>(
151 a: &Array<T, D1>,
152 b: &Array<T, D2>,
153) -> FerrayResult<Array<bool, IxDyn>>
154where
155 T: Element + PartialOrd + Copy,
156 D1: Dimension,
157 D2: Dimension,
158{
159 binary_broadcast_map_op(a, b, |x, y| x >= y)
160}
161
162pub fn isclose_broadcast<T, D1, D2>(
167 a: &Array<T, D1>,
168 b: &Array<T, D2>,
169 rtol: T,
170 atol: T,
171 equal_nan: bool,
172) -> FerrayResult<Array<bool, IxDyn>>
173where
174 T: Element + Float,
175 D1: Dimension,
176 D2: Dimension,
177{
178 binary_broadcast_map_op(a, b, |x, y| {
179 if equal_nan && x.is_nan() && y.is_nan() {
180 return true;
181 }
182 if x.is_nan() || y.is_nan() {
183 return false;
184 }
185 (x - y).abs() <= atol + rtol * y.abs()
186 })
187}
188
189pub fn array_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> bool
191where
192 T: Element + PartialEq,
193 D: Dimension,
194{
195 if a.shape() != b.shape() {
196 return false;
197 }
198 a.iter().zip(b.iter()).all(|(x, y)| x == y)
199}
200
/// Reports whether `a` and `b` are equivalent. This currently delegates to
/// [`array_equal`], so the result is `true` only for identical shapes with
/// pairwise-equal elements.
///
/// NOTE(review): NumPy's `array_equiv` also accepts shapes that are merely
/// broadcast-compatible. This implementation performs no broadcasting, so any
/// shape mismatch yields `false` — confirm whether that restriction is
/// intended.
pub fn array_equiv<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> bool
where
    T: Element + PartialEq,
    D: Dimension,
{
    // Plain delegation: no broadcasting is attempted here.
    array_equal(a, b)
}
213
214pub fn allclose<T, D>(a: &Array<T, D>, b: &Array<T, D>, rtol: T, atol: T) -> FerrayResult<bool>
218where
219 T: Element + Float,
220 D: Dimension,
221{
222 let close = isclose(a, b, rtol, atol, false)?;
223 Ok(close.iter().all(|&x| x))
224}
225
226pub fn isclose<T, D>(
232 a: &Array<T, D>,
233 b: &Array<T, D>,
234 rtol: T,
235 atol: T,
236 equal_nan: bool,
237) -> FerrayResult<Array<bool, D>>
238where
239 T: Element + Float,
240 D: Dimension,
241{
242 binary_map_op(a, b, |x, y| {
243 if equal_nan && x.is_nan() && y.is_nan() {
244 return true;
245 }
246 if x.is_nan() || y.is_nan() {
247 return false;
248 }
249 (x - y).abs() <= atol + rtol * y.abs()
250 })
251}
252
// Unit tests for the comparison functions: the same-dimension-type API, its
// in-rank broadcasting, and the `*_broadcast` variants that accept operands of
// different dimension types.
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    // Shared 1-D array constructor; used below with `f64` data.
    use crate::test_util::arr1;

    /// Builds a 1-D `i32` array whose length equals `data.len()`.
    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // --- Same-shape element-wise comparisons ---

    #[test]
    fn test_equal() {
        let a = arr1_i32(vec![1, 2, 3]);
        let b = arr1_i32(vec![1, 5, 3]);
        let r = equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_not_equal() {
        let a = arr1_i32(vec![1, 2, 3]);
        let b = arr1_i32(vec![1, 5, 3]);
        let r = not_equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn test_less() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = less(&a, &b).unwrap();
        // The final pair (3.0, 3.0) is equal, so strict `<` is false.
        assert_eq!(r.as_slice().unwrap(), &[true, false, false]);
    }

    #[test]
    fn test_less_equal() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = less_equal(&a, &b).unwrap();
        // The final pair (3.0, 3.0) is equal, so `<=` is true.
        assert_eq!(r.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_greater() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = greater(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn test_greater_equal() {
        let a = arr1(vec![1.0, 5.0, 3.0]);
        let b = arr1(vec![2.0, 3.0, 3.0]);
        let r = greater_equal(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, true]);
    }

    // --- Whole-array predicates ---

    #[test]
    fn test_array_equal() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.0, 3.0]);
        let c = arr1(vec![1.0, 2.0, 4.0]);
        assert!(array_equal(&a, &b));
        assert!(!array_equal(&a, &c));
    }

    #[test]
    fn test_array_equal_different_shapes() {
        // A length mismatch must yield `false`, not a panic or an error.
        let a = arr1(vec![1.0, 2.0]);
        let b = arr1(vec![1.0, 2.0, 3.0]);
        assert!(!array_equal(&a, &b));
    }

    #[test]
    fn test_allclose() {
        // Differences of 1e-9 are well inside rtol = 1e-5 / atol = 1e-8.
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0 + 1e-9, 2.0 + 1e-9, 3.0 + 1e-9]);
        assert!(allclose(&a, &b, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_allclose_not_close() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.0, 4.0]);
        assert!(!allclose(&a, &b, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_isclose() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.1, 3.0]);
        let r = isclose(&a, &b, 1e-5, 1e-8, false).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_isclose_equal_nan() {
        // With `equal_nan = true`, NaN in the same position counts as close.
        let a = arr1(vec![f64::NAN, 1.0]);
        let b = arr1(vec![f64::NAN, 1.0]);
        let r = isclose(&a, &b, 1e-5, 1e-8, true).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, true]);
    }

    // --- Broadcasting through the same-dimension-type API (both operands Ix2) ---

    #[test]
    fn test_equal_broadcasts() {
        use ferray_core::dimension::Ix2;
        // Shapes (2, 3) and (1, 3): the single row of `b` is compared against
        // each row of `a`.
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 2, 6]).unwrap();
        let b = Array::<i32, Ix2>::from_vec(Ix2::new([1, 3]), vec![1, 2, 3]).unwrap();
        let r = equal(&a, &b).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, false, true, false]
        );
    }

    #[test]
    fn test_less_broadcasts() {
        use ferray_core::dimension::Ix2;
        // Shapes (3, 1) and (1, 3) combine into a (3, 3) result.
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 5.0, 10.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![3.0, 5.0, 7.0]).unwrap();
        let r = less(&a, &b).unwrap();
        assert_eq!(r.shape(), &[3, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![
                true, true, true, // a = 1.0 against [3.0, 5.0, 7.0]
                false, false, true, // a = 5.0 against [3.0, 5.0, 7.0]
                false, false, false, // a = 10.0 against [3.0, 5.0, 7.0]
            ]
        );
    }

    #[test]
    fn test_isclose_broadcasts() {
        use ferray_core::dimension::Ix2;
        // The second row differs from `b` by 1e-4, inside rtol = 1e-3.
        let a = Array::<f64, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![1.0, 2.0, 3.0, 1.0001, 2.0001, 3.0001],
        )
        .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let r = isclose(&a, &b, 1e-3, 1e-8, false).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, true, true, true]
        );
    }

    // --- `*_broadcast` variants: operands with different dimension types ---

    #[test]
    fn equal_broadcast_ix2_against_ix1() {
        use ferray_core::dimension::Ix2;
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 2, 6]).unwrap();
        let b = arr1_i32(vec![1, 2, 3]);
        let r = equal_broadcast(&a, &b).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, false, true, false]
        );
    }

    #[test]
    fn not_equal_broadcast_ix2_against_ix1() {
        use ferray_core::dimension::Ix2;
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 2, 6]).unwrap();
        let b = arr1_i32(vec![1, 2, 3]);
        let r = not_equal_broadcast(&a, &b).unwrap();
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![false, false, false, true, false, true]
        );
    }

    #[test]
    fn less_broadcast_ix2_against_scalar_like_ix1() {
        use ferray_core::dimension::Ix2;
        // A length-1 Ix1 operand acts like a scalar threshold for every element.
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let threshold = arr1(vec![3.0]);
        let r = less_broadcast(&a, &threshold).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, false, false, false, false]
        );
    }

    #[test]
    fn greater_broadcast_ix2_against_ix1() {
        use ferray_core::dimension::Ix2;
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0, 5.0, 10.0, 1.0, 5.0, 9.0])
            .unwrap();
        let b = arr1(vec![1.0, 5.0, 9.0]);
        let r = greater_broadcast(&a, &b).unwrap();
        // Only 10.0 > 9.0 holds; the second row equals `b` exactly.
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![false, false, true, false, false, false]
        );
    }

    #[test]
    fn less_equal_broadcast_ix1_against_ix2() {
        use ferray_core::dimension::Ix2;
        // Lower-rank operand on the left: Ix1 against Ix2.
        let a = arr1(vec![1.0, 5.0, 9.0]);
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 5.0, 9.0, 0.5, 5.0, 10.0])
            .unwrap();
        let r = less_equal_broadcast(&a, &b).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, false, true, true]
        );
    }

    #[test]
    fn greater_equal_broadcast_ix1_against_ix2() {
        use ferray_core::dimension::Ix2;
        let a = arr1(vec![5.0, 5.0, 5.0]);
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![5.0, 4.0, 6.0, 5.0, 5.0, 5.0])
            .unwrap();
        let r = greater_equal_broadcast(&a, &b).unwrap();
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, false, true, true, true]
        );
    }

    #[test]
    fn isclose_broadcast_ix2_against_ix1() {
        use ferray_core::dimension::Ix2;
        // 2.5 against 2.0 is the only pair outside rtol = 1e-3.
        let a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 1.0001, 2.5, 3.0001])
                .unwrap();
        let b = arr1(vec![1.0, 2.0, 3.0]);
        let r = isclose_broadcast(&a, &b, 1e-3, 1e-8, false).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, true, false, true]
        );
    }

    #[test]
    fn equal_broadcast_returns_ixdyn_dim_type() {
        use ferray_core::dimension::{Ix2, IxDyn};
        // The explicit type annotation on `r` makes this a compile-time check
        // that the result uses the dynamic dimension type.
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let b = arr1_i32(vec![1, 2]);
        let r: Array<bool, IxDyn> = equal_broadcast(&a, &b).unwrap();
        assert_eq!(r.ndim(), 2);
    }

    #[test]
    fn equal_broadcast_incompatible_shapes_errors() {
        // Lengths 3 and 2 cannot be combined; the call must return `Err`.
        let a = arr1_i32(vec![1, 2, 3]);
        let b = arr1_i32(vec![1, 2]);
        assert!(equal_broadcast(&a, &b).is_err());
    }
}