Skip to main content

polars_compute/rolling/nulls/
quantile.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3use crate::rolling::quantile_filter::SealedRolling;
4
5pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
6    sorted: SortedBufNulls<'a, T>,
7    prob: f64,
8    method: QuantileMethod,
9}
10
11impl<
12    T: NativeType
13        + IsFloat
14        + Float
15        + std::iter::Sum
16        + AddAssign
17        + SubAssign
18        + Div<Output = T>
19        + NumCast
20        + One
21        + Zero
22        + SealedRolling
23        + PartialOrd
24        + Sub<Output = T>,
25> RollingAggWindowNulls<T> for QuantileWindow<'_, T>
26{
27    type This<'a> = QuantileWindow<'a, T>;
28
29    fn new<'a>(
30        slice: &'a [T],
31        validity: &'a Bitmap,
32        start: usize,
33        end: usize,
34        params: Option<RollingFnParams>,
35        window_size: Option<usize>,
36    ) -> Self::This<'a> {
37        let params = params.unwrap();
38        let RollingFnParams::Quantile(params) = params else {
39            unreachable!("expected Quantile params");
40        };
41        QuantileWindow {
42            sorted: SortedBufNulls::new(slice, validity, start, end, window_size),
43            prob: params.prob,
44            method: params.method,
45        }
46    }
47
48    unsafe fn update(&mut self, new_start: usize, new_end: usize) {
49        self.sorted.update(new_start, new_end);
50    }
51
52    fn get_agg(&self, _idx: usize) -> Option<T> {
53        let mut length = self.sorted.len();
54        let null_count = self.sorted.null_count;
55
56        // The min periods_issue will be taken care of when actually rolling
57        if null_count == length {
58            return None;
59        }
60        // Nulls are guaranteed to be at the front
61        length -= null_count;
62        let mut idx = match self.method {
63            QuantileMethod::Nearest => (((length as f64) - 1.0) * self.prob).round() as usize,
64            QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
65                ((length as f64 - 1.0) * self.prob).floor() as usize
66            },
67            QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
68            QuantileMethod::Equiprobable => {
69                ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
70            },
71        };
72
73        idx = std::cmp::min(idx, length - 1);
74
75        // we can unwrap because we sliced of the nulls
76        match self.method {
77            QuantileMethod::Midpoint => {
78                let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
79
80                debug_assert!(idx <= top_idx);
81                let v = if idx != top_idx {
82                    let low = self.sorted.get(idx + null_count).unwrap();
83                    let high = self.sorted.get(idx + null_count + 1).unwrap();
84                    (low + high) / T::from::<f64>(2.0f64).unwrap()
85                } else {
86                    self.sorted.get(idx + null_count).unwrap()
87                };
88
89                Some(v)
90            },
91            QuantileMethod::Linear => {
92                let float_idx = (length as f64 - 1.0) * self.prob;
93                let top_idx = f64::ceil(float_idx) as usize;
94
95                if top_idx == idx {
96                    Some(self.sorted.get(idx + null_count).unwrap())
97                } else {
98                    let low = self.sorted.get(idx + null_count).unwrap();
99                    let high = self.sorted.get(top_idx + null_count).unwrap();
100                    let proportion = T::from(float_idx - idx as f64).unwrap();
101                    Some(proportion * (high - low) + low)
102                }
103            },
104            _ => Some(self.sorted.get(idx + null_count).unwrap()),
105        }
106    }
107
108    fn is_valid(&self, min_periods: usize) -> bool {
109        self.sorted.is_valid(min_periods)
110    }
111
112    fn slice_len(&self) -> usize {
113        self.sorted.slice_len()
114    }
115}
116
117pub fn rolling_quantile<T>(
118    arr: &PrimitiveArray<T>,
119    window_size: usize,
120    min_periods: usize,
121    center: bool,
122    weights: Option<&[f64]>,
123    params: Option<RollingFnParams>,
124) -> ArrayRef
125where
126    T: NativeType
127        + IsFloat
128        + Float
129        + std::iter::Sum
130        + AddAssign
131        + SubAssign
132        + Div<Output = T>
133        + NumCast
134        + One
135        + Zero
136        + SealedRolling
137        + PartialOrd
138        + Sub<Output = T>,
139{
140    if weights.is_some() {
141        panic!("weights not yet supported on array with null values")
142    }
143    let offset_fn = match center {
144        true => det_offsets_center,
145        false => det_offsets,
146    };
147    /*
148    TODO: fix or remove the dancing links based rolling implementation
149    see https://github.com/pola-rs/polars/issues/23480
150    if !center {
151        let params = params.as_ref().unwrap();
152        let RollingFnParams::Quantile(params) = params else {
153            unreachable!("expected Quantile params");
154        };
155
156        let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
157            params.method,
158            min_periods,
159            window_size,
160            arr.clone(),
161            params.prob,
162        );
163        let out: PrimitiveArray<T> = out.into();
164        return Box::new(out);
165    }
166    */
167    rolling_apply_agg_window::<QuantileWindow<T>, _, _, _>(
168        arr.values().as_slice(),
169        arr.validity().as_ref().unwrap(),
170        window_size,
171        min_periods,
172        offset_fn,
173        params,
174    )
175}
176
#[cfg(test)]
mod test {
    use arrow::datatypes::ArrowDataType;
    use polars_buffer::Buffer;

    use super::*;

    /// Downcasts a rolling result to `PrimitiveArray<f64>` and copies its
    /// values into a plain vector so it can be compared with `assert_eq!`.
    fn to_f64_vec(out: ArrayRef) -> Vec<Option<f64>> {
        let typed = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        typed.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_median_nulls() {
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(vec![1.0, 2.0, 3.0, 4.0]),
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));

        // Rolling median for a given window size, `min_periods` and centering.
        let median = |window_size: usize, min_periods: usize, center: bool| {
            to_f64_vec(rolling_quantile(
                arr,
                window_size,
                min_periods,
                center,
                None,
                med_pars,
            ))
        };

        assert_eq!(median(2, 2, false), &[None, None, None, Some(3.5)]);
        assert_eq!(
            median(2, 1, false),
            &[Some(1.0), Some(1.0), Some(3.0), Some(3.5)]
        );
        assert_eq!(
            median(4, 1, false),
            &[Some(1.0), Some(1.0), Some(2.0), Some(3.0)]
        );
        assert_eq!(
            median(4, 1, true),
            &[Some(1.0), Some(2.0), Some(3.0), Some(3.5)]
        );
        assert_eq!(median(4, 4, true), &[None, None, None, None]);
    }

    #[test]
    fn test_rolling_quantile_nulls_limits() {
        // Quantiles at probability 0 and 1 must agree with the rolling
        // min/max for every interpolation method.
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        let expected_min = to_f64_vec(rolling_min(values, 2, 1, false, None, None));
        let expected_max = to_f64_vec(rolling_max(values, 2, 1, false, None, None));

        let quantile = |prob: f64, method: QuantileMethod| {
            let params = Some(RollingFnParams::Quantile(RollingQuantileParams {
                prob,
                method,
            }));
            to_f64_vec(rolling_quantile(values, 2, 1, false, None, params))
        };

        for method in [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ] {
            assert_eq!(expected_min, quantile(0.0, method));
            assert_eq!(expected_max, quantile(1.0, method));
        }
    }
}