Skip to main content

ferray_ma/
constructors.rs

// ferray-ma: Masking constructors (REQ-7, REQ-8, REQ-9)
//
// masked_where, masked_invalid, fix_invalid, masked_equal, masked_not_equal,
// masked_greater, masked_less, masked_greater_equal, masked_less_equal,
// masked_inside, masked_outside
6
7use ferray_core::Array;
8use ferray_core::dimension::Dimension;
9use ferray_core::dtype::Element;
10use ferray_core::error::FerrayResult;
11use num_traits::Float;
12
13use crate::MaskedArray;
14
15/// Create a `MaskedArray` by masking elements where the condition array is `true`.
16///
17/// # Errors
18/// Returns `FerrayError::ShapeMismatch` if `condition` and `data` have different shapes.
19pub fn masked_where<T: Element + Copy, D: Dimension>(
20    condition: &Array<bool, D>,
21    data: &Array<T, D>,
22) -> FerrayResult<MaskedArray<T, D>> {
23    MaskedArray::new(data.clone(), condition.clone())
24}
25
26/// Create a `MaskedArray` by masking NaN and Inf values.
27///
28/// The data is cloned as-is; masked positions retain their original
29/// NaN/Inf values in the data array. If you want the data positions
30/// replaced with a sentinel value too, use [`fix_invalid`] (#510).
31///
32/// # Errors
33/// Returns an error only for internal failures.
34pub fn masked_invalid<T: Element + Float, D: Dimension>(
35    data: &Array<T, D>,
36) -> FerrayResult<MaskedArray<T, D>> {
37    let mask_data: Vec<bool> = data.iter().map(|v| v.is_nan() || v.is_infinite()).collect();
38    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
39    MaskedArray::new(data.clone(), mask)
40}
41
42/// Create a `MaskedArray` by masking NaN and Inf values AND replacing
43/// them with `fill_value` in the underlying data.
44///
45/// Equivalent to `numpy.ma.fix_invalid(data, fill_value=fill_value)`
46/// (#510). This is a strict superset of [`masked_invalid`]: the
47/// result's mask is identical (positions where `x.is_nan() ||
48/// x.is_infinite()`) but the data array has those positions replaced
49/// with `fill_value` instead of the original NaN/Inf. The result's
50/// `fill_value` is also set to `fill_value` so subsequent operations
51/// that use the fill value behave consistently.
52///
53/// Cleaning a data array in a single pass is the primary use case —
54/// after `fix_invalid`, the underlying `data()` is free of NaN/Inf
55/// and can be passed to operations that would otherwise produce
56/// NaN propagation.
57///
58/// # Errors
59/// Returns an error only for internal failures (allocation).
60pub fn fix_invalid<T: Element + Float, D: Dimension>(
61    data: &Array<T, D>,
62    fill_value: T,
63) -> FerrayResult<MaskedArray<T, D>> {
64    let mut new_data: Vec<T> = Vec::with_capacity(data.size());
65    let mut new_mask: Vec<bool> = Vec::with_capacity(data.size());
66    for &v in data.iter() {
67        if v.is_nan() || v.is_infinite() {
68            new_data.push(fill_value);
69            new_mask.push(true);
70        } else {
71            new_data.push(v);
72            new_mask.push(false);
73        }
74    }
75    let data_arr = Array::from_vec(data.dim().clone(), new_data)?;
76    let mask_arr = Array::from_vec(data.dim().clone(), new_mask)?;
77    let mut out = MaskedArray::new(data_arr, mask_arr)?;
78    out.set_fill_value(fill_value);
79    Ok(out)
80}
81
82/// Create a `MaskedArray` by masking elements equal to `value`.
83///
84/// # Errors
85/// Returns an error only for internal failures.
86pub fn masked_equal<T: Element + PartialEq + Copy, D: Dimension>(
87    data: &Array<T, D>,
88    value: T,
89) -> FerrayResult<MaskedArray<T, D>> {
90    let mask_data: Vec<bool> = data.iter().map(|v| *v == value).collect();
91    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
92    MaskedArray::new(data.clone(), mask)
93}
94
95/// Create a `MaskedArray` by masking elements not equal to `value`.
96///
97/// # Errors
98/// Returns an error only for internal failures.
99pub fn masked_not_equal<T: Element + PartialEq + Copy, D: Dimension>(
100    data: &Array<T, D>,
101    value: T,
102) -> FerrayResult<MaskedArray<T, D>> {
103    let mask_data: Vec<bool> = data.iter().map(|v| *v != value).collect();
104    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
105    MaskedArray::new(data.clone(), mask)
106}
107
108/// Create a `MaskedArray` by masking elements greater than `value`.
109///
110/// # Errors
111/// Returns an error only for internal failures.
112pub fn masked_greater<T: Element + PartialOrd + Copy, D: Dimension>(
113    data: &Array<T, D>,
114    value: T,
115) -> FerrayResult<MaskedArray<T, D>> {
116    let mask_data: Vec<bool> = data.iter().map(|v| *v > value).collect();
117    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
118    MaskedArray::new(data.clone(), mask)
119}
120
121/// Create a `MaskedArray` by masking elements less than `value`.
122///
123/// # Errors
124/// Returns an error only for internal failures.
125pub fn masked_less<T: Element + PartialOrd + Copy, D: Dimension>(
126    data: &Array<T, D>,
127    value: T,
128) -> FerrayResult<MaskedArray<T, D>> {
129    let mask_data: Vec<bool> = data.iter().map(|v| *v < value).collect();
130    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
131    MaskedArray::new(data.clone(), mask)
132}
133
134/// Create a `MaskedArray` by masking elements greater than or equal to `value`.
135///
136/// # Errors
137/// Returns an error only for internal failures.
138pub fn masked_greater_equal<T: Element + PartialOrd + Copy, D: Dimension>(
139    data: &Array<T, D>,
140    value: T,
141) -> FerrayResult<MaskedArray<T, D>> {
142    let mask_data: Vec<bool> = data.iter().map(|v| *v >= value).collect();
143    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
144    MaskedArray::new(data.clone(), mask)
145}
146
147/// Create a `MaskedArray` by masking elements less than or equal to `value`.
148///
149/// # Errors
150/// Returns an error only for internal failures.
151pub fn masked_less_equal<T: Element + PartialOrd + Copy, D: Dimension>(
152    data: &Array<T, D>,
153    value: T,
154) -> FerrayResult<MaskedArray<T, D>> {
155    let mask_data: Vec<bool> = data.iter().map(|v| *v <= value).collect();
156    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
157    MaskedArray::new(data.clone(), mask)
158}
159
160/// Create a `MaskedArray` by masking elements inside the closed interval `[v1, v2]`.
161///
162/// Matches `numpy.ma.masked_inside`: if `v1 > v2` they are swapped so
163/// the interval is always non-degenerate (#266).
164///
165/// # Errors
166/// Returns an error only for internal failures.
167pub fn masked_inside<T: Element + PartialOrd + Copy, D: Dimension>(
168    data: &Array<T, D>,
169    v1: T,
170    v2: T,
171) -> FerrayResult<MaskedArray<T, D>> {
172    let (lo, hi) = if v1 <= v2 { (v1, v2) } else { (v2, v1) };
173    let mask_data: Vec<bool> = data.iter().map(|v| *v >= lo && *v <= hi).collect();
174    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
175    MaskedArray::new(data.clone(), mask)
176}
177
178/// Create a `MaskedArray` by masking elements outside the closed interval `[v1, v2]`.
179///
180/// Matches `numpy.ma.masked_outside`: if `v1 > v2` they are swapped so
181/// the interval is always non-degenerate (#266).
182///
183/// # Errors
184/// Returns an error only for internal failures.
185pub fn masked_outside<T: Element + PartialOrd + Copy, D: Dimension>(
186    data: &Array<T, D>,
187    v1: T,
188    v2: T,
189) -> FerrayResult<MaskedArray<T, D>> {
190    let (lo, hi) = if v1 <= v2 { (v1, v2) } else { (v2, v1) };
191    let mask_data: Vec<bool> = data.iter().map(|v| *v < lo || *v > hi).collect();
192    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
193    MaskedArray::new(data.clone(), mask)
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    // ---- masked_where ----

    #[test]
    fn masked_where_uses_condition_as_mask_and_keeps_data() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let cond =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, true]).unwrap();
        let ma = masked_where(&cond, &data).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, false, false, true]
        );
        assert_eq!(
            ma.data().iter().copied().collect::<Vec<_>>(),
            vec![10, 20, 30, 40]
        );
    }

    #[test]
    fn masked_where_rejects_mismatched_shapes() {
        // Documented contract: differing shapes must produce an error.
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let cond = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
        assert!(masked_where(&cond, &data).is_err());
    }

    // ---- masked_invalid ----

    #[test]
    fn masked_invalid_masks_nan_and_both_infinities_but_keeps_data() {
        let data = Array::<f64, Ix1>::from_vec(
            Ix1::new([5]),
            vec![1.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 5.0],
        )
        .unwrap();
        let ma = masked_invalid(&data).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, true, true, false]
        );
        // Unlike fix_invalid, the data still holds the original values.
        let d: Vec<f64> = ma.data().iter().copied().collect();
        assert_eq!(d[0], 1.0);
        assert!(d[1].is_nan());
        assert_eq!(d[2], f64::INFINITY);
        assert_eq!(d[3], f64::NEG_INFINITY);
        assert_eq!(d[4], 5.0);
    }

    // ---- equality / ordering constructors ----

    #[test]
    fn equality_constructors_mask_expected_positions() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();

        let eq = masked_equal(&data, 3).unwrap();
        assert_eq!(
            eq.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, true, false, false]
        );
        // Data is never altered by the comparison constructors.
        assert_eq!(
            eq.data().iter().copied().collect::<Vec<_>>(),
            vec![1, 2, 3, 4, 5]
        );

        let ne = masked_not_equal(&data, 3).unwrap();
        assert_eq!(
            ne.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, true, false, true, true]
        );
    }

    #[test]
    fn ordering_constructors_mask_expected_positions() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();

        let gt = masked_greater(&data, 3).unwrap();
        assert_eq!(
            gt.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, false, true, true]
        );

        let lt = masked_less(&data, 3).unwrap();
        assert_eq!(
            lt.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, true, false, false, false]
        );

        let ge = masked_greater_equal(&data, 3).unwrap();
        assert_eq!(
            ge.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, true, true, true]
        );

        let le = masked_less_equal(&data, 3).unwrap();
        assert_eq!(
            le.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, true, true, false, false]
        );
    }

    // ---- fix_invalid (#510) ----

    #[test]
    fn fix_invalid_masks_and_replaces_nan_and_inf() {
        let data = Array::<f64, Ix1>::from_vec(
            Ix1::new([6]),
            vec![1.0, f64::NAN, 3.0, f64::INFINITY, f64::NEG_INFINITY, 6.0],
        )
        .unwrap();
        let ma = fix_invalid(&data, -99.0).unwrap();

        // Mask has the invalid positions set.
        let m: Vec<bool> = ma.mask().iter().copied().collect();
        assert_eq!(m, vec![false, true, false, true, true, false]);

        // Data has the invalid positions replaced with the fill value.
        let d: Vec<f64> = ma.data().iter().copied().collect();
        assert_eq!(d, vec![1.0, -99.0, 3.0, -99.0, -99.0, 6.0]);

        // Result's fill_value is set to the passed value, not the default.
        assert_eq!(ma.fill_value(), -99.0);
    }

    #[test]
    fn fix_invalid_preserves_valid_values() {
        // No NaN/Inf → mask is all-false and data matches input exactly.
        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let ma = fix_invalid(&data, 0.0).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, false, false]
        );
        assert_eq!(
            ma.data().iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0]
        );
    }

    #[test]
    fn fix_invalid_all_nan_input() {
        let data =
            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![f64::NAN, f64::NAN, f64::NAN]).unwrap();
        let ma = fix_invalid(&data, 0.0).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, true, true]
        );
        assert_eq!(
            ma.data().iter().copied().collect::<Vec<_>>(),
            vec![0.0, 0.0, 0.0]
        );
        // Crucially, NaN isn't in the data anymore — downstream ops
        // that compare against NaN won't propagate.
        assert!(ma.data().iter().all(|v| !v.is_nan()));
    }

    #[test]
    fn fix_invalid_vs_masked_invalid_data_difference() {
        // masked_invalid leaves NaN in the data; fix_invalid substitutes.
        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, f64::NAN, 3.0]).unwrap();
        let via_masked = masked_invalid(&data).unwrap();
        let via_fixed = fix_invalid(&data, -1.0).unwrap();

        // Masks are identical.
        assert_eq!(
            via_masked.mask().iter().copied().collect::<Vec<_>>(),
            via_fixed.mask().iter().copied().collect::<Vec<_>>()
        );

        // But the data differs: masked_invalid keeps NaN, fix_invalid
        // substitutes -1.0.
        assert!(via_masked.data().iter().nth(1).unwrap().is_nan());
        assert_eq!(*via_fixed.data().iter().nth(1).unwrap(), -1.0);
    }

    #[test]
    fn fix_invalid_2d_shape_preserved() {
        use ferray_core::dimension::Ix2;
        let data = Array::<f64, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![1.0, f64::NAN, 3.0, 4.0, 5.0, f64::INFINITY],
        )
        .unwrap();
        let ma = fix_invalid(&data, -1.0).unwrap();
        assert_eq!(ma.shape(), &[2, 3]);
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, false, false, false, true]
        );
    }

    // ----- masked_inside / masked_outside swap (#266) ---------------------

    #[test]
    fn masked_inside_canonical_order_masks_interior() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let ma = masked_inside(&data, 2, 4).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, true, true, false]
        );
    }

    #[test]
    fn masked_inside_swaps_when_v1_greater_than_v2() {
        // #266: numpy auto-swaps; masked_inside(data, 4, 2) must be
        // equivalent to masked_inside(data, 2, 4).
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let swapped = masked_inside(&data, 4, 2).unwrap();
        let canonical = masked_inside(&data, 2, 4).unwrap();
        assert_eq!(
            swapped.mask().iter().copied().collect::<Vec<_>>(),
            canonical.mask().iter().copied().collect::<Vec<_>>()
        );
    }

    #[test]
    fn masked_outside_canonical_order_masks_exterior() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let ma = masked_outside(&data, 2, 4).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, false, false, false, true]
        );
    }

    #[test]
    fn masked_outside_swaps_when_v1_greater_than_v2() {
        // #266: same swap behavior on the exterior path.
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let swapped = masked_outside(&data, 4, 2).unwrap();
        let canonical = masked_outside(&data, 2, 4).unwrap();
        assert_eq!(
            swapped.mask().iter().copied().collect::<Vec<_>>(),
            canonical.mask().iter().copied().collect::<Vec<_>>()
        );
    }

    #[test]
    fn masked_inside_equal_endpoints_masks_only_that_value() {
        // [v, v] is a degenerate interval — masking should still pick
        // up exact-equal entries.
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let ma = masked_inside(&data, 3, 3).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, true, false, false]
        );
    }
}