1use ferray_core::error::{FerrayError, FerrayResult};
5use ferray_core::{Array, Dimension, Element, IxDyn};
6use num_traits::Float;
7
8use super::{collect_data, make_result, output_shape, reduce_axis_general, validate_axis};
9
10fn quantile_sorted<T: Float>(sorted: &[T], q: T) -> T {
17 let n = sorted.len();
18 if n == 0 {
19 return T::nan();
20 }
21 if n == 1 {
22 return sorted[0];
23 }
24 let idx_f = q * T::from(n - 1).unwrap();
25 let lo = idx_f.floor();
26 let hi = idx_f.ceil();
27 let lo_i = lo.to_usize().unwrap().min(n - 1);
28 let hi_i = hi.to_usize().unwrap().min(n - 1);
29 if lo_i == hi_i {
30 sorted[lo_i]
31 } else {
32 let frac = idx_f - lo;
33 sorted[lo_i] * (T::one() - frac) + sorted[hi_i] * frac
34 }
35}
36
37fn partial_sort<T: Float>(data: &mut [T]) {
39 data.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
40}
41
42fn lane_quantile<T: Float>(lane: &[T], q: T) -> T {
44 let mut sorted: Vec<T> = lane.to_vec();
45 partial_sort(&mut sorted);
46 quantile_sorted(&sorted, q)
47}
48
49fn lane_nanquantile<T: Float>(lane: &[T], q: T) -> T {
51 let mut sorted: Vec<T> = lane.iter().copied().filter(|x| !x.is_nan()).collect();
52 if sorted.is_empty() {
53 return T::nan();
54 }
55 partial_sort(&mut sorted);
56 quantile_sorted(&sorted, q)
57}
58
59pub fn quantile<T, D>(a: &Array<T, D>, q: T, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
68where
69 T: Element + Float,
70 D: Dimension,
71{
72 if q < <T as Element>::zero() || q > <T as Element>::one() {
73 return Err(FerrayError::invalid_value("quantile q must be in [0, 1]"));
74 }
75 if a.is_empty() {
76 return Err(FerrayError::invalid_value(
77 "cannot compute quantile of empty array",
78 ));
79 }
80 let data = collect_data(a);
81 match axis {
82 None => {
83 let val = lane_quantile(&data, q);
84 make_result(&[], vec![val])
85 }
86 Some(ax) => {
87 validate_axis(ax, a.ndim())?;
88 let shape = a.shape();
89 let out_s = output_shape(shape, ax);
90 let result = reduce_axis_general(&data, shape, ax, |lane| lane_quantile(lane, q));
91 make_result(&out_s, result)
92 }
93 }
94}
95
96pub fn percentile<T, D>(a: &Array<T, D>, q: T, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
105where
106 T: Element + Float,
107 D: Dimension,
108{
109 let hundred = T::from(100.0).unwrap();
110 if q < <T as Element>::zero() || q > hundred {
111 return Err(FerrayError::invalid_value(
112 "percentile q must be in [0, 100]",
113 ));
114 }
115 quantile(a, q / hundred, axis)
116}
117
118pub fn median<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
126where
127 T: Element + Float,
128 D: Dimension,
129{
130 let half = T::from(0.5).unwrap();
131 quantile(a, half, axis)
132}
133
134pub fn nanmedian<T, D>(a: &Array<T, D>, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
142where
143 T: Element + Float,
144 D: Dimension,
145{
146 let half = T::from(0.5).unwrap();
147 nanquantile(a, half, axis)
148}
149
150pub fn nanpercentile<T, D>(
154 a: &Array<T, D>,
155 q: T,
156 axis: Option<usize>,
157) -> FerrayResult<Array<T, IxDyn>>
158where
159 T: Element + Float,
160 D: Dimension,
161{
162 let hundred = T::from(100.0).unwrap();
163 if q < <T as Element>::zero() || q > hundred {
164 return Err(FerrayError::invalid_value(
165 "nanpercentile q must be in [0, 100]",
166 ));
167 }
168 nanquantile(a, q / hundred, axis)
169}
170
171fn nanquantile<T, D>(a: &Array<T, D>, q: T, axis: Option<usize>) -> FerrayResult<Array<T, IxDyn>>
173where
174 T: Element + Float,
175 D: Dimension,
176{
177 if q < <T as Element>::zero() || q > <T as Element>::one() {
178 return Err(FerrayError::invalid_value("quantile q must be in [0, 1]"));
179 }
180 if a.is_empty() {
181 return Err(FerrayError::invalid_value(
182 "cannot compute nanquantile of empty array",
183 ));
184 }
185 let data = collect_data(a);
186 match axis {
187 None => {
188 let val = lane_nanquantile(&data, q);
189 make_result(&[], vec![val])
190 }
191 Some(ax) => {
192 validate_axis(ax, a.ndim())?;
193 let shape = a.shape();
194 let out_s = output_shape(shape, ax);
195 let result = reduce_axis_general(&data, shape, ax, |lane| lane_nanquantile(lane, q));
196 make_result(&out_s, result)
197 }
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Ix1;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const TOL: f64 = 1e-12;

    /// Builds a one-dimensional `f64` array holding `values`.
    fn array_1d(values: Vec<f64>) -> Array<f64, Ix1> {
        let len = values.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Asserts that the first element of `result` is within `TOL` of `expected`.
    fn assert_first_close(result: &Array<f64, IxDyn>, expected: f64) {
        let first = result.iter().next().unwrap();
        assert!((first - expected).abs() < TOL);
    }

    #[test]
    fn test_median_odd() {
        let input = array_1d(vec![5.0, 1.0, 3.0, 2.0, 4.0]);
        assert_first_close(&median(&input, None).unwrap(), 3.0);
    }

    #[test]
    fn test_median_even() {
        let input = array_1d(vec![4.0, 1.0, 3.0, 2.0]);
        assert_first_close(&median(&input, None).unwrap(), 2.5);
    }

    #[test]
    fn test_percentile_0_50_100() {
        let input = array_1d(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let lowest = percentile(&input, 0.0, None).unwrap();
        let middle = percentile(&input, 50.0, None).unwrap();
        let highest = percentile(&input, 100.0, None).unwrap();
        assert_first_close(&lowest, 1.0);
        assert_first_close(&middle, 3.0);
        assert_first_close(&highest, 5.0);
    }

    #[test]
    fn test_quantile_bounds() {
        let input = array_1d(vec![1.0, 2.0, 3.0]);
        assert!(quantile(&input, -0.1, None).is_err());
        assert!(quantile(&input, 1.1, None).is_err());
    }

    #[test]
    fn test_quantile_interpolation() {
        // Rank 0.25 * 3 = 0.75 lies between 1.0 and 2.0, giving 1.75.
        let input = array_1d(vec![1.0, 2.0, 3.0, 4.0]);
        assert_first_close(&quantile(&input, 0.25, None).unwrap(), 1.75);
    }

    #[test]
    fn test_nanmedian() {
        // With the NaN dropped the remaining values are 1, 3, 5.
        let input = array_1d(vec![1.0, f64::NAN, 3.0, 5.0]);
        assert_first_close(&nanmedian(&input, None).unwrap(), 3.0);
    }

    #[test]
    fn test_nanpercentile() {
        let input = array_1d(vec![1.0, f64::NAN, 3.0, 5.0]);
        assert_first_close(&nanpercentile(&input, 50.0, None).unwrap(), 3.0);
    }

    #[test]
    fn test_nanmedian_all_nan() {
        let input = array_1d(vec![f64::NAN, f64::NAN]);
        let result = nanmedian(&input, None).unwrap();
        assert!(result.iter().next().unwrap().is_nan());
    }
}