polars_compute/rolling/nulls/
quantile.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3use crate::rolling::quantile_filter::SealedRolling;
4
5pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
6 sorted: SortedBufNulls<'a, T>,
7 prob: f64,
8 method: QuantileMethod,
9}
10
11impl<
12 T: NativeType
13 + IsFloat
14 + Float
15 + std::iter::Sum
16 + AddAssign
17 + SubAssign
18 + Div<Output = T>
19 + NumCast
20 + One
21 + Zero
22 + SealedRolling
23 + PartialOrd
24 + Sub<Output = T>,
25> RollingAggWindowNulls<T> for QuantileWindow<'_, T>
26{
27 type This<'a> = QuantileWindow<'a, T>;
28
29 fn new<'a>(
30 slice: &'a [T],
31 validity: &'a Bitmap,
32 start: usize,
33 end: usize,
34 params: Option<RollingFnParams>,
35 window_size: Option<usize>,
36 ) -> Self::This<'a> {
37 let params = params.unwrap();
38 let RollingFnParams::Quantile(params) = params else {
39 unreachable!("expected Quantile params");
40 };
41 QuantileWindow {
42 sorted: SortedBufNulls::new(slice, validity, start, end, window_size),
43 prob: params.prob,
44 method: params.method,
45 }
46 }
47
48 unsafe fn update(&mut self, new_start: usize, new_end: usize) {
49 self.sorted.update(new_start, new_end);
50 }
51
52 fn get_agg(&self, _idx: usize) -> Option<T> {
53 let mut length = self.sorted.len();
54 let null_count = self.sorted.null_count;
55
56 if null_count == length {
58 return None;
59 }
60 length -= null_count;
62 let mut idx = match self.method {
63 QuantileMethod::Nearest => (((length as f64) - 1.0) * self.prob).round() as usize,
64 QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
65 ((length as f64 - 1.0) * self.prob).floor() as usize
66 },
67 QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
68 QuantileMethod::Equiprobable => {
69 ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
70 },
71 };
72
73 idx = std::cmp::min(idx, length - 1);
74
75 match self.method {
77 QuantileMethod::Midpoint => {
78 let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
79
80 debug_assert!(idx <= top_idx);
81 let v = if idx != top_idx {
82 let low = self.sorted.get(idx + null_count).unwrap();
83 let high = self.sorted.get(idx + null_count + 1).unwrap();
84 (low + high) / T::from::<f64>(2.0f64).unwrap()
85 } else {
86 self.sorted.get(idx + null_count).unwrap()
87 };
88
89 Some(v)
90 },
91 QuantileMethod::Linear => {
92 let float_idx = (length as f64 - 1.0) * self.prob;
93 let top_idx = f64::ceil(float_idx) as usize;
94
95 if top_idx == idx {
96 Some(self.sorted.get(idx + null_count).unwrap())
97 } else {
98 let low = self.sorted.get(idx + null_count).unwrap();
99 let high = self.sorted.get(top_idx + null_count).unwrap();
100 let proportion = T::from(float_idx - idx as f64).unwrap();
101 Some(proportion * (high - low) + low)
102 }
103 },
104 _ => Some(self.sorted.get(idx + null_count).unwrap()),
105 }
106 }
107
108 fn is_valid(&self, min_periods: usize) -> bool {
109 self.sorted.is_valid(min_periods)
110 }
111
112 fn slice_len(&self) -> usize {
113 self.sorted.slice_len()
114 }
115}
116
117pub fn rolling_quantile<T>(
118 arr: &PrimitiveArray<T>,
119 window_size: usize,
120 min_periods: usize,
121 center: bool,
122 weights: Option<&[f64]>,
123 params: Option<RollingFnParams>,
124) -> ArrayRef
125where
126 T: NativeType
127 + IsFloat
128 + Float
129 + std::iter::Sum
130 + AddAssign
131 + SubAssign
132 + Div<Output = T>
133 + NumCast
134 + One
135 + Zero
136 + SealedRolling
137 + PartialOrd
138 + Sub<Output = T>,
139{
140 if weights.is_some() {
141 panic!("weights not yet supported on array with null values")
142 }
143 let offset_fn = match center {
144 true => det_offsets_center,
145 false => det_offsets,
146 };
147 rolling_apply_agg_window::<QuantileWindow<T>, _, _, _>(
168 arr.values().as_slice(),
169 arr.validity().as_ref().unwrap(),
170 window_size,
171 min_periods,
172 offset_fn,
173 params,
174 )
175}
176
#[cfg(test)]
mod test {
    use arrow::datatypes::ArrowDataType;
    use polars_buffer::Buffer;

    use super::*;

    /// Downcasts a kernel result to `PrimitiveArray<f64>` and collects its optional values.
    fn collect_f64(out: ArrayRef) -> Vec<Option<f64>> {
        let typed = out
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap();
        typed.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_median_nulls() {
        let buf = Buffer::from(vec![1.0, 2.0, 3.0, 4.0]);
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let median = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));

        // (window_size, min_periods, center, expected output)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, None, None, Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.0), Some(3.0), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.0), Some(2.0), Some(3.0)]),
            (4, 1, true, [Some(1.0), Some(2.0), Some(3.0), Some(3.5)]),
            (4, 4, true, [None, None, None, None]),
        ];

        for (window_size, min_periods, center, expected) in cases {
            let out = collect_f64(rolling_quantile(
                arr,
                window_size,
                min_periods,
                center,
                None,
                median,
            ));
            assert_eq!(out, expected);
        }
    }

    #[test]
    fn test_rolling_quantile_nulls_limits() {
        let buf = Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        // Quantiles at prob 0.0 / 1.0 must reproduce rolling min / max for every method.
        let expected_min = collect_f64(rolling_min(values, 2, 1, false, None, None));
        let expected_max = collect_f64(rolling_max(values, 2, 1, false, None, None));

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            let quantile_at = |prob: f64| {
                let params =
                    Some(RollingFnParams::Quantile(RollingQuantileParams { prob, method }));
                collect_f64(rolling_quantile(values, 2, 1, false, None, params))
            };
            assert_eq!(expected_min, quantile_at(0.0));
            assert_eq!(expected_max, quantile_at(1.0));
        }
    }
}