Skip to main content

ferray_ufunc/ops/
logical.rs

1// ferray-ufunc: Logical functions
2//
3// logical_and, logical_or, logical_xor, logical_not, all, any, all_axis,
4// any_axis.
5
6use ferray_core::Array;
7use ferray_core::dimension::{Dimension, IxDyn};
8use ferray_core::dtype::Element;
9use ferray_core::error::{FerrayError, FerrayResult};
10
11use crate::helpers::{binary_map_op, unary_map_op};
12
/// Truthiness conversion shared by every logical operation in this module.
///
/// A value is "truthy" when it is nonzero. For floating-point types this is
/// an IEEE comparison against zero, so `NaN` (unequal to everything) is
/// truthy while both `0.0` and `-0.0` are falsy.
pub trait Logical {
    /// Return true if the value is "truthy" (nonzero).
    fn is_truthy(&self) -> bool;
}

impl Logical for bool {
    #[inline]
    fn is_truthy(&self) -> bool {
        // A bool already is its own truth value.
        *self
    }
}

// Implements `Logical` for primitive numeric types by comparing against the
// type's zero, spelled `0 as $ty` (i.e. `0` for integers, `0.0` for floats).
macro_rules! impl_logical_numeric {
    ($($ty:ty),* $(,)?) => {
        $(
            impl Logical for $ty {
                #[inline]
                fn is_truthy(&self) -> bool {
                    let zero = 0 as $ty;
                    *self != zero
                }
            }
        )*
    };
}

impl_logical_numeric!(
    i8, i16, i32, i64, i128,
    u8, u16, u32, u64, u128,
    f32, f64,
);
54
55impl Logical for num_complex::Complex<f32> {
56    #[inline]
57    fn is_truthy(&self) -> bool {
58        self.re != 0.0 || self.im != 0.0
59    }
60}
61
62impl Logical for num_complex::Complex<f64> {
63    #[inline]
64    fn is_truthy(&self) -> bool {
65        self.re != 0.0 || self.im != 0.0
66    }
67}
68
69/// Elementwise logical AND.
70pub fn logical_and<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
71where
72    T: Element + Logical + Copy,
73    D: Dimension,
74{
75    binary_map_op(a, b, |x, y| x.is_truthy() && y.is_truthy())
76}
77
78/// Elementwise logical OR.
79pub fn logical_or<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
80where
81    T: Element + Logical + Copy,
82    D: Dimension,
83{
84    binary_map_op(a, b, |x, y| x.is_truthy() || y.is_truthy())
85}
86
87/// Elementwise logical XOR.
88pub fn logical_xor<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
89where
90    T: Element + Logical + Copy,
91    D: Dimension,
92{
93    binary_map_op(a, b, |x, y| x.is_truthy() ^ y.is_truthy())
94}
95
96/// Elementwise logical NOT.
97pub fn logical_not<T, D>(input: &Array<T, D>) -> FerrayResult<Array<bool, D>>
98where
99    T: Element + Logical + Copy,
100    D: Dimension,
101{
102    unary_map_op(input, |x| !x.is_truthy())
103}
104
105/// Test whether all elements are truthy.
106pub fn all<T, D>(input: &Array<T, D>) -> bool
107where
108    T: Element + Logical,
109    D: Dimension,
110{
111    input.iter().all(Logical::is_truthy)
112}
113
114/// Test whether any element is truthy.
115pub fn any<T, D>(input: &Array<T, D>) -> bool
116where
117    T: Element + Logical,
118    D: Dimension,
119{
120    input.iter().any(Logical::is_truthy)
121}
122
123/// Shared axis-reduction kernel for [`all_axis`] and [`any_axis`]. Collapses
124/// `axis` by folding truthiness with `op`, short-circuiting as soon as the
125/// accumulator reaches `stop_at` (the value the reduction can no longer
126/// change away from).
127fn reduce_truthy_axis<T, D, F>(
128    input: &Array<T, D>,
129    axis: usize,
130    identity: bool,
131    stop_at: bool,
132    op: F,
133) -> FerrayResult<Array<bool, IxDyn>>
134where
135    T: Element + Logical,
136    D: Dimension,
137    F: Fn(bool, &T) -> bool,
138{
139    let ndim = input.ndim();
140    if axis >= ndim {
141        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
142    }
143
144    let shape: Vec<usize> = input.shape().to_vec();
145    let axis_len = shape[axis];
146    let outer_size: usize = shape[..axis].iter().product();
147    let inner_size: usize = shape[axis + 1..].iter().product();
148
149    // Materialize in row-major (iteration) order so that
150    // (outer, k, inner) indexing resolves into the flat buffer directly.
151    // Uses Clone rather than Copy so the kernel works for any Element type.
152    let data: Vec<T> = input.iter().cloned().collect();
153
154    let mut out_shape: Vec<usize> = shape
155        .iter()
156        .enumerate()
157        .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
158        .collect();
159    let out_size: usize = out_shape.iter().product::<usize>().max(1);
160
161    let mut result = vec![identity; out_size];
162
163    for outer in 0..outer_size {
164        for inner in 0..inner_size {
165            let out_idx = outer * inner_size + inner;
166            let mut acc = identity;
167            for k in 0..axis_len {
168                let idx = outer * axis_len * inner_size + k * inner_size + inner;
169                acc = op(acc, &data[idx]);
170                if acc == stop_at {
171                    break;
172                }
173            }
174            result[out_idx] = acc;
175        }
176    }
177
178    // NumPy collapses a single-axis input to a 0-D scalar; we expose it as a
179    // length-1 1-D array because ferray has no Ix0 in the IxDyn constructor
180    // path used here and users can .reshape(&[]) if they truly need scalar
181    // rank.
182    if out_shape.is_empty() {
183        out_shape.push(1);
184    }
185    Array::from_vec(IxDyn::from(&out_shape[..]), result)
186}
187
188/// Test whether all elements along `axis` are truthy, returning an array
189/// with `axis` removed. Equivalent to `np.all(input, axis=axis)`.
190///
191/// Short-circuits per-slice: as soon as a `false` is found along the axis
192/// for a given output position, the remainder of that slice is skipped.
193///
194/// # Errors
195/// Returns `FerrayError::AxisOutOfBounds` if `axis >= input.ndim()`.
196pub fn all_axis<T, D>(input: &Array<T, D>, axis: usize) -> FerrayResult<Array<bool, IxDyn>>
197where
198    T: Element + Logical,
199    D: Dimension,
200{
201    reduce_truthy_axis(input, axis, true, false, |acc, x| acc && x.is_truthy())
202}
203
204/// Test whether any element along `axis` is truthy, returning an array with
205/// `axis` removed. Equivalent to `np.any(input, axis=axis)`.
206///
207/// Short-circuits per-slice: as soon as a `true` is found along the axis
208/// for a given output position, the remainder of that slice is skipped.
209///
210/// # Errors
211/// Returns `FerrayError::AxisOutOfBounds` if `axis >= input.ndim()`.
212pub fn any_axis<T, D>(input: &Array<T, D>, axis: usize) -> FerrayResult<Array<bool, IxDyn>>
213where
214    T: Element + Logical,
215    D: Dimension,
216{
217    reduce_truthy_axis(input, axis, false, true, |acc, x| acc || x.is_truthy())
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2, Ix3};

    fn arr1_bool(data: Vec<bool>) -> Array<bool, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    fn arr1_f64(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    #[test]
    fn test_logical_and() {
        let a = arr1_bool(vec![true, true, false, false]);
        let b = arr1_bool(vec![true, false, true, false]);
        let r = logical_and(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, false, false]);
    }

    #[test]
    fn test_logical_or() {
        let a = arr1_bool(vec![true, true, false, false]);
        let b = arr1_bool(vec![true, false, true, false]);
        let r = logical_or(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, true, true, false]);
    }

    #[test]
    fn test_logical_xor() {
        let a = arr1_bool(vec![true, true, false, false]);
        let b = arr1_bool(vec![true, false, true, false]);
        let r = logical_xor(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, true, false]);
    }

    #[test]
    fn test_logical_not() {
        let a = arr1_bool(vec![true, false, true]);
        let r = logical_not(&a).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn test_logical_and_numeric() {
        let a = arr1_i32(vec![1, 1, 0, 0]);
        let b = arr1_i32(vec![1, 0, 1, 0]);
        let r = logical_and(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, false, false, false]);
    }

    #[test]
    fn test_logical_xor_numeric() {
        // Any nonzero integer (including negatives) is truthy.
        let a = arr1_i32(vec![0, 3, 0, -1]);
        let b = arr1_i32(vec![0, 0, 5, 7]);
        let r = logical_xor(&a, &b).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false, true, true, false]);
    }

    #[test]
    fn test_logical_not_float_signed_zero() {
        // Both +0.0 and -0.0 compare equal to zero, so both are falsy.
        let a = arr1_f64(vec![0.0, -0.0, 2.5]);
        let r = logical_not(&a).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true, true, false]);
    }

    #[test]
    fn test_all() {
        let a = arr1_bool(vec![true, true, true]);
        assert!(all(&a));
        let b = arr1_bool(vec![true, false, true]);
        assert!(!all(&b));
    }

    #[test]
    fn test_any() {
        let a = arr1_bool(vec![false, false, true]);
        assert!(any(&a));
        let b = arr1_bool(vec![false, false, false]);
        assert!(!any(&b));
    }

    #[test]
    fn test_all_any_empty() {
        // all() is vacuously true and any() is false on an empty array.
        let empty = arr1_bool(vec![]);
        assert!(all(&empty));
        assert!(!any(&empty));
    }

    #[test]
    fn test_all_numeric() {
        let a = arr1_i32(vec![1, 2, 3]);
        assert!(all(&a));
        let b = arr1_i32(vec![1, 0, 3]);
        assert!(!all(&b));
    }

    #[test]
    fn complex_truthiness() {
        use num_complex::Complex;
        // Only the origin is falsy; a nonzero (or NaN) part makes it truthy.
        assert!(!Complex::new(0.0f64, 0.0).is_truthy());
        assert!(Complex::new(0.0f64, 1.0).is_truthy());
        assert!(Complex::new(-2.0f32, 0.0).is_truthy());
        assert!(Complex::new(f64::NAN, 0.0).is_truthy());
    }

    // -----------------------------------------------------------------------
    // Broadcasting tests for logical ops (issue #379)
    // -----------------------------------------------------------------------

    #[test]
    fn test_logical_and_broadcasts() {
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 1]), vec![true, false]).unwrap();
        let b = Array::<bool, Ix2>::from_vec(Ix2::new([1, 3]), vec![true, false, true]).unwrap();
        let r = logical_and(&a, &b).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, false, true, false, false, false]
        );
    }

    #[test]
    fn test_logical_or_broadcasts() {
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 1]), vec![true, false]).unwrap();
        let b = Array::<bool, Ix2>::from_vec(Ix2::new([1, 3]), vec![true, false, true]).unwrap();
        let r = logical_or(&a, &b).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(
            r.iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, true, false, true]
        );
    }

    // -----------------------------------------------------------------------
    // Axis-aware all/any (#389)
    // -----------------------------------------------------------------------

    #[test]
    fn all_axis_2d_rows() {
        // [[T,T,T],[T,F,T]] — row 0 all true, row 1 has a false.
        let a = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, true, true, true, false, true],
        )
        .unwrap();
        let r = all_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[true, false]);
    }

    #[test]
    fn all_axis_2d_cols() {
        // [[T,T,F],[T,T,T]] — col 0 all T, col 1 all T, col 2 has F.
        let a = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, true, false, true, true, true],
        )
        .unwrap();
        let r = all_axis(&a, 0).unwrap();
        assert_eq!(r.shape(), &[3]);
        assert_eq!(r.as_slice().unwrap(), &[true, true, false]);
    }

    #[test]
    fn any_axis_2d_rows() {
        let a = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, false, false, false, true, false],
        )
        .unwrap();
        let r = any_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[false, true]);
    }

    #[test]
    fn any_axis_2d_cols() {
        let a = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, true, false, false, false, false],
        )
        .unwrap();
        let r = any_axis(&a, 0).unwrap();
        assert_eq!(r.shape(), &[3]);
        assert_eq!(r.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn all_axis_numeric_integer_input() {
        // Integer truthiness: 0 is false, nonzero is true.
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 0, 6]).unwrap();
        let r = all_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[true, false]);
    }

    #[test]
    fn any_axis_numeric_float_input_with_nan() {
        // NaN compares unequal to zero, so it is truthy; 0.0 is falsy.
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![0.0, 0.0, f64::NAN, 0.0]).unwrap();
        let r = any_axis(&a, 0).unwrap();
        // Reducing a 1-D array along its only axis produces a length-1 result.
        assert_eq!(r.shape(), &[1]);
        assert_eq!(r.as_slice().unwrap(), &[true]);
    }

    #[test]
    fn any_axis_numeric_float_input_with_inf() {
        // Infinity is nonzero, hence truthy.
        let a = arr1_f64(vec![0.0, f64::INFINITY, 0.0]);
        let r = any_axis(&a, 0).unwrap();
        assert_eq!(r.shape(), &[1]);
        assert_eq!(r.as_slice().unwrap(), &[true]);
    }

    #[test]
    fn any_axis_float_signed_zero_is_falsy() {
        let a = arr1_f64(vec![0.0, -0.0]);
        let r = any_axis(&a, 0).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false]);
    }

    #[test]
    fn all_axis_empty_axis_returns_identity() {
        // np.all over an empty slice returns True (vacuously).
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 0]), vec![]).unwrap();
        let r = all_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[true, true]);
    }

    #[test]
    fn any_axis_empty_axis_returns_identity() {
        // np.any over an empty slice returns False.
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 0]), vec![]).unwrap();
        let r = any_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[false, false]);
    }

    #[test]
    fn all_axis_3d_middle_axis() {
        // shape (2, 3, 2): reduce axis=1 → shape (2, 2).
        // Layer 0:
        //   row0: T, T
        //   row1: T, F
        //   row2: T, T
        //   → per-col: T && T && T = T  ;  T && F && T = F
        // Layer 1:
        //   row0: T, T
        //   row1: T, T
        //   row2: T, T
        //   → T, T
        let data = vec![
            true, true, true, false, true, true, // layer 0
            true, true, true, true, true, true, // layer 1
        ];
        let a = Array::<bool, Ix3>::from_vec(Ix3::new([2, 3, 2]), data).unwrap();
        let r = all_axis(&a, 1).unwrap();
        assert_eq!(r.shape(), &[2, 2]);
        assert_eq!(r.as_slice().unwrap(), &[true, false, true, true]);
    }

    #[test]
    fn all_axis_out_of_bounds_errors() {
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![true; 6]).unwrap();
        assert!(all_axis(&a, 5).is_err());
    }

    #[test]
    fn any_axis_out_of_bounds_errors() {
        let a = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![true; 6]).unwrap();
        assert!(any_axis(&a, 2).is_err());
    }

    #[test]
    fn all_axis_short_circuit_correct_value() {
        // Regression: ensure that once we find a `false`, the output is
        // actually set to `false` (not the `true` identity seed).
        let a =
            Array::<bool, Ix2>::from_vec(Ix2::new([1, 4]), vec![false, true, true, true]).unwrap();
        let r = all_axis(&a, 1).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[false]);
    }

    #[test]
    fn any_axis_short_circuit_correct_value() {
        // Mirror of the `all_axis` regression: an early `true` must survive
        // the short-circuit instead of reverting to the `false` seed.
        let a =
            Array::<bool, Ix2>::from_vec(Ix2::new([1, 4]), vec![true, false, false, false]).unwrap();
        let r = any_axis(&a, 1).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[true]);
    }
}