polars_compute/rolling/nulls/
quantile.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::array::MutablePrimitiveArray;
3
4use super::*;
5use crate::rolling::quantile_filter::SealedRolling;
6
7pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
8 sorted: SortedBufNulls<'a, T>,
9 prob: f64,
10 method: QuantileMethod,
11}
12
13impl<
14 'a,
15 T: NativeType
16 + IsFloat
17 + Float
18 + std::iter::Sum
19 + AddAssign
20 + SubAssign
21 + Div<Output = T>
22 + NumCast
23 + One
24 + Zero
25 + SealedRolling
26 + PartialOrd
27 + Sub<Output = T>,
28> RollingAggWindowNulls<'a, T> for QuantileWindow<'a, T>
29{
30 unsafe fn new(
31 slice: &'a [T],
32 validity: &'a Bitmap,
33 start: usize,
34 end: usize,
35 params: Option<RollingFnParams>,
36 window_size: Option<usize>,
37 ) -> Self {
38 let params = params.unwrap();
39 let RollingFnParams::Quantile(params) = params else {
40 unreachable!("expected Quantile params");
41 };
42 Self {
43 sorted: SortedBufNulls::new(slice, validity, start, end, window_size),
44 prob: params.prob,
45 method: params.method,
46 }
47 }
48
49 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
50 let null_count = self.sorted.update(start, end);
51 let mut length = self.sorted.len();
52 if null_count == length {
54 return None;
55 }
56 length -= null_count;
58 let mut idx = match self.method {
59 QuantileMethod::Nearest => ((length as f64) * self.prob) as usize,
60 QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
61 ((length as f64 - 1.0) * self.prob).floor() as usize
62 },
63 QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
64 QuantileMethod::Equiprobable => {
65 ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
66 },
67 };
68
69 idx = std::cmp::min(idx, length - 1);
70
71 match self.method {
73 QuantileMethod::Midpoint => {
74 let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
75 Some(
76 (self.sorted.get(idx + null_count).unwrap()
77 + self.sorted.get(top_idx + null_count).unwrap())
78 / T::from::<f64>(2.0f64).unwrap(),
79 )
80 },
81 QuantileMethod::Linear => {
82 let float_idx = (length as f64 - 1.0) * self.prob;
83 let top_idx = f64::ceil(float_idx) as usize;
84
85 if top_idx == idx {
86 Some(self.sorted.get(idx + null_count).unwrap())
87 } else {
88 let proportion = T::from(float_idx - idx as f64).unwrap();
89 Some(
90 proportion
91 * (self.sorted.get(top_idx + null_count).unwrap()
92 - self.sorted.get(idx + null_count).unwrap())
93 + self.sorted.get(idx + null_count).unwrap(),
94 )
95 }
96 },
97 _ => Some(self.sorted.get(idx + null_count).unwrap()),
98 }
99 }
100
101 fn is_valid(&self, min_periods: usize) -> bool {
102 self.sorted.is_valid(min_periods)
103 }
104}
105
106pub fn rolling_quantile<T>(
107 arr: &PrimitiveArray<T>,
108 window_size: usize,
109 min_periods: usize,
110 center: bool,
111 weights: Option<&[f64]>,
112 params: Option<RollingFnParams>,
113) -> ArrayRef
114where
115 T: NativeType
116 + IsFloat
117 + Float
118 + std::iter::Sum
119 + AddAssign
120 + SubAssign
121 + Div<Output = T>
122 + NumCast
123 + One
124 + Zero
125 + SealedRolling
126 + PartialOrd
127 + Sub<Output = T>,
128{
129 if weights.is_some() {
130 panic!("weights not yet supported on array with null values")
131 }
132 let offset_fn = match center {
133 true => det_offsets_center,
134 false => det_offsets,
135 };
136 if !center {
137 let params = params.as_ref().unwrap();
138 let RollingFnParams::Quantile(params) = params else {
139 unreachable!("expected Quantile params");
140 };
141
142 let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
143 params.method,
144 min_periods,
145 window_size,
146 arr.clone(),
147 params.prob,
148 );
149 let out: PrimitiveArray<T> = out.into();
150 return Box::new(out);
151 }
152 rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
153 arr.values().as_slice(),
154 arr.validity().as_ref().unwrap(),
155 window_size,
156 min_periods,
157 offset_fn,
158 params,
159 )
160}
161
#[cfg(test)]
mod test {
    use arrow::buffer::Buffer;
    use arrow::datatypes::ArrowDataType;

    use super::*;

    /// Downcasts a rolling result to `PrimitiveArray<f64>` and collects it,
    /// mapping nulls to `None`.
    fn collect_f64(result: ArrayRef) -> Vec<Option<f64>> {
        let prim = result
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap();
        prim.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_median_nulls() {
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(vec![1.0, 2.0, 3.0, 4.0]),
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));

        // (window_size, min_periods, center, expected)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, None, None, Some(3.5)]),
            (2, 1, false, [Some(1.0), Some(1.0), Some(3.0), Some(3.5)]),
            (4, 1, false, [Some(1.0), Some(1.0), Some(2.0), Some(3.0)]),
            (4, 1, true, [Some(1.0), Some(2.0), Some(3.0), Some(3.5)]),
            (4, 4, true, [None, None, None, None]),
        ];

        for (window_size, min_periods, center, expected) in cases {
            let got = collect_f64(rolling_quantile(
                arr,
                window_size,
                min_periods,
                center,
                None,
                med_pars,
            ));
            assert_eq!(
                got,
                &expected,
                "window_size={window_size}, min_periods={min_periods}, center={center}"
            );
        }
    }

    #[test]
    fn test_rolling_quantile_nulls_limits() {
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]),
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        let methods = [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ];

        for method in methods {
            let quantile_params =
                |prob: f64| Some(RollingFnParams::Quantile(RollingQuantileParams { prob, method }));

            // The 0.0 quantile must agree with the rolling minimum ...
            assert_eq!(
                collect_f64(rolling_min(values, 2, 1, false, None, None)),
                collect_f64(rolling_quantile(values, 2, 1, false, None, quantile_params(0.0))),
            );
            // ... and the 1.0 quantile with the rolling maximum.
            assert_eq!(
                collect_f64(rolling_max(values, 2, 1, false, None, None)),
                collect_f64(rolling_quantile(values, 2, 1, false, None, quantile_params(1.0))),
            );
        }
    }
}