ferray_ma/reductions.rs
1// ferray-ma: Masked reductions (REQ-4)
2//
3// mean, sum, min, max, var, std, count — all skip masked elements.
4
5use ferray_core::dimension::Dimension;
6use ferray_core::dtype::Element;
7use ferray_core::error::{FerrayError, FerrayResult};
8use num_traits::Float;
9
10use crate::MaskedArray;
11
12impl<T, D> MaskedArray<T, D>
13where
14 T: Element + Copy,
15 D: Dimension,
16{
17 /// Count the number of unmasked (valid) elements.
18 ///
19 /// # Errors
20 /// This function does not currently error but returns `Result` for API
21 /// consistency.
22 pub fn count(&self) -> FerrayResult<usize> {
23 let n = self
24 .data()
25 .iter()
26 .zip(self.mask().iter())
27 .filter(|(_, m)| !**m)
28 .count();
29 Ok(n)
30 }
31}
32
33impl<T, D> MaskedArray<T, D>
34where
35 T: Element + Float,
36 D: Dimension,
37{
38 /// Compute the sum of unmasked elements.
39 ///
40 /// Returns zero if all elements are masked.
41 ///
42 /// # Errors
43 /// Returns an error only for internal failures.
44 pub fn sum(&self) -> FerrayResult<T> {
45 let zero = num_traits::zero::<T>();
46 let s = self
47 .data()
48 .iter()
49 .zip(self.mask().iter())
50 .filter(|(_, m)| !**m)
51 .fold(zero, |acc, (v, _)| acc + *v);
52 Ok(s)
53 }
54
55 /// Compute the mean of unmasked elements.
56 ///
57 /// Returns `NaN` if no elements are unmasked.
58 ///
59 /// # Errors
60 /// Returns an error only for internal failures.
61 pub fn mean(&self) -> FerrayResult<T> {
62 let zero = num_traits::zero::<T>();
63 let one: T = num_traits::one();
64 let (sum, count) = self
65 .data()
66 .iter()
67 .zip(self.mask().iter())
68 .filter(|(_, m)| !**m)
69 .fold((zero, 0usize), |(s, c), (v, _)| (s + *v, c + 1));
70 if count == 0 {
71 return Ok(T::nan());
72 }
73 Ok(sum / T::from(count).unwrap_or(one))
74 }
75
76 /// Compute the minimum of unmasked elements.
77 ///
78 /// NaN values in unmasked elements are propagated (returns NaN), matching NumPy.
79 ///
80 /// # Errors
81 /// Returns `FerrayError::InvalidValue` if no elements are unmasked.
82 pub fn min(&self) -> FerrayResult<T> {
83 self.data()
84 .iter()
85 .zip(self.mask().iter())
86 .filter(|(_, m)| !**m)
87 .map(|(v, _)| *v)
88 .fold(None, |acc: Option<T>, v| {
89 Some(match acc {
90 Some(a) => {
91 // NaN-propagating: if comparison is unordered, propagate NaN
92 if a <= v { a } else if a > v { v } else { a }
93 }
94 None => v,
95 })
96 })
97 .ok_or_else(|| FerrayError::invalid_value("min: all elements are masked"))
98 }
99
100 /// Compute the maximum of unmasked elements.
101 ///
102 /// NaN values in unmasked elements are propagated (returns NaN), matching NumPy.
103 ///
104 /// # Errors
105 /// Returns `FerrayError::InvalidValue` if no elements are unmasked.
106 pub fn max(&self) -> FerrayResult<T> {
107 self.data()
108 .iter()
109 .zip(self.mask().iter())
110 .filter(|(_, m)| !**m)
111 .map(|(v, _)| *v)
112 .fold(None, |acc: Option<T>, v| {
113 Some(match acc {
114 Some(a) => {
115 if a >= v { a } else if a < v { v } else { a }
116 }
117 None => v,
118 })
119 })
120 .ok_or_else(|| FerrayError::invalid_value("max: all elements are masked"))
121 }
122
123 /// Compute the variance of unmasked elements (population variance, ddof=0).
124 ///
125 /// Returns `NaN` if no elements are unmasked.
126 ///
127 /// # Errors
128 /// Returns an error only for internal failures.
129 pub fn var(&self) -> FerrayResult<T> {
130 let mean = self.mean()?;
131 if mean.is_nan() {
132 return Ok(T::nan());
133 }
134 let zero = num_traits::zero::<T>();
135 let one: T = num_traits::one();
136 let (sum_sq, count) = self
137 .data()
138 .iter()
139 .zip(self.mask().iter())
140 .filter(|(_, m)| !**m)
141 .fold((zero, 0usize), |(s, c), (v, _)| {
142 let d = *v - mean;
143 (s + d * d, c + 1)
144 });
145 if count == 0 {
146 return Ok(T::nan());
147 }
148 Ok(sum_sq / T::from(count).unwrap_or(one))
149 }
150
151 /// Compute the standard deviation of unmasked elements (population, ddof=0).
152 ///
153 /// Returns `NaN` if no elements are unmasked.
154 ///
155 /// # Errors
156 /// Returns an error only for internal failures.
157 pub fn std(&self) -> FerrayResult<T> {
158 Ok(self.var()?.sqrt())
159 }
160}