Skip to main content

ferray_ma/
lib.rs

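// ferray-ma: Masked arrays with mask propagation
//
// This crate implements `numpy.ma`-style masked arrays for the ferray workspace.
// A `MaskedArray<T, D>` pairs a data array with a boolean mask array where
// `true` = masked/invalid. All operations (arithmetic, reductions, ufuncs)
// respect the mask by skipping masked elements.
//
// # Modules
// - `masked_array`: The core `MaskedArray<T, D>` type
// - `reductions`: Masked mean, sum, min, max, var, std, count
// - `constructors`: masked_where, masked_invalid, masked_equal, etc.
// - `arithmetic`: Masked binary ops with mask union
// - `ufunc_support`: Wrapper functions for ufunc operations on MaskedArrays,
//   including the generic `masked_unary`/`masked_binary` helpers and the
//   domain-aware `*_domain` wrappers
// - `sorting`: Masked sort, argsort
// - `mask_ops`: harden_mask, soften_mask, getmask, getdata, is_masked,
//   count_masked, count_masked_axis
// - `filled`: filled, compressed
// - `extras`: further `numpy.ma` functionality (full reductions, constructors,
//   mask manipulation, set ops, fill-value protocol, comparison/logical ufuncs)
// - `interop`: the `MaskAware` trait and `ma_apply_unary`
// - `manipulation`: see the module's own documentation
// - `io` (only with the `io` cargo feature): binary save/load via ferray-io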
17
18// Masked reductions divide running sums by valid-element counts and
19// truncate `f64` results to integer index types in argmin/argmax. Float
20// equality is also intrinsic to `masked_equal` and `getdata` semantics.
21#![allow(
22    clippy::cast_possible_truncation,
23    clippy::cast_possible_wrap,
24    clippy::cast_precision_loss,
25    clippy::cast_sign_loss,
26    clippy::cast_lossless,
27    clippy::float_cmp,
28    clippy::missing_errors_doc,
29    clippy::missing_panics_doc,
30    clippy::many_single_char_names,
31    clippy::similar_names,
32    clippy::items_after_statements,
33    clippy::option_if_let_else,
34    clippy::too_long_first_doc_paragraph,
35    clippy::needless_pass_by_value,
36    clippy::match_same_arms
37)]
38
39pub mod arithmetic;
40pub mod constructors;
41pub mod extras;
42pub mod filled;
43pub mod interop;
44/// Binary I/O (save/load) for `MaskedArray` via ferray-io (#509).
45///
46/// Gated behind the `io` cargo feature so callers who don't need
47/// disk I/O don't have to pull in the zip + binary reader dependency
48/// tree through ferray-io.
49#[cfg(feature = "io")]
50pub mod io;
51pub mod manipulation;
52pub mod mask_ops;
53pub mod masked_array;
54pub mod reductions;
55pub mod sorting;
56pub mod ufunc_support;
57
58// Re-export the primary type at crate root
59pub use masked_array::MaskedArray;
60
61// Re-export masking constructors
62pub use constructors::{
63    fix_invalid, masked_equal, masked_greater, masked_greater_equal, masked_inside, masked_invalid,
64    masked_less, masked_less_equal, masked_not_equal, masked_outside, masked_where,
65};
66
67// Re-export arithmetic operations
68pub use arithmetic::{
69    masked_add, masked_add_array, masked_div, masked_div_array, masked_mul, masked_mul_array,
70    masked_sub, masked_sub_array,
71};
72
73// Re-export mask manipulation functions
74pub use mask_ops::{count_masked, count_masked_axis, getdata, getmask, is_masked};
75
76// Re-export MaskAware trait (#505) for downstream code that wants
77// to write functions polymorphic over Array and MaskedArray.
78pub use interop::{MaskAware, ma_apply_unary};
79
80// Re-export generic ufunc helpers (#513) — the escape hatch for
81// ufuncs that don't have a dedicated named wrapper. Users with an
82// arbitrary `Fn(T) -> T` / `Fn(T, T) -> T` closure can call
83// `ferray_ma::masked_unary(ma, my_fn)` directly.
84pub use ufunc_support::{masked_binary, masked_unary};
85
86// Domain-aware ufunc wrappers (#503) — auto-mask out-of-domain
87// inputs so the result mask carries a "safe to use" contract.
88pub use ufunc_support::{
89    arccos_domain, arccosh_domain, arcsin_domain, arctanh_domain, divide_domain, log_domain,
90    log2_domain, log10_domain, masked_binary_domain, masked_unary_domain, sqrt_domain,
91};
92
93// numpy.ma extras: full reductions, constructors, mask manipulation,
94// linalg-lite, set ops, fill-value protocol, comparison/logical ufuncs,
95// and class helpers. See extras.rs for the catalogue.
96pub use extras::{
97    NOMASK, common_fill_value, default_fill_value_bool, default_fill_value_f32,
98    default_fill_value_f64, default_fill_value_i64, getmaskarray, ids, is_ma, is_masked_array,
99    ma_apply_along_axis, ma_apply_over_axes, ma_concatenate, ma_equal, ma_greater,
100    ma_greater_equal, ma_in1d, ma_isin, ma_less, ma_less_equal, ma_logical_and, ma_logical_not,
101    ma_logical_or, ma_logical_xor, ma_not_equal, ma_unique, ma_vander, make_mask, make_mask_none,
102    mask_or, masked_all, masked_all_like, masked_values, maximum_fill_value, minimum_fill_value,
103};
104
105#[cfg(test)]
106mod tests {
107    use super::*;
108    use ferray_core::Array;
109    use ferray_core::dimension::Ix1;
110
111    // -----------------------------------------------------------------------
112    // AC-1: MaskedArray::new([1,2,3,4,5], [false,false,true,false,false]).mean() == 3.0
113    // -----------------------------------------------------------------------
114    #[test]
115    fn ac1_masked_mean_skips_masked() {
116        let data =
117            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
118        let mask =
119            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, false, true, false, false])
120                .unwrap();
121        let ma = MaskedArray::new(data, mask).unwrap();
122        let mean = ma.mean().unwrap();
123        // (1 + 2 + 4 + 5) / 4 = 3.0
124        assert!((mean - 3.0).abs() < 1e-10);
125    }
126
127    // -----------------------------------------------------------------------
128    // AC-2: filled(0.0) replaces masked elements with 0.0
129    // -----------------------------------------------------------------------
130    #[test]
131    fn ac2_filled_replaces_masked() {
132        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
133        let mask =
134            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, true]).unwrap();
135        let ma = MaskedArray::new(data, mask).unwrap();
136        let filled = ma.filled(0.0).unwrap();
137        assert_eq!(filled.as_slice().unwrap(), &[1.0, 0.0, 3.0, 0.0]);
138    }
139
140    // -----------------------------------------------------------------------
141    // AC-3: compressed() returns only unmasked elements as 1D
142    // -----------------------------------------------------------------------
143    #[test]
144    fn ac3_compressed_returns_unmasked() {
145        let data =
146            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![10.0, 20.0, 30.0, 40.0, 50.0]).unwrap();
147        let mask =
148            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, true, false, true, false])
149                .unwrap();
150        let ma = MaskedArray::new(data, mask).unwrap();
151        let compressed = ma.compressed().unwrap();
152        assert_eq!(compressed.as_slice().unwrap(), &[10.0, 30.0, 50.0]);
153    }
154
155    // -----------------------------------------------------------------------
156    // AC-4: masked_invalid masks NaN and Inf
157    // -----------------------------------------------------------------------
158    #[test]
159    fn ac4_masked_invalid_nan_inf() {
160        let data =
161            Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, f64::NAN, 3.0, f64::INFINITY])
162                .unwrap();
163        let ma = masked_invalid(&data).unwrap();
164        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
165        assert_eq!(mask_vals, vec![false, true, false, true]);
166    }
167
168    // -----------------------------------------------------------------------
169    // AC-5: ma1 + ma2 produces correct mask union and correct values
170    // -----------------------------------------------------------------------
171    #[test]
172    fn ac5_add_mask_union() {
173        let d1 = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
174        let m1 =
175            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, false]).unwrap();
176        let ma1 = MaskedArray::new(d1, m1).unwrap();
177
178        let d2 = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
179        let m2 =
180            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, false, true, false]).unwrap();
181        let ma2 = MaskedArray::new(d2, m2).unwrap();
182
183        let result = masked_add(&ma1, &ma2).unwrap();
184        let mask_vals: Vec<bool> = result.mask().iter().copied().collect();
185        // Mask union: [false, true, true, false]
186        assert_eq!(mask_vals, vec![false, true, true, false]);
187        // Unmasked values: 1+10=11, 4+40=44; masked get 0.0
188        let data_vals: Vec<f64> = result.data().iter().copied().collect();
189        assert!((data_vals[0] - 11.0).abs() < 1e-10);
190        assert!((data_vals[3] - 44.0).abs() < 1e-10);
191    }
192
193    #[test]
194    fn operator_add_matches_masked_add() {
195        let d1 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
196        let m1 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
197        let ma1 = MaskedArray::new(d1, m1).unwrap();
198
199        let d2 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
200        let m2 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, true]).unwrap();
201        let ma2 = MaskedArray::new(d2, m2).unwrap();
202
203        // Use operator syntax
204        let result = (&ma1 + &ma2).unwrap();
205        let mask_vals: Vec<bool> = result.mask().iter().copied().collect();
206        assert_eq!(mask_vals, vec![false, true, true]);
207        let data_vals: Vec<f64> = result.data().iter().copied().collect();
208        assert!((data_vals[0] - 11.0).abs() < 1e-10);
209    }
210
211    // -----------------------------------------------------------------------
212    // AC-7: sin(masked_array) returns same mask, correct values
213    // -----------------------------------------------------------------------
214    #[test]
215    fn ac7_ufunc_sin_masked() {
216        use std::f64::consts::FRAC_PI_2;
217        let data =
218            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![0.0, FRAC_PI_2, FRAC_PI_2]).unwrap();
219        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
220        let ma = MaskedArray::new(data, mask).unwrap();
221        let result = ufunc_support::sin(&ma).unwrap();
222        let mask_vals: Vec<bool> = result.mask().iter().copied().collect();
223        assert_eq!(mask_vals, vec![false, true, false]);
224        let data_vals: Vec<f64> = result.data().iter().copied().collect();
225        // sin(0) = 0, masked = 0.0 (skipped), sin(pi/2) = 1.0
226        assert!((data_vals[0] - 0.0).abs() < 1e-10);
227        assert!((data_vals[2] - 1.0).abs() < 1e-10);
228    }
229
230    // -----------------------------------------------------------------------
231    // AC-8: sort places masked at end; harden_mask prevents clearing
232    // -----------------------------------------------------------------------
233    #[test]
234    fn ac8_sort_masked_at_end() {
235        let data =
236            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![5.0, 1.0, 3.0, 2.0, 4.0]).unwrap();
237        let mask =
238            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, false, true, false, false])
239                .unwrap();
240        let ma = MaskedArray::new(data, mask).unwrap();
241        let sorted = ma.sort().unwrap();
242        let data_vals: Vec<f64> = sorted.data().iter().copied().collect();
243        let mask_vals: Vec<bool> = sorted.mask().iter().copied().collect();
244        // Unmasked [5, 1, 2, 4] sorted = [1, 2, 4, 5], then masked [3]
245        assert_eq!(data_vals, vec![1.0, 2.0, 4.0, 5.0, 3.0]);
246        assert_eq!(mask_vals, vec![false, false, false, false, true]);
247    }
248
249    #[test]
250    fn ac8_harden_mask_prevents_clearing() {
251        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
252        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
253        let mut ma = MaskedArray::new(data, mask).unwrap();
254
255        ma.harden_mask().unwrap();
256        assert!(ma.is_hard_mask());
257
258        // Try to clear the mask at index 1 — should be silently ignored
259        ma.set_mask_flat(1, false).unwrap();
260        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
261        assert_eq!(mask_vals, vec![false, true, false]);
262
263        // Setting a mask bit to true should still work
264        ma.set_mask_flat(0, true).unwrap();
265        let mask_vals2: Vec<bool> = ma.mask().iter().copied().collect();
266        assert_eq!(mask_vals2, vec![true, true, false]);
267
268        // Soften and then clearing should work
269        ma.soften_mask().unwrap();
270        assert!(!ma.is_hard_mask());
271        ma.set_mask_flat(1, false).unwrap();
272        let mask_vals3: Vec<bool> = ma.mask().iter().copied().collect();
273        assert_eq!(mask_vals3, vec![true, false, false]);
274    }
275
276    // -----------------------------------------------------------------------
277    // AC-9: is_masked returns true/false correctly
278    // -----------------------------------------------------------------------
279    #[test]
280    fn ac9_is_masked() {
281        let data1 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
282        let mask1 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
283        let ma1 = MaskedArray::new(data1, mask1).unwrap();
284        assert!(is_masked(&ma1).unwrap());
285
286        let data2 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
287        let mask2 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
288        let ma2 = MaskedArray::new(data2, mask2).unwrap();
289        assert!(!is_masked(&ma2).unwrap());
290    }
291
292    // -----------------------------------------------------------------------
293    // Additional tests
294    // -----------------------------------------------------------------------
295
296    #[test]
297    fn shape_mismatch_error() {
298        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
299        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![false, true]).unwrap();
300        assert!(MaskedArray::new(data, mask).is_err());
301    }
302
303    #[test]
304    fn from_data_no_mask() {
305        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
306        let ma = MaskedArray::from_data(data).unwrap();
307        assert!(!is_masked(&ma).unwrap());
308        assert_eq!(ma.count().unwrap(), 3);
309    }
310
311    #[test]
312    fn sum_skips_masked() {
313        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
314        let mask =
315            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, true]).unwrap();
316        let ma = MaskedArray::new(data, mask).unwrap();
317        assert!((ma.sum().unwrap() - 4.0).abs() < 1e-10);
318    }
319
320    #[test]
321    fn min_max_skip_masked() {
322        let data =
323            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![5.0, 1.0, 3.0, 2.0, 4.0]).unwrap();
324        let mask =
325            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, true, false, false, false])
326                .unwrap();
327        let ma = MaskedArray::new(data, mask).unwrap();
328        assert!((ma.min().unwrap() - 2.0).abs() < 1e-10);
329        assert!((ma.max().unwrap() - 5.0).abs() < 1e-10);
330    }
331
332    #[test]
333    fn var_std_skip_masked() {
334        // values: [2, 4, 6] (mask out index 1 and 4)
335        let data =
336            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![2.0, 99.0, 4.0, 6.0, 99.0]).unwrap();
337        let mask =
338            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, true, false, false, true])
339                .unwrap();
340        let ma = MaskedArray::new(data, mask).unwrap();
341        let mean = ma.mean().unwrap();
342        assert!((mean - 4.0).abs() < 1e-10);
343        // var = ((2-4)^2 + (4-4)^2 + (6-4)^2) / 3 = 8/3
344        let v = ma.var().unwrap();
345        assert!((v - 8.0 / 3.0).abs() < 1e-10);
346        let s = ma.std().unwrap();
347        assert!((s - (8.0_f64 / 3.0).sqrt()).abs() < 1e-10);
348    }
349
350    #[test]
351    fn count_elements() {
352        let data = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0; 5]).unwrap();
353        let mask =
354            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, true, true, false, false])
355                .unwrap();
356        let ma = MaskedArray::new(data, mask).unwrap();
357        assert_eq!(ma.count().unwrap(), 3);
358    }
359
360    #[test]
361    fn masked_equal_test() {
362        let data =
363            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 2.0, 1.0]).unwrap();
364        let ma = masked_equal(&data, 2.0).unwrap();
365        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
366        assert_eq!(mask_vals, vec![false, true, false, true, false]);
367    }
368
369    #[test]
370    fn masked_greater_test() {
371        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
372        let ma = masked_greater(&data, 2.0).unwrap();
373        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
374        assert_eq!(mask_vals, vec![false, false, true, true]);
375    }
376
377    #[test]
378    fn masked_less_test() {
379        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
380        let ma = masked_less(&data, 3.0).unwrap();
381        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
382        assert_eq!(mask_vals, vec![true, true, false, false]);
383    }
384
385    #[test]
386    fn masked_not_equal_test() {
387        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
388        let ma = masked_not_equal(&data, 2.0).unwrap();
389        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
390        assert_eq!(mask_vals, vec![true, false, true]);
391    }
392
393    #[test]
394    fn masked_greater_equal_test() {
395        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
396        let ma = masked_greater_equal(&data, 3.0).unwrap();
397        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
398        assert_eq!(mask_vals, vec![false, false, true, true]);
399    }
400
401    #[test]
402    fn masked_less_equal_test() {
403        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
404        let ma = masked_less_equal(&data, 2.0).unwrap();
405        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
406        assert_eq!(mask_vals, vec![true, true, false, false]);
407    }
408
409    #[test]
410    fn masked_inside_test() {
411        let data =
412            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
413        let ma = masked_inside(&data, 2.0, 4.0).unwrap();
414        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
415        assert_eq!(mask_vals, vec![false, true, true, true, false]);
416    }
417
418    #[test]
419    fn masked_outside_test() {
420        let data =
421            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
422        let ma = masked_outside(&data, 2.0, 4.0).unwrap();
423        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
424        assert_eq!(mask_vals, vec![true, false, false, false, true]);
425    }
426
427    #[test]
428    fn masked_where_test() {
429        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
430        let cond =
431            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
432        let ma = masked_where(&cond, &data).unwrap();
433        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
434        assert_eq!(mask_vals, vec![true, false, true, false]);
435    }
436
437    #[test]
438    fn argsort_test() {
439        let data =
440            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![5.0, 1.0, 3.0, 2.0, 4.0]).unwrap();
441        let mask =
442            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, false, true, false, false])
443                .unwrap();
444        let ma = MaskedArray::new(data, mask).unwrap();
445        let indices = ma.argsort().unwrap();
446        // Unmasked: index 1 (1.0), 3 (2.0), 4 (4.0), 0 (5.0); masked: 2
447        assert_eq!(indices.shape(), &[5]);
448        assert_eq!(indices.as_slice().unwrap(), &[1u64, 3, 4, 0, 2]);
449    }
450
451    #[test]
452    fn getmask_getdata_test() {
453        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
454        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
455        let ma = MaskedArray::new(data.clone(), mask.clone()).unwrap();
456
457        let got_mask = getmask(&ma).unwrap();
458        let got_data = getdata(&ma).unwrap();
459
460        assert_eq!(got_mask.as_slice().unwrap(), mask.as_slice().unwrap());
461        assert_eq!(got_data.as_slice().unwrap(), data.as_slice().unwrap());
462    }
463
464    #[test]
465    fn count_masked_test() {
466        let data = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0; 5]).unwrap();
467        let mask =
468            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, true, false])
469                .unwrap();
470        let ma = MaskedArray::new(data, mask).unwrap();
471        assert_eq!(count_masked(&ma).unwrap(), 3);
472    }
473
474    #[test]
475    fn count_masked_axis_2d_along_rows() {
476        // #268: per-row masked counts on a 2x3 array.
477        // Mask:
478        //   [[T, F, T],
479        //    [F, F, T]]
480        use ferray_core::dimension::Ix2;
481        let data = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
482        let mask = Array::<bool, Ix2>::from_vec(
483            Ix2::new([2, 3]),
484            vec![true, false, true, false, false, true],
485        )
486        .unwrap();
487        let ma = MaskedArray::new(data, mask).unwrap();
488        // axis=1 reduces columns -> per-row counts: [2, 1].
489        let counts = count_masked_axis(&ma, 1).unwrap();
490        assert_eq!(counts.shape(), &[2]);
491        assert_eq!(counts.as_slice().unwrap(), &[2u64, 1]);
492    }
493
494    #[test]
495    fn count_masked_axis_2d_along_cols() {
496        // axis=0 reduces rows -> per-column counts: [1, 0, 2].
497        use ferray_core::dimension::Ix2;
498        let data = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
499        let mask = Array::<bool, Ix2>::from_vec(
500            Ix2::new([2, 3]),
501            vec![true, false, true, false, false, true],
502        )
503        .unwrap();
504        let ma = MaskedArray::new(data, mask).unwrap();
505        let counts = count_masked_axis(&ma, 0).unwrap();
506        assert_eq!(counts.shape(), &[3]);
507        assert_eq!(counts.as_slice().unwrap(), &[1u64, 0, 2]);
508    }
509
510    #[test]
511    fn count_masked_axis_rejects_out_of_bounds() {
512        use ferray_core::dimension::Ix2;
513        let data = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
514        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![false; 6]).unwrap();
515        let ma = MaskedArray::new(data, mask).unwrap();
516        assert!(count_masked_axis(&ma, 2).is_err());
517    }
518
519    #[test]
520    fn sort_axis_2d_per_row() {
521        // #271: sort along axis=1 (columns) should sort each row
522        // independently. Row 0: [3, 1, _] (mask 2) → unmasked sorted
523        // ascending [1, 3] then masked → [1, 3, _].
524        // Row 1: [_, 4, 2] (mask 0) → [2, 4, _].
525        use ferray_core::dimension::Ix2;
526        let data =
527            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![3.0, 1.0, 99.0, 99.0, 4.0, 2.0])
528                .unwrap();
529        let mask = Array::<bool, Ix2>::from_vec(
530            Ix2::new([2, 3]),
531            vec![false, false, true, true, false, false],
532        )
533        .unwrap();
534        let ma = MaskedArray::new(data, mask).unwrap();
535        let sorted = ma.sort_axis(1).unwrap();
536        assert_eq!(sorted.shape(), &[2, 3]);
537        let d: Vec<f64> = sorted.data().iter().copied().collect();
538        let m: Vec<bool> = sorted.mask().iter().copied().collect();
539        // Row 0
540        assert!((d[0] - 1.0).abs() < 1e-12);
541        assert!((d[1] - 3.0).abs() < 1e-12);
542        assert!(m[2], "row 0 col 2 should be masked");
543        // Row 1
544        assert!((d[3] - 2.0).abs() < 1e-12);
545        assert!((d[4] - 4.0).abs() < 1e-12);
546        assert!(m[5], "row 1 col 2 should be masked");
547    }
548
549    #[test]
550    fn sort_axis_2d_per_column() {
551        // axis=0 sorts each column. Column 0: [3, 1] both unmasked →
552        // [1, 3]. Column 1: [2, _] (row 1 masked) → [2, _].
553        use ferray_core::dimension::Ix2;
554        let data =
555            Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![3.0, 2.0, 1.0, 99.0]).unwrap();
556        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![false, false, false, true])
557            .unwrap();
558        let ma = MaskedArray::new(data, mask).unwrap();
559        let sorted = ma.sort_axis(0).unwrap();
560        let d: Vec<f64> = sorted.data().iter().copied().collect();
561        let m: Vec<bool> = sorted.mask().iter().copied().collect();
562        // Column 0: index [0,0]=1, [1,0]=3.
563        assert!((d[0] - 1.0).abs() < 1e-12);
564        assert!((d[2] - 3.0).abs() < 1e-12);
565        // Column 1: index [0,1]=2 (unmasked), [1,1]=99 (masked).
566        assert!((d[1] - 2.0).abs() < 1e-12);
567        assert!(m[3]);
568    }
569
570    #[test]
571    fn sort_axis_rejects_out_of_bounds() {
572        use ferray_core::dimension::Ix2;
573        let data = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
574        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![false; 6]).unwrap();
575        let ma = MaskedArray::new(data, mask).unwrap();
576        assert!(ma.sort_axis(2).is_err());
577    }
578
579    #[test]
580    fn data_mut_only_exposes_element_slice() {
581        // #273: data_mut returns &mut [T], not &mut Array — callers can
582        // update values but cannot reshape or resize.
583        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
584        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false; 4]).unwrap();
585        let mut ma = MaskedArray::new(data, mask).unwrap();
586        // Mutate values in place.
587        if let Some(s) = ma.data_mut() {
588            s[2] = 99.0;
589        }
590        assert_eq!(ma.shape(), &[4]);
591        let vals: Vec<f64> = ma.data().iter().copied().collect();
592        assert_eq!(vals, vec![1.0, 2.0, 99.0, 4.0]);
593        // Mask stays the same length and value.
594        assert_eq!(ma.mask().shape(), &[4]);
595    }
596
597    #[test]
598    fn masked_add_array_test() {
599        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
600        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
601        let ma = MaskedArray::new(data, mask).unwrap();
602        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
603        let result = masked_add_array(&ma, &arr).unwrap();
604        let data_vals: Vec<f64> = result.data().iter().copied().collect();
605        let mask_vals: Vec<bool> = result.mask().iter().copied().collect();
606        assert_eq!(mask_vals, vec![false, true, false]);
607        assert!((data_vals[0] - 11.0).abs() < 1e-10);
608        assert!((data_vals[2] - 33.0).abs() < 1e-10);
609    }
610
611    #[test]
612    fn all_masked_mean_is_nan() {
613        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
614        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
615        let ma = MaskedArray::new(data, mask).unwrap();
616        assert!(ma.mean().unwrap().is_nan());
617    }
618
619    #[test]
620    fn all_masked_min_errors() {
621        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
622        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
623        let ma = MaskedArray::new(data, mask).unwrap();
624        assert!(ma.min().is_err());
625    }
626
627    #[test]
628    fn ufunc_exp_masked() {
629        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![0.0, 1.0, 2.0]).unwrap();
630        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
631        let ma = MaskedArray::new(data, mask).unwrap();
632        let result = ufunc_support::exp(&ma).unwrap();
633        let data_vals: Vec<f64> = result.data().iter().copied().collect();
634        let mask_vals: Vec<bool> = result.mask().iter().copied().collect();
635        assert_eq!(mask_vals, vec![false, true, false]);
636        assert!((data_vals[0] - 1.0).abs() < 1e-10); // exp(0) = 1
637        assert!((data_vals[2] - 2.0_f64.exp()).abs() < 1e-10);
638    }
639
640    #[test]
641    fn ufunc_sqrt_masked() {
642        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![4.0, 9.0, 16.0, 25.0]).unwrap();
643        let mask =
644            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, true]).unwrap();
645        let ma = MaskedArray::new(data, mask).unwrap();
646        let result = ufunc_support::sqrt(&ma).unwrap();
647        let data_vals: Vec<f64> = result.data().iter().copied().collect();
648        assert!((data_vals[0] - 2.0).abs() < 1e-10);
649        assert!((data_vals[2] - 4.0).abs() < 1e-10);
650    }
651
652    #[test]
653    fn set_mask_hardened() {
654        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
655        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
656        let mut ma = MaskedArray::new(data, mask).unwrap();
657        ma.harden_mask().unwrap();
658
659        // set_mask with all-false should not clear the existing true
660        let new_mask =
661            Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
662        ma.set_mask(new_mask).unwrap();
663        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
664        // Hard mask: union of old [false, true, false] and new [false, false, false] = [false, true, false]
665        assert_eq!(mask_vals, vec![false, true, false]);
666    }
667
668    #[test]
669    fn masked_sub_test() {
670        let d1 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
671        let m1 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, true]).unwrap();
672        let ma1 = MaskedArray::new(d1, m1).unwrap();
673
674        let d2 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
675        let m2 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
676        let ma2 = MaskedArray::new(d2, m2).unwrap();
677
678        let result = masked_sub(&ma1, &ma2).unwrap();
679        let mask_vals: Vec<bool> = result.mask().iter().copied().collect();
680        assert_eq!(mask_vals, vec![false, true, true]);
681        let data_vals: Vec<f64> = result.data().iter().copied().collect();
682        assert!((data_vals[0] - 9.0).abs() < 1e-10);
683    }
684
685    #[test]
686    fn masked_mul_test() {
687        let d1 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0, 3.0, 4.0]).unwrap();
688        let m1 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
689        let ma1 = MaskedArray::new(d1, m1).unwrap();
690
691        let d2 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![5.0, 6.0, 7.0]).unwrap();
692        let m2 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
693        let ma2 = MaskedArray::new(d2, m2).unwrap();
694
695        let result = masked_mul(&ma1, &ma2).unwrap();
696        let mask_vals: Vec<bool> = result.mask().iter().copied().collect();
697        assert_eq!(mask_vals, vec![false, true, false]);
698        let data_vals: Vec<f64> = result.data().iter().copied().collect();
699        assert!((data_vals[0] - 10.0).abs() < 1e-10);
700        assert!((data_vals[2] - 28.0).abs() < 1e-10);
701    }
702
703    #[test]
704    fn masked_div_test() {
705        let d1 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
706        let m1 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, true]).unwrap();
707        let ma1 = MaskedArray::new(d1, m1).unwrap();
708
709        let d2 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0, 5.0, 6.0]).unwrap();
710        let m2 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
711        let ma2 = MaskedArray::new(d2, m2).unwrap();
712
713        let result = masked_div(&ma1, &ma2).unwrap();
714        let data_vals: Vec<f64> = result.data().iter().copied().collect();
715        assert!((data_vals[0] - 5.0).abs() < 1e-10);
716        assert!((data_vals[1] - 4.0).abs() < 1e-10);
717    }
718
719    #[test]
720    fn masked_invalid_negative_inf() {
721        let data =
722            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, f64::NEG_INFINITY, 3.0]).unwrap();
723        let ma = masked_invalid(&data).unwrap();
724        let mask_vals: Vec<bool> = ma.mask().iter().copied().collect();
725        assert_eq!(mask_vals, vec![false, true, false]);
726    }
727
728    #[test]
729    fn empty_array_operations() {
730        let data = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
731        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
732        let ma = MaskedArray::new(data, mask).unwrap();
733        assert_eq!(ma.count().unwrap(), 0);
734        assert!(ma.mean().unwrap().is_nan());
735        let compressed = ma.compressed().unwrap();
736        assert_eq!(compressed.size(), 0);
737    }
738
739    #[test]
740    fn ndim_shape_size() {
741        let data = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0; 5]).unwrap();
742        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false; 5]).unwrap();
743        let ma = MaskedArray::new(data, mask).unwrap();
744        assert_eq!(ma.ndim(), 1);
745        assert_eq!(ma.shape(), &[5]);
746        assert_eq!(ma.size(), 5);
747    }
748
749    #[test]
750    fn ufunc_binary_power() {
751        let d1 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0, 3.0, 4.0]).unwrap();
752        let m1 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
753        let ma1 = MaskedArray::new(d1, m1).unwrap();
754
755        let d2 = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![3.0, 2.0, 2.0]).unwrap();
756        let m2 = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
757        let ma2 = MaskedArray::new(d2, m2).unwrap();
758
759        let result = ufunc_support::power(&ma1, &ma2).unwrap();
760        let data_vals: Vec<f64> = result.data().iter().copied().collect();
761        let mask_vals: Vec<bool> = result.mask().iter().copied().collect();
762        assert_eq!(mask_vals, vec![false, true, false]);
763        assert!((data_vals[0] - 8.0).abs() < 1e-10); // 2^3 = 8
764        assert!((data_vals[2] - 16.0).abs() < 1e-10); // 4^2 = 16
765    }
766
767    #[test]
768    fn filled_with_custom_value() {
769        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
770        let mask =
771            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
772        let ma = MaskedArray::new(data, mask).unwrap();
773        let filled = ma.filled(-999.0).unwrap();
774        assert_eq!(filled.as_slice().unwrap(), &[-999.0, 2.0, -999.0, 4.0]);
775    }
776
777    // --- 2D masked array tests ---
778
779    #[test]
780    fn masked_2d_construction() {
781        use ferray_core::dimension::Ix2;
782        let data =
783            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
784                .unwrap();
785        let mask = Array::<bool, Ix2>::from_vec(
786            Ix2::new([2, 3]),
787            vec![false, true, false, false, false, true],
788        )
789        .unwrap();
790        let ma = MaskedArray::new(data, mask).unwrap();
791        assert_eq!(ma.ndim(), 2);
792        assert_eq!(ma.shape(), &[2, 3]);
793        assert_eq!(ma.size(), 6);
794        assert_eq!(ma.count().unwrap(), 4);
795    }
796
797    #[test]
798    fn masked_2d_mean() {
799        use ferray_core::dimension::Ix2;
800        let data =
801            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
802                .unwrap();
803        // Mask out 2.0 and 6.0
804        let mask = Array::<bool, Ix2>::from_vec(
805            Ix2::new([2, 3]),
806            vec![false, true, false, false, false, true],
807        )
808        .unwrap();
809        let ma = MaskedArray::new(data, mask).unwrap();
810        // mean of [1, 3, 4, 5] = 13/4 = 3.25
811        let m = ma.mean().unwrap();
812        assert!((m - 3.25).abs() < 1e-10);
813    }
814
815    #[test]
816    fn masked_2d_sum() {
817        use ferray_core::dimension::Ix2;
818        let data =
819            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
820                .unwrap();
821        let mask = Array::<bool, Ix2>::from_vec(
822            Ix2::new([2, 3]),
823            vec![false, true, false, false, false, true],
824        )
825        .unwrap();
826        let ma = MaskedArray::new(data, mask).unwrap();
827        // sum of [1, 3, 4, 5] = 13
828        assert!((ma.sum().unwrap() - 13.0).abs() < 1e-10);
829    }
830
831    #[test]
832    fn masked_2d_add_operator() {
833        use ferray_core::dimension::Ix2;
834        let d1 = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
835        let m1 = Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![false, true, false, false])
836            .unwrap();
837        let ma1 = MaskedArray::new(d1, m1).unwrap();
838
839        let d2 =
840            Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
841        let m2 = Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![false, false, true, false])
842            .unwrap();
843        let ma2 = MaskedArray::new(d2, m2).unwrap();
844
845        let result = (&ma1 + &ma2).unwrap();
846        let mask_vals: Vec<bool> = result.mask().iter().copied().collect();
847        assert_eq!(mask_vals, vec![false, true, true, false]);
848        let data_vals: Vec<f64> = result.data().iter().copied().collect();
849        assert!((data_vals[0] - 11.0).abs() < 1e-10);
850        assert!((data_vals[3] - 44.0).abs() < 1e-10);
851    }
852
853    #[test]
854    fn masked_2d_compressed() {
855        use ferray_core::dimension::Ix2;
856        let data = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
857        let mask =
858            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![false, true, false, true]).unwrap();
859        let ma = MaskedArray::new(data, mask).unwrap();
860        let compressed = ma.compressed().unwrap();
861        assert_eq!(compressed.as_slice().unwrap(), &[1.0, 3.0]);
862    }
863}