augurs_core/float_iter.rs
1use std::cmp::Ordering;
2
3use num_traits::{Float, FromPrimitive};
4
/// Outcome of a call to `nanminmax`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NanMinMaxResult<T> {
    /// At least two distinct floats were seen; holds the minimum and the maximum, in that order.
    MinMax(T, T),
    /// Exactly one distinct float was seen, after optionally ignoring NaNs.
    OneElement(T),
    /// Nothing was seen: the iterator was empty, or was empty once NaNs were ignored.
    NoElements,
    /// At least one NaN value was encountered while NaNs were not being ignored.
    ///
    /// This variant is never produced when `nanminmax` is called with `ignore_nans: true`.
    NaN,
}
19
20// Helper function used by nanmin and nanmax.
21fn nan_reduce<I, T, F>(iter: I, ignore_nans: bool, f: F) -> T
22where
23 I: Iterator<Item = T>,
24 T: Float + FromPrimitive,
25 F: Fn(T, T) -> T,
26{
27 iter.reduce(|acc, x| {
28 if ignore_nans && x.is_nan() {
29 acc
30 } else if x.is_nan() || acc.is_nan() {
31 T::nan()
32 } else {
33 f(acc, x)
34 }
35 })
36 .unwrap_or_else(T::nan)
37}
38
39/// Helper trait for calculating summary statistics on floating point iterators with alternative NaN handling.
40///
41/// This is intended to be similar to numpy's `nanmean`, `nanmin`, `nanmax` etc.
42pub trait FloatIterExt<T: Float + FromPrimitive>: Iterator<Item = T> {
43 /// Returns the minimum of all elements in the iterator, handling NaN values.
44 ///
45 /// If `ignore_nans` is true, NaN values will be ignored and
46 /// not included in the minimum.
47 /// Otherwise, the minimum will be NaN if any element is NaN.
48 ///
49 /// # Examples
50 ///
51 /// ## Simple usage
52 ///
53 /// ```rust
54 /// use augurs_core::FloatIterExt;
55 ///
56 /// let x = [1.0, 2.0, 3.0, f64::NAN, 5.0];
57 /// assert_eq!(x.iter().copied().nanmin(true), 1.0);
58 /// assert!(x.iter().copied().nanmin(false).is_nan());
59 /// ```
60 ///
61 /// ## Empty iterator
62 ///
63 /// ```rust
64 /// use augurs_core::FloatIterExt;
65 ///
66 /// let x: [f64; 0] = [];
67 /// assert!(x.iter().copied().nanmin(true).is_nan());
68 /// assert!(x.iter().copied().nanmin(false).is_nan());
69 /// ```
70 ///
71 /// ## Only NaN values
72 ///
73 /// ```rust
74 /// use augurs_core::FloatIterExt;
75 ///
76 /// let x = [f64::NAN, f64::NAN];
77 /// assert!(x.iter().copied().nanmin(true).is_nan());
78 /// assert!(x.iter().copied().nanmin(false).is_nan());
79 /// ```
80 fn nanmin(self, ignore_nans: bool) -> T
81 where
82 Self: Sized,
83 {
84 nan_reduce(self, ignore_nans, T::min)
85 }
86
87 /// Returns the maximum of all elements in the iterator, handling NaN values.
88 ///
89 /// If `ignore_nans` is true, NaN values will be ignored and
90 /// not included in the maximum.
91 /// Otherwise, the maximum will be NaN if any element is NaN.
92 ///
93 /// # Examples
94 ///
95 /// ## Simple usage
96 ///
97 /// ```rust
98 /// use augurs_core::FloatIterExt;
99 ///
100 /// let x = [1.0, 2.0, 3.0, f64::NAN, 5.0];
101 /// assert_eq!(x.iter().copied().nanmax(true), 5.0);
102 /// assert!(x.iter().copied().nanmax(false).is_nan());
103 /// ```
104 ///
105 /// ## Empty iterator
106 ///
107 /// ```rust
108 /// use augurs_core::FloatIterExt;
109 ///
110 /// let x: [f64; 0] = [];
111 /// assert!(x.iter().copied().nanmax(true).is_nan());
112 /// assert!(x.iter().copied().nanmax(false).is_nan());
113 /// ```
114 ///
115 /// ## Only NaN values
116 ///
117 /// ```rust
118 /// use augurs_core::FloatIterExt;
119 ///
120 /// let x = [f64::NAN, f64::NAN];
121 /// assert!(x.iter().copied().nanmax(true).is_nan());
122 /// assert!(x.iter().copied().nanmax(false).is_nan());
123 /// ```
124 fn nanmax(self, ignore_nans: bool) -> T
125 where
126 Self: Sized,
127 {
128 nan_reduce(self, ignore_nans, T::max)
129 }
130
131 /// Returns the minimum and maximum of all elements in the iterator,
132 /// handling NaN values.
133 ///
134 /// If `ignore_nans` is true, NaN values will be ignored and
135 /// not included in the minimum or maximum.
136 /// Otherwise, the minimum and maximum will be NaN if any element is NaN.
137 ///
138 /// The return value is a [`NanMinMaxResult`], which is similar to
139 /// [`itertools::MinMaxResult`](https://docs.rs/itertools/latest/itertools/enum.MinMaxResult.html)
140 /// and provides more granular information on the result.
141 ///
142 /// # Examples
143 ///
144 /// ## Simple usage, ignoring NaNs
145 ///
146 /// ```
147 /// use augurs_core::{FloatIterExt, NanMinMaxResult};
148 ///
149 /// let x = [1.0, 2.0, 3.0, f64::NAN, 5.0];
150 /// let min_max = x.iter().copied().nanminmax(true);
151 /// assert_eq!(min_max, NanMinMaxResult::MinMax(1.0, 5.0));
152 /// ```
153 ///
154 /// ## Simple usage, including NaNs
155 ///
156 /// ```
157 /// use augurs_core::{FloatIterExt, NanMinMaxResult};
158 ///
159 /// let x = [1.0, 2.0, 3.0, f64::NAN, 5.0];
160 /// let min_max = x.iter().copied().nanminmax(false);
161 /// assert_eq!(min_max, NanMinMaxResult::NaN);
162 /// ```
163 ///
164 /// ## Only NaNs
165 ///
166 /// ```
167 /// use augurs_core::{FloatIterExt, NanMinMaxResult};
168 ///
169 /// let x = [f64::NAN, f64::NAN, f64::NAN];
170 /// let min_max = x.iter().copied().nanminmax(true);
171 /// assert_eq!(min_max, NanMinMaxResult::NoElements);
172 ///
173 /// let min_max = x.iter().copied().nanminmax(false);
174 /// assert_eq!(min_max, NanMinMaxResult::NaN);
175 /// ```
176 ///
177 /// ## Empty iterator
178 ///
179 /// ```
180 /// use augurs_core::{FloatIterExt, NanMinMaxResult};
181 ///
182 /// let x: [f64; 0] = [];
183 /// let min_max = x.iter().copied().nanminmax(true);
184 /// assert_eq!(min_max, NanMinMaxResult::NoElements);
185 ///
186 /// let min_max = x.iter().copied().nanminmax(false);
187 /// assert_eq!(min_max, NanMinMaxResult::NoElements);
188 /// ```
189 ///
190 /// ## Only one distinct element
191 ///
192 /// ```
193 /// use augurs_core::{FloatIterExt, NanMinMaxResult};
194 ///
195 /// let x = [1.0, f64::NAN, 1.0];
196 /// let min_max = x.iter().copied().nanminmax(true);
197 /// assert_eq!(min_max, NanMinMaxResult::OneElement(1.0));
198 ///
199 /// let min_max = x.iter().copied().nanminmax(false);
200 /// assert_eq!(min_max, NanMinMaxResult::NaN);
201 /// ```
202 fn nanminmax(self, ignore_nans: bool) -> NanMinMaxResult<T>
203 where
204 Self: Sized,
205 {
206 let mut acc = NanMinMaxResult::NoElements;
207 for x in self {
208 let is_nan = x.is_nan();
209 if is_nan && !ignore_nans {
210 return NanMinMaxResult::NaN;
211 }
212 if is_nan {
213 continue;
214 }
215 // From here on, we're ignoring NaNs.
216 acc = match acc {
217 NanMinMaxResult::NoElements => NanMinMaxResult::OneElement(x),
218 NanMinMaxResult::OneElement(one) => {
219 match one.partial_cmp(&x).expect("x should not be NaN") {
220 Ordering::Equal => acc,
221 Ordering::Less => NanMinMaxResult::MinMax(one, x),
222 Ordering::Greater => NanMinMaxResult::MinMax(x, one),
223 }
224 }
225 NanMinMaxResult::MinMax(min, max) => {
226 NanMinMaxResult::MinMax(min.min(x), max.max(x))
227 }
228 // This case is unreachable because we return early for NaN values when ignore_nans is false
229 NanMinMaxResult::NaN => {
230 unreachable!("NaN case should have been handled by early return")
231 }
232 };
233 }
234 acc
235 }
236
237 /// Returns the mean of all elements in the iterator, handling NaN values.
238 ///
239 /// If `ignore_nans` is true, NaN values will be ignored and
240 /// not included in the mean.
241 /// Otherwise, the mean will be NaN if any element is NaN.
242 ///
243 /// # Examples
244 ///
245 /// ## Simple usage
246 ///
247 /// ```rust
248 /// use augurs_core::FloatIterExt;
249 ///
250 /// let x = [1.0, 2.0, 3.0, f64::NAN, 4.0];
251 /// assert_eq!(x.iter().copied().nanmean(true), 2.5);
252 /// assert!(x.iter().copied().nanmean(false).is_nan());
253 /// ```
254 ///
255 /// ## Empty iterator
256 ///
257 /// ```rust
258 /// use augurs_core::FloatIterExt;
259 ///
260 /// let x: [f64; 0] = [];
261 /// assert!(x.iter().copied().nanmean(true).is_nan());
262 /// assert!(x.iter().copied().nanmean(false).is_nan());
263 /// ```
264 ///
265 /// ## Only NaN values
266 ///
267 /// ```rust
268 /// use augurs_core::FloatIterExt;
269 ///
270 /// let x = [f64::NAN, f64::NAN];
271 /// assert!(x.iter().copied().nanmean(true).is_nan());
272 /// assert!(x.iter().copied().nanmean(false).is_nan());
273 /// ```
274 fn nanmean(self, ignore_nans: bool) -> T
275 where
276 Self: Sized,
277 {
278 let (n, sum) = self.fold((0, T::zero()), |(n, sum), x| {
279 if ignore_nans && x.is_nan() {
280 (n, sum)
281 } else if x.is_nan() || sum.is_nan() {
282 (n, T::nan())
283 } else {
284 (n + 1, sum + x)
285 }
286 });
287 if n == 0 {
288 T::nan()
289 } else if sum.is_nan() {
290 sum
291 } else {
292 sum / T::from_usize(n).unwrap_or_else(|| T::nan())
293 }
294 }
295}
296
297impl<T: Float + FromPrimitive, I: Iterator<Item = T>> FloatIterExt<T> for I {}
298
#[cfg(test)]
mod test {
    use super::*;

    // Shared fixtures: a finite sample symmetric around zero, the same sample with its
    // middle value replaced by NaN, an all-NaN sample and an empty one.
    const FINITE: [f64; 7] = [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0];
    const WITH_NAN: [f64; 7] = [-3.0, -2.0, -1.0, f64::NAN, 1.0, 2.0, 3.0];
    const ALL_NAN: [f64; 2] = [f64::NAN, f64::NAN];
    const EMPTY: [f64; 0] = [];

    #[test]
    fn empty() {
        assert!(EMPTY.iter().copied().nanmin(true).is_nan());
        assert!(EMPTY.iter().copied().nanmax(true).is_nan());
    }

    #[test]
    fn no_nans() {
        // Without NaNs the flag must make no difference.
        for ignore_nans in [true, false].iter().copied() {
            assert_eq!(FINITE.iter().copied().nanmin(ignore_nans), -3.0);
            assert_eq!(FINITE.iter().copied().nanmax(ignore_nans), 3.0);
        }
    }

    #[test]
    fn nans() {
        assert_eq!(WITH_NAN.iter().copied().nanmin(true), -3.0);
        assert_eq!(WITH_NAN.iter().copied().nanmax(true), 3.0);

        assert!(WITH_NAN.iter().copied().nanmin(false).is_nan());
        assert!(WITH_NAN.iter().copied().nanmax(false).is_nan());
    }

    #[test]
    fn nanmean() {
        assert_eq!(FINITE.iter().copied().nanmean(true), 0.0);

        assert_eq!(WITH_NAN.iter().copied().nanmean(true), 0.0);
        assert!(WITH_NAN.iter().copied().nanmean(false).is_nan());

        assert!(ALL_NAN.iter().copied().nanmean(true).is_nan());
    }

    #[test]
    fn nanminmax() {
        let one_distinct: [f64; 3] = [1.0, f64::NAN, 1.0];

        // (input, ignore_nans, expected result)
        let cases: [(&[f64], bool, NanMinMaxResult<f64>); 10] = [
            (&FINITE[..], true, NanMinMaxResult::MinMax(-3.0, 3.0)),
            (&FINITE[..], false, NanMinMaxResult::MinMax(-3.0, 3.0)),
            (&WITH_NAN[..], true, NanMinMaxResult::MinMax(-3.0, 3.0)),
            (&WITH_NAN[..], false, NanMinMaxResult::NaN),
            (&ALL_NAN[..], true, NanMinMaxResult::NoElements),
            (&ALL_NAN[..], false, NanMinMaxResult::NaN),
            (&EMPTY[..], true, NanMinMaxResult::NoElements),
            (&EMPTY[..], false, NanMinMaxResult::NoElements),
            (&one_distinct[..], true, NanMinMaxResult::OneElement(1.0)),
            (&one_distinct[..], false, NanMinMaxResult::NaN),
        ];

        for (input, ignore_nans, expected) in cases.iter() {
            assert_eq!(input.iter().copied().nanminmax(*ignore_nans), *expected);
        }
    }
}