ferray_ma/mask_ops.rs
1// ferray-ma: Mask manipulation utilities (REQ-15, REQ-16, REQ-17)
2//
3// harden_mask, soften_mask, getmask, getdata, is_masked, count_masked
4
5use ferray_core::Array;
6use ferray_core::dimension::Dimension;
7use ferray_core::dtype::Element;
8use ferray_core::error::{FerrayError, FerrayResult};
9
10use crate::MaskedArray;
11
12impl<T: Element, D: Dimension> MaskedArray<T, D> {
13 /// Harden the mask: prevent subsequent assignments from clearing mask bits.
14 ///
15 /// After this call, any attempt to set a mask bit to `false` via
16 /// `set_mask_flat` or `set_mask` will be silently ignored.
17 ///
18 /// # Errors
19 /// This function does not currently error but returns `Result` for API
20 /// consistency.
21 pub const fn harden_mask(&mut self) -> FerrayResult<()> {
22 self.hard_mask = true;
23 Ok(())
24 }
25
26 /// Soften the mask: allow subsequent assignments to clear mask bits.
27 ///
28 /// # Errors
29 /// This function does not currently error but returns `Result` for API
30 /// consistency.
31 pub const fn soften_mask(&mut self) -> FerrayResult<()> {
32 self.hard_mask = false;
33 Ok(())
34 }
35}
36
37/// Return the mask array of a masked array.
38///
39/// This is equivalent to `ma.mask()` but provided as a free function
40/// for API parity with `NumPy`'s `np.ma.getmask`.
41///
42/// # Errors
43/// This function does not currently error but returns `Result` for API
44/// consistency.
45pub fn getmask<T: Element, D: Dimension>(ma: &MaskedArray<T, D>) -> FerrayResult<Array<bool, D>> {
46 Ok(ma.mask().clone())
47}
48
49/// Return the underlying data array of a masked array.
50///
51/// This is equivalent to `ma.data()` but provided as a free function
52/// for API parity with `NumPy`'s `np.ma.getdata`.
53///
54/// # Errors
55/// This function does not currently error but returns `Result` for API
56/// consistency.
57pub fn getdata<T: Element + Copy, D: Dimension>(
58 ma: &MaskedArray<T, D>,
59) -> FerrayResult<Array<T, D>> {
60 Ok(ma.data().clone())
61}
62
63/// Return `true` if any element in the masked array is masked.
64///
65/// # Errors
66/// This function does not currently error but returns `Result` for API
67/// consistency.
68pub fn is_masked<T: Element, D: Dimension>(ma: &MaskedArray<T, D>) -> FerrayResult<bool> {
69 Ok(ma.mask().iter().any(|m| *m))
70}
71
72/// Count the number of masked elements, optionally along an axis.
73///
74/// If `axis` is `None`, returns the total count of masked elements as a
75/// single-element vector.
76///
77/// For axis-wise counting use [`count_masked_axis`].
78///
79/// # Errors
80/// This function does not currently error but returns `Result` for API
81/// consistency.
82pub fn count_masked<T: Element, D: Dimension>(ma: &MaskedArray<T, D>) -> FerrayResult<usize> {
83 let count = ma.mask().iter().filter(|m| **m).count();
84 Ok(count)
85}
86
87/// Count masked elements along a specific axis (#268).
88///
89/// Reduces the array of masked booleans along `axis`, returning a
90/// `usize` count for each remaining slice. The output shape drops
91/// the reduced axis. Mirrors `numpy.ma.count_masked(a, axis=k)`.
92///
93/// The previous `count_masked(ma, axis: Option<usize>)` accepted an
94/// `axis` argument but ignored it — a silent footgun where callers
95/// got the total count back when they expected per-slice counts.
96///
97/// # Errors
98/// Returns `FerrayError::AxisOutOfBounds` if `axis >= ma.ndim()`.
99pub fn count_masked_axis<T: Element, D: Dimension>(
100 ma: &MaskedArray<T, D>,
101 axis: usize,
102) -> FerrayResult<ferray_core::Array<u64, ferray_core::dimension::IxDyn>> {
103 use ferray_core::dimension::IxDyn;
104
105 let ndim = ma.ndim();
106 if axis >= ndim {
107 return Err(FerrayError::axis_out_of_bounds(axis, ndim));
108 }
109 let shape = ma.shape();
110 let axis_len = shape[axis];
111
112 // Output shape: drop the reduced axis.
113 let out_shape: Vec<usize> = shape
114 .iter()
115 .enumerate()
116 .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
117 .collect();
118 let out_size: usize = if out_shape.is_empty() {
119 1
120 } else {
121 out_shape.iter().product()
122 };
123
124 // Materialize the mask in row-major flat order so we can index by
125 // computed flat indices regardless of the source memory layout.
126 let mask_data: Vec<bool> = ma.mask().iter().copied().collect();
127 let mut strides = vec![1usize; ndim];
128 for i in (0..ndim.saturating_sub(1)).rev() {
129 strides[i] = strides[i + 1] * shape[i + 1];
130 }
131
132 let mut out_data: Vec<u64> = Vec::with_capacity(out_size);
133 let mut out_multi = vec![0usize; out_shape.len()];
134 for _ in 0..out_size {
135 // Reconstruct the source-side multi-index by inserting the
136 // reduced-axis position back into the right spot.
137 let mut count: u64 = 0;
138 for k in 0..axis_len {
139 let mut flat = 0usize;
140 let mut out_idx = 0usize;
141 for (i, &stride) in strides.iter().enumerate() {
142 if i == axis {
143 flat += stride * k;
144 } else {
145 flat += stride * out_multi[out_idx];
146 out_idx += 1;
147 }
148 }
149 if mask_data[flat] {
150 count += 1;
151 }
152 }
153 out_data.push(count);
154
155 // Increment the output multi-index in row-major order.
156 for i in (0..out_shape.len()).rev() {
157 out_multi[i] += 1;
158 if out_multi[i] < out_shape[i] {
159 break;
160 }
161 out_multi[i] = 0;
162 }
163 }
164
165 let out_dim = if out_shape.is_empty() {
166 IxDyn::new(&[])
167 } else {
168 IxDyn::new(&out_shape)
169 };
170 ferray_core::Array::from_vec(out_dim, out_data)
171}