Skip to main content

ferray_stats/reductions/
quantile.rs

1// ferray-stats: Quantile-based reductions — median, percentile, quantile (REQ-1)
2// Also nanmedian, nanpercentile (REQ-3)
3//
4// ## REQ status (ferray-stats quantiles, NumPy parity)
5//  - REQ-1 (median/percentile/quantile with optional `axis`) — SHIPPED:
6//    `pub fn median`, `pub fn percentile`, `pub fn quantile` (this file), each
7//    taking `axis: Option<usize>`. `percentile` forwards to `quantile` after
8//    scaling `q` by 1/100, matching numpy (numpy/lib/_function_base_impl.py:4065
9//    `percentile`, :4268 `quantile`, :3915 `median`). Non-test consumers: the
10//    `ferray_stats::median`/`percentile`/`quantile` `#[pyfunction]` shims in
11//    `ferray-python/src/stats.rs`, and in-crate `descriptive::iqr` which calls
12//    `crate::reductions::quantile::quantile` for the 25th/75th percentiles.
13//  - all 9 interpolation methods incl. ClosestObservation (#1080) — SHIPPED:
14//    `pub enum QuantileMethod` (this file) and `pub fn quantile_with_method` /
15//    `pub fn percentile_with_method` cover the Hyndman-Fan continuous methods
16//    (`Linear`, `Lower`, `Higher`, `Nearest`, `Midpoint`, plus the
17//    interpolation-parameterized variants) and the discrete
18//    `ClosestObservation` rule, matching numpy's `method=` set
19//    (numpy/lib/_function_base_impl.py:4268-4300). Fixed and audited green.
20//    Consumers: `quantile`/`percentile`/`median` (this file) call
21//    `quantile_with_method` with `QuantileMethod::Linear`; the
22//    `ferray_stats::percentile`/`quantile` python shims expose the method arg.
23//  - REQ-3 (nanmedian / nanpercentile / nanquantile — skip NaN) — SHIPPED:
24//    `pub fn nanmedian`, `pub fn nanpercentile`, `pub fn nanquantile` (this
25//    file) drop NaN from each lane before quantiling, matching numpy
26//    (numpy/lib/_nanfunctions_impl.py). Consumers: `ferray_stats::nanmedian`/
27//    `nanpercentile`/`nanquantile` `#[pyfunction]` shims in
28//    `ferray-python/src/stats.rs`.
29
30use ferray_core::error::{FerrayError, FerrayResult};
31use ferray_core::{Array, Dimension, Element, IxDyn};
32use num_traits::Float;
33
34use super::{collect_data, make_result, output_shape, reduce_axis_general, validate_axis};
35
36// ---------------------------------------------------------------------------
37// Helpers
38// ---------------------------------------------------------------------------
39
/// Rule used by [`quantile_with_method`] / [`percentile_with_method`] to turn
/// a quantile `q` into a value taken from (or interpolated between) the
/// sorted elements of a lane of length `n`.
///
/// The thirteen variants correspond to the `method=` options of
/// `numpy.quantile` (#462, #566) and fall into three families:
///
/// * **Continuous** (`Linear`, `InterpolatedInvertedCdf`, `Hazen`,
///   `Weibull`, `MedianUnbiased`, `NormalUnbiased`): the Hyndman-Fan 1996
///   `(alpha, beta)` family. The virtual index is
///   `n*q + alpha + q*(1 - alpha - beta) - 1` and the result is a linear
///   blend of the two sorted elements that bracket it.
/// * **Legacy** (`Lower`, `Higher`, `Nearest`, `Midpoint`): start from the
///   linear index `q * (n - 1)` and apply a rounding rule to it.
/// * **Step functions** (`InvertedCdf`, `AveragedInvertedCdf`,
///   `ClosestObservation`): derive an integer index from `n * q`.
///   `AveragedInvertedCdf` additionally averages two neighbours when
///   `n * q` is an integer.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QuantileMethod {
    /// The `NumPy` default: continuous family with `alpha = beta = 1`,
    /// i.e. `lo_val * (1 - frac) + hi_val * frac`.
    Linear,
    /// The lower bracketing element, `sorted[floor(q * (n - 1))]`.
    Lower,
    /// The upper bracketing element, `sorted[ceil(q * (n - 1))]`.
    Higher,
    /// The sorted element closest to position `q * (n - 1)`. An exact
    /// half-way position (`frac = 0.5`) goes to the even index, following
    /// `NumPy`'s round-half-to-even convention.
    Nearest,
    /// The mean of the two bracketing sorted elements, `(lo + hi) / 2`.
    Midpoint,
    /// Hyndman definition 1 (`NumPy` `'inverted_cdf'`): the step function
    /// `sorted[ceil(n * q) - 1]`, which jumps at multiples of `1 / n`.
    InvertedCdf,
    /// Hyndman definition 2 (`NumPy` `'averaged_inverted_cdf'`): identical
    /// to `InvertedCdf` unless `n * q` is an integer, in which case the two
    /// bracketing sorted elements are averaged.
    AveragedInvertedCdf,
    /// Hyndman definition 3 (`NumPy` `'closest_observation'`). With
    /// `idx = n*q - 1.5` the selected index is `floor(idx) + 1`, or
    /// `floor(idx)` when `idx` is integral and `floor(idx)` is odd (the
    /// even-order-statistic correction); the index is then clipped to
    /// `[0, n-1]`.
    ClosestObservation,
    /// Hyndman definition 4 (`NumPy` `'interpolated_inverted_cdf'`):
    /// continuous family with `alpha = 0`, `beta = 1`.
    InterpolatedInvertedCdf,
    /// Hyndman definition 5 (`NumPy` `'hazen'`): continuous family with
    /// `alpha = beta = 0.5`.
    Hazen,
    /// Hyndman definition 6 (`NumPy` `'weibull'`): continuous family with
    /// `alpha = beta = 0`.
    Weibull,
    /// Hyndman definition 8 (`NumPy` `'median_unbiased'`): continuous
    /// family with `alpha = beta = 1/3`.
    MedianUnbiased,
    /// Hyndman definition 9 (`NumPy` `'normal_unbiased'`): continuous
    /// family with `alpha = beta = 3/8`.
    NormalUnbiased,
}
97
98/// Compute a Hyndman-Fan virtual index for the continuous quantile
99/// methods. Returns `(lo_i, gamma)` with `lo_i` clamped to `[0, n - 1]`
100/// and `gamma` in `[0, 1]`.
101///
102/// `virtual_index = n * q + alpha + q * (1 - alpha - beta) - 1`
103#[inline]
104fn continuous_vidx<T: Float>(n: usize, q: T, alpha: T, beta: T) -> (usize, T) {
105    let nf = T::from(n).unwrap();
106    let zero = T::zero();
107    let one = T::one();
108    let n_minus_1 = T::from(n - 1).unwrap();
109
110    let vidx = nf * q + alpha + q * (one - alpha - beta) - one;
111
112    // Clamp the virtual index into the addressable range.
113    let vidx_clamped = if vidx < zero {
114        zero
115    } else if vidx > n_minus_1 {
116        n_minus_1
117    } else {
118        vidx
119    };
120
121    let lo = vidx_clamped.floor();
122    let lo_i = lo.to_usize().unwrap_or(0).min(n - 1);
123    let gamma = vidx_clamped - lo;
124    (lo_i, gamma)
125}
126
127/// Compute `(lo_i, gamma)` for a given quantile method, where the
128/// result of the quantile is `(1 - gamma) * sorted[lo_i] + gamma *
129/// sorted[lo_i + 1]`. For all discrete methods and for integer virtual
130/// indices `gamma = 0`, which short-circuits to `sorted[lo_i]` and
131/// avoids the second-pass scan for `hi_val`.
132//
133// Each `QuantileMethod` arm reproduces the corresponding NumPy formula
134// verbatim; splitting the dispatch into helpers would scatter the spec
135// across the file and make maintenance harder than the line count.
136#[allow(clippy::too_many_lines)]
137fn method_index_and_gamma<T: Float>(n: usize, q: T, method: QuantileMethod) -> (usize, T) {
138    let zero = T::zero();
139    let one = T::one();
140    let half = T::from(0.5).unwrap();
141    let nf = T::from(n).unwrap();
142
143    match method {
144        // --- Continuous methods via (alpha, beta). Linear is (1, 1). ---
145        QuantileMethod::Linear => continuous_vidx(n, q, one, one),
146        QuantileMethod::Weibull => continuous_vidx(n, q, zero, zero),
147        QuantileMethod::Hazen => continuous_vidx(n, q, half, half),
148        QuantileMethod::InterpolatedInvertedCdf => continuous_vidx(n, q, zero, one),
149        QuantileMethod::MedianUnbiased => {
150            let third = T::from(1.0 / 3.0).unwrap();
151            continuous_vidx(n, q, third, third)
152        }
153        QuantileMethod::NormalUnbiased => {
154            let ae = T::from(3.0 / 8.0).unwrap();
155            continuous_vidx(n, q, ae, ae)
156        }
157
158        // --- Old discrete classics: reuse the linear virtual index
159        //     and apply their specific rounding rules.
160        QuantileMethod::Lower => {
161            let vidx = q * T::from(n - 1).unwrap();
162            let lo_i = vidx.floor().to_usize().unwrap_or(0).min(n - 1);
163            (lo_i, zero)
164        }
165        QuantileMethod::Higher => {
166            let vidx = q * T::from(n - 1).unwrap();
167            let lo = vidx.floor();
168            let lo_i = lo.to_usize().unwrap_or(0).min(n - 1);
169            let frac = vidx - lo;
170            // If there's a fractional part, gamma=1 picks sorted[lo_i+1]
171            // exactly. If not, lo_val itself is the ceiling.
172            if frac > zero && lo_i + 1 < n {
173                (lo_i, one)
174            } else {
175                (lo_i, zero)
176            }
177        }
178        QuantileMethod::Nearest => {
179            let vidx = q * T::from(n - 1).unwrap();
180            let lo = vidx.floor();
181            let lo_i = lo.to_usize().unwrap_or(0).min(n - 1);
182            let frac = vidx - lo;
183            if frac < half {
184                (lo_i, zero)
185            } else if frac > half {
186                if lo_i + 1 < n {
187                    (lo_i, one)
188                } else {
189                    (lo_i, zero)
190                }
191            } else {
192                // Tie: round to the even lo_i.
193                if lo_i.is_multiple_of(2) || lo_i + 1 >= n {
194                    (lo_i, zero)
195                } else {
196                    (lo_i, one)
197                }
198            }
199        }
200        QuantileMethod::Midpoint => {
201            let vidx = q * T::from(n - 1).unwrap();
202            let lo = vidx.floor();
203            let lo_i = lo.to_usize().unwrap_or(0).min(n - 1);
204            let frac = vidx - lo;
205            if frac > zero && lo_i + 1 < n {
206                (lo_i, half)
207            } else {
208                (lo_i, zero)
209            }
210        }
211
212        // --- Discrete step-function methods ---
213        QuantileMethod::InvertedCdf => {
214            // k = ceil(n * q) - 1, clamped to [0, n - 1].
215            let nq = nf * q;
216            let k = if nq <= zero {
217                0
218            } else {
219                nq.ceil()
220                    .to_usize()
221                    .unwrap_or(0)
222                    .saturating_sub(1)
223                    .min(n - 1)
224            };
225            (k, zero)
226        }
227        QuantileMethod::AveragedInvertedCdf => {
228            // Same as InvertedCdf for non-integer n*q; for exact
229            // integer n*q the result is (sorted[k-1] + sorted[k]) / 2.
230            let nq = nf * q;
231            let floor_nq = nq.floor();
232            let is_integer = nq == floor_nq;
233            if is_integer && nq > zero && nq < nf {
234                // nq is in (0, n), so k = floor(nq) = floor(nq) - 1 + 1
235                // and we average sorted[k-1] and sorted[k].
236                let k = floor_nq.to_usize().unwrap_or(0);
237                let lo_i = k.saturating_sub(1).min(n - 1);
238                if lo_i + 1 < n {
239                    (lo_i, half)
240                } else {
241                    (lo_i, zero)
242                }
243            } else {
244                // Non-integer, or at the boundary — same as InvertedCdf.
245                let k = if nq <= zero {
246                    0
247                } else {
248                    nq.ceil()
249                        .to_usize()
250                        .unwrap_or(0)
251                        .saturating_sub(1)
252                        .min(n - 1)
253                };
254                (k, zero)
255            }
256        }
257        QuantileMethod::ClosestObservation => {
258            // NumPy `_closest_observation` delegates to
259            // `_discrete_interpolation_to_boundaries` with virtual index
260            // `(n*q) - 1 - 0.5 = n*q - 1.5`
261            // (numpy/lib/_function_base_impl.py:4611-4630):
262            //   previous = floor(index); next = previous + 1;
263            //   gamma = index - previous;
264            //   res = next, except `previous` where
265            //         `(gamma == 0) & (floor(index) % 2 == 1)`
266            //         (the even-order-statistic correction);
267            //   then `res[res < 0] = 0`.
268            let one_and_half = one + half;
269            let idx = nf * q - one_and_half;
270            let previous = idx.floor();
271            let gamma = idx - previous;
272            // `previous` as a signed integer: it can be negative for small
273            // `q` (e.g. n=4, q=0.3 -> idx = -0.3 -> previous = -1).
274            let previous_i = previous.to_i64().unwrap_or(0);
275            // Even-order-statistic correction: keep `previous` only when the
276            // virtual index sits exactly on an integer (`gamma == 0`) and
277            // `floor(index)` is odd; otherwise advance to `next`.
278            let chosen = if gamma == zero && previous_i.rem_euclid(2) == 1 {
279                previous_i
280            } else {
281                previous_i + 1
282            };
283            // Clip into the addressable range: numpy clips `< 0` to 0, and the
284            // top is bounded by `n - 1` for the q-near-1 boundary.
285            let k = if chosen < 0 {
286                0
287            } else {
288                usize::try_from(chosen).unwrap_or(0).min(n - 1)
289            };
290            (k, zero)
291        }
292    }
293}
294
295/// Compute a single quantile value from an unsorted slice using
296/// `select_nth_unstable_by` rather than a full `sort_by`.
297///
298/// The selection algorithm gives an O(n) average-time path (quickselect)
299/// instead of the O(n log n) full sort the previous implementation used
300/// (#175). All 13 `NumPy` quantile methods are supported: every method
301/// produces a `(lo_i, gamma)` pair via [`method_index_and_gamma`] and
302/// the kernel applies a single uniform interpolation formula.
303///
304/// `data` is consumed: it is partitioned in place so the caller should
305/// pass an owned buffer (or a clone they no longer need).
306fn quantile_select<T: Float>(mut data: Vec<T>, q: T, method: QuantileMethod) -> T {
307    let n = data.len();
308    if n == 0 {
309        return T::nan();
310    }
311    if n == 1 {
312        return data[0];
313    }
314
315    let cmp = |a: &T, b: &T| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal);
316    let (lo_i, gamma) = method_index_and_gamma(n, q, method);
317
318    // First selection: place the lo_i-th smallest at position lo_i.
319    data.select_nth_unstable_by(lo_i, cmp);
320    let lo_val = data[lo_i];
321
322    // Fast exit: no interpolation needed for discrete methods or
323    // whenever the virtual index landed exactly on an integer position.
324    if gamma == T::zero() || lo_i >= n - 1 {
325        return lo_val;
326    }
327
328    // After the partial select, every element in `data[lo_i + 1..]` is
329    // ordered-after `lo_val`; the smallest of them is the
330    // `(lo_i + 1)`-th smallest element overall, which is the `hi_val`
331    // the interpolation formula needs.
332    let hi_val = data[lo_i + 1..]
333        .iter()
334        .copied()
335        .reduce(|a, b| match cmp(&a, &b) {
336            std::cmp::Ordering::Less | std::cmp::Ordering::Equal => a,
337            std::cmp::Ordering::Greater => b,
338        })
339        .unwrap_or(lo_val);
340
341    (T::one() - gamma) * lo_val + gamma * hi_val
342}
343
344/// Compute quantile on a lane using a caller-chosen interpolation method.
345fn lane_quantile_with_method<T: Float>(lane: &[T], q: T, method: QuantileMethod) -> T {
346    quantile_select(lane.to_vec(), q, method)
347}
348
349/// Compute quantile on a lane, excluding NaNs.
350fn lane_nanquantile<T: Float>(lane: &[T], q: T) -> T {
351    let filtered: Vec<T> = lane.iter().copied().filter(|x| !x.is_nan()).collect();
352    if filtered.is_empty() {
353        return T::nan();
354    }
355    quantile_select(filtered, q, QuantileMethod::Linear)
356}
357
358// ---------------------------------------------------------------------------
359// quantile
360// ---------------------------------------------------------------------------
361
362/// Compute the q-th quantile of array data along a given axis.
363///
364/// `q` must be in \[0, 1\]. Uses linear interpolation (`NumPy` default method).
365/// Equivalent to `numpy.quantile`. See [`quantile_with_method`] for the
366/// variant that accepts a [`QuantileMethod`] selector (#462).
367pub fn quantile<T, D>(a: &Array<T, D>, q: T, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
368where
369    T: Element + Float,
370    D: Dimension,
371{
372    quantile_with_method(a, q, axis, QuantileMethod::Linear)
373}
374
375/// Compute the q-th quantile of array data along a given axis using a
376/// specific interpolation method.
377///
378/// Equivalent to `numpy.quantile(a, q, axis=axis, method=method)` for the
379/// five classic methods exposed via [`QuantileMethod`]. `q` must be in
380/// \[0, 1\]. Added for #462.
381///
382/// # Errors
383/// - `FerrayError::InvalidValue` if `q` is outside \[0, 1\] or the array
384///   is empty.
385/// - `FerrayError::AxisOutOfBounds` if `axis` is out of range.
386pub fn quantile_with_method<T, D>(
387    a: &Array<T, D>,
388    q: T,
389    axis: Option<usize>,
390    method: QuantileMethod,
391) -> FerrayResult<Array<T, IxDyn>>
392where
393    T: Element + Float,
394    D: Dimension,
395{
396    if q < <T as Element>::zero() || q > <T as Element>::one() {
397        return Err(FerrayError::invalid_value("quantile q must be in [0, 1]"));
398    }
399    if a.is_empty() {
400        return Err(FerrayError::invalid_value(
401            "cannot compute quantile of empty array",
402        ));
403    }
404    let data = collect_data(a);
405    match axis {
406        None => {
407            let val = lane_quantile_with_method(&data, q, method);
408            make_result(&[], vec![val])
409        }
410        Some(ax) => {
411            validate_axis(ax, a.ndim())?;
412            let shape = a.shape();
413            let out_s = output_shape(shape, ax);
414            let result = reduce_axis_general(&data, shape, ax, |lane| {
415                lane_quantile_with_method(lane, q, method)
416            });
417            make_result(&out_s, result)
418        }
419    }
420}
421
422// ---------------------------------------------------------------------------
423// percentile
424// ---------------------------------------------------------------------------
425
426/// Compute the q-th percentile of array data along a given axis.
427///
428/// `q` must be in \[0, 100\]. Uses linear interpolation. See
429/// [`percentile_with_method`] for the variant that accepts a
430/// [`QuantileMethod`] selector.
431///
432/// Equivalent to `numpy.percentile`.
433pub fn percentile<T, D>(a: &Array<T, D>, q: T, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
434where
435    T: Element + Float,
436    D: Dimension,
437{
438    percentile_with_method(a, q, axis, QuantileMethod::Linear)
439}
440
441/// Compute the q-th percentile of array data along a given axis using a
442/// specific interpolation method.
443///
444/// `q` must be in \[0, 100\]. Equivalent to
445/// `numpy.percentile(a, q, axis=axis, method=method)` for the five
446/// classic methods exposed via [`QuantileMethod`].
447pub fn percentile_with_method<T, D>(
448    a: &Array<T, D>,
449    q: T,
450    axis: Option<usize>,
451    method: QuantileMethod,
452) -> FerrayResult<Array<T, IxDyn>>
453where
454    T: Element + Float,
455    D: Dimension,
456{
457    let hundred = T::from(100.0).unwrap();
458    if q < <T as Element>::zero() || q > hundred {
459        return Err(FerrayError::invalid_value(
460            "percentile q must be in [0, 100]",
461        ));
462    }
463    quantile_with_method(a, q / hundred, axis, method)
464}
465
466// ---------------------------------------------------------------------------
467// median
468// ---------------------------------------------------------------------------
469
470/// Compute the median of array elements along a given axis.
471///
472/// Equivalent to `numpy.median`.
473pub fn median<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
474where
475    T: Element + Float,
476    D: Dimension,
477{
478    let half = T::from(0.5).unwrap();
479    quantile(a, half, axis)
480}
481
482// ---------------------------------------------------------------------------
483// NaN-aware variants
484// ---------------------------------------------------------------------------
485
486/// Median, skipping NaN values.
487///
488/// Equivalent to `numpy.nanmedian`.
489pub fn nanmedian<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
490where
491    T: Element + Float,
492    D: Dimension,
493{
494    let half = T::from(0.5).unwrap();
495    nanquantile(a, half, axis)
496}
497
498/// Percentile, skipping NaN values.
499///
500/// Equivalent to `numpy.nanpercentile`.
501pub fn nanpercentile<T, D>(
502    a: &Array<T, D>,
503    q: T,
504    axis: Option<usize>,
505) -> FerrayResult<Array<T, IxDyn>>
506where
507    T: Element + Float,
508    D: Dimension,
509{
510    let hundred = T::from(100.0).unwrap();
511    if q < <T as Element>::zero() || q > hundred {
512        return Err(FerrayError::invalid_value(
513            "nanpercentile q must be in [0, 100]",
514        ));
515    }
516    nanquantile(a, q / hundred, axis)
517}
518
519/// Quantile, skipping NaN values. Equivalent to `numpy.nanquantile`
520/// (#93 — was previously private, only accessible indirectly through
521/// `nanmedian`/`nanpercentile`).
522pub fn nanquantile<T, D>(
523    a: &Array<T, D>,
524    q: T,
525    axis: Option<usize>,
526) -> FerrayResult<Array<T, IxDyn>>
527where
528    T: Element + Float,
529    D: Dimension,
530{
531    if q < <T as Element>::zero() || q > <T as Element>::one() {
532        return Err(FerrayError::invalid_value("quantile q must be in [0, 1]"));
533    }
534    if a.is_empty() {
535        return Err(FerrayError::invalid_value(
536            "cannot compute nanquantile of empty array",
537        ));
538    }
539    let data = collect_data(a);
540    match axis {
541        None => {
542            let val = lane_nanquantile(&data, q);
543            make_result(&[], vec![val])
544        }
545        Some(ax) => {
546            validate_axis(ax, a.ndim())?;
547            let shape = a.shape();
548            let out_s = output_shape(shape, ax);
549            let result = reduce_axis_general(&data, shape, ax, |lane| lane_nanquantile(lane, q));
550            make_result(&out_s, result)
551        }
552    }
553}
554
555#[cfg(test)]
556mod tests {
557    use super::*;
558    use ferray_core::Ix1;
559
560    #[test]
561    fn test_median_odd() {
562        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![5.0, 1.0, 3.0, 2.0, 4.0]).unwrap();
563        let m = median(&a, None).unwrap();
564        assert!((m.iter().next().unwrap() - 3.0).abs() < 1e-12);
565    }
566
567    #[test]
568    fn test_median_even() {
569        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![4.0, 1.0, 3.0, 2.0]).unwrap();
570        let m = median(&a, None).unwrap();
571        assert!((m.iter().next().unwrap() - 2.5).abs() < 1e-12);
572    }
573
574    #[test]
575    fn test_percentile_0_50_100() {
576        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
577        let p0 = percentile(&a, 0.0, None).unwrap();
578        let p50 = percentile(&a, 50.0, None).unwrap();
579        let p100 = percentile(&a, 100.0, None).unwrap();
580        assert!((p0.iter().next().unwrap() - 1.0).abs() < 1e-12);
581        assert!((p50.iter().next().unwrap() - 3.0).abs() < 1e-12);
582        assert!((p100.iter().next().unwrap() - 5.0).abs() < 1e-12);
583    }
584
585    #[test]
586    fn test_quantile_bounds() {
587        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
588        assert!(quantile(&a, -0.1, None).is_err());
589        assert!(quantile(&a, 1.1, None).is_err());
590    }
591
592    #[test]
593    fn test_quantile_interpolation() {
594        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
595        let q = quantile(&a, 0.25, None).unwrap();
596        // index = 0.25 * 3 = 0.75, interp between 1.0 and 2.0 -> 1.75
597        assert!((q.iter().next().unwrap() - 1.75).abs() < 1e-12);
598    }
599
600    #[test]
601    fn test_nanmedian() {
602        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, f64::NAN, 3.0, 5.0]).unwrap();
603        let m = nanmedian(&a, None).unwrap();
604        // non-nan sorted: [1, 3, 5], median = 3.0
605        assert!((m.iter().next().unwrap() - 3.0).abs() < 1e-12);
606    }
607
608    #[test]
609    fn test_nanpercentile() {
610        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, f64::NAN, 3.0, 5.0]).unwrap();
611        let p = nanpercentile(&a, 50.0, None).unwrap();
612        assert!((p.iter().next().unwrap() - 3.0).abs() < 1e-12);
613    }
614
615    #[test]
616    fn test_nanmedian_all_nan() {
617        let a = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![f64::NAN, f64::NAN]).unwrap();
618        let m = nanmedian(&a, None).unwrap();
619        assert!(m.iter().next().unwrap().is_nan());
620    }
621
622    // ---- quantile interpolation methods (#462) ----
623
624    fn arr_1_5() -> Array<f64, Ix1> {
625        Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap()
626    }
627
628    #[test]
629    fn test_quantile_method_linear_matches_legacy() {
630        // Default quantile uses Linear; explicit Linear must match.
631        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
632        let legacy = quantile(&a, 0.25, None).unwrap();
633        let with_flag = quantile_with_method(&a, 0.25, None, QuantileMethod::Linear).unwrap();
634        assert_eq!(
635            legacy.iter().next().unwrap(),
636            with_flag.iter().next().unwrap()
637        );
638    }
639
640    #[test]
641    fn test_quantile_method_lower() {
642        // n=5, q=0.25 → idx=1.0 → lo_i=1, frac=0.0 (integer index)
643        // All methods agree: result = 2.0
644        let a = arr_1_5();
645        let q = quantile_with_method(&a, 0.25, None, QuantileMethod::Lower).unwrap();
646        assert!((q.iter().next().unwrap() - 2.0).abs() < 1e-12);
647
648        // n=4, q=0.25 → idx=0.75, lo_i=0, frac=0.75 → Lower = lo_val = 1.0
649        let a4 = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
650        let q = quantile_with_method(&a4, 0.25, None, QuantileMethod::Lower).unwrap();
651        assert!((q.iter().next().unwrap() - 1.0).abs() < 1e-12);
652    }
653
654    #[test]
655    fn test_quantile_method_higher() {
656        // n=4, q=0.25 → idx=0.75 → Higher = hi_val = 2.0
657        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
658        let q = quantile_with_method(&a, 0.25, None, QuantileMethod::Higher).unwrap();
659        assert!((q.iter().next().unwrap() - 2.0).abs() < 1e-12);
660    }
661
662    #[test]
663    fn test_quantile_method_nearest_round_down() {
664        // n=4, q=0.2 → idx=0.6, frac=0.6 > 0.5 → pick hi_val (index 1) = 2.0
665        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
666        let q = quantile_with_method(&a, 0.2, None, QuantileMethod::Nearest).unwrap();
667        assert!((q.iter().next().unwrap() - 2.0).abs() < 1e-12);
668
669        // n=4, q=0.1 → idx=0.3, frac=0.3 < 0.5 → pick lo_val (index 0) = 1.0
670        let q2 = quantile_with_method(&a, 0.1, None, QuantileMethod::Nearest).unwrap();
671        assert!((q2.iter().next().unwrap() - 1.0).abs() < 1e-12);
672    }
673
674    #[test]
675    fn test_quantile_method_nearest_tie_even() {
676        // n=5, q=0.125 → idx=0.5, frac=0.5, lo_i=0 (even) → pick lo_val = 1.0
677        let a = arr_1_5();
678        let q = quantile_with_method(&a, 0.125, None, QuantileMethod::Nearest).unwrap();
679        assert!((q.iter().next().unwrap() - 1.0).abs() < 1e-12);
680
681        // n=5, q=0.375 → idx=1.5, frac=0.5, lo_i=1 (odd) → pick hi_val = 3.0
682        let q2 = quantile_with_method(&a, 0.375, None, QuantileMethod::Nearest).unwrap();
683        assert!((q2.iter().next().unwrap() - 3.0).abs() < 1e-12);
684    }
685
686    #[test]
687    fn test_quantile_method_midpoint() {
688        // n=4, q=0.25 → idx=0.75, lo_val=1.0, hi_val=2.0 → midpoint = 1.5
689        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
690        let q = quantile_with_method(&a, 0.25, None, QuantileMethod::Midpoint).unwrap();
691        assert!((q.iter().next().unwrap() - 1.5).abs() < 1e-12);
692
693        // n=4, q=0.75 → idx=2.25, lo_val=3.0, hi_val=4.0 → midpoint = 3.5
694        let q2 = quantile_with_method(&a, 0.75, None, QuantileMethod::Midpoint).unwrap();
695        assert!((q2.iter().next().unwrap() - 3.5).abs() < 1e-12);
696    }
697
698    #[test]
699    fn test_quantile_method_integer_index_all_agree() {
700        // n=5, q=0.5 → idx=2.0, exactly on sorted[2]. All five methods
701        // must return the same value.
702        let a = arr_1_5();
703        let linear = quantile_with_method(&a, 0.5, None, QuantileMethod::Linear).unwrap();
704        let lower = quantile_with_method(&a, 0.5, None, QuantileMethod::Lower).unwrap();
705        let higher = quantile_with_method(&a, 0.5, None, QuantileMethod::Higher).unwrap();
706        let nearest = quantile_with_method(&a, 0.5, None, QuantileMethod::Nearest).unwrap();
707        let midpoint = quantile_with_method(&a, 0.5, None, QuantileMethod::Midpoint).unwrap();
708        let expected = 3.0;
709        assert!((linear.iter().next().unwrap() - expected).abs() < 1e-12);
710        assert!((lower.iter().next().unwrap() - expected).abs() < 1e-12);
711        assert!((higher.iter().next().unwrap() - expected).abs() < 1e-12);
712        assert!((nearest.iter().next().unwrap() - expected).abs() < 1e-12);
713        assert!((midpoint.iter().next().unwrap() - expected).abs() < 1e-12);
714    }
715
716    #[test]
717    fn test_quantile_method_axis_variant() {
718        // Per-row quantile with a non-linear method.
719        use ferray_core::Ix2;
720        let a = Array::<f64, Ix2>::from_vec(
721            Ix2::new([2, 4]),
722            vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0],
723        )
724        .unwrap();
725        // q=0.25, n=4, idx=0.75 → Lower picks lo_val (index 0).
726        let r = quantile_with_method(&a, 0.25, Some(1), QuantileMethod::Lower).unwrap();
727        assert_eq!(r.as_slice().unwrap(), &[1.0, 10.0]);
728    }
729
#[test]
fn test_percentile_with_method_50() {
    let samples = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

    // The 50th percentile is the 0.5 quantile; with n=4 the virtual index
    // is 1.5, i.e. exactly between sorted[1]=2 and sorted[2]=3.
    let at_p50 = |method: QuantileMethod| -> f64 {
        percentile_with_method(&samples, 50.0, None, method)
            .unwrap()
            .iter()
            .next()
            .copied()
            .unwrap()
    };

    // Evaluate every method first, then compare, in the listed order.
    let observed = [
        QuantileMethod::Linear,
        QuantileMethod::Lower,
        QuantileMethod::Higher,
        QuantileMethod::Nearest,
        QuantileMethod::Midpoint,
    ]
    .map(at_p50);

    // Linear: 2.5, Lower: 2.0, Higher: 3.0,
    // Nearest (tie, lo_i=1 odd): 3.0, Midpoint: 2.5
    let expected: [f64; 5] = [2.5, 2.0, 3.0, 3.0, 2.5];
    for (got, want) in observed.iter().zip(expected.iter()) {
        assert!((got - want).abs() < 1e-12);
    }
}
746
#[test]
fn test_percentile_with_method_rejects_out_of_range() {
    // One probe below 0 and one above 100: both must come back as `Err`.
    let data = arr_1_5();
    for bad_q in [-1.0, 101.0] {
        let outcome = percentile_with_method(&data, bad_q, None, QuantileMethod::Linear);
        assert!(outcome.is_err());
    }
}
753
754    // ---- remaining 8 NumPy quantile methods (#566) ----
755    //
756    // Hand-verified expected values come from Hyndman & Fan 1996 / NumPy
757    // source. For continuous methods the virtual index is
758    //   vidx = n*q + alpha + q*(1 - alpha - beta) - 1
759    // and the result is (1 - gamma) * sorted[lo_i] + gamma * sorted[lo_i+1]
760    // with gamma = frac(vidx_clamped), lo_i = floor(vidx_clamped).
761
762    fn arr_1_4() -> Array<f64, Ix1> {
763        Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap()
764    }
765
#[test]
fn test_quantile_weibull_q_half() {
    // Weibull: alpha = beta = 0. For n=4, q=0.5:
    //   vidx = 4*0.5 + 0 + 0.5*(1 - 0 - 0) - 1 = 1.5
    // which sits halfway between sorted[1]=2 and sorted[2]=3 → 2.5.
    let want: f64 = 2.5;
    let got = quantile_with_method(&arr_1_4(), 0.5, None, QuantileMethod::Weibull)
        .unwrap()
        .iter()
        .next()
        .copied()
        .unwrap();
    assert!((got - want).abs() < 1e-12);
}
775
#[test]
fn test_quantile_weibull_q_quarter() {
    // Weibull: alpha = beta = 0. For n=4, q=0.25:
    //   vidx = 4*0.25 + 0 + 0.25 - 1 = 0.25
    // so lo_i=0 and gamma=0.25, giving 0.75*1 + 0.25*2 = 1.25.
    let input = arr_1_4();
    let result = quantile_with_method(&input, 0.25, None, QuantileMethod::Weibull).unwrap();
    let got = result.iter().next().copied().unwrap();
    let deviation = (got - 1.25).abs();
    assert!(deviation < 1e-12);
}
785
#[test]
fn test_quantile_hazen_q_quarter() {
    // Hazen: alpha = beta = 0.5. For n=4, q=0.25:
    //   vidx = 4*0.25 + 0.5 + 0.25*(1 - 1) - 1 = 0.5
    // so lo_i=0 and gamma=0.5, giving 0.5*1 + 0.5*2 = 1.5.
    let want: f64 = 1.5;
    let fixture = arr_1_4();
    let got = quantile_with_method(&fixture, 0.25, None, QuantileMethod::Hazen)
        .unwrap()
        .iter()
        .next()
        .copied()
        .unwrap();
    assert!((got - want).abs() < 1e-12);
}
795
#[test]
fn test_quantile_median_unbiased_q_half() {
    // MedianUnbiased: alpha = beta = 1/3. For n=4, q=0.5:
    //   vidx = 4*0.5 + 1/3 + 0.5*(1 - 2/3) - 1
    //        = 2 + 1/3 + 1/6 - 1
    //        = 1.5
    // so lo_i=1 and gamma=0.5 → 2.5, the same value Linear gives for n=4.
    let data = arr_1_4();
    let result = quantile_with_method(&data, 0.5, None, QuantileMethod::MedianUnbiased).unwrap();
    let got = result.iter().next().copied().unwrap();
    let deviation = (got - 2.5).abs();
    assert!(deviation < 1e-12);
}
808
#[test]
fn test_quantile_normal_unbiased_q_half() {
    // NormalUnbiased: alpha = beta = 3/8. For n=4, q=0.5:
    //   vidx = 4*0.5 + 3/8 + 0.5*(1 - 6/8) - 1
    //        = 2 + 0.375 + 0.125 - 1
    //        = 1.5
    // which interpolates to 2.5, matching the Linear median at n=4.
    let want: f64 = 2.5;
    let got = quantile_with_method(&arr_1_4(), 0.5, None, QuantileMethod::NormalUnbiased)
        .unwrap()
        .iter()
        .next()
        .copied()
        .unwrap();
    assert!((got - want).abs() < 1e-12);
}
821
#[test]
fn test_quantile_interpolated_inverted_cdf_q_half() {
    // InterpolatedInvertedCdf: alpha=0, beta=1. For n=4, q=0.5:
    //   vidx = 4*0.5 + 0 + 0.5*(1 - 0 - 1) - 1 = 1
    // so lo_i=1 with gamma=0 and the answer is sorted[1] = 2, unlike the
    // median-family methods, which give 2.5 here.
    let method = QuantileMethod::InterpolatedInvertedCdf;
    let input = arr_1_4();
    let result = quantile_with_method(&input, 0.5, None, method).unwrap();
    let got = result.iter().next().copied().unwrap();
    assert!((got - 2.0).abs() < 1e-12);
}
834
#[test]
fn test_quantile_inverted_cdf_q_half() {
    // n=4, q=0.5 gives nq=2, and ceil(nq) - 1 = 1, so sorted[1] = 2.
    let want: f64 = 2.0;
    let got = quantile_with_method(&arr_1_4(), 0.5, None, QuantileMethod::InvertedCdf)
        .unwrap()
        .iter()
        .next()
        .copied()
        .unwrap();
    assert!((got - want).abs() < 1e-12);
}
842
#[test]
fn test_quantile_inverted_cdf_step_function() {
    // With n=5 the result steps at multiples of 1/n. Probe either side of
    // the first step at q=0.2:
    //   q=0.19 → nq=0.95 → ceil-1 = 0 → sorted[0] = 1
    //   q=0.21 → nq=1.05 → ceil-1 = 1 → sorted[1] = 2
    let data = arr_1_5();
    let probes: [(f64, f64); 2] = [(0.19, 1.0), (0.21, 2.0)];
    for (q, want) in probes.iter().copied() {
        let result = quantile_with_method(&data, q, None, QuantileMethod::InvertedCdf).unwrap();
        let got = result.iter().next().copied().unwrap();
        assert!((got - want).abs() < 1e-12);
    }
}
854
#[test]
fn test_quantile_averaged_inverted_cdf_integer_nq_averages() {
    // n=4, q=0.5 makes nq=2 an exact integer, so the method averages the
    // two neighbours sorted[1] and sorted[2]: 0.5*2 + 0.5*3 = 2.5.
    let method = QuantileMethod::AveragedInvertedCdf;
    let input = arr_1_4();
    let result = quantile_with_method(&input, 0.5, None, method).unwrap();
    let got = result.iter().next().copied().unwrap();
    assert!((got - 2.5).abs() < 1e-12);
}
863
#[test]
fn test_quantile_averaged_inverted_cdf_non_integer_nq_matches_inverted_cdf() {
    // n=5, q=0.3 gives nq=1.5, which is not an integer, so no averaging
    // happens and the result equals plain InvertedCdf:
    //   ceil(1.5) - 1 = 1 → sorted[1] = 2
    let data = arr_1_5();
    let averaged = quantile_with_method(&data, 0.3, None, QuantileMethod::AveragedInvertedCdf)
        .unwrap();
    let plain = quantile_with_method(&data, 0.3, None, QuantileMethod::InvertedCdf).unwrap();

    let averaged_value = averaged.iter().next().copied().unwrap();
    let plain_value = plain.iter().next().copied().unwrap();
    assert_eq!(averaged_value, plain_value);
    assert!((averaged_value - 2.0).abs() < 1e-12);
}
877
#[test]
fn test_quantile_closest_observation_half_to_even() {
    // NumPy's virtual index for this method is n*q - 1.5.
    //
    //  * n=4, q=0.5:   idx = 2 - 1.5 = 0.5; gamma=0.5 != 0, so take
    //    next = floor(0.5)+1 = 1 → sorted[1] = 2 (numpy 2.4.x oracle).
    //  * n=4, q=0.125: idx = 0.5 - 1.5 = -1.0; gamma=0 and floor=-1 is odd,
    //    so keep previous = -1, which is clipped to 0 → sorted[0] = 1.
    let data = arr_1_4();
    let cases: [(f64, f64); 2] = [(0.5, 2.0), (0.125, 1.0)];
    for (q, want) in cases.iter().copied() {
        let result =
            quantile_with_method(&data, q, None, QuantileMethod::ClosestObservation).unwrap();
        let got = result.iter().next().copied().unwrap();
        assert!((got - want).abs() < 1e-12);
    }
}
892
#[test]
fn test_quantile_closest_observation_nq_0_875_rounds_up() {
    // n=4, q=0.875: idx = 3.5 - 1.5 = 2.0; gamma=0 but floor=2 is even, so
    // the index is not corrected and next = 3 is used → sorted[3] = 4
    // (numpy 2.4.x oracle).
    let want: f64 = 4.0;
    let got = quantile_with_method(&arr_1_4(), 0.875, None, QuantileMethod::ClosestObservation)
        .unwrap()
        .iter()
        .next()
        .copied()
        .unwrap();
    assert!((got - want).abs() < 1e-12);
}
901
#[test]
fn test_quantile_continuous_methods_agree_at_q_0_and_q_1() {
    // Sanity check on the virtual-index clamp: every continuous method must
    // return the minimum at q=0 and the maximum at q=1.
    let a = arr_1_5();
    let continuous = [
        QuantileMethod::Linear,
        QuantileMethod::Weibull,
        QuantileMethod::Hazen,
        QuantileMethod::InterpolatedInvertedCdf,
        QuantileMethod::MedianUnbiased,
        QuantileMethod::NormalUnbiased,
    ];
    for m in continuous.iter().copied() {
        // Evaluate both endpoints before asserting on either.
        let at_zero = quantile_with_method(&a, 0.0, None, m).unwrap();
        let at_one = quantile_with_method(&a, 1.0, None, m).unwrap();
        let lowest = at_zero.iter().next().copied().unwrap();
        let highest = at_one.iter().next().copied().unwrap();
        assert!(
            (lowest - 1.0).abs() < 1e-12,
            "method {m:?} at q=0 should be min"
        );
        assert!(
            (highest - 5.0).abs() < 1e-12,
            "method {m:?} at q=1 should be max"
        );
    }
}
929
#[test]
fn test_quantile_discrete_methods_agree_at_q_1() {
    // All three discrete rules must land on the largest sample at q=1.0.
    let a = arr_1_5();
    [
        QuantileMethod::InvertedCdf,
        QuantileMethod::AveragedInvertedCdf,
        QuantileMethod::ClosestObservation,
    ]
    .iter()
    .copied()
    .for_each(|m| {
        let result = quantile_with_method(&a, 1.0, None, m).unwrap();
        let top = result.iter().next().copied().unwrap();
        assert!(
            (top - 5.0).abs() < 1e-12,
            "method {m:?} at q=1 should be max"
        );
    });
}
947
#[test]
fn test_quantile_all_13_methods_at_integer_index_agree() {
    // Covers all 13 `QuantileMethod` variants on n=5, q=0.5.
    //
    // Nine methods place their virtual index exactly on position 2.0 for
    // this input, so gamma is 0 and each must return sorted[2] = 3 with no
    // interpolation. The other four use a different virtual index, so they
    // are checked afterwards against their own expected values.
    let a = arr_1_5();
    let first_value = |m: QuantileMethod| -> f64 {
        quantile_with_method(&a, 0.5, None, m)
            .unwrap()
            .iter()
            .next()
            .copied()
            .unwrap()
    };

    let integer_index_methods = [
        QuantileMethod::Linear,
        QuantileMethod::Lower,
        QuantileMethod::Higher,
        QuantileMethod::Nearest,
        QuantileMethod::Midpoint,
        QuantileMethod::Weibull,
        QuantileMethod::Hazen,
        QuantileMethod::MedianUnbiased,
        QuantileMethod::NormalUnbiased,
    ];
    for &m in &integer_index_methods {
        let r = first_value(m);
        assert!(
            (r - 3.0).abs() < 1e-12,
            "method {m:?} at odd n, q=0.5 should be 3.0"
        );
    }

    // Remaining four methods, n=5, q=0.5:
    //  * InterpolatedInvertedCdf (alpha=0, beta=1):
    //      vidx = 5*0.5 + 0 + 0.5*(1 - 0 - 1) - 1 = 1.5
    //      → 0.5*sorted[1] + 0.5*sorted[2] = 2.5
    //  * InvertedCdf: nq = 2.5 → ceil-1 = 2 → sorted[2] = 3
    //  * AveragedInvertedCdf: nq = 2.5 is non-integer → same as InvertedCdf = 3
    //  * ClosestObservation: idx = n*q - 1.5 = 1.0, gamma=0 and floor=1 is
    //    odd → keep previous = 1 → sorted[1] = 2
    let other_methods: [(QuantileMethod, f64); 4] = [
        (QuantileMethod::InterpolatedInvertedCdf, 2.5),
        (QuantileMethod::InvertedCdf, 3.0),
        (QuantileMethod::AveragedInvertedCdf, 3.0),
        (QuantileMethod::ClosestObservation, 2.0),
    ];
    for &(m, expected) in &other_methods {
        let r = first_value(m);
        assert!(
            (r - expected).abs() < 1e-12,
            "method {m:?} at n=5, q=0.5 should be {expected}, got {r}"
        );
    }
}
977
#[test]
fn test_quantile_method_axis_variant_weibull() {
    use ferray_core::Ix2;

    // A (2, 4) array reduced along axis 1: the Weibull quantile at q=0.5 is
    // 2.5 for the row [1,2,3,4] and 25.0 for the row [10,20,30,40].
    let data = vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0];
    let grid = Array::<f64, Ix2>::from_vec(Ix2::new([2, 4]), data).unwrap();

    let per_row = quantile_with_method(&grid, 0.5, Some(1), QuantileMethod::Weibull).unwrap();
    assert_eq!(per_row.shape(), &[2]);

    let values = per_row.as_slice().unwrap();
    let expected: [f64; 2] = [2.5, 25.0];
    for (row, want) in expected.iter().enumerate() {
        assert!((values[row] - want).abs() < 1e-12);
    }
}
994}