Skip to main content

ferray_ma/
reductions.rs

// ferray-ma: Masked reductions (REQ-4)
//
// mean, sum, min, max, var, std, count — all skip masked elements.
//
// Whole-array variants (`sum`, `mean`, ...) return a scalar `T`.
// Per-axis variants (`sum_axis`, `mean_axis`, ...) return a `MaskedArray<T, IxDyn>`
// where each output position holds the reduction of one lane along the
// chosen axis. Lanes that contain only masked elements are themselves
// masked in the output (and hold the source array's `fill_value`).
// `count_axis` is the exception: it returns a plain `Array<u64, IxDyn>`,
// since a count is always defined.
10
11use ferray_core::Array;
12use ferray_core::dimension::{Dimension, IxDyn};
13use ferray_core::dtype::Element;
14use ferray_core::error::{FerrayError, FerrayResult};
15use num_traits::Float;
16
17use crate::MaskedArray;
18
// ---------------------------------------------------------------------------
// Internal helpers for axis-aware reductions
// ---------------------------------------------------------------------------
22
/// Row-major (C-order) element strides for an array of extents `shape`.
///
/// The last axis has stride 1. Each earlier axis steps over one full block
/// formed by all axes after it. An empty `shape` yields an empty vector.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk outward from the innermost axis. `strides[d - 1]` is the size of
    // the block spanned by axes `d..`.
    for d in (1..shape.len()).rev() {
        strides[d - 1] = strides[d] * shape[d];
    }
    strides
}
32
/// Advance `multi` to the next index of an array with extents `shape`, in
/// row-major order (last axis varies fastest).
///
/// Returns `true` if a next index exists. After the final index, every
/// component wraps back to zero and `false` is returned. An empty (0-d)
/// index always returns `false`.
fn increment_multi(multi: &mut [usize], shape: &[usize]) -> bool {
    let mut d = multi.len();
    while d > 0 {
        d -= 1;
        if multi[d] + 1 < shape[d] {
            multi[d] += 1;
            return true;
        }
        // This axis rolled over: reset it and carry into the next-outer axis.
        multi[d] = 0;
    }
    false
}
44
45/// Apply a per-lane masked reduction along `axis`.
46///
47/// `kernel` receives a `&[(T, bool)]` slice (data + mask) for each lane and
48/// returns either `Some(value)` (the reduction result) or `None` if every
49/// element in the lane was masked. Masked output positions are filled with
50/// `fill_value`.
51fn reduce_axis<T, D, F>(
52    ma: &MaskedArray<T, D>,
53    axis: usize,
54    fill_value: T,
55    kernel: F,
56) -> FerrayResult<MaskedArray<T, IxDyn>>
57where
58    T: Element + Copy,
59    D: Dimension,
60    F: Fn(&[(T, bool)]) -> Option<T>,
61{
62    let ndim = ma.ndim();
63    if axis >= ndim {
64        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
65    }
66    let shape = ma.shape();
67    let axis_len = shape[axis];
68
69    // Output shape: drop the reduced axis.
70    let out_shape: Vec<usize> = shape
71        .iter()
72        .enumerate()
73        .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
74        .collect();
75    let out_size: usize = if out_shape.is_empty() {
76        1
77    } else {
78        out_shape.iter().product()
79    };
80
81    // Materialize source data + mask in row-major order so we can index by
82    // computed flat indices regardless of the source memory layout.
83    let src_data: Vec<T> = ma.data().iter().copied().collect();
84    let src_mask: Vec<bool> = ma.mask().iter().copied().collect();
85    let strides = compute_strides(shape);
86
87    let mut out_data = Vec::with_capacity(out_size);
88    let mut out_mask = Vec::with_capacity(out_size);
89    let mut out_multi = vec![0usize; out_shape.len()];
90    let mut in_multi = vec![0usize; ndim];
91    let mut lane: Vec<(T, bool)> = Vec::with_capacity(axis_len);
92
93    for _ in 0..out_size {
94        // Map output multi-index back into the input multi-index by inserting
95        // a placeholder at `axis`.
96        let mut out_dim = 0;
97        for (d, idx) in in_multi.iter_mut().enumerate() {
98            if d == axis {
99                *idx = 0;
100            } else {
101                *idx = out_multi[out_dim];
102                out_dim += 1;
103            }
104        }
105
106        lane.clear();
107        for k in 0..axis_len {
108            in_multi[axis] = k;
109            let flat = in_multi
110                .iter()
111                .zip(strides.iter())
112                .map(|(i, s)| i * s)
113                .sum::<usize>();
114            lane.push((src_data[flat], src_mask[flat]));
115        }
116
117        match kernel(&lane) {
118            Some(value) => {
119                out_data.push(value);
120                out_mask.push(false);
121            }
122            None => {
123                out_data.push(fill_value);
124                out_mask.push(true);
125            }
126        }
127
128        if !out_shape.is_empty() {
129            increment_multi(&mut out_multi, &out_shape);
130        }
131    }
132
133    let data_arr = Array::<T, IxDyn>::from_vec(IxDyn::new(&out_shape), out_data)?;
134    let mask_arr = Array::<bool, IxDyn>::from_vec(IxDyn::new(&out_shape), out_mask)?;
135    let mut result = MaskedArray::new(data_arr, mask_arr)?;
136    result.set_fill_value(fill_value);
137    Ok(result)
138}
139
140/// Per-axis count of unmasked elements. Returns a plain `Array<u64, IxDyn>`
141/// (not masked, since count is always defined) with the reduced axis dropped.
142fn count_axis<T, D>(ma: &MaskedArray<T, D>, axis: usize) -> FerrayResult<Array<u64, IxDyn>>
143where
144    T: Element + Copy,
145    D: Dimension,
146{
147    let ndim = ma.ndim();
148    if axis >= ndim {
149        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
150    }
151    let shape = ma.shape();
152    let axis_len = shape[axis];
153    let out_shape: Vec<usize> = shape
154        .iter()
155        .enumerate()
156        .filter_map(|(i, &s)| if i == axis { None } else { Some(s) })
157        .collect();
158    let out_size: usize = if out_shape.is_empty() {
159        1
160    } else {
161        out_shape.iter().product()
162    };
163
164    let src_mask: Vec<bool> = ma.mask().iter().copied().collect();
165    let strides = compute_strides(shape);
166    let mut out: Vec<u64> = Vec::with_capacity(out_size);
167    let mut out_multi = vec![0usize; out_shape.len()];
168    let mut in_multi = vec![0usize; ndim];
169
170    for _ in 0..out_size {
171        let mut out_dim = 0;
172        for (d, idx) in in_multi.iter_mut().enumerate() {
173            if d == axis {
174                *idx = 0;
175            } else {
176                *idx = out_multi[out_dim];
177                out_dim += 1;
178            }
179        }
180
181        let mut count: u64 = 0;
182        for k in 0..axis_len {
183            in_multi[axis] = k;
184            let flat = in_multi
185                .iter()
186                .zip(strides.iter())
187                .map(|(i, s)| i * s)
188                .sum::<usize>();
189            if !src_mask[flat] {
190                count += 1;
191            }
192        }
193        out.push(count);
194
195        if !out_shape.is_empty() {
196            increment_multi(&mut out_multi, &out_shape);
197        }
198    }
199
200    Array::<u64, IxDyn>::from_vec(IxDyn::new(&out_shape), out)
201}
202
203impl<T, D> MaskedArray<T, D>
204where
205    T: Element + Copy,
206    D: Dimension,
207{
208    /// Count the number of unmasked (valid) elements.
209    ///
210    /// # Errors
211    /// This function does not currently error but returns `Result` for API
212    /// consistency.
213    pub fn count(&self) -> FerrayResult<usize> {
214        let n = self
215            .data()
216            .iter()
217            .zip(self.mask().iter())
218            .filter(|(_, m)| !**m)
219            .count();
220        Ok(n)
221    }
222
223    /// Count the number of unmasked elements per lane along `axis`.
224    ///
225    /// Returns a plain `Array<u64, IxDyn>` (not masked, since count is
226    /// always defined) with the reduced axis removed.
227    ///
228    /// # Errors
229    /// Returns `FerrayError::AxisOutOfBounds` if `axis >= ndim`.
230    pub fn count_axis(&self, axis: usize) -> FerrayResult<Array<u64, IxDyn>> {
231        count_axis(self, axis)
232    }
233}
234
235impl<T, D> MaskedArray<T, D>
236where
237    T: Element + Float,
238    D: Dimension,
239{
240    /// Compute the sum of unmasked elements.
241    ///
242    /// Returns zero if all elements are masked.
243    ///
244    /// # Errors
245    /// Returns an error only for internal failures.
246    pub fn sum(&self) -> FerrayResult<T> {
247        let zero = num_traits::zero::<T>();
248        let s = self
249            .data()
250            .iter()
251            .zip(self.mask().iter())
252            .filter(|(_, m)| !**m)
253            .fold(zero, |acc, (v, _)| acc + *v);
254        Ok(s)
255    }
256
257    /// Compute the mean of unmasked elements.
258    ///
259    /// Returns `NaN` if no elements are unmasked.
260    ///
261    /// # Errors
262    /// Returns an error only for internal failures.
263    pub fn mean(&self) -> FerrayResult<T> {
264        let zero = num_traits::zero::<T>();
265        let one: T = num_traits::one();
266        let (sum, count) = self
267            .data()
268            .iter()
269            .zip(self.mask().iter())
270            .filter(|(_, m)| !**m)
271            .fold((zero, 0usize), |(s, c), (v, _)| (s + *v, c + 1));
272        if count == 0 {
273            return Ok(T::nan());
274        }
275        Ok(sum / T::from(count).unwrap_or(one))
276    }
277
278    /// Compute the minimum of unmasked elements.
279    ///
280    /// NaN values in unmasked elements are propagated (returns NaN), matching NumPy.
281    ///
282    /// # Errors
283    /// Returns `FerrayError::InvalidValue` if no elements are unmasked.
284    pub fn min(&self) -> FerrayResult<T> {
285        self.data()
286            .iter()
287            .zip(self.mask().iter())
288            .filter(|(_, m)| !**m)
289            .map(|(v, _)| *v)
290            .fold(None, |acc: Option<T>, v| {
291                Some(match acc {
292                    Some(a) => {
293                        // NaN-propagating: if comparison is unordered, propagate NaN
294                        if a <= v {
295                            a
296                        } else if a > v {
297                            v
298                        } else {
299                            a
300                        }
301                    }
302                    None => v,
303                })
304            })
305            .ok_or_else(|| FerrayError::invalid_value("min: all elements are masked"))
306    }
307
308    /// Compute the maximum of unmasked elements.
309    ///
310    /// NaN values in unmasked elements are propagated (returns NaN), matching NumPy.
311    ///
312    /// # Errors
313    /// Returns `FerrayError::InvalidValue` if no elements are unmasked.
314    pub fn max(&self) -> FerrayResult<T> {
315        self.data()
316            .iter()
317            .zip(self.mask().iter())
318            .filter(|(_, m)| !**m)
319            .map(|(v, _)| *v)
320            .fold(None, |acc: Option<T>, v| {
321                Some(match acc {
322                    Some(a) => {
323                        if a >= v {
324                            a
325                        } else if a < v {
326                            v
327                        } else {
328                            a
329                        }
330                    }
331                    None => v,
332                })
333            })
334            .ok_or_else(|| FerrayError::invalid_value("max: all elements are masked"))
335    }
336
337    /// Compute the variance of unmasked elements (population variance, ddof=0).
338    ///
339    /// Returns `NaN` if no elements are unmasked.
340    ///
341    /// # Errors
342    /// Returns an error only for internal failures.
343    pub fn var(&self) -> FerrayResult<T> {
344        let mean = self.mean()?;
345        if mean.is_nan() {
346            return Ok(T::nan());
347        }
348        let zero = num_traits::zero::<T>();
349        let one: T = num_traits::one();
350        let (sum_sq, count) = self
351            .data()
352            .iter()
353            .zip(self.mask().iter())
354            .filter(|(_, m)| !**m)
355            .fold((zero, 0usize), |(s, c), (v, _)| {
356                let d = *v - mean;
357                (s + d * d, c + 1)
358            });
359        if count == 0 {
360            return Ok(T::nan());
361        }
362        Ok(sum_sq / T::from(count).unwrap_or(one))
363    }
364
365    /// Compute the standard deviation of unmasked elements (population, ddof=0).
366    ///
367    /// Returns `NaN` if no elements are unmasked.
368    ///
369    /// # Errors
370    /// Returns an error only for internal failures.
371    pub fn std(&self) -> FerrayResult<T> {
372        Ok(self.var()?.sqrt())
373    }
374
375    // -----------------------------------------------------------------------
376    // Per-axis reductions (issue #500)
377    //
378    // Each lane along `axis` is reduced independently. Lanes containing only
379    // masked elements produce a masked output position holding `fill_value`.
380    // -----------------------------------------------------------------------
381
382    /// Sum unmasked elements along `axis`. Returns a masked array with the
383    /// reduced axis removed; lanes that are entirely masked produce a
384    /// masked output position holding `fill_value`.
385    pub fn sum_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
386        let zero = num_traits::zero::<T>();
387        let fill = self.fill_value();
388        reduce_axis(self, axis, fill, |lane| {
389            let mut acc = zero;
390            let mut any = false;
391            for &(v, m) in lane {
392                if !m {
393                    acc = acc + v;
394                    any = true;
395                }
396            }
397            if any { Some(acc) } else { None }
398        })
399    }
400
401    /// Mean of unmasked elements along `axis`. All-masked lanes are masked
402    /// in the output.
403    pub fn mean_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
404        let zero = num_traits::zero::<T>();
405        let fill = self.fill_value();
406        reduce_axis(self, axis, fill, |lane| {
407            let mut acc = zero;
408            let mut count = 0usize;
409            for &(v, m) in lane {
410                if !m {
411                    acc = acc + v;
412                    count += 1;
413                }
414            }
415            if count == 0 {
416                None
417            } else {
418                Some(acc / T::from(count).unwrap_or_else(|| num_traits::one()))
419            }
420        })
421    }
422
423    /// Min of unmasked elements along `axis`. NaN-propagating per NumPy.
424    pub fn min_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
425        let fill = self.fill_value();
426        reduce_axis(self, axis, fill, |lane| {
427            let mut acc: Option<T> = None;
428            for &(v, m) in lane {
429                if !m {
430                    acc = Some(match acc {
431                        Some(a) => {
432                            // NaN-propagating: if comparison is unordered, return NaN
433                            if a <= v {
434                                a
435                            } else if a > v {
436                                v
437                            } else {
438                                a
439                            }
440                        }
441                        None => v,
442                    });
443                }
444            }
445            acc
446        })
447    }
448
449    /// Max of unmasked elements along `axis`. NaN-propagating per NumPy.
450    pub fn max_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
451        let fill = self.fill_value();
452        reduce_axis(self, axis, fill, |lane| {
453            let mut acc: Option<T> = None;
454            for &(v, m) in lane {
455                if !m {
456                    acc = Some(match acc {
457                        Some(a) => {
458                            if a >= v {
459                                a
460                            } else if a < v {
461                                v
462                            } else {
463                                a
464                            }
465                        }
466                        None => v,
467                    });
468                }
469            }
470            acc
471        })
472    }
473
474    /// Population variance (ddof=0) of unmasked elements along `axis`.
475    pub fn var_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
476        let zero = num_traits::zero::<T>();
477        let fill = self.fill_value();
478        reduce_axis(self, axis, fill, |lane| {
479            let mut acc = zero;
480            let mut count = 0usize;
481            for &(v, m) in lane {
482                if !m {
483                    acc = acc + v;
484                    count += 1;
485                }
486            }
487            if count == 0 {
488                return None;
489            }
490            let n = T::from(count).unwrap_or_else(|| num_traits::one());
491            let mean = acc / n;
492            let mut sum_sq = zero;
493            for &(v, m) in lane {
494                if !m {
495                    let d = v - mean;
496                    sum_sq = sum_sq + d * d;
497                }
498            }
499            Some(sum_sq / n)
500        })
501    }
502
503    /// Population standard deviation (ddof=0) of unmasked elements along `axis`.
504    pub fn std_axis(&self, axis: usize) -> FerrayResult<MaskedArray<T, IxDyn>> {
505        let result = self.var_axis(axis)?;
506        // Take sqrt of unmasked positions; masked positions stay masked.
507        let fill = self.fill_value();
508        let mask = result.mask().clone();
509        let new_data: Vec<T> = result
510            .data()
511            .iter()
512            .zip(result.mask().iter())
513            .map(|(v, m)| if *m { fill } else { v.sqrt() })
514            .collect();
515        let data_arr = Array::<T, IxDyn>::from_vec(IxDyn::new(result.shape()), new_data)?;
516        let mut out = MaskedArray::new(data_arr, mask)?;
517        out.set_fill_value(fill);
518        Ok(out)
519    }
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};

    /// Build a `rows x cols` masked array from row-major `data` and `mask`.
    fn ma2d(rows: usize, cols: usize, data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix2> {
        let values = Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), data).unwrap();
        let flags = Array::<bool, Ix2>::from_vec(Ix2::new([rows, cols]), mask).unwrap();
        MaskedArray::new(values, flags).unwrap()
    }

    /// Build a 1-D masked array; `data` and `mask` must have equal length.
    fn ma1d(data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix1> {
        let len = data.len();
        let values = Array::<f64, Ix1>::from_vec(Ix1::new([len]), data).unwrap();
        let flags = Array::<bool, Ix1>::from_vec(Ix1::new([len]), mask).unwrap();
        MaskedArray::new(values, flags).unwrap()
    }

    /// Data of a per-axis reduction result, flattened in row-major order.
    fn data_of(r: &MaskedArray<f64, IxDyn>) -> Vec<f64> {
        r.data().iter().copied().collect()
    }

    /// Mask of a per-axis reduction result, flattened in row-major order.
    fn mask_of(r: &MaskedArray<f64, IxDyn>) -> Vec<bool> {
        r.mask().iter().copied().collect()
    }

    // ---- #500: per-axis reductions ----

    #[test]
    fn sum_axis_drops_axis() {
        // [[1, 2, 3],
        //  [4, 5, 6]], nothing masked.
        let ma = ma2d(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![false; 6]);

        let per_column = ma.sum_axis(0).unwrap();
        assert_eq!(per_column.shape(), &[3]);
        assert_eq!(data_of(&per_column), vec![5.0, 7.0, 9.0]);

        let per_row = ma.sum_axis(1).unwrap();
        assert_eq!(per_row.shape(), &[2]);
        assert_eq!(data_of(&per_row), vec![6.0, 15.0]);
    }

    #[test]
    fn sum_axis_skips_masked() {
        // Element (0, 1) is masked, so column 1 only contributes row 1.
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, false],
        );
        let per_column = ma.sum_axis(0).unwrap();
        // col0 = 1 + 4, col1 = 5 (row 1 only), col2 = 3 + 6
        assert_eq!(data_of(&per_column), vec![5.0, 5.0, 9.0]);
        assert_eq!(mask_of(&per_column), vec![false; 3]);
    }

    #[test]
    fn sum_axis_all_masked_lane_is_masked() {
        // Column 1 is masked in both rows.
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, true, false],
        );
        assert_eq!(mask_of(&ma.sum_axis(0).unwrap()), vec![false, true, false]);
    }

    #[test]
    fn mean_axis_skips_masked() {
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, false],
        );
        let row_means = ma.mean_axis(1).unwrap();
        // row 0 = (1 + 3) / 2, row 1 = (4 + 5 + 6) / 3
        assert_eq!(data_of(&row_means), vec![2.0, 5.0]);
    }

    #[test]
    fn min_max_axis() {
        let ma = ma2d(2, 3, vec![3.0, 1.0, 5.0, 2.0, 4.0, 0.0], vec![false; 6]);
        assert_eq!(data_of(&ma.min_axis(0).unwrap()), vec![2.0, 1.0, 0.0]);
        assert_eq!(data_of(&ma.max_axis(0).unwrap()), vec![3.0, 4.0, 5.0]);
    }

    #[test]
    fn count_axis_basic() {
        // (0, 1) and (1, 2) are masked -> per-column counts 2, 1, 1.
        let ma = ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, false, true],
        );
        let counts: Vec<u64> = ma.count_axis(0).unwrap().iter().copied().collect();
        assert_eq!(counts, vec![2, 1, 1]);
    }

    #[test]
    fn axis_out_of_bounds_errors() {
        // A 2-D array only has axes 0 and 1.
        let ma = ma2d(2, 3, vec![0.0; 6], vec![false; 6]);
        assert!(ma.sum_axis(2).is_err());
    }

    #[test]
    fn var_std_axis() {
        // Both rows are [1, 2, 3, 4, 5]: population variance 2, std sqrt(2).
        let row = [1.0, 2.0, 3.0, 4.0, 5.0];
        let data: Vec<f64> = row.iter().chain(row.iter()).copied().collect();
        let ma = ma2d(2, 5, data, vec![false; 10]);
        let expected_std = 2.0_f64.sqrt();
        for x in data_of(&ma.var_axis(1).unwrap()) {
            assert!((x - 2.0).abs() < 1e-12);
        }
        for x in data_of(&ma.std_axis(1).unwrap()) {
            assert!((x - expected_std).abs() < 1e-12);
        }
    }

    // ---- #501: fill_value ----

    #[test]
    fn fill_value_default_is_zero() {
        let ma = ma1d(vec![1.0, 2.0, 3.0], vec![false; 3]);
        assert_eq!(ma.fill_value(), 0.0);
    }

    #[test]
    fn with_fill_value_sets_field() {
        let ma = ma1d(vec![1.0, 2.0, 3.0], vec![false; 3]).with_fill_value(99.0);
        assert_eq!(ma.fill_value(), 99.0);
    }

    #[test]
    fn filled_default_uses_stored_fill_value() {
        let ma = ma1d(vec![1.0, 2.0, 3.0, 4.0], vec![false, true, false, true])
            .with_fill_value(-1.0);
        let filled: Vec<f64> = ma.filled_default().unwrap().iter().copied().collect();
        // Masked slots 1 and 3 take the stored fill value.
        assert_eq!(filled, vec![1.0, -1.0, 3.0, -1.0]);
    }

    #[test]
    fn arithmetic_uses_fill_value() {
        use crate::masked_add;
        // The receiver's fill value (not zero) lands in masked result slots.
        let a = ma1d(vec![1.0, 2.0, 3.0], vec![false, true, false]).with_fill_value(-999.0);
        let b = ma1d(vec![10.0, 20.0, 30.0], vec![false; 3]);
        let r = masked_add(&a, &b).unwrap();
        let r_data: Vec<f64> = r.data().iter().copied().collect();
        assert_eq!(r_data, vec![11.0, -999.0, 33.0]);
        assert_eq!(r.fill_value(), -999.0);
    }

    // ---- #504: broadcasting in masked arithmetic ----

    #[test]
    fn masked_add_broadcasts_within_same_rank() {
        use crate::masked_add;
        // (3, 1) + (1, 4) -> (3, 4); both operands are Ix2.
        let a = ma2d(3, 1, vec![1.0, 2.0, 3.0], vec![false; 3]);
        let b = ma2d(1, 4, vec![10.0, 20.0, 30.0, 40.0], vec![false; 4]);
        let r = masked_add(&a, &b).unwrap();
        assert_eq!(r.shape(), &[3, 4]);

        let r_data: Vec<f64> = r.data().iter().copied().collect();
        assert_eq!(
            r_data,
            vec![
                11.0, 21.0, 31.0, 41.0, // 1 + {10, 20, 30, 40}
                12.0, 22.0, 32.0, 42.0, // 2 + ...
                13.0, 23.0, 33.0, 43.0, // 3 + ...
            ]
        );
        let r_mask: Vec<bool> = r.mask().iter().copied().collect();
        assert_eq!(r_mask, vec![false; 12]);
    }

    #[test]
    fn masked_sub_broadcasts_with_mask_union() {
        use crate::masked_sub;
        // Masking row 1 of the (3, 1) operand masks all of row 1 after
        // broadcasting to (3, 4).
        let a = ma2d(3, 1, vec![10.0, 20.0, 30.0], vec![false, true, false]);
        let b = ma2d(1, 4, vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        let r = masked_sub(&a, &b).unwrap();
        let r_mask: Vec<bool> = r.mask().iter().copied().collect();
        let expected: Vec<bool> = [false, true, false]
            .iter()
            .flat_map(|&row_masked| std::iter::repeat(row_masked).take(4))
            .collect();
        assert_eq!(r_mask, expected);
    }

    // ---- #276: all-masked whole-array reductions ----
    //
    // Pinned semantics when every element is masked:
    //   sum       -> 0 (neutral element of the fold)
    //   mean      -> NaN
    //   var / std -> NaN
    //   min / max -> error (FerrayError::InvalidValue)

    fn all_masked_ma1d(n: usize) -> MaskedArray<f64, Ix1> {
        ma1d(vec![1.0; n], vec![true; n])
    }

    #[test]
    fn sum_all_masked_returns_zero() {
        assert_eq!(all_masked_ma1d(4).sum().unwrap(), 0.0);
    }

    #[test]
    fn mean_all_masked_returns_nan() {
        assert!(all_masked_ma1d(4).mean().unwrap().is_nan());
    }

    #[test]
    fn var_all_masked_returns_nan() {
        assert!(all_masked_ma1d(4).var().unwrap().is_nan());
    }

    #[test]
    fn std_all_masked_returns_nan() {
        assert!(all_masked_ma1d(4).std().unwrap().is_nan());
    }

    #[test]
    fn min_max_all_masked_error() {
        // min/max have no neutral element, so unlike sum/mean/var/std they
        // cannot fall back to a 0/NaN sentinel and report an error instead.
        let ma = all_masked_ma1d(4);
        assert!(ma.min().is_err());
        assert!(ma.max().is_err());
    }

    #[test]
    fn sum_var_std_all_masked_2d_matches_1d() {
        // Multi-dimensional whole-array reductions follow the same rules.
        let ma = ma2d(2, 3, vec![9.0; 6], vec![true; 6]);
        assert_eq!(ma.sum().unwrap(), 0.0);
        assert!(ma.var().unwrap().is_nan());
        assert!(ma.std().unwrap().is_nan());
    }

    #[test]
    fn masked_add_broadcast_incompatible_errors() {
        use crate::masked_add;
        // Shapes (3,) and (4,) cannot broadcast together.
        let a = ma1d(vec![1.0, 2.0, 3.0], vec![false; 3]);
        let b = ma1d(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        assert!(masked_add(&a, &b).is_err());
    }
}