1use ferray_core::Array;
7use ferray_core::dimension::{Dimension, IxDyn};
8use ferray_core::dtype::Element;
9use ferray_core::error::{FerrayError, FerrayResult};
10
11use crate::helpers::{binary_map_op, unary_map_op};
12
/// Conversion of a scalar element to a truth value, used by the logical
/// functions in this module (`logical_*`, `all`, `any`, `all_axis`, `any_axis`).
///
/// The implementations provided here follow one rule: zero is `false` and
/// every other value is `true`.
/// - For floats, `0.0` and `-0.0` are falsy, while NaN and the infinities are truthy.
/// - A complex number is falsy only when both of its parts are zero.
pub trait Logical {
    /// Returns `true` when `self` should be treated as logically true.
    fn is_truthy(&self) -> bool;
}
18
19impl Logical for bool {
20 #[inline]
21 fn is_truthy(&self) -> bool {
22 *self
23 }
24}
25
26macro_rules! impl_logical_numeric {
27 ($($ty:ty),*) => {
28 $(
29 impl Logical for $ty {
30 #[inline]
31 fn is_truthy(&self) -> bool {
32 *self != 0 as $ty
33 }
34 }
35 )*
36 };
37}
38
39impl_logical_numeric!(i8, i16, i32, i64, i128, u8, u16, u32, u64, u128);
40
41impl Logical for f32 {
42 #[inline]
43 fn is_truthy(&self) -> bool {
44 *self != 0.0
45 }
46}
47
48impl Logical for f64 {
49 #[inline]
50 fn is_truthy(&self) -> bool {
51 *self != 0.0
52 }
53}
54
55impl Logical for num_complex::Complex<f32> {
56 #[inline]
57 fn is_truthy(&self) -> bool {
58 self.re != 0.0 || self.im != 0.0
59 }
60}
61
62impl Logical for num_complex::Complex<f64> {
63 #[inline]
64 fn is_truthy(&self) -> bool {
65 self.re != 0.0 || self.im != 0.0
66 }
67}
68
69pub fn logical_and<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
71where
72 T: Element + Logical + Copy,
73 D: Dimension,
74{
75 binary_map_op(a, b, |x, y| x.is_truthy() && y.is_truthy())
76}
77
78pub fn logical_or<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
80where
81 T: Element + Logical + Copy,
82 D: Dimension,
83{
84 binary_map_op(a, b, |x, y| x.is_truthy() || y.is_truthy())
85}
86
87pub fn logical_xor<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
89where
90 T: Element + Logical + Copy,
91 D: Dimension,
92{
93 binary_map_op(a, b, |x, y| x.is_truthy() ^ y.is_truthy())
94}
95
96pub fn logical_not<T, D>(input: &Array<T, D>) -> FerrayResult<Array<bool, D>>
98where
99 T: Element + Logical + Copy,
100 D: Dimension,
101{
102 unary_map_op(input, |x| !x.is_truthy())
103}
104
105pub fn all<T, D>(input: &Array<T, D>) -> bool
107where
108 T: Element + Logical,
109 D: Dimension,
110{
111 input.iter().all(Logical::is_truthy)
112}
113
114pub fn any<T, D>(input: &Array<T, D>) -> bool
116where
117 T: Element + Logical,
118 D: Dimension,
119{
120 input.iter().any(Logical::is_truthy)
121}
122
123fn reduce_truthy_axis<T, D, F>(
128 input: &Array<T, D>,
129 axis: usize,
130 identity: bool,
131 stop_at: bool,
132 op: F,
133) -> FerrayResult<Array<bool, IxDyn>>
134where
135 T: Element + Logical,
136 D: Dimension,
137 F: Fn(bool, &T) -> bool,
138{
139 let ndim = input.ndim();
140 if axis >= ndim {
141 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
142 }
143
144 let shape: Vec<usize> = input.shape().to_vec();
145 let axis_len = shape[axis];
146 let outer_size: usize = shape[..axis].iter().product();
147 let inner_size: usize = shape[axis + 1..].iter().product();
148
149 let data: Vec<T> = input.iter().cloned().collect();
153
154 let mut out_shape: Vec<usize> = shape
155 .iter()
156 .enumerate()
157 .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
158 .collect();
159 let out_size: usize = out_shape.iter().product::<usize>().max(1);
160
161 let mut result = vec![identity; out_size];
162
163 for outer in 0..outer_size {
164 for inner in 0..inner_size {
165 let out_idx = outer * inner_size + inner;
166 let mut acc = identity;
167 for k in 0..axis_len {
168 let idx = outer * axis_len * inner_size + k * inner_size + inner;
169 acc = op(acc, &data[idx]);
170 if acc == stop_at {
171 break;
172 }
173 }
174 result[out_idx] = acc;
175 }
176 }
177
178 if out_shape.is_empty() {
183 out_shape.push(1);
184 }
185 Array::from_vec(IxDyn::from(&out_shape[..]), result)
186}
187
188pub fn all_axis<T, D>(input: &Array<T, D>, axis: usize) -> FerrayResult<Array<bool, IxDyn>>
197where
198 T: Element + Logical,
199 D: Dimension,
200{
201 reduce_truthy_axis(input, axis, true, false, |acc, x| acc && x.is_truthy())
202}
203
204pub fn any_axis<T, D>(input: &Array<T, D>, axis: usize) -> FerrayResult<Array<bool, IxDyn>>
213where
214 T: Element + Logical,
215 D: Dimension,
216{
217 reduce_truthy_axis(input, axis, false, true, |acc, x| acc || x.is_truthy())
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2, Ix3};

    /// Builds a 1-D `bool` array from `data`.
    fn arr1_bool(data: Vec<bool>) -> Array<bool, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Builds a 1-D `i32` array from `data`.
    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // ---- scalar truthiness -------------------------------------------------

    #[test]
    fn float_truthiness_matches_numpy() {
        // Both signed zeros are falsy; NaN and the infinities are truthy.
        assert!(!0.0f64.is_truthy());
        assert!(!(-0.0f64).is_truthy());
        assert!(f64::NAN.is_truthy());
        assert!(f64::INFINITY.is_truthy());
        assert!(!0.0f32.is_truthy());
        assert!(!(-0.0f32).is_truthy());
        assert!(f32::NAN.is_truthy());
    }

    #[test]
    fn complex_truthiness() {
        use num_complex::Complex;
        // Falsy only when both parts are zero.
        assert!(!Complex::new(0.0f64, 0.0).is_truthy());
        assert!(Complex::new(0.0f64, 1.0).is_truthy());
        assert!(Complex::new(2.0f32, 0.0).is_truthy());
        assert!(!Complex::new(0.0f32, -0.0).is_truthy());
    }

    // ---- element-wise operators --------------------------------------------

    #[test]
    fn test_logical_and() {
        let a = arr1_bool(vec![true, true, false, false]);
        let b = arr1_bool(vec![true, false, true, false]);
        let r = logical_and(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, false, false]);
    }

    #[test]
    fn test_logical_or() {
        let a = arr1_bool(vec![true, true, false, false]);
        let b = arr1_bool(vec![true, false, true, false]);
        let r = logical_or(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, true, true, false]);
    }

    #[test]
    fn test_logical_xor() {
        let a = arr1_bool(vec![true, true, false, false]);
        let b = arr1_bool(vec![true, false, true, false]);
        let r = logical_xor(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, true, false]);
    }

    #[test]
    fn test_logical_not() {
        let a = arr1_bool(vec![true, false, true]);
        let r = logical_not(&a).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn test_logical_and_numeric() {
        let a = arr1_i32(vec![1, 1, 0, 0]);
        let b = arr1_i32(vec![1, 0, 1, 0]);
        let r = logical_and(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, false, false]);
    }

    #[test]
    fn test_logical_xor_numeric() {
        // Any non-zero value (including negatives) counts as true.
        let a = arr1_i32(vec![2, 0, -1, 0]);
        let b = arr1_i32(vec![0, 5, 3, 0]);
        let r = logical_xor(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, true, false, false]);
    }

    #[test]
    fn test_logical_not_numeric() {
        let a = arr1_i32(vec![0, 3, -2, 0]);
        let r = logical_not(&a).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, false, true]);
    }

    // ---- whole-array reductions --------------------------------------------

    #[test]
    fn test_all() {
        let a = arr1_bool(vec![true, true, true]);
        assert!(all(&a));
        let b = arr1_bool(vec![true, false, true]);
        assert!(!all(&b));
    }

    #[test]
    fn test_any() {
        let a = arr1_bool(vec![false, false, true]);
        assert!(any(&a));
        let b = arr1_bool(vec![false, false, false]);
        assert!(!any(&b));
    }

    #[test]
    fn test_all_numeric() {
        let a = arr1_i32(vec![1, 2, 3]);
        assert!(all(&a));
        let b = arr1_i32(vec![1, 0, 3]);
        assert!(!all(&b));
    }

    #[test]
    fn all_and_any_of_empty_array() {
        // Vacuous truth: `all` of nothing is true, `any` of nothing is false.
        let empty = arr1_bool(vec![]);
        assert!(all(&empty));
        assert!(!any(&empty));
    }

    // ---- broadcasting ------------------------------------------------------
    // `[2, 1]` combined with `[1, 3]` broadcasts to `[2, 3]`.

    #[test]
    fn test_logical_and_broadcasts() {
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 1]), vec![true, false]).unwrap();
        let b = Array::<bool, Ix2>::from_vec(Ix2::new([1, 3]), vec![true, false, true]).unwrap();
        let r = logical_and(&a, &b).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, false, true, false, false, false]
        );
    }

    #[test]
    fn test_logical_or_broadcasts() {
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 1]), vec![true, false]).unwrap();
        let b = Array::<bool, Ix2>::from_vec(Ix2::new([1, 3]), vec![true, false, true]).unwrap();
        let r = logical_or(&a, &b).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, true, false, true]
        );
    }

    // ---- axis reductions ---------------------------------------------------

    #[test]
    fn all_axis_2d_rows() {
        // [[t, t, t],
        //  [t, f, t]]  -> reduce each row
        let a = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, true, true, true, false, true],
        )
        .unwrap();
        let r = all_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[true, false]);
    }

    #[test]
    fn all_axis_2d_cols() {
        // [[t, t, f],
        //  [t, t, t]]  -> reduce each column
        let a = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, true, false, true, true, true],
        )
        .unwrap();
        let r = all_axis(&a, 0).unwrap();
        assert_eq!(r.shape(), &[3]);
        assert_eq!(r.as_slice().unwrap(), &[true, true, false]);
    }

    #[test]
    fn any_axis_2d_rows() {
        let a = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, false, false, false, true, false],
        )
        .unwrap();
        let r = any_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[false, true]);
    }

    #[test]
    fn any_axis_2d_cols() {
        let a = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, true, false, false, false, false],
        )
        .unwrap();
        let r = any_axis(&a, 0).unwrap();
        assert_eq!(r.shape(), &[3]);
        assert_eq!(r.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn all_axis_numeric_integer_input() {
        // Row 1 contains a zero, so it is not all-truthy.
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 0, 6]).unwrap();
        let r = all_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[true, false]);
    }

    #[test]
    fn any_axis_numeric_float_input_with_nan() {
        // NaN is non-zero and therefore truthy.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![0.0, 0.0, f64::NAN, 0.0]).unwrap();
        let r = any_axis(&a, 0).unwrap();
        // Reducing a 1-D input yields shape `[1]` rather than a 0-d array.
        assert_eq!(r.shape(), &[1]);
        assert_eq!(r.as_slice().unwrap(), &[true]);
    }

    #[test]
    fn all_axis_empty_axis_returns_identity() {
        // A zero-length reduced axis leaves every output at AND's identity.
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 0]), vec![]).unwrap();
        let r = all_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[true, true]);
    }

    #[test]
    fn any_axis_empty_axis_returns_identity() {
        // A zero-length reduced axis leaves every output at OR's identity.
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 0]), vec![]).unwrap();
        let r = any_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[false, false]);
    }

    #[test]
    fn all_axis_3d_middle_axis() {
        // Shape [2, 3, 2]; reducing axis 1 checks each (batch, column) lane.
        let data = vec![
            // batch 0: rows [t, t], [t, f], [t, t] -> column 1 has a `false`
            true, true, true, false, true, true,
            // batch 1: all `true`
            true, true, true, true, true, true,
        ];
        let a = Array::<bool, Ix3>::from_vec(Ix3::new([2, 3, 2]), data).unwrap();
        let r = all_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2, 2]);
        assert_eq!(r.as_slice().unwrap(), &[true, false, true, true]);
    }

    #[test]
    fn any_axis_3d_last_axis() {
        // Shape [2, 2, 3]; reducing the innermost axis.
        let data = vec![
            false, false, false, // [0, 0, :]
            false, true, false, // [0, 1, :]
            false, false, false, // [1, 0, :]
            true, false, false, // [1, 1, :]
        ];
        let a = Array::<bool, Ix3>::from_vec(Ix3::new([2, 2, 3]), data).unwrap();
        let r = any_axis(&a, 2).unwrap();
        assert_eq!(r.shape(), &[2, 2]);
        assert_eq!(r.as_slice().unwrap(), &[false, true, false, true]);
    }

    #[test]
    fn all_axis_out_of_bounds_errors() {
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![true; 6]).unwrap();
        assert!(all_axis(&a, 5).is_err());
    }

    #[test]
    fn any_axis_out_of_bounds_errors() {
        // `axis == ndim` is already out of bounds.
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![true; 6]).unwrap();
        assert!(any_axis(&a, 2).is_err());
    }

    #[test]
    fn all_axis_short_circuit_correct_value() {
        // The first element is `false`, so the lane exits immediately; the
        // early exit must still report `false`, not the identity.
        let a =
            Array::<bool, Ix2>::from_vec(Ix2::new([1, 4]), vec![false, true, true, true]).unwrap();
        let r = all_axis(&a, 1).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false]);
    }
}