// ferray_stats/reductions/nan_aware.rs

// ferray-stats: NaN-aware reductions — nansum, nanprod, nanmin, nanmax, nanmean, nanvar, nanstd (REQ-3, REQ-4)
// Also provides nancumsum and nancumprod (REQ-2b).

4use ferray_core::error::{FerrayError, FerrayResult};
5use ferray_core::{Array, Dimension, Element, IxDyn};
6use num_traits::Float;
7
8use super::{borrow_data, make_result, output_shape, reduce_axis_general, validate_axis};

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

14/// Sum a lane, skipping NaN values. Returns zero for all-NaN.
15fn lane_nansum<T: Float>(lane: &[T]) -> T {
16    let non_nan: Vec<T> = lane.iter().copied().filter(|x| !x.is_nan()).collect();
17    crate::parallel::pairwise_sum(&non_nan, T::zero())
18}
19
20/// Product of a lane, skipping NaN values. Returns one for all-NaN.
21fn lane_nanprod<T: Float>(lane: &[T]) -> T {
22    lane.iter()
23        .copied()
24        .filter(|x| !x.is_nan())
25        .fold(T::one(), |a, b| a * b)
26}
27
28/// Mean of a lane, skipping NaN values. Returns NaN for all-NaN.
29fn lane_nanmean<T: Float>(lane: &[T]) -> T {
30    let non_nan: Vec<T> = lane.iter().copied().filter(|x| !x.is_nan()).collect();
31    if non_nan.is_empty() {
32        T::nan()
33    } else {
34        crate::parallel::pairwise_sum(&non_nan, T::zero()) / T::from(non_nan.len()).unwrap()
35    }
36}
37
38/// Variance of a lane, skipping NaN values.
39fn lane_nanvar<T: Float>(lane: &[T], ddof: usize) -> T {
40    let non_nan: Vec<T> = lane.iter().copied().filter(|x| !x.is_nan()).collect();
41    let count = non_nan.len();
42    if count <= ddof {
43        return T::nan();
44    }
45    let mean = crate::parallel::pairwise_sum(&non_nan, T::zero()) / T::from(count).unwrap();
46    let sq_diffs: Vec<T> = non_nan
47        .iter()
48        .map(|&x| {
49            let d = x - mean;
50            d * d
51        })
52        .collect();
53    crate::parallel::pairwise_sum(&sq_diffs, T::zero()) / T::from(count - ddof).unwrap()
54}
55
56/// Min of a lane, skipping NaN values. Returns NaN for all-NaN.
57fn lane_nanmin<T: Float>(lane: &[T]) -> T {
58    lane.iter()
59        .copied()
60        .filter(|x| !x.is_nan())
61        .reduce(|a, b| if a <= b { a } else { b })
62        .unwrap_or_else(T::nan)
63}
64
65/// Max of a lane, skipping NaN values. Returns NaN for all-NaN.
66fn lane_nanmax<T: Float>(lane: &[T]) -> T {
67    lane.iter()
68        .copied()
69        .filter(|x| !x.is_nan())
70        .reduce(|a, b| if a >= b { a } else { b })
71        .unwrap_or_else(T::nan)
72}
73
// ---------------------------------------------------------------------------
// Public API
// ---------------------------------------------------------------------------

78/// Sum of array elements, treating NaN as zero.
79///
80/// Equivalent to `numpy.nansum`.
81pub fn nansum<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
82where
83    T: Element + Float,
84    D: Dimension,
85{
86    let data = borrow_data(a);
87    match axis {
88        None => {
89            let total = lane_nansum(&data);
90            make_result(&[], vec![total])
91        }
92        Some(ax) => {
93            validate_axis(ax, a.ndim())?;
94            let shape = a.shape();
95            let out_s = output_shape(shape, ax);
96            let result = reduce_axis_general(&data, shape, ax, lane_nansum);
97            make_result(&out_s, result)
98        }
99    }
100}
101
102/// Product of array elements, treating NaN as one.
103///
104/// Equivalent to `numpy.nanprod`.
105pub fn nanprod<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
106where
107    T: Element + Float,
108    D: Dimension,
109{
110    let data = borrow_data(a);
111    match axis {
112        None => {
113            let total = lane_nanprod(&data);
114            make_result(&[], vec![total])
115        }
116        Some(ax) => {
117            validate_axis(ax, a.ndim())?;
118            let shape = a.shape();
119            let out_s = output_shape(shape, ax);
120            let result = reduce_axis_general(&data, shape, ax, lane_nanprod);
121            make_result(&out_s, result)
122        }
123    }
124}
125
126/// Mean of array elements, skipping NaN. Returns NaN for all-NaN slices.
127///
128/// Equivalent to `numpy.nanmean`.
129pub fn nanmean<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
130where
131    T: Element + Float,
132    D: Dimension,
133{
134    let data = borrow_data(a);
135    match axis {
136        None => {
137            let m = lane_nanmean(&data);
138            make_result(&[], vec![m])
139        }
140        Some(ax) => {
141            validate_axis(ax, a.ndim())?;
142            let shape = a.shape();
143            let out_s = output_shape(shape, ax);
144            let result = reduce_axis_general(&data, shape, ax, lane_nanmean);
145            make_result(&out_s, result)
146        }
147    }
148}
149
150/// Variance of array elements, skipping NaN.
151///
152/// Equivalent to `numpy.nanvar`.
153pub fn nanvar<T, D>(
154    a: &Array<T, D>,
155    axis: Option<usize>,
156    ddof: usize,
157) -> FerrayResult<Array<T, IxDyn>>
158where
159    T: Element + Float,
160    D: Dimension,
161{
162    let data = borrow_data(a);
163    match axis {
164        None => {
165            let v = lane_nanvar(&data, ddof);
166            make_result(&[], vec![v])
167        }
168        Some(ax) => {
169            validate_axis(ax, a.ndim())?;
170            let shape = a.shape();
171            let out_s = output_shape(shape, ax);
172            let result = reduce_axis_general(&data, shape, ax, |lane| lane_nanvar(lane, ddof));
173            make_result(&out_s, result)
174        }
175    }
176}
177
178/// Standard deviation of array elements, skipping NaN.
179///
180/// Equivalent to `numpy.nanstd`.
181pub fn nanstd<T, D>(
182    a: &Array<T, D>,
183    axis: Option<usize>,
184    ddof: usize,
185) -> FerrayResult<Array<T, IxDyn>>
186where
187    T: Element + Float,
188    D: Dimension,
189{
190    let v = nanvar(a, axis, ddof)?;
191    let data: Vec<T> = v.iter().map(|x| x.sqrt()).collect();
192    make_result(v.shape(), data)
193}
194
195/// Minimum of array elements, skipping NaN.
196///
197/// Equivalent to `numpy.nanmin`.
198pub fn nanmin<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
199where
200    T: Element + Float,
201    D: Dimension,
202{
203    if a.is_empty() {
204        return Err(FerrayError::invalid_value(
205            "cannot compute nanmin of empty array",
206        ));
207    }
208    let data = borrow_data(a);
209    match axis {
210        None => {
211            let m = lane_nanmin(&data);
212            make_result(&[], vec![m])
213        }
214        Some(ax) => {
215            validate_axis(ax, a.ndim())?;
216            let shape = a.shape();
217            let out_s = output_shape(shape, ax);
218            let result = reduce_axis_general(&data, shape, ax, lane_nanmin);
219            make_result(&out_s, result)
220        }
221    }
222}
223
224/// Maximum of array elements, skipping NaN.
225///
226/// Equivalent to `numpy.nanmax`.
227pub fn nanmax<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
228where
229    T: Element + Float,
230    D: Dimension,
231{
232    if a.is_empty() {
233        return Err(FerrayError::invalid_value(
234            "cannot compute nanmax of empty array",
235        ));
236    }
237    let data = borrow_data(a);
238    match axis {
239        None => {
240            let m = lane_nanmax(&data);
241            make_result(&[], vec![m])
242        }
243        Some(ax) => {
244            validate_axis(ax, a.ndim())?;
245            let shape = a.shape();
246            let out_s = output_shape(shape, ax);
247            let result = reduce_axis_general(&data, shape, ax, lane_nanmax);
248            make_result(&out_s, result)
249        }
250    }
251}
252
253/// Cumulative sum, treating NaN as zero.
254///
255/// Re-exported from `ferray_ufunc::nancumsum`.
256pub fn nancumsum<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
257where
258    T: Element + Float,
259    D: Dimension,
260{
261    ferray_ufunc::nancumsum(a, axis)
262}
263
264/// Cumulative product, treating NaN as one.
265///
266/// Re-exported from `ferray_ufunc::nancumprod`.
267pub fn nancumprod<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, D>>
268where
269    T: Element + Float,
270    D: Dimension,
271{
272    ferray_ufunc::nancumprod(a, axis)
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix1;

    /// Builds a 1-D `f64` array whose length matches `values`.
    fn arr1(values: Vec<f64>) -> Array<f64, Ix1> {
        let n = values.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([n]), values).unwrap()
    }

    /// First element of a reduction result (the only one for full reductions).
    fn first(a: &Array<f64, IxDyn>) -> f64 {
        *a.iter().next().unwrap()
    }

    #[test]
    fn test_nanmean_basic() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, f64::NAN, 3.0]).unwrap();
        let m = nanmean(&a, None).unwrap();
        assert!((m.iter().next().unwrap() - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_nanmean_all_nan() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![f64::NAN, f64::NAN]).unwrap();
        let m = nanmean(&a, None).unwrap();
        assert!(m.iter().next().unwrap().is_nan());
    }

    #[test]
    fn test_nanmean_axis0() {
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        let m = nanmean(&a, Some(0)).unwrap();
        assert!((first(&m) - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_nansum_basic() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, f64::NAN, 3.0]).unwrap();
        let s = nansum(&a, None).unwrap();
        assert!((s.iter().next().unwrap() - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_nansum_all_nan() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![f64::NAN, f64::NAN]).unwrap();
        let s = nansum(&a, None).unwrap();
        assert!((s.iter().next().unwrap() - 0.0).abs() < 1e-12);
    }

    #[test]
    fn test_nansum_axis0() {
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        let s = nansum(&a, Some(0)).unwrap();
        assert!((first(&s) - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_reductions_reject_invalid_axis() {
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        assert!(nansum(&a, Some(1)).is_err());
        assert!(nanprod(&a, Some(1)).is_err());
        assert!(nanmean(&a, Some(1)).is_err());
        assert!(nanvar(&a, Some(1), 0).is_err());
        assert!(nanstd(&a, Some(1), 0).is_err());
        assert!(nanmin(&a, Some(1)).is_err());
        assert!(nanmax(&a, Some(1)).is_err());
    }

    #[test]
    fn test_nanmin_nanmax() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![3.0, f64::NAN, 1.0, 4.0]).unwrap();
        let mn = nanmin(&a, None).unwrap();
        let mx = nanmax(&a, None).unwrap();
        assert!((mn.iter().next().unwrap() - 1.0).abs() < 1e-12);
        assert!((mx.iter().next().unwrap() - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_nanmin_nanmax_all_nan() {
        let a = arr1(vec![f64::NAN, f64::NAN]);
        assert!(first(&nanmin(&a, None).unwrap()).is_nan());
        assert!(first(&nanmax(&a, None).unwrap()).is_nan());
    }

    #[test]
    fn test_nanmin_nanmax_empty_errors() {
        let a = arr1(Vec::new());
        assert!(nanmin(&a, None).is_err());
        assert!(nanmax(&a, None).is_err());
        assert!(nanmin(&a, Some(0)).is_err());
        assert!(nanmax(&a, Some(0)).is_err());
    }

    #[test]
    fn test_nanvar() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, f64::NAN, 3.0, 5.0]).unwrap();
        let v = nanvar(&a, None, 0).unwrap();
        // non-nan values: [1, 3, 5], mean=3, var = (4+0+4)/3 = 8/3
        let expected = 8.0 / 3.0;
        assert!((v.iter().next().unwrap() - expected).abs() < 1e-12);
    }

    #[test]
    fn test_nanvar_ddof1() {
        let a = arr1(vec![1.0, f64::NAN, 3.0, 5.0]);
        // Sum of squared deviations is 8, divided by (3 - 1).
        let v = nanvar(&a, None, 1).unwrap();
        assert!((first(&v) - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_nanvar_nanstd_ddof_exceeds_count() {
        let a = arr1(vec![1.0, f64::NAN, 3.0, 5.0]);
        // Three non-NaN values with ddof = 3 leave no degrees of freedom.
        assert!(first(&nanvar(&a, None, 3).unwrap()).is_nan());
        assert!(first(&nanstd(&a, None, 3).unwrap()).is_nan());
    }

    #[test]
    fn test_nanstd() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, f64::NAN, 3.0, 5.0]).unwrap();
        let s = nanstd(&a, None, 0).unwrap();
        let expected = (8.0_f64 / 3.0).sqrt();
        assert!((s.iter().next().unwrap() - expected).abs() < 1e-12);
    }

    #[test]
    fn test_nanprod() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0, f64::NAN, 3.0]).unwrap();
        let p = nanprod(&a, None).unwrap();
        assert!((p.iter().next().unwrap() - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_nanprod_all_nan() {
        let a = arr1(vec![f64::NAN, f64::NAN]);
        let p = nanprod(&a, None).unwrap();
        assert!((first(&p) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_nancumsum() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, f64::NAN, 3.0]).unwrap();
        let cs = nancumsum(&a, None).unwrap();
        let data: Vec<f64> = cs.iter().copied().collect();
        assert!((data[0] - 1.0).abs() < 1e-12);
        assert!((data[1] - 1.0).abs() < 1e-12);
        assert!((data[2] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_nancumprod() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![2.0, f64::NAN, 3.0]).unwrap();
        let cp = nancumprod(&a, None).unwrap();
        let data: Vec<f64> = cp.iter().copied().collect();
        assert!((data[0] - 2.0).abs() < 1e-12);
        assert!((data[1] - 2.0).abs() < 1e-12);
        assert!((data[2] - 6.0).abs() < 1e-12);
    }
}