Skip to main content

ferray_stats/reductions/
quantile.rs

1// ferray-stats: Quantile-based reductions — median, percentile, quantile (REQ-1)
2// Also nanmedian, nanpercentile (REQ-3)
3
4use ferray_core::error::{FerrayError, FerrayResult};
5use ferray_core::{Array, Dimension, Element, IxDyn};
6use num_traits::Float;
7
8use super::{collect_data, make_result, output_shape, reduce_axis_general, validate_axis};
9
10// ---------------------------------------------------------------------------
11// Helpers
12// ---------------------------------------------------------------------------
13
14/// Compute a single quantile value from a sorted slice using linear interpolation.
15/// `q` must be in [0, 1].
16fn quantile_sorted<T: Float>(sorted: &[T], q: T) -> T {
17    let n = sorted.len();
18    if n == 0 {
19        return T::nan();
20    }
21    if n == 1 {
22        return sorted[0];
23    }
24    let idx_f = q * T::from(n - 1).unwrap();
25    let lo = idx_f.floor();
26    let hi = idx_f.ceil();
27    let lo_i = lo.to_usize().unwrap().min(n - 1);
28    let hi_i = hi.to_usize().unwrap().min(n - 1);
29    if lo_i == hi_i {
30        sorted[lo_i]
31    } else {
32        let frac = idx_f - lo;
33        sorted[lo_i] * (T::one() - frac) + sorted[hi_i] * frac
34    }
35}
36
37/// Sort a mutable slice by partial_cmp, placing NaN at the end.
38fn partial_sort<T: Float>(data: &mut [T]) {
39    data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
40}
41
42/// Sort and compute quantile from a lane.
43fn lane_quantile<T: Float>(lane: &[T], q: T) -> T {
44    let mut sorted: Vec<T> = lane.to_vec();
45    partial_sort(&mut sorted);
46    quantile_sorted(&sorted, q)
47}
48
49/// Sort (excluding NaN) and compute quantile from a lane.
50fn lane_nanquantile<T: Float>(lane: &[T], q: T) -> T {
51    let mut sorted: Vec<T> = lane.iter().copied().filter(|x| !x.is_nan()).collect();
52    if sorted.is_empty() {
53        return T::nan();
54    }
55    partial_sort(&mut sorted);
56    quantile_sorted(&sorted, q)
57}
58
59// ---------------------------------------------------------------------------
60// quantile
61// ---------------------------------------------------------------------------
62
63/// Compute the q-th quantile of array data along a given axis.
64///
65/// `q` must be in \[0, 1\]. Uses linear interpolation (NumPy default method).
66/// Equivalent to `numpy.quantile`.
67pub fn quantile<T, D>(a: &Array<T, D>, q: T, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
68where
69    T: Element + Float,
70    D: Dimension,
71{
72    if q < <T as Element>::zero() || q > <T as Element>::one() {
73        return Err(FerrayError::invalid_value("quantile q must be in [0, 1]"));
74    }
75    if a.is_empty() {
76        return Err(FerrayError::invalid_value(
77            "cannot compute quantile of empty array",
78        ));
79    }
80    let data = collect_data(a);
81    match axis {
82        None => {
83            let val = lane_quantile(&data, q);
84            make_result(&[], vec![val])
85        }
86        Some(ax) => {
87            validate_axis(ax, a.ndim())?;
88            let shape = a.shape();
89            let out_s = output_shape(shape, ax);
90            let result = reduce_axis_general(&data, shape, ax, |lane| lane_quantile(lane, q));
91            make_result(&out_s, result)
92        }
93    }
94}
95
96// ---------------------------------------------------------------------------
97// percentile
98// ---------------------------------------------------------------------------
99
100/// Compute the q-th percentile of array data along a given axis.
101///
102/// `q` must be in \[0, 100\].
103/// Equivalent to `numpy.percentile`.
104pub fn percentile<T, D>(a: &Array<T, D>, q: T, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
105where
106    T: Element + Float,
107    D: Dimension,
108{
109    let hundred = T::from(100.0).unwrap();
110    if q < <T as Element>::zero() || q > hundred {
111        return Err(FerrayError::invalid_value(
112            "percentile q must be in [0, 100]",
113        ));
114    }
115    quantile(a, q / hundred, axis)
116}
117
118// ---------------------------------------------------------------------------
119// median
120// ---------------------------------------------------------------------------
121
122/// Compute the median of array elements along a given axis.
123///
124/// Equivalent to `numpy.median`.
125pub fn median<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
126where
127    T: Element + Float,
128    D: Dimension,
129{
130    let half = T::from(0.5).unwrap();
131    quantile(a, half, axis)
132}
133
134// ---------------------------------------------------------------------------
135// NaN-aware variants
136// ---------------------------------------------------------------------------
137
138/// Median, skipping NaN values.
139///
140/// Equivalent to `numpy.nanmedian`.
141pub fn nanmedian<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
142where
143    T: Element + Float,
144    D: Dimension,
145{
146    let half = T::from(0.5).unwrap();
147    nanquantile(a, half, axis)
148}
149
150/// Percentile, skipping NaN values.
151///
152/// Equivalent to `numpy.nanpercentile`.
153pub fn nanpercentile<T, D>(
154    a: &Array<T, D>,
155    q: T,
156    axis: Option<usize>,
157) -> FerrayResult<Array<T, IxDyn>>
158where
159    T: Element + Float,
160    D: Dimension,
161{
162    let hundred = T::from(100.0).unwrap();
163    if q < <T as Element>::zero() || q > hundred {
164        return Err(FerrayError::invalid_value(
165            "nanpercentile q must be in [0, 100]",
166        ));
167    }
168    nanquantile(a, q / hundred, axis)
169}
170
171/// Quantile, skipping NaN values.
172fn nanquantile<T, D>(a: &Array<T, D>, q: T, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
173where
174    T: Element + Float,
175    D: Dimension,
176{
177    if q < <T as Element>::zero() || q > <T as Element>::one() {
178        return Err(FerrayError::invalid_value("quantile q must be in [0, 1]"));
179    }
180    if a.is_empty() {
181        return Err(FerrayError::invalid_value(
182            "cannot compute nanquantile of empty array",
183        ));
184    }
185    let data = collect_data(a);
186    match axis {
187        None => {
188            let val = lane_nanquantile(&data, q);
189            make_result(&[], vec![val])
190        }
191        Some(ax) => {
192            validate_axis(ax, a.ndim())?;
193            let shape = a.shape();
194            let out_s = output_shape(shape, ax);
195            let result = reduce_axis_general(&data, shape, ax, |lane| lane_nanquantile(lane, q));
196            make_result(&out_s, result)
197        }
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix1;

    /// Absolute tolerance used when comparing floating-point results.
    const TOL: f64 = 1e-12;

    /// Build a one-dimensional `f64` array holding `values`.
    fn arr1(values: Vec<f64>) -> Array<f64, Ix1> {
        Array::<f64, Ix1>::from_vec(Ix1::new([values.len()]), values).unwrap()
    }

    /// Assert that the first element of `result` is within `TOL` of `expected`.
    fn assert_first_close(result: &Array<f64, IxDyn>, expected: f64) {
        let diff = result.iter().next().unwrap() - expected;
        assert!(diff.abs() < TOL);
    }

    #[test]
    fn test_median_odd() {
        let input = arr1(vec![5.0, 1.0, 3.0, 2.0, 4.0]);
        assert_first_close(&median(&input, None).unwrap(), 3.0);
    }

    #[test]
    fn test_median_even() {
        let input = arr1(vec![4.0, 1.0, 3.0, 2.0]);
        assert_first_close(&median(&input, None).unwrap(), 2.5);
    }

    #[test]
    fn test_percentile_0_50_100() {
        let input = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let lowest = percentile(&input, 0.0, None).unwrap();
        let middle = percentile(&input, 50.0, None).unwrap();
        let highest = percentile(&input, 100.0, None).unwrap();
        assert_first_close(&lowest, 1.0);
        assert_first_close(&middle, 3.0);
        assert_first_close(&highest, 5.0);
    }

    #[test]
    fn test_quantile_bounds() {
        let input = arr1(vec![1.0, 2.0, 3.0]);
        assert!(quantile(&input, -0.1, None).is_err());
        assert!(quantile(&input, 1.1, None).is_err());
    }

    #[test]
    fn test_quantile_interpolation() {
        let input = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        // Position = 0.25 * 3 = 0.75, i.e. three quarters of the way from 1.0 to 2.0.
        assert_first_close(&quantile(&input, 0.25, None).unwrap(), 1.75);
    }

    #[test]
    fn test_nanmedian() {
        let input = arr1(vec![1.0, f64::NAN, 3.0, 5.0]);
        // Non-NaN values sorted: [1, 3, 5]; their median is 3.0.
        assert_first_close(&nanmedian(&input, None).unwrap(), 3.0);
    }

    #[test]
    fn test_nanpercentile() {
        let input = arr1(vec![1.0, f64::NAN, 3.0, 5.0]);
        assert_first_close(&nanpercentile(&input, 50.0, None).unwrap(), 3.0);
    }

    #[test]
    fn test_nanmedian_all_nan() {
        let input = arr1(vec![f64::NAN, f64::NAN]);
        let result = nanmedian(&input, None).unwrap();
        assert!(result.iter().next().unwrap().is_nan());
    }
}