1#![allow(unsafe_op_in_unsafe_fn)]
2use arrow::legacy::utils::CustomIterTools;
3use num_traits::ToPrimitive;
4use polars_error::polars_ensure;
5
6use super::QuantileMethod::*;
7use super::*;
8use crate::rolling::quantile_filter::SealedRolling;
9
10pub struct QuantileWindow<'a, T: NativeType> {
11 sorted: SortedBuf<'a, T>,
12 prob: f64,
13 method: QuantileMethod,
14}
15
16impl<
17 'a,
18 T: NativeType
19 + Float
20 + std::iter::Sum
21 + AddAssign
22 + SubAssign
23 + Div<Output = T>
24 + NumCast
25 + One
26 + Zero
27 + SealedRolling
28 + Sub<Output = T>,
29> RollingAggWindowNoNulls<'a, T> for QuantileWindow<'a, T>
30{
31 fn new(
32 slice: &'a [T],
33 start: usize,
34 end: usize,
35 params: Option<RollingFnParams>,
36 window_size: Option<usize>,
37 ) -> Self {
38 let params = params.unwrap();
39 let RollingFnParams::Quantile(params) = params else {
40 unreachable!("expected Quantile params");
41 };
42
43 Self {
44 sorted: SortedBuf::new(slice, start, end, window_size),
45 prob: params.prob,
46 method: params.method,
47 }
48 }
49
50 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
51 self.sorted.update(start, end);
52 let length = self.sorted.len();
53
54 let idx = match self.method {
55 Linear => {
56 let length_f = length as f64;
58 let idx = ((length_f - 1.0) * self.prob).floor() as usize;
59
60 let float_idx_top = (length_f - 1.0) * self.prob;
61 let top_idx = float_idx_top.ceil() as usize;
62 return if idx == top_idx {
63 Some(self.sorted.get(idx))
64 } else {
65 let proportion = T::from(float_idx_top - idx as f64).unwrap();
66 let vi = self.sorted.get(idx);
67 let vj = self.sorted.get(top_idx);
68
69 Some(proportion * (vj - vi) + vi)
70 };
71 },
72 Midpoint => {
73 let length_f = length as f64;
74 let idx = (length_f * self.prob) as usize;
75 let idx = std::cmp::min(idx, length - 1);
76
77 let top_idx = ((length_f - 1.0) * self.prob).ceil() as usize;
78 return if top_idx == idx {
79 Some(self.sorted.get(idx))
80 } else {
81 let (mid, mid_plus_1) = (self.sorted.get(idx), (self.sorted.get(idx + 1)));
82
83 Some((mid + mid_plus_1) / (T::one() + T::one()))
84 };
85 },
86 Nearest => {
87 let idx = ((length as f64) * self.prob) as usize;
88 std::cmp::min(idx, length - 1)
89 },
90 Lower => ((length as f64 - 1.0) * self.prob).floor() as usize,
91 Higher => {
92 let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
93 std::cmp::min(idx, length - 1)
94 },
95 Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
96 };
97
98 Some(self.sorted.get(idx))
99 }
100}
101
102pub fn rolling_quantile<T>(
103 values: &[T],
104 window_size: usize,
105 min_periods: usize,
106 center: bool,
107 weights: Option<&[f64]>,
108 params: Option<RollingFnParams>,
109) -> PolarsResult<ArrayRef>
110where
111 T: NativeType
112 + IsFloat
113 + Float
114 + std::iter::Sum
115 + AddAssign
116 + SubAssign
117 + Div<Output = T>
118 + NumCast
119 + One
120 + Zero
121 + SealedRolling
122 + PartialOrd
123 + Sub<Output = T>,
124{
125 let offset_fn = match center {
126 true => det_offsets_center,
127 false => det_offsets,
128 };
129 match weights {
130 None => {
131 if !center {
132 let params = params.as_ref().unwrap();
133 let RollingFnParams::Quantile(params) = params else {
134 unreachable!("expected Quantile params");
135 };
136 let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
137 params.method,
138 min_periods,
139 window_size,
140 values,
141 params.prob,
142 );
143 let validity = create_validity(min_periods, values.len(), window_size, offset_fn);
144 return Ok(Box::new(PrimitiveArray::new(
145 T::PRIMITIVE.into(),
146 out.into(),
147 validity.map(|b| b.into()),
148 )));
149 }
150
151 rolling_apply_agg_window::<QuantileWindow<_>, _, _>(
152 values,
153 window_size,
154 min_periods,
155 offset_fn,
156 params,
157 )
158 },
159 Some(weights) => {
160 let wsum = weights.iter().sum();
161 polars_ensure!(
162 wsum != 0.0,
163 ComputeError: "Weighted quantile is undefined if weights sum to 0"
164 );
165 let params = params.unwrap();
166 let RollingFnParams::Quantile(params) = params else {
167 unreachable!("expected Quantile params");
168 };
169
170 Ok(rolling_apply_weighted_quantile(
171 values,
172 params.prob,
173 params.method,
174 window_size,
175 min_periods,
176 offset_fn,
177 weights,
178 wsum,
179 ))
180 },
181 }
182}
183
184#[inline]
185fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T
186where
187 T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
188{
189 let (mut s, mut s_old, mut vk, mut v_old) = (0.0, 0.0, T::zero(), T::zero());
193
194 let h: f64 = p * (wsum - buf[0].1) + buf[0].1;
197 for &(v, w) in buf.iter() {
198 if s > h {
199 break;
200 }
201 (s_old, v_old, vk) = (s, vk, v);
202 s += w;
203 }
204 match (h == s_old, method) {
205 (true, _) => v_old, (_, Lower) => v_old,
207 (_, Higher) => vk,
208 (_, Nearest) => {
209 if s - h > h - s_old {
210 v_old
211 } else {
212 vk
213 }
214 },
215 (_, Equiprobable) => {
216 let threshold = (wsum * p).ceil() - 1.0;
217 if s > threshold { vk } else { v_old }
218 },
219 (_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),
220 (_, Linear) => {
222 v_old + <T as NumCast>::from((h - s_old) / (s - s_old)).unwrap() * (vk - v_old)
223 },
224 }
225}
226
227#[allow(clippy::too_many_arguments)]
228fn rolling_apply_weighted_quantile<T, Fo>(
229 values: &[T],
230 p: f64,
231 method: QuantileMethod,
232 window_size: usize,
233 min_periods: usize,
234 det_offsets_fn: Fo,
235 weights: &[f64],
236 wsum: f64,
237) -> ArrayRef
238where
239 Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
240 T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
241{
242 assert_eq!(weights.len(), window_size);
243 let nz_idx_wts: Vec<_> = weights.iter().enumerate().filter(|x| x.1 != &0.0).collect();
245 let mut buf = vec![(T::zero(), 0.0); nz_idx_wts.len()];
246 let len = values.len();
247 let out = (0..len)
248 .map(|idx| {
249 let (start, _) = det_offsets_fn(idx, window_size, len);
251
252 unsafe {
254 buf.iter_mut()
255 .zip(nz_idx_wts.iter())
256 .for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));
257 }
258 buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));
259 compute_wq(&buf, p, wsum, method)
260 })
261 .collect_trusted::<Vec<T>>();
262
263 let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
264 Box::new(PrimitiveArray::new(
265 T::PRIMITIVE.into(),
266 out.into(),
267 validity.map(|b| b.into()),
268 ))
269}
270
#[cfg(test)]
mod test {
    use super::*;

    /// Downcasts a rolling kernel result to `f64` and collects it for comparison.
    fn collect_f64(arr: ArrayRef) -> Vec<Option<f64>> {
        arr.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_median() {
        let values = &[1.0, 2.0, 3.0, 4.0];
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: Linear,
        }));
        // Runs the median kernel with the given window configuration.
        let median = |window_size: usize, min_periods: usize, center: bool| {
            collect_f64(
                rolling_quantile(values, window_size, min_periods, center, None, med_pars)
                    .unwrap(),
            )
        };

        assert_eq!(median(2, 2, false), &[None, Some(1.5), Some(2.5), Some(3.5)]);
        assert_eq!(median(2, 1, false), &[Some(1.0), Some(1.5), Some(2.5), Some(3.5)]);
        assert_eq!(median(4, 1, false), &[Some(1.0), Some(1.5), Some(2.0), Some(2.5)]);
        assert_eq!(median(4, 1, true), &[Some(1.5), Some(2.0), Some(2.5), Some(3.0)]);
        assert_eq!(median(4, 4, true), &[None, None, Some(2.5), None]);
    }

    #[test]
    fn test_rolling_quantile_limits() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        for method in [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ] {
            let pars_at = |prob: f64| {
                Some(RollingFnParams::Quantile(RollingQuantileParams {
                    prob,
                    method,
                }))
            };

            // The 0-quantile must match the rolling minimum.
            let expected = collect_f64(rolling_min(values, 2, 2, false, None, None).unwrap());
            let actual =
                collect_f64(rolling_quantile(values, 2, 2, false, None, pars_at(0.0)).unwrap());
            assert_eq!(expected, actual);

            // The 1-quantile must match the rolling maximum.
            let expected = collect_f64(rolling_max(values, 2, 2, false, None, None).unwrap());
            let actual =
                collect_f64(rolling_quantile(values, 2, 2, false, None, pars_at(1.0)).unwrap());
            assert_eq!(expected, actual);
        }
    }
}