Skip to main content

ferray_ma/
mask_ops.rs

// ferray-ma: Mask manipulation utilities (REQ-15, REQ-16, REQ-17)
//
// harden_mask, soften_mask, getmask, getdata, is_masked, count_masked,
// count_masked_axis
4
5use ferray_core::Array;
6use ferray_core::dimension::Dimension;
7use ferray_core::dtype::Element;
8use ferray_core::error::{FerrayError, FerrayResult};
9
10use crate::MaskedArray;
11
12impl<T: Element, D: Dimension> MaskedArray<T, D> {
13    /// Harden the mask: prevent subsequent assignments from clearing mask bits.
14    ///
15    /// After this call, any attempt to set a mask bit to `false` via
16    /// `set_mask_flat` or `set_mask` will be silently ignored.
17    ///
18    /// # Errors
19    /// This function does not currently error but returns `Result` for API
20    /// consistency.
21    pub const fn harden_mask(&mut self) -> FerrayResult<()> {
22        self.hard_mask = true;
23        Ok(())
24    }
25
26    /// Soften the mask: allow subsequent assignments to clear mask bits.
27    ///
28    /// # Errors
29    /// This function does not currently error but returns `Result` for API
30    /// consistency.
31    pub const fn soften_mask(&mut self) -> FerrayResult<()> {
32        self.hard_mask = false;
33        Ok(())
34    }
35}
36
37/// Return the mask array of a masked array.
38///
39/// This is equivalent to `ma.mask()` but provided as a free function
40/// for API parity with `NumPy`'s `np.ma.getmask`.
41///
42/// # Errors
43/// This function does not currently error but returns `Result` for API
44/// consistency.
45pub fn getmask<T: Element, D: Dimension>(ma: &MaskedArray<T, D>) -> FerrayResult<Array<bool, D>> {
46    Ok(ma.mask().clone())
47}
48
49/// Return the underlying data array of a masked array.
50///
51/// This is equivalent to `ma.data()` but provided as a free function
52/// for API parity with `NumPy`'s `np.ma.getdata`.
53///
54/// # Errors
55/// This function does not currently error but returns `Result` for API
56/// consistency.
57pub fn getdata<T: Element + Copy, D: Dimension>(
58    ma: &MaskedArray<T, D>,
59) -> FerrayResult<Array<T, D>> {
60    Ok(ma.data().clone())
61}
62
63/// Return `true` if any element in the masked array is masked.
64///
65/// # Errors
66/// This function does not currently error but returns `Result` for API
67/// consistency.
68pub fn is_masked<T: Element, D: Dimension>(ma: &MaskedArray<T, D>) -> FerrayResult<bool> {
69    Ok(ma.mask().iter().any(|m| *m))
70}
71
72/// Count the number of masked elements, optionally along an axis.
73///
74/// If `axis` is `None`, returns the total count of masked elements as a
75/// single-element vector.
76///
77/// For axis-wise counting use [`count_masked_axis`].
78///
79/// # Errors
80/// This function does not currently error but returns `Result` for API
81/// consistency.
82pub fn count_masked<T: Element, D: Dimension>(ma: &MaskedArray<T, D>) -> FerrayResult<usize> {
83    let count = ma.mask().iter().filter(|m| **m).count();
84    Ok(count)
85}
86
87/// Count masked elements along a specific axis (#268).
88///
89/// Reduces the array of masked booleans along `axis`, returning a
90/// `usize` count for each remaining slice. The output shape drops
91/// the reduced axis. Mirrors `numpy.ma.count_masked(a, axis=k)`.
92///
93/// The previous `count_masked(ma, axis: Option<usize>)` accepted an
94/// `axis` argument but ignored it — a silent footgun where callers
95/// got the total count back when they expected per-slice counts.
96///
97/// # Errors
98/// Returns `FerrayError::AxisOutOfBounds` if `axis >= ma.ndim()`.
99pub fn count_masked_axis<T: Element, D: Dimension>(
100    ma: &MaskedArray<T, D>,
101    axis: usize,
102) -> FerrayResult<ferray_core::Array<u64, ferray_core::dimension::IxDyn>> {
103    use ferray_core::dimension::IxDyn;
104
105    let ndim = ma.ndim();
106    if axis >= ndim {
107        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
108    }
109    let shape = ma.shape();
110    let axis_len = shape[axis];
111
112    // Output shape: drop the reduced axis.
113    let out_shape: Vec<usize> = shape
114        .iter()
115        .enumerate()
116        .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
117        .collect();
118    let out_size: usize = if out_shape.is_empty() {
119        1
120    } else {
121        out_shape.iter().product()
122    };
123
124    // Materialize the mask in row-major flat order so we can index by
125    // computed flat indices regardless of the source memory layout.
126    let mask_data: Vec<bool> = ma.mask().iter().copied().collect();
127    let mut strides = vec![1usize; ndim];
128    for i in (0..ndim.saturating_sub(1)).rev() {
129        strides[i] = strides[i + 1] * shape[i + 1];
130    }
131
132    let mut out_data: Vec<u64> = Vec::with_capacity(out_size);
133    let mut out_multi = vec![0usize; out_shape.len()];
134    for _ in 0..out_size {
135        // Reconstruct the source-side multi-index by inserting the
136        // reduced-axis position back into the right spot.
137        let mut count: u64 = 0;
138        for k in 0..axis_len {
139            let mut flat = 0usize;
140            let mut out_idx = 0usize;
141            for (i, &stride) in strides.iter().enumerate() {
142                if i == axis {
143                    flat += stride * k;
144                } else {
145                    flat += stride * out_multi[out_idx];
146                    out_idx += 1;
147                }
148            }
149            if mask_data[flat] {
150                count += 1;
151            }
152        }
153        out_data.push(count);
154
155        // Increment the output multi-index in row-major order.
156        for i in (0..out_shape.len()).rev() {
157            out_multi[i] += 1;
158            if out_multi[i] < out_shape[i] {
159                break;
160            }
161            out_multi[i] = 0;
162        }
163    }
164
165    let out_dim = if out_shape.is_empty() {
166        IxDyn::new(&[])
167    } else {
168        IxDyn::new(&out_shape)
169    };
170    ferray_core::Array::from_vec(out_dim, out_data)
171}