Skip to main content

ferray_ma/
reductions.rs

1// ferray-ma: Masked reductions (REQ-4)
2//
3// mean, sum, min, max, var, std, count — all skip masked elements.
4
5use ferray_core::dimension::Dimension;
6use ferray_core::dtype::Element;
7use ferray_core::error::{FerrayError, FerrayResult};
8use num_traits::Float;
9
10use crate::MaskedArray;
11
12impl<T, D> MaskedArray<T, D>
13where
14    T: Element + Copy,
15    D: Dimension,
16{
17    /// Count the number of unmasked (valid) elements.
18    ///
19    /// # Errors
20    /// This function does not currently error but returns `Result` for API
21    /// consistency.
22    pub fn count(&self) -> FerrayResult<usize> {
23        let n = self
24            .data()
25            .iter()
26            .zip(self.mask().iter())
27            .filter(|(_, m)| !**m)
28            .count();
29        Ok(n)
30    }
31}
32
33impl<T, D> MaskedArray<T, D>
34where
35    T: Element + Float,
36    D: Dimension,
37{
38    /// Compute the sum of unmasked elements.
39    ///
40    /// Returns zero if all elements are masked.
41    ///
42    /// # Errors
43    /// Returns an error only for internal failures.
44    pub fn sum(&self) -> FerrayResult<T> {
45        let zero = num_traits::zero::<T>();
46        let s = self
47            .data()
48            .iter()
49            .zip(self.mask().iter())
50            .filter(|(_, m)| !**m)
51            .fold(zero, |acc, (v, _)| acc + *v);
52        Ok(s)
53    }
54
55    /// Compute the mean of unmasked elements.
56    ///
57    /// Returns `NaN` if no elements are unmasked.
58    ///
59    /// # Errors
60    /// Returns an error only for internal failures.
61    pub fn mean(&self) -> FerrayResult<T> {
62        let zero = num_traits::zero::<T>();
63        let one: T = num_traits::one();
64        let (sum, count) = self
65            .data()
66            .iter()
67            .zip(self.mask().iter())
68            .filter(|(_, m)| !**m)
69            .fold((zero, 0usize), |(s, c), (v, _)| (s + *v, c + 1));
70        if count == 0 {
71            return Ok(T::nan());
72        }
73        Ok(sum / T::from(count).unwrap_or(one))
74    }
75
76    /// Compute the minimum of unmasked elements.
77    ///
78    /// NaN values in unmasked elements are propagated (returns NaN), matching NumPy.
79    ///
80    /// # Errors
81    /// Returns `FerrayError::InvalidValue` if no elements are unmasked.
82    pub fn min(&self) -> FerrayResult<T> {
83        self.data()
84            .iter()
85            .zip(self.mask().iter())
86            .filter(|(_, m)| !**m)
87            .map(|(v, _)| *v)
88            .fold(None, |acc: Option<T>, v| {
89                Some(match acc {
90                    Some(a) => {
91                        // NaN-propagating: if comparison is unordered, propagate NaN
92                        if a <= v { a } else if a > v { v } else { a }
93                    }
94                    None => v,
95                })
96            })
97            .ok_or_else(|| FerrayError::invalid_value("min: all elements are masked"))
98    }
99
100    /// Compute the maximum of unmasked elements.
101    ///
102    /// NaN values in unmasked elements are propagated (returns NaN), matching NumPy.
103    ///
104    /// # Errors
105    /// Returns `FerrayError::InvalidValue` if no elements are unmasked.
106    pub fn max(&self) -> FerrayResult<T> {
107        self.data()
108            .iter()
109            .zip(self.mask().iter())
110            .filter(|(_, m)| !**m)
111            .map(|(v, _)| *v)
112            .fold(None, |acc: Option<T>, v| {
113                Some(match acc {
114                    Some(a) => {
115                        if a >= v { a } else if a < v { v } else { a }
116                    }
117                    None => v,
118                })
119            })
120            .ok_or_else(|| FerrayError::invalid_value("max: all elements are masked"))
121    }
122
123    /// Compute the variance of unmasked elements (population variance, ddof=0).
124    ///
125    /// Returns `NaN` if no elements are unmasked.
126    ///
127    /// # Errors
128    /// Returns an error only for internal failures.
129    pub fn var(&self) -> FerrayResult<T> {
130        let mean = self.mean()?;
131        if mean.is_nan() {
132            return Ok(T::nan());
133        }
134        let zero = num_traits::zero::<T>();
135        let one: T = num_traits::one();
136        let (sum_sq, count) = self
137            .data()
138            .iter()
139            .zip(self.mask().iter())
140            .filter(|(_, m)| !**m)
141            .fold((zero, 0usize), |(s, c), (v, _)| {
142                let d = *v - mean;
143                (s + d * d, c + 1)
144            });
145        if count == 0 {
146            return Ok(T::nan());
147        }
148        Ok(sum_sq / T::from(count).unwrap_or(one))
149    }
150
151    /// Compute the standard deviation of unmasked elements (population, ddof=0).
152    ///
153    /// Returns `NaN` if no elements are unmasked.
154    ///
155    /// # Errors
156    /// Returns an error only for internal failures.
157    pub fn std(&self) -> FerrayResult<T> {
158        Ok(self.var()?.sqrt())
159    }
160}