// ndarray_stats/quantile/mod.rs

use self::interpolate::{higher_index, lower_index, Interpolate};
use super::sort::get_many_from_sorted_mut_unchecked;
use crate::errors::QuantileError;
use crate::errors::{EmptyInput, MinMaxError, MinMaxError::UndefinedOrder};
use crate::{MaybeNan, MaybeNanExt};
use ndarray::prelude::*;
use ndarray::{Data, DataMut, RemoveAxis, Zip};
use noisy_float::types::N64;
use std::cmp;

11/// Quantile methods for `ArrayBase`.
12pub trait QuantileExt<A, S, D>
13where
14    S: Data<Elem = A>,
15    D: Dimension,
16{
17    /// Finds the index of the minimum value of the array.
18    ///
19    /// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise
20    /// orderings tested by the function are undefined. (For example, this
21    /// occurs if there are any floating-point NaN values in the array.)
22    ///
23    /// Returns `Err(MinMaxError::EmptyInput)` if the array is empty.
24    ///
25    /// Even if there are multiple (equal) elements that are minima, only one
26    /// index is returned. (Which one is returned is unspecified and may depend
27    /// on the memory layout of the array.)
28    ///
29    /// # Example
30    ///
31    /// ```
32    /// use ndarray::array;
33    /// use ndarray_stats::QuantileExt;
34    ///
35    /// let a = array![[1., 3., 5.],
36    ///                [2., 0., 6.]];
37    /// assert_eq!(a.argmin(), Ok((1, 1)));
38    /// ```
39    fn argmin(&self) -> Result<D::Pattern, MinMaxError>
40    where
41        A: PartialOrd;
42
43    /// Finds the index of the minimum value of the array skipping NaN values.
44    ///
45    /// Returns `Err(EmptyInput)` if the array is empty or none of the values in the array
46    /// are non-NaN values.
47    ///
48    /// Even if there are multiple (equal) elements that are minima, only one
49    /// index is returned. (Which one is returned is unspecified and may depend
50    /// on the memory layout of the array.)
51    ///
52    /// # Example
53    ///
54    /// ```
55    /// use ndarray::array;
56    /// use ndarray_stats::QuantileExt;
57    ///
58    /// let a = array![[::std::f64::NAN, 3., 5.],
59    ///                [2., 0., 6.]];
60    /// assert_eq!(a.argmin_skipnan(), Ok((1, 1)));
61    /// ```
62    fn argmin_skipnan(&self) -> Result<D::Pattern, EmptyInput>
63    where
64        A: MaybeNan,
65        A::NotNan: Ord;
66
67    /// Finds the elementwise minimum of the array.
68    ///
69    /// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise
70    /// orderings tested by the function are undefined. (For example, this
71    /// occurs if there are any floating-point NaN values in the array.)
72    ///
73    /// Returns `Err(MinMaxError::EmptyInput)` if the array is empty.
74    ///
75    /// Even if there are multiple (equal) elements that are minima, only one
76    /// is returned. (Which one is returned is unspecified and may depend on
77    /// the memory layout of the array.)
78    fn min(&self) -> Result<&A, MinMaxError>
79    where
80        A: PartialOrd;
81
82    /// Finds the elementwise minimum of the array, skipping NaN values.
83    ///
84    /// Even if there are multiple (equal) elements that are minima, only one
85    /// is returned. (Which one is returned is unspecified and may depend on
86    /// the memory layout of the array.)
87    ///
88    /// **Warning** This method will return a NaN value if none of the values
89    /// in the array are non-NaN values. Note that the NaN value might not be
90    /// in the array.
91    fn min_skipnan(&self) -> &A
92    where
93        A: MaybeNan,
94        A::NotNan: Ord;
95
96    /// Finds the index of the maximum value of the array.
97    ///
98    /// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise
99    /// orderings tested by the function are undefined. (For example, this
100    /// occurs if there are any floating-point NaN values in the array.)
101    ///
102    /// Returns `Err(MinMaxError::EmptyInput)` if the array is empty.
103    ///
104    /// Even if there are multiple (equal) elements that are maxima, only one
105    /// index is returned. (Which one is returned is unspecified and may depend
106    /// on the memory layout of the array.)
107    ///
108    /// # Example
109    ///
110    /// ```
111    /// use ndarray::array;
112    /// use ndarray_stats::QuantileExt;
113    ///
114    /// let a = array![[1., 3., 7.],
115    ///                [2., 5., 6.]];
116    /// assert_eq!(a.argmax(), Ok((0, 2)));
117    /// ```
118    fn argmax(&self) -> Result<D::Pattern, MinMaxError>
119    where
120        A: PartialOrd;
121
122    /// Finds the index of the maximum value of the array skipping NaN values.
123    ///
124    /// Returns `Err(EmptyInput)` if the array is empty or none of the values in the array
125    /// are non-NaN values.
126    ///
127    /// Even if there are multiple (equal) elements that are maxima, only one
128    /// index is returned. (Which one is returned is unspecified and may depend
129    /// on the memory layout of the array.)
130    ///
131    /// # Example
132    ///
133    /// ```
134    /// use ndarray::array;
135    /// use ndarray_stats::QuantileExt;
136    ///
137    /// let a = array![[::std::f64::NAN, 3., 5.],
138    ///                [2., 0., 6.]];
139    /// assert_eq!(a.argmax_skipnan(), Ok((1, 2)));
140    /// ```
141    fn argmax_skipnan(&self) -> Result<D::Pattern, EmptyInput>
142    where
143        A: MaybeNan,
144        A::NotNan: Ord;
145
146    /// Finds the elementwise maximum of the array.
147    ///
148    /// Returns `Err(MinMaxError::UndefinedOrder)` if any of the pairwise
149    /// orderings tested by the function are undefined. (For example, this
150    /// occurs if there are any floating-point NaN values in the array.)
151    ///
152    /// Returns `Err(EmptyInput)` if the array is empty.
153    ///
154    /// Even if there are multiple (equal) elements that are maxima, only one
155    /// is returned. (Which one is returned is unspecified and may depend on
156    /// the memory layout of the array.)
157    fn max(&self) -> Result<&A, MinMaxError>
158    where
159        A: PartialOrd;
160
161    /// Finds the elementwise maximum of the array, skipping NaN values.
162    ///
163    /// Even if there are multiple (equal) elements that are maxima, only one
164    /// is returned. (Which one is returned is unspecified and may depend on
165    /// the memory layout of the array.)
166    ///
167    /// **Warning** This method will return a NaN value if none of the values
168    /// in the array are non-NaN values. Note that the NaN value might not be
169    /// in the array.
170    fn max_skipnan(&self) -> &A
171    where
172        A: MaybeNan,
173        A::NotNan: Ord;
174
175    /// Return the qth quantile of the data along the specified axis.
176    ///
177    /// `q` needs to be a float between 0 and 1, bounds included.
178    /// The qth quantile for a 1-dimensional lane of length `N` is defined
179    /// as the element that would be indexed as `(N-1)q` if the lane were to be sorted
180    /// in increasing order.
181    /// If `(N-1)q` is not an integer the desired quantile lies between
182    /// two data points: we return the lower, nearest, higher or interpolated
183    /// value depending on the `interpolate` strategy.
184    ///
185    /// Some examples:
186    /// - `q=0.` returns the minimum along each 1-dimensional lane;
187    /// - `q=0.5` returns the median along each 1-dimensional lane;
188    /// - `q=1.` returns the maximum along each 1-dimensional lane.
189    /// (`q=0` and `q=1` are considered improper quantiles)
190    ///
191    /// The array is shuffled **in place** along each 1-dimensional lane in
192    /// order to produce the required quantile without allocating a copy
193    /// of the original array. Each 1-dimensional lane is shuffled independently
194    /// from the others.
195    /// No assumptions should be made on the ordering of the array elements
196    /// after this computation.
197    ///
198    /// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)):
199    /// - average case: O(`m`);
200    /// - worst case: O(`m`^2);
201    /// where `m` is the number of elements in the array.
202    ///
203    /// Returns `Err(EmptyInput)` when the specified axis has length 0.
204    ///
205    /// Returns `Err(InvalidQuantile(q))` if `q` is not between `0.` and `1.` (inclusive).
206    ///
207    /// **Panics** if `axis` is out of bounds.
208    fn quantile_axis_mut<I>(
209        &mut self,
210        axis: Axis,
211        q: N64,
212        interpolate: &I,
213    ) -> Result<Array<A, D::Smaller>, QuantileError>
214    where
215        D: RemoveAxis,
216        A: Ord + Clone,
217        S: DataMut,
218        I: Interpolate<A>;
219
220    /// A bulk version of [`quantile_axis_mut`], optimized to retrieve multiple
221    /// quantiles at once.
222    ///
223    /// Returns an `Array`, where subviews along `axis` of the array correspond
224    /// to the elements of `qs`.
225    ///
226    /// See [`quantile_axis_mut`] for additional details on quantiles and the algorithm
227    /// used to retrieve them.
228    ///
229    /// Returns `Err(EmptyInput)` when the specified axis has length 0.
230    ///
231    /// Returns `Err(InvalidQuantile(q))` if any `q` in `qs` is not between `0.` and `1.` (inclusive).
232    ///
233    /// **Panics** if `axis` is out of bounds.
234    ///
235    /// [`quantile_axis_mut`]: #tymethod.quantile_axis_mut
236    ///
237    /// # Example
238    ///
239    /// ```rust
240    /// use ndarray::{array, aview1, Axis};
241    /// use ndarray_stats::{QuantileExt, interpolate::Nearest};
242    /// use noisy_float::types::n64;
243    ///
244    /// let mut data = array![[3, 4, 5], [6, 7, 8]];
245    /// let axis = Axis(1);
246    /// let qs = &[n64(0.3), n64(0.7)];
247    /// let quantiles = data.quantiles_axis_mut(axis, &aview1(qs), &Nearest).unwrap();
248    /// for (&q, quantile) in qs.iter().zip(quantiles.axis_iter(axis)) {
249    ///     assert_eq!(quantile, data.quantile_axis_mut(axis, q, &Nearest).unwrap());
250    /// }
251    /// ```
252    fn quantiles_axis_mut<S2, I>(
253        &mut self,
254        axis: Axis,
255        qs: &ArrayBase<S2, Ix1>,
256        interpolate: &I,
257    ) -> Result<Array<A, D>, QuantileError>
258    where
259        D: RemoveAxis,
260        A: Ord + Clone,
261        S: DataMut,
262        S2: Data<Elem = N64>,
263        I: Interpolate<A>;
264
265    /// Return the `q`th quantile of the data along the specified axis, skipping NaN values.
266    ///
267    /// See [`quantile_axis_mut`](#tymethod.quantile_axis_mut) for details.
268    fn quantile_axis_skipnan_mut<I>(
269        &mut self,
270        axis: Axis,
271        q: N64,
272        interpolate: &I,
273    ) -> Result<Array<A, D::Smaller>, QuantileError>
274    where
275        D: RemoveAxis,
276        A: MaybeNan,
277        A::NotNan: Clone + Ord,
278        S: DataMut,
279        I: Interpolate<A::NotNan>;
280
281    private_decl! {}
282}
283
284impl<A, S, D> QuantileExt<A, S, D> for ArrayBase<S, D>
285where
286    S: Data<Elem = A>,
287    D: Dimension,
288{
289    fn argmin(&self) -> Result<D::Pattern, MinMaxError>
290    where
291        A: PartialOrd,
292    {
293        let mut current_min = self.first().ok_or(EmptyInput)?;
294        let mut current_pattern_min = D::zeros(self.ndim()).into_pattern();
295
296        for (pattern, elem) in self.indexed_iter() {
297            if elem.partial_cmp(current_min).ok_or(UndefinedOrder)? == cmp::Ordering::Less {
298                current_pattern_min = pattern;
299                current_min = elem
300            }
301        }
302
303        Ok(current_pattern_min)
304    }
305
306    fn argmin_skipnan(&self) -> Result<D::Pattern, EmptyInput>
307    where
308        A: MaybeNan,
309        A::NotNan: Ord,
310    {
311        let mut pattern_min = D::zeros(self.ndim()).into_pattern();
312        let min = self.indexed_fold_skipnan(None, |current_min, (pattern, elem)| {
313            Some(match current_min {
314                Some(m) if (m <= elem) => m,
315                _ => {
316                    pattern_min = pattern;
317                    elem
318                }
319            })
320        });
321        if min.is_some() {
322            Ok(pattern_min)
323        } else {
324            Err(EmptyInput)
325        }
326    }
327
328    fn min(&self) -> Result<&A, MinMaxError>
329    where
330        A: PartialOrd,
331    {
332        let first = self.first().ok_or(EmptyInput)?;
333        self.fold(Ok(first), |acc, elem| {
334            let acc = acc?;
335            match elem.partial_cmp(acc).ok_or(UndefinedOrder)? {
336                cmp::Ordering::Less => Ok(elem),
337                _ => Ok(acc),
338            }
339        })
340    }
341
342    fn min_skipnan(&self) -> &A
343    where
344        A: MaybeNan,
345        A::NotNan: Ord,
346    {
347        let first = self.first().and_then(|v| v.try_as_not_nan());
348        A::from_not_nan_ref_opt(self.fold_skipnan(first, |acc, elem| {
349            Some(match acc {
350                Some(acc) => acc.min(elem),
351                None => elem,
352            })
353        }))
354    }
355
356    fn argmax(&self) -> Result<D::Pattern, MinMaxError>
357    where
358        A: PartialOrd,
359    {
360        let mut current_max = self.first().ok_or(EmptyInput)?;
361        let mut current_pattern_max = D::zeros(self.ndim()).into_pattern();
362
363        for (pattern, elem) in self.indexed_iter() {
364            if elem.partial_cmp(current_max).ok_or(UndefinedOrder)? == cmp::Ordering::Greater {
365                current_pattern_max = pattern;
366                current_max = elem
367            }
368        }
369
370        Ok(current_pattern_max)
371    }
372
373    fn argmax_skipnan(&self) -> Result<D::Pattern, EmptyInput>
374    where
375        A: MaybeNan,
376        A::NotNan: Ord,
377    {
378        let mut pattern_max = D::zeros(self.ndim()).into_pattern();
379        let max = self.indexed_fold_skipnan(None, |current_max, (pattern, elem)| {
380            Some(match current_max {
381                Some(m) if m >= elem => m,
382                _ => {
383                    pattern_max = pattern;
384                    elem
385                }
386            })
387        });
388        if max.is_some() {
389            Ok(pattern_max)
390        } else {
391            Err(EmptyInput)
392        }
393    }
394
395    fn max(&self) -> Result<&A, MinMaxError>
396    where
397        A: PartialOrd,
398    {
399        let first = self.first().ok_or(EmptyInput)?;
400        self.fold(Ok(first), |acc, elem| {
401            let acc = acc?;
402            match elem.partial_cmp(acc).ok_or(UndefinedOrder)? {
403                cmp::Ordering::Greater => Ok(elem),
404                _ => Ok(acc),
405            }
406        })
407    }
408
409    fn max_skipnan(&self) -> &A
410    where
411        A: MaybeNan,
412        A::NotNan: Ord,
413    {
414        let first = self.first().and_then(|v| v.try_as_not_nan());
415        A::from_not_nan_ref_opt(self.fold_skipnan(first, |acc, elem| {
416            Some(match acc {
417                Some(acc) => acc.max(elem),
418                None => elem,
419            })
420        }))
421    }
422
423    fn quantiles_axis_mut<S2, I>(
424        &mut self,
425        axis: Axis,
426        qs: &ArrayBase<S2, Ix1>,
427        interpolate: &I,
428    ) -> Result<Array<A, D>, QuantileError>
429    where
430        D: RemoveAxis,
431        A: Ord + Clone,
432        S: DataMut,
433        S2: Data<Elem = N64>,
434        I: Interpolate<A>,
435    {
436        // Minimize number of type parameters to avoid monomorphization bloat.
437        fn quantiles_axis_mut<A, D, I>(
438            mut data: ArrayViewMut<'_, A, D>,
439            axis: Axis,
440            qs: ArrayView1<'_, N64>,
441            _interpolate: &I,
442        ) -> Result<Array<A, D>, QuantileError>
443        where
444            D: RemoveAxis,
445            A: Ord + Clone,
446            I: Interpolate<A>,
447        {
448            for &q in qs {
449                if !((q >= 0.) && (q <= 1.)) {
450                    return Err(QuantileError::InvalidQuantile(q));
451                }
452            }
453
454            let axis_len = data.len_of(axis);
455            if axis_len == 0 {
456                return Err(QuantileError::EmptyInput);
457            }
458
459            let mut results_shape = data.raw_dim();
460            results_shape[axis.index()] = qs.len();
461            if results_shape.size() == 0 {
462                return Ok(Array::from_shape_vec(results_shape, Vec::new()).unwrap());
463            }
464
465            let mut searched_indexes = Vec::with_capacity(2 * qs.len());
466            for &q in &qs {
467                if I::needs_lower(q, axis_len) {
468                    searched_indexes.push(lower_index(q, axis_len));
469                }
470                if I::needs_higher(q, axis_len) {
471                    searched_indexes.push(higher_index(q, axis_len));
472                }
473            }
474            searched_indexes.sort();
475            searched_indexes.dedup();
476
477            let mut results = Array::from_elem(results_shape, data.first().unwrap().clone());
478            Zip::from(results.lanes_mut(axis))
479                .and(data.lanes_mut(axis))
480                .for_each(|mut results, mut data| {
481                    let index_map =
482                        get_many_from_sorted_mut_unchecked(&mut data, &searched_indexes);
483                    for (result, &q) in results.iter_mut().zip(qs) {
484                        let lower = if I::needs_lower(q, axis_len) {
485                            Some(index_map[&lower_index(q, axis_len)].clone())
486                        } else {
487                            None
488                        };
489                        let higher = if I::needs_higher(q, axis_len) {
490                            Some(index_map[&higher_index(q, axis_len)].clone())
491                        } else {
492                            None
493                        };
494                        *result = I::interpolate(lower, higher, q, axis_len);
495                    }
496                });
497            Ok(results)
498        }
499
500        quantiles_axis_mut(self.view_mut(), axis, qs.view(), interpolate)
501    }
502
503    fn quantile_axis_mut<I>(
504        &mut self,
505        axis: Axis,
506        q: N64,
507        interpolate: &I,
508    ) -> Result<Array<A, D::Smaller>, QuantileError>
509    where
510        D: RemoveAxis,
511        A: Ord + Clone,
512        S: DataMut,
513        I: Interpolate<A>,
514    {
515        self.quantiles_axis_mut(axis, &aview1(&[q]), interpolate)
516            .map(|a| a.index_axis_move(axis, 0))
517    }
518
519    fn quantile_axis_skipnan_mut<I>(
520        &mut self,
521        axis: Axis,
522        q: N64,
523        interpolate: &I,
524    ) -> Result<Array<A, D::Smaller>, QuantileError>
525    where
526        D: RemoveAxis,
527        A: MaybeNan,
528        A::NotNan: Clone + Ord,
529        S: DataMut,
530        I: Interpolate<A::NotNan>,
531    {
532        if !((q >= 0.) && (q <= 1.)) {
533            return Err(QuantileError::InvalidQuantile(q));
534        }
535
536        if self.len_of(axis) == 0 {
537            return Err(QuantileError::EmptyInput);
538        }
539
540        let quantile = self.map_axis_mut(axis, |lane| {
541            let mut not_nan = A::remove_nan_mut(lane);
542            A::from_not_nan_opt(if not_nan.is_empty() {
543                None
544            } else {
545                Some(
546                    not_nan
547                        .quantile_axis_mut::<I>(Axis(0), q, interpolate)
548                        .unwrap()
549                        .into_scalar(),
550                )
551            })
552        });
553        Ok(quantile)
554    }
555
556    private_impl! {}
557}
558
559/// Quantile methods for 1-D arrays.
560pub trait Quantile1dExt<A, S>
561where
562    S: Data<Elem = A>,
563{
564    /// Return the qth quantile of the data.
565    ///
566    /// `q` needs to be a float between 0 and 1, bounds included.
567    /// The qth quantile for a 1-dimensional array of length `N` is defined
568    /// as the element that would be indexed as `(N-1)q` if the array were to be sorted
569    /// in increasing order.
570    /// If `(N-1)q` is not an integer the desired quantile lies between
571    /// two data points: we return the lower, nearest, higher or interpolated
572    /// value depending on the `interpolate` strategy.
573    ///
574    /// Some examples:
575    /// - `q=0.` returns the minimum;
576    /// - `q=0.5` returns the median;
577    /// - `q=1.` returns the maximum.
578    /// (`q=0` and `q=1` are considered improper quantiles)
579    ///
580    /// The array is shuffled **in place** in order to produce the required quantile
581    /// without allocating a copy.
582    /// No assumptions should be made on the ordering of the array elements
583    /// after this computation.
584    ///
585    /// Complexity ([quickselect](https://en.wikipedia.org/wiki/Quickselect)):
586    /// - average case: O(`m`);
587    /// - worst case: O(`m`^2);
588    /// where `m` is the number of elements in the array.
589    ///
590    /// Returns `Err(EmptyInput)` if the array is empty.
591    ///
592    /// Returns `Err(InvalidQuantile(q))` if `q` is not between `0.` and `1.` (inclusive).
593    fn quantile_mut<I>(&mut self, q: N64, interpolate: &I) -> Result<A, QuantileError>
594    where
595        A: Ord + Clone,
596        S: DataMut,
597        I: Interpolate<A>;
598
599    /// A bulk version of [`quantile_mut`], optimized to retrieve multiple
600    /// quantiles at once.
601    ///
602    /// Returns an `Array`, where the elements of the array correspond to the
603    /// elements of `qs`.
604    ///
605    /// Returns `Err(EmptyInput)` if the array is empty.
606    ///
607    /// Returns `Err(InvalidQuantile(q))` if any `q` in
608    /// `qs` is not between `0.` and `1.` (inclusive).
609    ///
610    /// See [`quantile_mut`] for additional details on quantiles and the algorithm
611    /// used to retrieve them.
612    ///
613    /// [`quantile_mut`]: #tymethod.quantile_mut
614    fn quantiles_mut<S2, I>(
615        &mut self,
616        qs: &ArrayBase<S2, Ix1>,
617        interpolate: &I,
618    ) -> Result<Array1<A>, QuantileError>
619    where
620        A: Ord + Clone,
621        S: DataMut,
622        S2: Data<Elem = N64>,
623        I: Interpolate<A>;
624
625    private_decl! {}
626}
627
628impl<A, S> Quantile1dExt<A, S> for ArrayBase<S, Ix1>
629where
630    S: Data<Elem = A>,
631{
632    fn quantile_mut<I>(&mut self, q: N64, interpolate: &I) -> Result<A, QuantileError>
633    where
634        A: Ord + Clone,
635        S: DataMut,
636        I: Interpolate<A>,
637    {
638        Ok(self
639            .quantile_axis_mut(Axis(0), q, interpolate)?
640            .into_scalar())
641    }
642
643    fn quantiles_mut<S2, I>(
644        &mut self,
645        qs: &ArrayBase<S2, Ix1>,
646        interpolate: &I,
647    ) -> Result<Array1<A>, QuantileError>
648    where
649        A: Ord + Clone,
650        S: DataMut,
651        S2: Data<Elem = N64>,
652        I: Interpolate<A>,
653    {
654        self.quantiles_axis_mut(Axis(0), qs, interpolate)
655    }
656
657    private_impl! {}
658}
659
660pub mod interpolate;