augurs_core/
float_iter.rs

1use std::cmp::Ordering;
2
3use num_traits::{Float, FromPrimitive};
4
/// The result of a call to [`FloatIterExt::nanminmax`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NanMinMaxResult<T> {
    /// The iterator contains multiple distinct floats; the minimum and maximum are
    /// returned, in that order.
    MinMax(T, T),
    /// The iterator contains exactly one distinct float, after optionally ignoring NaNs.
    ///
    /// Repeated occurrences of the same value still count as a single distinct float.
    OneElement(T),
    /// The iterator was empty, or was empty after ignoring NaNs.
    NoElements,
    /// The iterator contains at least one NaN value, and NaNs were not ignored.
    ///
    /// This is unreachable if `nanminmax` was called with `ignore_nans: true`.
    NaN,
}
19
20// Helper function used by nanmin and nanmax.
21fn nan_reduce<I, T, F>(iter: I, ignore_nans: bool, f: F) -> T
22where
23    I: Iterator<Item = T>,
24    T: Float + FromPrimitive,
25    F: Fn(T, T) -> T,
26{
27    iter.reduce(|acc, x| {
28        if ignore_nans && x.is_nan() {
29            acc
30        } else if x.is_nan() || acc.is_nan() {
31            T::nan()
32        } else {
33            f(acc, x)
34        }
35    })
36    .unwrap_or_else(T::nan)
37}
38
39/// Helper trait for calculating summary statistics on floating point iterators with alternative NaN handling.
40///
41/// This is intended to be similar to numpy's `nanmean`, `nanmin`, `nanmax` etc.
42pub trait FloatIterExt<T: Float + FromPrimitive>: Iterator<Item = T> {
43    /// Returns the minimum of all elements in the iterator, handling NaN values.
44    ///
45    /// If `ignore_nans` is true, NaN values will be ignored and
46    /// not included in the minimum.
47    /// Otherwise, the minimum will be NaN if any element is NaN.
48    ///
49    /// # Examples
50    ///
51    /// ## Simple usage
52    ///
53    /// ```rust
54    /// use augurs_core::FloatIterExt;
55    ///
56    /// let x = [1.0, 2.0, 3.0, f64::NAN, 5.0];
57    /// assert_eq!(x.iter().copied().nanmin(true), 1.0);
58    /// assert!(x.iter().copied().nanmin(false).is_nan());
59    /// ```
60    ///
61    /// ## Empty iterator
62    ///
63    /// ```rust
64    /// use augurs_core::FloatIterExt;
65    ///
66    /// let x: [f64; 0] = [];
67    /// assert!(x.iter().copied().nanmin(true).is_nan());
68    /// assert!(x.iter().copied().nanmin(false).is_nan());
69    /// ```
70    ///
71    /// ## Only NaN values
72    ///
73    /// ```rust
74    /// use augurs_core::FloatIterExt;
75    ///
76    /// let x = [f64::NAN, f64::NAN];
77    /// assert!(x.iter().copied().nanmin(true).is_nan());
78    /// assert!(x.iter().copied().nanmin(false).is_nan());
79    /// ```
80    fn nanmin(self, ignore_nans: bool) -> T
81    where
82        Self: Sized,
83    {
84        nan_reduce(self, ignore_nans, T::min)
85    }
86
87    /// Returns the maximum of all elements in the iterator, handling NaN values.
88    ///
89    /// If `ignore_nans` is true, NaN values will be ignored and
90    /// not included in the maximum.
91    /// Otherwise, the maximum will be NaN if any element is NaN.
92    ///
93    /// # Examples
94    ///
95    /// ## Simple usage
96    ///
97    /// ```rust
98    /// use augurs_core::FloatIterExt;
99    ///
100    /// let x = [1.0, 2.0, 3.0, f64::NAN, 5.0];
101    /// assert_eq!(x.iter().copied().nanmax(true), 5.0);
102    /// assert!(x.iter().copied().nanmax(false).is_nan());
103    /// ```
104    ///
105    /// ## Empty iterator
106    ///
107    /// ```rust
108    /// use augurs_core::FloatIterExt;
109    ///
110    /// let x: [f64; 0] = [];
111    /// assert!(x.iter().copied().nanmax(true).is_nan());
112    /// assert!(x.iter().copied().nanmax(false).is_nan());
113    /// ```
114    ///
115    /// ## Only NaN values
116    ///
117    /// ```rust
118    /// use augurs_core::FloatIterExt;
119    ///
120    /// let x = [f64::NAN, f64::NAN];
121    /// assert!(x.iter().copied().nanmax(true).is_nan());
122    /// assert!(x.iter().copied().nanmax(false).is_nan());
123    /// ```
124    fn nanmax(self, ignore_nans: bool) -> T
125    where
126        Self: Sized,
127    {
128        nan_reduce(self, ignore_nans, T::max)
129    }
130
131    /// Returns the minimum and maximum of all elements in the iterator,
132    /// handling NaN values.
133    ///
134    /// If `ignore_nans` is true, NaN values will be ignored and
135    /// not included in the minimum or maximum.
136    /// Otherwise, the minimum and maximum will be NaN if any element is NaN.
137    ///
138    /// The return value is a [`NanMinMaxResult`], which is similar to
139    /// [`itertools::MinMaxResult`](https://docs.rs/itertools/latest/itertools/enum.MinMaxResult.html)
140    /// and provides more granular information on the result.
141    ///
142    /// # Examples
143    ///
144    /// ## Simple usage, ignoring NaNs
145    ///
146    /// ```
147    /// use augurs_core::{FloatIterExt, NanMinMaxResult};
148    ///
149    /// let x = [1.0, 2.0, 3.0, f64::NAN, 5.0];
150    /// let min_max = x.iter().copied().nanminmax(true);
151    /// assert_eq!(min_max, NanMinMaxResult::MinMax(1.0, 5.0));
152    /// ```
153    ///
154    /// ## Simple usage, including NaNs
155    ///
156    /// ```
157    /// use augurs_core::{FloatIterExt, NanMinMaxResult};
158    ///
159    /// let x = [1.0, 2.0, 3.0, f64::NAN, 5.0];
160    /// let min_max = x.iter().copied().nanminmax(false);
161    /// assert_eq!(min_max, NanMinMaxResult::NaN);
162    /// ```
163    ///
164    /// ## Only NaNs
165    ///
166    /// ```
167    /// use augurs_core::{FloatIterExt, NanMinMaxResult};
168    ///
169    /// let x = [f64::NAN, f64::NAN, f64::NAN];
170    /// let min_max = x.iter().copied().nanminmax(true);
171    /// assert_eq!(min_max, NanMinMaxResult::NoElements);
172    ///
173    /// let min_max = x.iter().copied().nanminmax(false);
174    /// assert_eq!(min_max, NanMinMaxResult::NaN);
175    /// ```
176    ///
177    /// ## Empty iterator
178    ///
179    /// ```
180    /// use augurs_core::{FloatIterExt, NanMinMaxResult};
181    ///
182    /// let x: [f64; 0] = [];
183    /// let min_max = x.iter().copied().nanminmax(true);
184    /// assert_eq!(min_max, NanMinMaxResult::NoElements);
185    ///
186    /// let min_max = x.iter().copied().nanminmax(false);
187    /// assert_eq!(min_max, NanMinMaxResult::NoElements);
188    /// ```
189    ///
190    /// ## Only one distinct element
191    ///
192    /// ```
193    /// use augurs_core::{FloatIterExt, NanMinMaxResult};
194    ///
195    /// let x = [1.0, f64::NAN, 1.0];
196    /// let min_max = x.iter().copied().nanminmax(true);
197    /// assert_eq!(min_max, NanMinMaxResult::OneElement(1.0));
198    ///
199    /// let min_max = x.iter().copied().nanminmax(false);
200    /// assert_eq!(min_max, NanMinMaxResult::NaN);
201    /// ```
202    fn nanminmax(self, ignore_nans: bool) -> NanMinMaxResult<T>
203    where
204        Self: Sized,
205    {
206        let mut acc = NanMinMaxResult::NoElements;
207        for x in self {
208            let is_nan = x.is_nan();
209            if is_nan && !ignore_nans {
210                return NanMinMaxResult::NaN;
211            }
212            if is_nan {
213                continue;
214            }
215            // From here on, we're ignoring NaNs.
216            acc = match acc {
217                NanMinMaxResult::NoElements => NanMinMaxResult::OneElement(x),
218                NanMinMaxResult::OneElement(one) => {
219                    match one.partial_cmp(&x).expect("x should not be NaN") {
220                        Ordering::Equal => acc,
221                        Ordering::Less => NanMinMaxResult::MinMax(one, x),
222                        Ordering::Greater => NanMinMaxResult::MinMax(x, one),
223                    }
224                }
225                NanMinMaxResult::MinMax(min, max) => {
226                    NanMinMaxResult::MinMax(min.min(x), max.max(x))
227                }
228                // This case is unreachable because we return early for NaN values when ignore_nans is false
229                NanMinMaxResult::NaN => {
230                    unreachable!("NaN case should have been handled by early return")
231                }
232            };
233        }
234        acc
235    }
236
237    /// Returns the mean of all elements in the iterator, handling NaN values.
238    ///
239    /// If `ignore_nans` is true, NaN values will be ignored and
240    /// not included in the mean.
241    /// Otherwise, the mean will be NaN if any element is NaN.
242    ///
243    /// # Examples
244    ///
245    /// ## Simple usage
246    ///
247    /// ```rust
248    /// use augurs_core::FloatIterExt;
249    ///
250    /// let x = [1.0, 2.0, 3.0, f64::NAN, 4.0];
251    /// assert_eq!(x.iter().copied().nanmean(true), 2.5);
252    /// assert!(x.iter().copied().nanmean(false).is_nan());
253    /// ```
254    ///
255    /// ## Empty iterator
256    ///
257    /// ```rust
258    /// use augurs_core::FloatIterExt;
259    ///
260    /// let x: [f64; 0] = [];
261    /// assert!(x.iter().copied().nanmean(true).is_nan());
262    /// assert!(x.iter().copied().nanmean(false).is_nan());
263    /// ```
264    ///
265    /// ## Only NaN values
266    ///
267    /// ```rust
268    /// use augurs_core::FloatIterExt;
269    ///
270    /// let x = [f64::NAN, f64::NAN];
271    /// assert!(x.iter().copied().nanmean(true).is_nan());
272    /// assert!(x.iter().copied().nanmean(false).is_nan());
273    /// ```
274    fn nanmean(self, ignore_nans: bool) -> T
275    where
276        Self: Sized,
277    {
278        let (n, sum) = self.fold((0, T::zero()), |(n, sum), x| {
279            if ignore_nans && x.is_nan() {
280                (n, sum)
281            } else if x.is_nan() || sum.is_nan() {
282                (n, T::nan())
283            } else {
284                (n + 1, sum + x)
285            }
286        });
287        if n == 0 {
288            T::nan()
289        } else if sum.is_nan() {
290            sum
291        } else {
292            sum / T::from_usize(n).unwrap_or_else(|| T::nan())
293        }
294    }
295}
296
297impl<T: Float + FromPrimitive, I: Iterator<Item = T>> FloatIterExt<T> for I {}
298
#[cfg(test)]
mod test {
    use super::*;

    /// Turns a slice of floats into the by-value iterator the extension methods consume.
    fn floats(xs: &[f64]) -> std::iter::Copied<std::slice::Iter<'_, f64>> {
        xs.iter().copied()
    }

    /// Checks `nanminmax` on `xs`, first ignoring NaNs and then propagating them.
    fn assert_minmax(
        xs: &[f64],
        ignoring: NanMinMaxResult<f64>,
        propagating: NanMinMaxResult<f64>,
    ) {
        assert_eq!(floats(xs).nanminmax(true), ignoring);
        assert_eq!(floats(xs).nanminmax(false), propagating);
    }

    #[test]
    fn empty() {
        let nothing: &[f64] = &[];
        assert!(floats(nothing).nanmin(true).is_nan());
        assert!(floats(nothing).nanmax(true).is_nan());
    }

    #[test]
    fn no_nans() {
        let xs: &[f64] = &[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0];
        // With no NaNs present, the flag must not change the answer.
        for ignore_nans in [true, false] {
            assert_eq!(floats(xs).nanmin(ignore_nans), -3.0);
            assert_eq!(floats(xs).nanmax(ignore_nans), 3.0);
        }
    }

    #[test]
    fn nans() {
        let xs: &[f64] = &[-3.0, -2.0, -1.0, f64::NAN, 1.0, 2.0, 3.0];

        // Ignoring NaNs behaves as if the NaN were absent...
        assert_eq!(floats(xs).nanmin(true), -3.0);
        assert_eq!(floats(xs).nanmax(true), 3.0);

        // ...while propagating them poisons the result.
        assert!(floats(xs).nanmin(false).is_nan());
        assert!(floats(xs).nanmax(false).is_nan());
    }

    #[test]
    fn nanmean() {
        let symmetric: &[f64] = &[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0];
        assert_eq!(floats(symmetric).nanmean(true), 0.0);

        let with_nan: &[f64] = &[-3.0, -2.0, -1.0, f64::NAN, 1.0, 2.0, 3.0];
        assert_eq!(floats(with_nan).nanmean(true), 0.0);
        assert!(floats(with_nan).nanmean(false).is_nan());

        let only_nans: &[f64] = &[f64::NAN, f64::NAN];
        assert!(floats(only_nans).nanmean(true).is_nan());
    }

    #[test]
    fn nanminmax() {
        assert_minmax(
            &[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0],
            NanMinMaxResult::MinMax(-3.0, 3.0),
            NanMinMaxResult::MinMax(-3.0, 3.0),
        );
        assert_minmax(
            &[-3.0, -2.0, -1.0, f64::NAN, 1.0, 2.0, 3.0],
            NanMinMaxResult::MinMax(-3.0, 3.0),
            NanMinMaxResult::NaN,
        );
        assert_minmax(
            &[f64::NAN, f64::NAN],
            NanMinMaxResult::NoElements,
            NanMinMaxResult::NaN,
        );
        assert_minmax(
            &[],
            NanMinMaxResult::NoElements,
            NanMinMaxResult::NoElements,
        );
        assert_minmax(
            &[1.0, f64::NAN, 1.0],
            NanMinMaxResult::OneElement(1.0),
            NanMinMaxResult::NaN,
        );
    }
}