polars_compute/rolling/no_nulls/
quantile.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::legacy::utils::CustomIterTools;
3use num_traits::ToPrimitive;
4use polars_error::polars_ensure;
5
6use super::QuantileMethod::*;
7use super::*;
8use crate::rolling::quantile_filter::SealedRolling;
9
10pub struct QuantileWindow<'a, T: NativeType> {
11    sorted: SortedBuf<'a, T>,
12    prob: f64,
13    method: QuantileMethod,
14}
15
16impl<
17    'a,
18    T: NativeType
19        + Float
20        + std::iter::Sum
21        + AddAssign
22        + SubAssign
23        + Div<Output = T>
24        + NumCast
25        + One
26        + Zero
27        + SealedRolling
28        + Sub<Output = T>,
29> RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>
30{
31    fn new(
32        slice: &'a [T],
33        start: usize,
34        end: usize,
35        params: Option<RollingFnParams>,
36        window_size: Option<usize>,
37    ) -> Self {
38        let params = params.unwrap();
39        let RollingFnParams::Quantile(params) = params else {
40            unreachable!("expected Quantile params");
41        };
42
43        Self {
44            sorted: SortedBuf::new(slice, start, end, window_size),
45            prob: params.prob,
46            method: params.method,
47        }
48    }
49
50    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
51        self.sorted.update(start, end);
52        let length = self.sorted.len();
53
54        let idx = match self.method {
55            Linear => {
56                // Maybe add a fast path for median case? They could branch depending on odd/even.
57                let length_f = length as f64;
58                let idx = ((length_f - 1.0) * self.prob).floor() as usize;
59
60                let float_idx_top = (length_f - 1.0) * self.prob;
61                let top_idx = float_idx_top.ceil() as usize;
62                return if idx == top_idx {
63                    Some(self.sorted.get(idx))
64                } else {
65                    let proportion = T::from(float_idx_top - idx as f64).unwrap();
66                    let vi = self.sorted.get(idx);
67                    let vj = self.sorted.get(top_idx);
68
69                    Some(proportion * (vj - vi) + vi)
70                };
71            },
72            Midpoint => {
73                let length_f = length as f64;
74                let idx = (length_f * self.prob) as usize;
75                let idx = std::cmp::min(idx, length - 1);
76
77                let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;
78                return if top_idx == idx {
79                    Some(self.sorted.get(idx))
80                } else {
81                    let (mid, mid_plus_1) = (self.sorted.get(idx), (self.sorted.get(idx + 1)));
82
83                    Some((mid + mid_plus_1) / (T::one() + T::one()))
84                };
85            },
86            Nearest => {
87                let idx = ((length as f64) * self.prob) as usize;
88                std::cmp::min(idx, length - 1)
89            },
90            Lower => ((length as f64 - 1.0) * self.prob).floor() as usize,
91            Higher => {
92                let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
93                std::cmp::min(idx, length - 1)
94            },
95            Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
96        };
97
98        Some(self.sorted.get(idx))
99    }
100}
101
102pub fn rolling_quantile<T>(
103    values: &[T],
104    window_size: usize,
105    min_periods: usize,
106    center: bool,
107    weights: Option<&[f64]>,
108    params: Option<RollingFnParams>,
109) -> PolarsResult<ArrayRef>
110where
111    T: NativeType
112        + IsFloat
113        + Float
114        + std::iter::Sum
115        + AddAssign
116        + SubAssign
117        + Div<Output = T>
118        + NumCast
119        + One
120        + Zero
121        + SealedRolling
122        + PartialOrd
123        + Sub<Output = T>,
124{
125    let offset_fn = match center {
126        true => det_offsets_center,
127        false => det_offsets,
128    };
129    match weights {
130        None => {
131            if !center {
132                let params = params.as_ref().unwrap();
133                let RollingFnParams::Quantile(params) = params else {
134                    unreachable!("expected Quantile params");
135                };
136                let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
137                    params.method,
138                    min_periods,
139                    window_size,
140                    values,
141                    params.prob,
142                );
143                let validity = create_validity(min_periods, values.len(), window_size, offset_fn);
144                return Ok(Box::new(PrimitiveArray::new(
145                    T::PRIMITIVE.into(),
146                    out.into(),
147                    validity.map(|b| b.into()),
148                )));
149            }
150
151            rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
152                values,
153                window_size,
154                min_periods,
155                offset_fn,
156                params,
157            )
158        },
159        Some(weights) => {
160            let wsum = weights.iter().sum();
161            polars_ensure!(
162                wsum != 0.0,
163                ComputeError: "Weighted quantile is undefined if weights sum to 0"
164            );
165            let params = params.unwrap();
166            let RollingFnParams::Quantile(params) = params else {
167                unreachable!("expected Quantile params");
168            };
169
170            Ok(rolling_apply_weighted_quantile(
171                values,
172                params.prob,
173                params.method,
174                window_size,
175                min_periods,
176                offset_fn,
177                weights,
178                wsum,
179            ))
180        },
181    }
182}
183
184#[inline]
185fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T
186where
187    T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
188{
189    // There are a few ways to compute a weighted quantile but no "canonical" way.
190    // This is mostly taken from the Julia implementation which was readable and reasonable
191    // https://juliastats.org/StatsBase.jl/stable/scalarstats/#Quantile-and-Related-Functions-1
192    let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero());
193
194    // Once the cumulative weight crosses h, we've found our ind{ex/ices}. The definition may look
195    // odd but it's the equivalent of taking h = p * (n - 1) + 1 if your data is indexed from 1.
196    let h: f64 = p * (wsum - buf[0].1) + buf[0].1;
197    for &(v, w) in buf.iter() {
198        if s > h {
199            break;
200        }
201        (s_old, v_old, vk) = (s, vk, v);
202        s += w;
203    }
204    match (h == s_old, method) {
205        (true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter
206        (_, Lower) => v_old,
207        (_, Higher) => vk,
208        (_, Nearest) => {
209            if s - h > h - s_old {
210                v_old
211            } else {
212                vk
213            }
214        },
215        (_, Equiprobable) => {
216            let threshold = (wsum * p).ceil() - 1.0;
217            if s > threshold { vk } else { v_old }
218        },
219        (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),
220        // This is seemingly the canonical way to do it.
221        (_, Linear) => {
222            v_old + <T as NumCast>::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old)
223        },
224    }
225}
226
227#[allow(clippy::too_many_arguments)]
228fn rolling_apply_weighted_quantile<T, Fo>(
229    values: &[T],
230    p: f64,
231    method: QuantileMethod,
232    window_size: usize,
233    min_periods: usize,
234    det_offsets_fn: Fo,
235    weights: &[f64],
236    wsum: f64,
237) -> ArrayRef
238where
239    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
240    T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
241{
242    assert_eq!(weights.len(), window_size);
243    // Keep nonzero weights and their indices to know which values we need each iteration.
244    let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect();
245    let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()];
246    let len = values.len();
247    let out = (0..len)
248        .map(|idx| {
249            // Don't need end. Window size is constant and we computed offsets from start above.
250            let (start, _) = det_offsets_fn(idx, window_size, len);
251
252            // Sorting is not ideal, see https://github.com/tobiasschoch/wquantile for something faster
253            unsafe {
254                buf.iter_mut()
255                    .zip(nz_idx_wts.iter())
256                    .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));
257            }
258            buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));
259            compute_wq(&buf, p, wsum, method)
260        })
261        .collect_trusted::<Vec<T>>();
262
263    let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
264    Box::new(PrimitiveArray::new(
265        T::PRIMITIVE.into(),
266        out.into(),
267        validity.map(|b| b.into()),
268    ))
269}
270
#[cfg(test)]
mod test {
    use super::*;

    /// Wraps a probability and a method into the parameter shape the rolling functions take.
    fn quantile_pars(prob: f64, method: QuantileMethod) -> Option<RollingFnParams> {
        Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob,
            method,
        }))
    }

    /// Downcasts a rolling result to an `f64` primitive array and copies out its optional values.
    fn collect_f64(arr: &ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect::<Vec<_>>()
    }

    #[test]
    fn test_rolling_median() {
        let values = &[1.0, 2.0, 3.0, 4.0];
        let med_pars = quantile_pars(0.5, Linear);

        // (window_size, min_periods, center, expected output)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, Some(1.5), Some(2.5), Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.5), Some(2.5), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.5), Some(2.0), Some(2.5)]),
            (4, 1, true, [Some(1.5), Some(2.0), Some(2.5), Some(3.0)]),
            (4, 4, true, [None, None, Some(2.5), None]),
        ];

        for (window_size, min_periods, center, expected) in cases {
            let out =
                rolling_quantile(values, window_size, min_periods, center, None, med_pars).unwrap();
            let out = collect_f64(&out);
            assert_eq!(out, &expected);
        }
    }

    #[test]
    fn test_rolling_quantile_limits() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            // The 0-quantile must agree with the rolling minimum for every method.
            let min_pars = quantile_pars(0.0, method);
            let expected_min = collect_f64(&rolling_min(values, 2, 2, false, None, None).unwrap());
            let actual_min =
                collect_f64(&rolling_quantile(values, 2, 2, false, None, min_pars).unwrap());
            assert_eq!(expected_min, actual_min);

            // The 1-quantile must agree with the rolling maximum for every method.
            let max_pars = quantile_pars(1.0, method);
            let expected_max = collect_f64(&rolling_max(values, 2, 2, false, None, None).unwrap());
            let actual_max =
                collect_f64(&rolling_quantile(values, 2, 2, false, None, max_pars).unwrap());
            assert_eq!(expected_max, actual_max);
        }
    }
}