polars_compute/rolling/nulls/
quantile.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::array::MutablePrimitiveArray;
3
4use super::*;
5use crate::rolling::quantile_filter::SealedRolling;
6
7pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
8    sorted: SortedBufNulls<'a, T>,
9    prob: f64,
10    method: QuantileMethod,
11}
12
13impl<
14    'a,
15    T: NativeType
16        + IsFloat
17        + Float
18        + std::iter::Sum
19        + AddAssign
20        + SubAssign
21        + Div<Output = T>
22        + NumCast
23        + One
24        + Zero
25        + SealedRolling
26        + PartialOrd
27        + Sub<Output = T>,
28> RollingAggWindowNulls<'a, T> for QuantileWindow<'a, T>
29{
30    unsafe fn new(
31        slice: &'a [T],
32        validity: &'a Bitmap,
33        start: usize,
34        end: usize,
35        params: Option<RollingFnParams>,
36        window_size: Option<usize>,
37    ) -> Self {
38        let params = params.unwrap();
39        let RollingFnParams::Quantile(params) = params else {
40            unreachable!("expected Quantile params");
41        };
42        Self {
43            sorted: SortedBufNulls::new(slice, validity, start, end, window_size),
44            prob: params.prob,
45            method: params.method,
46        }
47    }
48
49    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
50        let null_count = self.sorted.update(start, end);
51        let mut length = self.sorted.len();
52        // The min periods_issue will be taken care of when actually rolling
53        if null_count == length {
54            return None;
55        }
56        // Nulls are guaranteed to be at the front
57        length -= null_count;
58        let mut idx = match self.method {
59            QuantileMethod::Nearest => ((length as f64) * self.prob) as usize,
60            QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
61                ((length as f64 - 1.0) * self.prob).floor() as usize
62            },
63            QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
64            QuantileMethod::Equiprobable => {
65                ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
66            },
67        };
68
69        idx = std::cmp::min(idx, length - 1);
70
71        // we can unwrap because we sliced of the nulls
72        match self.method {
73            QuantileMethod::Midpoint => {
74                let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
75                Some(
76                    (self.sorted.get(idx + null_count).unwrap()
77                        + self.sorted.get(top_idx + null_count).unwrap())
78                        / T::from::<f64>(2.0f64).unwrap(),
79                )
80            },
81            QuantileMethod::Linear => {
82                let float_idx = (length as f64 - 1.0) * self.prob;
83                let top_idx = f64::ceil(float_idx) as usize;
84
85                if top_idx == idx {
86                    Some(self.sorted.get(idx + null_count).unwrap())
87                } else {
88                    let proportion = T::from(float_idx - idx as f64).unwrap();
89                    Some(
90                        proportion
91                            * (self.sorted.get(top_idx + null_count).unwrap()
92                                - self.sorted.get(idx + null_count).unwrap())
93                            + self.sorted.get(idx + null_count).unwrap(),
94                    )
95                }
96            },
97            _ => Some(self.sorted.get(idx + null_count).unwrap()),
98        }
99    }
100
101    fn is_valid(&self, min_periods: usize) -> bool {
102        self.sorted.is_valid(min_periods)
103    }
104}
105
106pub fn rolling_quantile<T>(
107    arr: &PrimitiveArray<T>,
108    window_size: usize,
109    min_periods: usize,
110    center: bool,
111    weights: Option<&[f64]>,
112    params: Option<RollingFnParams>,
113) -> ArrayRef
114where
115    T: NativeType
116        + IsFloat
117        + Float
118        + std::iter::Sum
119        + AddAssign
120        + SubAssign
121        + Div<Output = T>
122        + NumCast
123        + One
124        + Zero
125        + SealedRolling
126        + PartialOrd
127        + Sub<Output = T>,
128{
129    if weights.is_some() {
130        panic!("weights not yet supported on array with null values")
131    }
132    let offset_fn = match center {
133        true => det_offsets_center,
134        false => det_offsets,
135    };
136    if !center {
137        let params = params.as_ref().unwrap();
138        let RollingFnParams::Quantile(params) = params else {
139            unreachable!("expected Quantile params");
140        };
141
142        let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
143            params.method,
144            min_periods,
145            window_size,
146            arr.clone(),
147            params.prob,
148        );
149        let out: PrimitiveArray<T> = out.into();
150        return Box::new(out);
151    }
152    rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
153        arr.values().as_slice(),
154        arr.validity().as_ref().unwrap(),
155        window_size,
156        min_periods,
157        offset_fn,
158        params,
159    )
160}
161
#[cfg(test)]
mod test {
    use arrow::buffer::Buffer;
    use arrow::datatypes::ArrowDataType;

    use super::*;

    /// Downcasts a rolling result to `f64` and copies it into a plain vector
    /// so it can be compared with ordinary `assert_eq!`.
    fn to_f64_vec(out: ArrayRef) -> Vec<Option<f64>> {
        let typed = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        typed.into_iter().map(|v| v.copied()).collect::<Vec<_>>()
    }

    #[test]
    fn test_rolling_median_nulls() {
        let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));

        // (window_size, min_periods, center, expected medians)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, None, None, Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.0), Some(3.0), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.0), Some(2.0), Some(3.0)]),
            (4, 1, true, [Some(1.0), Some(2.0), Some(3.0), Some(3.5)]),
            (4, 4, true, [None, None, None, None]),
        ];

        for (window_size, min_periods, center, expected) in cases {
            let out = to_f64_vec(rolling_quantile(
                arr,
                window_size,
                min_periods,
                center,
                None,
                med_pars,
            ));
            assert_eq!(out, expected);
        }
    }

    #[test]
    fn test_rolling_quantile_nulls_limits() {
        // The 0.0 and 1.0 quantiles must agree with the rolling min and max,
        // whichever quantile method is used.
        let buf = Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            let quantile_at = |prob: f64| {
                let pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
                    prob,
                    method,
                }));
                to_f64_vec(rolling_quantile(values, 2, 1, false, None, pars))
            };

            let expected_min = to_f64_vec(rolling_min(values, 2, 1, false, None, None));
            let actual_min = quantile_at(0.0);
            assert_eq!(expected_min, actual_min);

            let expected_max = to_f64_vec(rolling_max(values, 2, 1, false, None, None));
            let actual_max = quantile_at(1.0);
            assert_eq!(expected_max, actual_max);
        }
    }
}