Skip to main content

ferray_ma/
reductions.rs

// ferray-ma: Masked reductions (REQ-4)
//
// mean, sum, min, max, var, std, count — all skip masked elements.
4
5use ferray_core::dimension::Dimension;
6use ferray_core::dtype::Element;
7use ferray_core::error::{FerrayError, FerrayResult};
8use num_traits::Float;
9
10use crate::MaskedArray;
11
12impl<T, D> MaskedArray<T, D>
13where
14    T: Element + Copy,
15    D: Dimension,
16{
17    /// Count the number of unmasked (valid) elements.
18    ///
19    /// # Errors
20    /// This function does not currently error but returns `Result` for API
21    /// consistency.
22    pub fn count(&self) -> FerrayResult<usize> {
23        let n = self
24            .data()
25            .iter()
26            .zip(self.mask().iter())
27            .filter(|(_, m)| !**m)
28            .count();
29        Ok(n)
30    }
31}
32
33impl<T, D> MaskedArray<T, D>
34where
35    T: Element + Float,
36    D: Dimension,
37{
38    /// Compute the sum of unmasked elements.
39    ///
40    /// Returns zero if all elements are masked.
41    ///
42    /// # Errors
43    /// Returns an error only for internal failures.
44    pub fn sum(&self) -> FerrayResult<T> {
45        let zero = num_traits::zero::<T>();
46        let s = self
47            .data()
48            .iter()
49            .zip(self.mask().iter())
50            .filter(|(_, m)| !**m)
51            .fold(zero, |acc, (v, _)| acc + *v);
52        Ok(s)
53    }
54
55    /// Compute the mean of unmasked elements.
56    ///
57    /// Returns `NaN` if no elements are unmasked.
58    ///
59    /// # Errors
60    /// Returns an error only for internal failures.
61    pub fn mean(&self) -> FerrayResult<T> {
62        let zero = num_traits::zero::<T>();
63        let one: T = num_traits::one();
64        let (sum, count) = self
65            .data()
66            .iter()
67            .zip(self.mask().iter())
68            .filter(|(_, m)| !**m)
69            .fold((zero, 0usize), |(s, c), (v, _)| (s + *v, c + 1));
70        if count == 0 {
71            return Ok(T::nan());
72        }
73        Ok(sum / T::from(count).unwrap_or(one))
74    }
75
76    /// Compute the minimum of unmasked elements.
77    ///
78    /// # Errors
79    /// Returns `FerrayError::InvalidValue` if no elements are unmasked.
80    pub fn min(&self) -> FerrayResult<T> {
81        self.data()
82            .iter()
83            .zip(self.mask().iter())
84            .filter(|(_, m)| !**m)
85            .map(|(v, _)| *v)
86            .fold(None, |acc: Option<T>, v| {
87                Some(match acc {
88                    Some(a) => {
89                        if v < a {
90                            v
91                        } else {
92                            a
93                        }
94                    }
95                    None => v,
96                })
97            })
98            .ok_or_else(|| FerrayError::invalid_value("min: all elements are masked"))
99    }
100
101    /// Compute the maximum of unmasked elements.
102    ///
103    /// # Errors
104    /// Returns `FerrayError::InvalidValue` if no elements are unmasked.
105    pub fn max(&self) -> FerrayResult<T> {
106        self.data()
107            .iter()
108            .zip(self.mask().iter())
109            .filter(|(_, m)| !**m)
110            .map(|(v, _)| *v)
111            .fold(None, |acc: Option<T>, v| {
112                Some(match acc {
113                    Some(a) => {
114                        if v > a {
115                            v
116                        } else {
117                            a
118                        }
119                    }
120                    None => v,
121                })
122            })
123            .ok_or_else(|| FerrayError::invalid_value("max: all elements are masked"))
124    }
125
126    /// Compute the variance of unmasked elements (population variance, ddof=0).
127    ///
128    /// Returns `NaN` if no elements are unmasked.
129    ///
130    /// # Errors
131    /// Returns an error only for internal failures.
132    pub fn var(&self) -> FerrayResult<T> {
133        let mean = self.mean()?;
134        if mean.is_nan() {
135            return Ok(T::nan());
136        }
137        let zero = num_traits::zero::<T>();
138        let one: T = num_traits::one();
139        let (sum_sq, count) = self
140            .data()
141            .iter()
142            .zip(self.mask().iter())
143            .filter(|(_, m)| !**m)
144            .fold((zero, 0usize), |(s, c), (v, _)| {
145                let d = *v - mean;
146                (s + d * d, c + 1)
147            });
148        if count == 0 {
149            return Ok(T::nan());
150        }
151        Ok(sum_sq / T::from(count).unwrap_or(one))
152    }
153
154    /// Compute the standard deviation of unmasked elements (population, ddof=0).
155    ///
156    /// Returns `NaN` if no elements are unmasked.
157    ///
158    /// # Errors
159    /// Returns an error only for internal failures.
160    pub fn std(&self) -> FerrayResult<T> {
161        Ok(self.var()?.sqrt())
162    }
163}