Skip to main content

ferray_ma/
reductions.rs

// ferray-ma: Masked reductions (REQ-4)
//
// mean, sum, min, max, var, std, count — all skip masked elements.
//
// Whole-array variants (`sum`, `mean`, ...) return a scalar `T`.
// Per-axis variants (`sum_axis`, `mean_axis`, ...) return a `MaskedArray<T, IxDyn>`
// where each output position holds the reduction of one lane along the
// chosen axis. Lanes that contain only masked elements are themselves
// masked in the output (and hold the source array's `fill_value`).
// The counting reductions are the exception: `count` returns a `usize` and
// `count_axis` returns a plain (unmasked) `Array<u64, IxDyn>`.
10
11use ferray_core::Array;
12use ferray_core::dimension::{Dimension, IxDyn};
13use ferray_core::dtype::Element;
14use ferray_core::error::{FerrayError, FerrayResult};
15use num_traits::Float;
16
17use crate::MaskedArray;
18
19// ---------------------------------------------------------------------------
20// Internal helpers for axis-aware reductions
21// ---------------------------------------------------------------------------
22
/// Return the row-major (C-order) element strides for an array of `shape`.
///
/// The last axis has stride 1; each earlier axis's stride is the next axis's
/// stride multiplied by that next axis's extent. An empty `shape` yields an
/// empty stride vector.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk backwards: axis `ax - 1` steps over one full slab of axis `ax`.
    for ax in (1..shape.len()).rev() {
        strides[ax - 1] = strides[ax] * shape[ax];
    }
    strides
}
32
/// Advance `multi` to the next multi-index of `shape` in row-major order,
/// like an odometer whose last digit turns fastest.
///
/// Returns `true` if a next index exists. Returns `false` when every digit
/// wrapped around (i.e. `multi` was the final index); `multi` is then all
/// zeros. An empty `multi` always returns `false`.
fn increment_multi(multi: &mut [usize], shape: &[usize]) -> bool {
    let mut pos = multi.len();
    while let Some(prev) = pos.checked_sub(1) {
        pos = prev;
        multi[pos] += 1;
        if multi[pos] < shape[pos] {
            return true;
        }
        // Carry into the next-slower axis.
        multi[pos] = 0;
    }
    false
}
44
45/// Apply a per-lane masked reduction along `axis`.
46///
47/// `kernel` receives a `&[(T, bool)]` slice (data + mask) for each lane and
48/// returns either `Some(value)` (the reduction result) or `None` if every
49/// element in the lane was masked. Masked output positions are filled with
50/// `fill_value`.
51fn reduce_axis<T, D, F>(
52    ma: &MaskedArray<T, D>,
53    axis: usize,
54    fill_value: T,
55    kernel: F,
56) -> FerrayResult<MaskedArray<T, IxDyn>>
57where
58    T: Element + Copy,
59    D: Dimension,
60    F: Fn(&[(T, bool)]) -> Option<T>,
61{
62    let ndim = ma.ndim();
63    if axis >= ndim {
64        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
65    }
66    let shape = ma.shape();
67    let axis_len = shape[axis];
68
69    // Output shape: drop the reduced axis.
70    let out_shape: Vec<usize> = shape
71        .iter()
72        .enumerate()
73        .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
74        .collect();
75    let out_size: usize = if out_shape.is_empty() {
76        1
77    } else {
78        out_shape.iter().product()
79    };
80
81    // Materialize source data + mask in row-major order so we can index by
82    // computed flat indices regardless of the source memory layout.
83    let src_data: Vec<T> = ma.data().iter().copied().collect();
84    let src_mask: Vec<bool> = ma.mask().iter().copied().collect();
85    let strides = compute_strides(shape);
86
87    let mut out_data = Vec::with_capacity(out_size);
88    let mut out_mask = Vec::with_capacity(out_size);
89    let mut out_multi = vec![0usize; out_shape.len()];
90    let mut in_multi = vec![0usize; ndim];
91    let mut lane: Vec<(T, bool)> = Vec::with_capacity(axis_len);
92
93    for _ in 0..out_size {
94        // Map output multi-index back into the input multi-index by inserting
95        // a placeholder at `axis`.
96        let mut out_dim = 0;
97        for (d, idx) in in_multi.iter_mut().enumerate() {
98            if d == axis {
99                *idx = 0;
100            } else {
101                *idx = out_multi[out_dim];
102                out_dim += 1;
103            }
104        }
105
106        lane.clear();
107        for k in 0..axis_len {
108            in_multi[axis] = k;
109            let flat = in_multi
110                .iter()
111                .zip(strides.iter())
112                .map(|(i, s)| i * s)
113                .sum::<usize>();
114            lane.push((src_data[flat], src_mask[flat]));
115        }
116
117        if let Some(value) = kernel(&lane) {
118            out_data.push(value);
119            out_mask.push(false);
120        } else {
121            out_data.push(fill_value);
122            out_mask.push(true);
123        }
124
125        if !out_shape.is_empty() {
126            increment_multi(&mut out_multi, &out_shape);
127        }
128    }
129
130    let data_arr = Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), out_data)?;
131    let mask_arr = Array::<bool, IxDyn>::from_vec(IxDyn::new(&out_shape), out_mask)?;
132    let mut result = MaskedArray::new(data_arr, mask_arr)?;
133    result.set_fill_value(fill_value);
134    Ok(result)
135}
136
137/// Per-axis count of unmasked elements. Returns a plain `Array<u64, IxDyn>`
138/// (not masked, since count is always defined) with the reduced axis dropped.
139fn count_axis<T, D>(ma: &MaskedArray<T, D>, axis: usize) -> FerrayResult<Array<u64, IxDyn>>
140where
141    T: Element + Copy,
142    D: Dimension,
143{
144    let ndim = ma.ndim();
145    if axis >= ndim {
146        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
147    }
148    let shape = ma.shape();
149    let axis_len = shape[axis];
150    let out_shape: Vec<usize> = shape
151        .iter()
152        .enumerate()
153        .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
154        .collect();
155    let out_size: usize = if out_shape.is_empty() {
156        1
157    } else {
158        out_shape.iter().product()
159    };
160
161    let src_mask: Vec<bool> = ma.mask().iter().copied().collect();
162    let strides = compute_strides(shape);
163    let mut out: Vec<u64> = Vec::with_capacity(out_size);
164    let mut out_multi = vec![0usize; out_shape.len()];
165    let mut in_multi = vec![0usize; ndim];
166
167    for _ in 0..out_size {
168        let mut out_dim = 0;
169        for (d, idx) in in_multi.iter_mut().enumerate() {
170            if d == axis {
171                *idx = 0;
172            } else {
173                *idx = out_multi[out_dim];
174                out_dim += 1;
175            }
176        }
177
178        let mut count: u64 = 0;
179        for k in 0..axis_len {
180            in_multi[axis] = k;
181            let flat = in_multi
182                .iter()
183                .zip(strides.iter())
184                .map(|(i, s)| i * s)
185                .sum::<usize>();
186            if !src_mask[flat] {
187                count += 1;
188            }
189        }
190        out.push(count);
191
192        if !out_shape.is_empty() {
193            increment_multi(&mut out_multi, &out_shape);
194        }
195    }
196
197    Array::<u64, IxDyn>::from_vec(IxDyn::new(&out_shape), out)
198}
199
200impl<T, D> MaskedArray<T, D>
201where
202    T: Element + Copy,
203    D: Dimension,
204{
205    /// Count the number of unmasked (valid) elements.
206    ///
207    /// # Errors
208    /// This function does not currently error but returns `Result` for API
209    /// consistency.
210    pub fn count(&self) -> FerrayResult<usize> {
211        let n = self
212            .data()
213            .iter()
214            .zip(self.mask().iter())
215            .filter(|(_, m)| !**m)
216            .count();
217        Ok(n)
218    }
219
220    /// Count the number of unmasked elements per lane along `axis`.
221    ///
222    /// Returns a plain `Array<u64, IxDyn>` (not masked, since count is
223    /// always defined) with the reduced axis removed.
224    ///
225    /// # Errors
226    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
227    pub fn count_axis(&self, axis: usize) -> FerrayResult<Array<u64, IxDyn>> {
228        count_axis(self, axis)
229    }
230}
231
232impl<T, D> MaskedArray<T, D>
233where
234    T: Element + Float,
235    D: Dimension,
236{
237    /// Compute the sum of unmasked elements.
238    ///
239    /// Returns zero if all elements are masked.
240    ///
241    /// # Errors
242    /// Returns an error only for internal failures.
243    pub fn sum(&self) -> FerrayResult<T> {
244        let zero = num_traits::zero::<T>();
245        let s = self
246            .data()
247            .iter()
248            .zip(self.mask().iter())
249            .filter(|(_, m)| !**m)
250            .fold(zero, |acc, (v, _)| acc + *v);
251        Ok(s)
252    }
253
254    /// Compute the mean of unmasked elements.
255    ///
256    /// Returns `NaN` if no elements are unmasked.
257    ///
258    /// # Errors
259    /// Returns an error only for internal failures.
260    pub fn mean(&self) -> FerrayResult<T> {
261        let zero = num_traits::zero::<T>();
262        let one: T = num_traits::one();
263        let (sum, count) = self
264            .data()
265            .iter()
266            .zip(self.mask().iter())
267            .filter(|(_, m)| !**m)
268            .fold((zero, 0usize), |(s, c), (v, _)| (s + *v, c + 1));
269        if count == 0 {
270            return Ok(T::nan());
271        }
272        Ok(sum / T::from(count).unwrap_or(one))
273    }
274
275    /// Compute the minimum of unmasked elements.
276    ///
277    /// NaN values in unmasked elements are propagated (returns NaN), matching `NumPy`.
278    ///
279    /// # Errors
280    /// Returns `FerrayError::InvalidValue` if no elements are unmasked.
281    pub fn min(&self) -> FerrayResult<T> {
282        self.data()
283            .iter()
284            .zip(self.mask().iter())
285            .filter(|(_, m)| !**m)
286            .map(|(v, _)| *v)
287            .fold(None, |acc: Option<T>, v| {
288                Some(match acc {
289                    Some(a) => {
290                        // NaN-propagating: if comparison is unordered, propagate NaN
291                        if a <= v {
292                            a
293                        } else if a > v {
294                            v
295                        } else {
296                            a
297                        }
298                    }
299                    None => v,
300                })
301            })
302            .ok_or_else(|| FerrayError::invalid_value("min: all elements are masked"))
303    }
304
305    /// Compute the maximum of unmasked elements.
306    ///
307    /// NaN values in unmasked elements are propagated (returns NaN), matching `NumPy`.
308    ///
309    /// # Errors
310    /// Returns `FerrayError::InvalidValue` if no elements are unmasked.
311    pub fn max(&self) -> FerrayResult<T> {
312        self.data()
313            .iter()
314            .zip(self.mask().iter())
315            .filter(|(_, m)| !**m)
316            .map(|(v, _)| *v)
317            .fold(None, |acc: Option<T>, v| {
318                Some(match acc {
319                    Some(a) => {
320                        if a >= v {
321                            a
322                        } else if a < v {
323                            v
324                        } else {
325                            a
326                        }
327                    }
328                    None => v,
329                })
330            })
331            .ok_or_else(|| FerrayError::invalid_value("max: all elements are masked"))
332    }
333
334    /// Compute the variance of unmasked elements (population variance, ddof=0).
335    ///
336    /// Returns `NaN` if no elements are unmasked.
337    ///
338    /// # Errors
339    /// Returns an error only for internal failures.
340    pub fn var(&self) -> FerrayResult<T> {
341        let mean = self.mean()?;
342        if mean.is_nan() {
343            return Ok(T::nan());
344        }
345        let zero = num_traits::zero::<T>();
346        let one: T = num_traits::one();
347        let (sum_sq, count) = self
348            .data()
349            .iter()
350            .zip(self.mask().iter())
351            .filter(|(_, m)| !**m)
352            .fold((zero, 0usize), |(s, c), (v, _)| {
353                let d = *v - mean;
354                (s + d * d, c + 1)
355            });
356        if count == 0 {
357            return Ok(T::nan());
358        }
359        Ok(sum_sq / T::from(count).unwrap_or(one))
360    }
361
362    /// Compute the standard deviation of unmasked elements (population, ddof=0).
363    ///
364    /// Returns `NaN` if no elements are unmasked.
365    ///
366    /// # Errors
367    /// Returns an error only for internal failures.
368    pub fn std(&self) -> FerrayResult<T> {
369        Ok(self.var()?.sqrt())
370    }
371
372    // -----------------------------------------------------------------------
373    // Per-axis reductions (issue #500)
374    //
375    // Each lane along `axis` is reduced independently. Lanes containing only
376    // masked elements produce a masked output position holding `fill_value`.
377    // -----------------------------------------------------------------------
378
379    /// Sum unmasked elements along `axis`. Returns a masked array with the
380    /// reduced axis removed; lanes that are entirely masked produce a
381    /// masked output position holding `fill_value`.
382    pub fn sum_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
383        let zero = num_traits::zero::<T>();
384        let fill = self.fill_value();
385        reduce_axis(self, axis, fill, |lane| {
386            let mut acc = zero;
387            let mut any = false;
388            for &(v, m) in lane {
389                if !m {
390                    acc = acc + v;
391                    any = true;
392                }
393            }
394            if any { Some(acc) } else { None }
395        })
396    }
397
398    /// Mean of unmasked elements along `axis`. All-masked lanes are masked
399    /// in the output.
400    pub fn mean_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
401        let zero = num_traits::zero::<T>();
402        let fill = self.fill_value();
403        reduce_axis(self, axis, fill, |lane| {
404            let mut acc = zero;
405            let mut count = 0usize;
406            for &(v, m) in lane {
407                if !m {
408                    acc = acc + v;
409                    count += 1;
410                }
411            }
412            if count == 0 {
413                None
414            } else {
415                Some(acc / T::from(count).unwrap_or_else(|| num_traits::one()))
416            }
417        })
418    }
419
420    /// Min of unmasked elements along `axis`. NaN-propagating per `NumPy`.
421    pub fn min_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
422        let fill = self.fill_value();
423        reduce_axis(self, axis, fill, |lane| {
424            let mut acc: Option<T> = None;
425            for &(v, m) in lane {
426                if !m {
427                    acc = Some(match acc {
428                        Some(a) => {
429                            // NaN-propagating: if comparison is unordered, return NaN
430                            if a <= v {
431                                a
432                            } else if a > v {
433                                v
434                            } else {
435                                a
436                            }
437                        }
438                        None => v,
439                    });
440                }
441            }
442            acc
443        })
444    }
445
446    /// Max of unmasked elements along `axis`. NaN-propagating per `NumPy`.
447    pub fn max_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
448        let fill = self.fill_value();
449        reduce_axis(self, axis, fill, |lane| {
450            let mut acc: Option<T> = None;
451            for &(v, m) in lane {
452                if !m {
453                    acc = Some(match acc {
454                        Some(a) => {
455                            if a >= v {
456                                a
457                            } else if a < v {
458                                v
459                            } else {
460                                a
461                            }
462                        }
463                        None => v,
464                    });
465                }
466            }
467            acc
468        })
469    }
470
471    /// Population variance (ddof=0) of unmasked elements along `axis`.
472    pub fn var_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
473        let zero = num_traits::zero::<T>();
474        let fill = self.fill_value();
475        reduce_axis(self, axis, fill, |lane| {
476            let mut acc = zero;
477            let mut count = 0usize;
478            for &(v, m) in lane {
479                if !m {
480                    acc = acc + v;
481                    count += 1;
482                }
483            }
484            if count == 0 {
485                return None;
486            }
487            let n = T::from(count).unwrap_or_else(|| num_traits::one());
488            let mean = acc / n;
489            let mut sum_sq = zero;
490            for &(v, m) in lane {
491                if !m {
492                    let d = v - mean;
493                    sum_sq = sum_sq + d * d;
494                }
495            }
496            Some(sum_sq / n)
497        })
498    }
499
500    /// Population standard deviation (ddof=0) of unmasked elements along `axis`.
501    pub fn std_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
502        let result = self.var_axis(axis)?;
503        // Take sqrt of unmasked positions; masked positions stay masked.
504        let fill = self.fill_value();
505        let mask = result.mask().clone();
506        let new_data: Vec<T> = result
507            .data()
508            .iter()
509            .zip(result.mask().iter())
510            .map(|(v, m)| if *m { fill } else { v.sqrt() })
511            .collect();
512        let data_arr = Array::<T, IxDyn>::from_vec(IxDyn::new(result.shape()), new_data)?;
513        let mut out = MaskedArray::new(data_arr, mask)?;
514        out.set_fill_value(fill);
515        Ok(out)
516    }
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};

    fn ma2d(rows: usize, cols: usize, data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix2> {
        let d = Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), data).unwrap();
        let m = Array::<bool, Ix2>::from_vec(Ix2::new([rows, cols]), mask).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    /// Build a dynamic-rank masked array (used for 3-D axis tests).
    fn ma_dyn(shape: &[usize], data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, IxDyn> {
        let d = Array::<f64, IxDyn>::from_vec(IxDyn::new(shape), data).unwrap();
        let m = Array::<bool, IxDyn>::from_vec(IxDyn::new(shape), mask).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    /// 0.0, 1.0, ..., 11.0 laid out as a (2, 3, 2) array.
    fn iota_2x3x2(mask: Vec<bool>) -> MaskedArray<f64, IxDyn> {
        let data: Vec<f64> = (0..12).map(f64::from).collect();
        ma_dyn(&[2, 3, 2], data, mask)
    }

    // ---- #500: per-axis reductions ----

    #[test]
    fn sum_axis_drops_axis() {
        // 2x3 array, no masks. axis=0 sums columns, axis=1 sums rows.
        let ma = ma2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![false; 6]);
        let s0 = ma.sum_axis(0).unwrap();
        assert_eq!(s0.shape(), &[3]);
        let d0: Vec<f64> = s0.data().iter().copied().collect();
        assert_eq!(d0, vec![5.0, 7.0, 9.0]);

        let s1 = ma.sum_axis(1).unwrap();
        assert_eq!(s1.shape(), &[2]);
        let d1: Vec<f64> = s1.data().iter().copied().collect();
        assert_eq!(d1, vec![6.0, 15.0]);
    }

    #[test]
    fn sum_axis_3d_every_axis() {
        // Element (i, j, k) holds i*6 + j*2 + k.
        let ma = iota_2x3x2(vec![false; 12]);

        let s0 = ma.sum_axis(0).unwrap();
        assert_eq!(s0.shape(), &[3, 2]);
        let d0: Vec<f64> = s0.data().iter().copied().collect();
        assert_eq!(d0, vec![6.0, 8.0, 10.0, 12.0, 14.0, 16.0]);

        // Middle axis: exercises the stride mapping for a non-edge axis.
        let s1 = ma.sum_axis(1).unwrap();
        assert_eq!(s1.shape(), &[2, 2]);
        let d1: Vec<f64> = s1.data().iter().copied().collect();
        assert_eq!(d1, vec![6.0, 9.0, 24.0, 27.0]);

        let s2 = ma.sum_axis(2).unwrap();
        assert_eq!(s2.shape(), &[2, 3]);
        let d2: Vec<f64> = s2.data().iter().copied().collect();
        assert_eq!(d2, vec![1.0, 5.0, 9.0, 13.0, 17.0, 21.0]);
    }

    #[test]
    fn sum_axis_skips_masked() {
        // 2x3 array. Mask out (0, 1) so column 1 and row 0 lose one element.
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, false],
        );
        // axis=0 (per-column): col0 = 1+4=5, col1 = 5 (only row 1), col2 = 3+6=9
        let s0 = ma.sum_axis(0).unwrap();
        let d0: Vec<f64> = s0.data().iter().copied().collect();
        assert_eq!(d0, vec![5.0, 5.0, 9.0]);
        let m0: Vec<bool> = s0.mask().iter().copied().collect();
        assert_eq!(m0, vec![false, false, false]);
    }

    #[test]
    fn sum_axis_all_masked_lane_is_masked() {
        // 2x3 array, mask out entire column 1.
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, true, false],
        );
        let s0 = ma.sum_axis(0).unwrap();
        let m0: Vec<bool> = s0.mask().iter().copied().collect();
        assert_eq!(m0, vec![false, true, false]);
    }

    #[test]
    fn mean_axis_skips_masked() {
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, false],
        );
        // mean axis 1 (per-row): row 0 = (1+3)/2 = 2.0, row 1 = (4+5+6)/3 = 5.0
        let m1 = ma.mean_axis(1).unwrap();
        let d: Vec<f64> = m1.data().iter().copied().collect();
        assert_eq!(d, vec![2.0, 5.0]);
    }

    #[test]
    fn min_max_axis() {
        let ma = ma2d(2, 3, vec![3.0, 1.0, 5.0, 2.0, 4.0, 0.0], vec![false; 6]);
        let mn = ma.min_axis(0).unwrap();
        let mx = ma.max_axis(0).unwrap();
        let mn_d: Vec<f64> = mn.data().iter().copied().collect();
        let mx_d: Vec<f64> = mx.data().iter().copied().collect();
        assert_eq!(mn_d, vec![2.0, 1.0, 0.0]);
        assert_eq!(mx_d, vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn min_axis_all_masked_lane_uses_fill_value() {
        // Column 1 fully masked; its output must be masked and hold the fill.
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, true, false],
        )
        .with_fill_value(-1.0);
        let mn = ma.min_axis(0).unwrap();
        let d: Vec<f64> = mn.data().iter().copied().collect();
        let m: Vec<bool> = mn.mask().iter().copied().collect();
        assert_eq!(d, vec![1.0, -1.0, 3.0]);
        assert_eq!(m, vec![false, true, false]);
        assert_eq!(mn.fill_value(), -1.0);
    }

    #[test]
    fn count_axis_basic() {
        // 2x3 array. Mask: (0,1) and (1,2). col0:2, col1:1, col2:1.
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, true],
        );
        let c0 = ma.count_axis(0).unwrap();
        let v: Vec<u64> = c0.iter().copied().collect();
        assert_eq!(v, vec![2u64, 1, 1]);
    }

    #[test]
    fn count_axis_3d_middle_axis() {
        // Masked flat indices 2 = (0,1,0), 7 = (1,0,1), 11 = (1,2,1).
        let mut mask = vec![false; 12];
        mask[2] = true;
        mask[7] = true;
        mask[11] = true;
        let ma = iota_2x3x2(mask);
        let c1 = ma.count_axis(1).unwrap();
        assert_eq!(c1.shape(), &[2, 2]);
        let v: Vec<u64> = c1.iter().copied().collect();
        assert_eq!(v, vec![2u64, 3, 3, 1]);
    }

    #[test]
    fn count_whole_array() {
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, true],
        );
        assert_eq!(ma.count().unwrap(), 4);
    }

    #[test]
    fn min_max_skip_masked() {
        // The masked extremes (-10 and 100) must not influence the result.
        let ma = ma2d(
            1,
            4,
            vec![5.0, -10.0, 2.0, 100.0],
            vec![false, true, false, true],
        );
        assert_eq!(ma.min().unwrap(), 2.0);
        assert_eq!(ma.max().unwrap(), 5.0);
    }

    #[test]
    fn axis_out_of_bounds_errors() {
        let ma = ma2d(2, 3, vec![0.0; 6], vec![false; 6]);
        assert!(ma.sum_axis(2).is_err());
    }

    #[test]
    fn var_std_axis() {
        // Two rows of [1, 2, 3, 4, 5] — variance along axis 1 should be 2.0 each.
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0];
        let ma = ma2d(2, 5, data, vec![false; 10]);
        let v = ma.var_axis(1).unwrap();
        let s = ma.std_axis(1).unwrap();
        let v_d: Vec<f64> = v.data().iter().copied().collect();
        let s_d: Vec<f64> = s.data().iter().copied().collect();
        for &x in &v_d {
            assert!((x - 2.0).abs() < 1e-12);
        }
        for &x in &s_d {
            assert!((x - 2.0_f64.sqrt()).abs() < 1e-12);
        }
    }

    // ---- #501: fill_value ----

    #[test]
    fn fill_value_default_is_zero() {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        let ma = MaskedArray::new(d, m).unwrap();
        assert_eq!(ma.fill_value(), 0.0);
    }

    #[test]
    fn with_fill_value_sets_field() {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        let ma = MaskedArray::new(d, m).unwrap().with_fill_value(99.0);
        assert_eq!(ma.fill_value(), 99.0);
    }

    #[test]
    fn filled_default_uses_stored_fill_value() {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let m =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false, true, false, true]).unwrap();
        let ma = MaskedArray::new(d, m).unwrap().with_fill_value(-1.0);
        let filled = ma.filled_default().unwrap();
        let v: Vec<f64> = filled.iter().copied().collect();
        assert_eq!(v, vec![1.0, -1.0, 3.0, -1.0]);
    }

    #[test]
    fn arithmetic_uses_fill_value() {
        // (Adding two masked arrays) — result data at masked positions should
        // be the receiver's fill_value, not zero.
        use crate::masked_add;
        let d_a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m_a = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
        let d_b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let m_b = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        let a = MaskedArray::new(d_a, m_a).unwrap().with_fill_value(-999.0);
        let b = MaskedArray::new(d_b, m_b).unwrap();
        let r = masked_add(&a, &b).unwrap();
        let r_d: Vec<f64> = r.data().iter().copied().collect();
        assert_eq!(r_d, vec![11.0, -999.0, 33.0]);
        assert_eq!(r.fill_value(), -999.0);
    }

    // ---- #504: broadcasting in masked arithmetic ----

    #[test]
    fn masked_add_broadcasts_within_same_rank() {
        use crate::masked_add;
        // (3, 1) + (1, 4) -> (3, 4) — both Ix2.
        let d_a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let m_a = Array::<bool, Ix2>::from_vec(Ix2::new([3, 1]), vec![false; 3]).unwrap();
        let d_b =
            Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let m_b = Array::<bool, Ix2>::from_vec(Ix2::new([1, 4]), vec![false; 4]).unwrap();
        let a = MaskedArray::new(d_a, m_a).unwrap();
        let b = MaskedArray::new(d_b, m_b).unwrap();
        let r = masked_add(&a, &b).unwrap();
        assert_eq!(r.shape(), &[3, 4]);
        let r_d: Vec<f64> = r.data().iter().copied().collect();
        assert_eq!(
            r_d,
            vec![
                11.0, 21.0, 31.0, 41.0, // row 0 = 1 + {10,20,30,40}
                12.0, 22.0, 32.0, 42.0, // row 1 = 2 + ...
                13.0, 23.0, 33.0, 43.0, // row 2 = 3 + ...
            ]
        );
        let r_m: Vec<bool> = r.mask().iter().copied().collect();
        assert_eq!(r_m, vec![false; 12]);
    }

    #[test]
    fn masked_sub_broadcasts_with_mask_union() {
        use crate::masked_sub;
        // Mask one element in `a`. After broadcasting (3,1) -> (3,4),
        // the masked row becomes a fully-masked row in the result.
        let d_a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![10.0, 20.0, 30.0]).unwrap();
        let m_a = Array::<bool, Ix2>::from_vec(Ix2::new([3, 1]), vec![false, true, false]).unwrap();
        let d_b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let m_b = Array::<bool, Ix2>::from_vec(Ix2::new([1, 4]), vec![false; 4]).unwrap();
        let a = MaskedArray::new(d_a, m_a).unwrap();
        let b = MaskedArray::new(d_b, m_b).unwrap();
        let r = masked_sub(&a, &b).unwrap();
        let r_m: Vec<bool> = r.mask().iter().copied().collect();
        // Row 1 is fully masked (4 positions).
        assert_eq!(
            r_m,
            vec![
                false, false, false, false, // row 0
                true, true, true, true, // row 1 (masked)
                false, false, false, false, // row 2
            ]
        );
    }

    // ---- #276: all-masked whole-array reductions ----
    //
    // Pin the current semantics when every element is masked:
    //   sum   -> 0 (neutral element of the fold)
    //   mean  -> NaN
    //   var   -> NaN
    //   std   -> NaN
    //   min   -> error (FerrayError::InvalidValue)
    //   max   -> error (FerrayError::InvalidValue)

    fn all_masked_ma1d(n: usize) -> MaskedArray<f64, Ix1> {
        let d = Array::<f64, Ix1>::from_vec(Ix1::new([n]), vec![1.0; n]).unwrap();
        let m = Array::<bool, Ix1>::from_vec(Ix1::new([n]), vec![true; n]).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    #[test]
    fn sum_all_masked_returns_zero() {
        let ma = all_masked_ma1d(4);
        assert_eq!(ma.sum().unwrap(), 0.0);
    }

    #[test]
    fn mean_all_masked_returns_nan() {
        let ma = all_masked_ma1d(4);
        assert!(ma.mean().unwrap().is_nan());
    }

    #[test]
    fn var_all_masked_returns_nan() {
        let ma = all_masked_ma1d(4);
        assert!(ma.var().unwrap().is_nan());
    }

    #[test]
    fn std_all_masked_returns_nan() {
        let ma = all_masked_ma1d(4);
        assert!(ma.std().unwrap().is_nan());
    }

    #[test]
    fn min_max_all_masked_error() {
        // Documenting the asymmetry: sum/mean/var/std fall through with
        // 0/NaN sentinels, but min/max have no neutral element and error.
        let ma = all_masked_ma1d(4);
        assert!(ma.min().is_err());
        assert!(ma.max().is_err());
    }

    #[test]
    fn sum_var_std_all_masked_2d_matches_1d() {
        // Same semantics hold for multi-dimensional whole-array reductions.
        let d = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![9.0; 6]).unwrap();
        let m = Array::<bool, Ix2>::from_vec(Ix2::new([2, 3]), vec![true; 6]).unwrap();
        let ma = MaskedArray::new(d, m).unwrap();
        assert_eq!(ma.sum().unwrap(), 0.0);
        assert!(ma.var().unwrap().is_nan());
        assert!(ma.std().unwrap().is_nan());
    }

    #[test]
    fn masked_add_broadcast_incompatible_errors() {
        use crate::masked_add;
        let d_a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m_a = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false; 3]).unwrap();
        let d_b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let m_b = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![false; 4]).unwrap();
        let a = MaskedArray::new(d_a, m_a).unwrap();
        let b = MaskedArray::new(d_b, m_b).unwrap();
        assert!(masked_add(&a, &b).is_err());
    }
}
809}