#![allow(unsafe_op_in_unsafe_fn)]
use super::*;
use crate::rolling::quantile_filter::SealedRolling;
/// Rolling-window quantile state for input that carries a validity bitmap.
///
/// Bundles the window buffer (a `SortedBufNulls`) with the probability and
/// interpolation method that `get_agg` uses to pick a value from it.
pub struct QuantileWindow<'a, T>
where
    T: NativeType + IsFloat + PartialOrd,
{
    /// Window buffer; also exposes how many entries of the window are null.
    sorted: SortedBufNulls<'a, T>,
    /// Quantile probability copied from the rolling parameters.
    prob: f64,
    /// Interpolation rule copied from the rolling parameters.
    method: QuantileMethod,
}
impl<T> RollingAggWindowNulls<T> for QuantileWindow<'_, T>
where
    T: NativeType
        + IsFloat
        + Float
        + std::iter::Sum
        + AddAssign
        + SubAssign
        + Div<Output = T>
        + NumCast
        + One
        + Zero
        + SealedRolling
        + PartialOrd
        + Sub<Output = T>,
{
    type This<'a> = QuantileWindow<'a, T>;

    /// Builds the window over `slice[start..end]` together with its validity.
    ///
    /// # Panics
    ///
    /// Panics when `params` is `None`, or when it holds anything other than
    /// `RollingFnParams::Quantile`.
    fn new<'a>(
        slice: &'a [T],
        validity: &'a Bitmap,
        start: usize,
        end: usize,
        params: Option<RollingFnParams>,
        window_size: Option<usize>,
    ) -> Self::This<'a> {
        // Quantile windows cannot be built without their parameters.
        let RollingFnParams::Quantile(quantile) = params.unwrap() else {
            unreachable!("expected Quantile params");
        };

        QuantileWindow {
            sorted: SortedBufNulls::new(slice, validity, start, end, window_size),
            prob: quantile.prob,
            method: quantile.method,
        }
    }

    /// Moves the window to `new_start..new_end` by forwarding to the buffer.
    ///
    /// # Safety
    ///
    /// The caller must uphold whatever `SortedBufNulls::update` requires
    /// (presumably in-bounds indices — confirm against that type).
    unsafe fn update(&mut self, new_start: usize, new_end: usize) {
        self.sorted.update(new_start, new_end);
    }

    /// Returns the quantile of the non-null values in the current window, or
    /// `None` when every entry of the window is null.
    fn get_agg(&self, _idx: usize) -> Option<T> {
        let total = self.sorted.len();
        let nulls = self.sorted.null_count;
        if total == nulls {
            return None;
        }

        // Only non-null entries take part in the quantile.
        let valid = total - nulls;
        let last = valid - 1;
        // Fractional rank of the requested quantile among the valid values.
        let position = (valid as f64 - 1.0) * self.prob;

        let raw = match self.method {
            QuantileMethod::Nearest => position.round() as usize,
            QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
                position.floor() as usize
            },
            QuantileMethod::Higher => position.ceil() as usize,
            QuantileMethod::Equiprobable => {
                ((valid as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
            },
        };
        let idx = std::cmp::min(raw, last);

        // Valid values are addressed at an offset of `nulls`, i.e. this code
        // assumes the buffer places its null entries first.
        let fetch = |offset: usize| self.sorted.get(offset + nulls).unwrap();

        let value = match self.method {
            QuantileMethod::Midpoint => {
                let upper = position.ceil() as usize;
                debug_assert!(idx <= upper);
                if idx == upper {
                    fetch(idx)
                } else {
                    // Average the two neighbours surrounding the fractional rank.
                    let low = fetch(idx);
                    let high = fetch(idx + 1);
                    (low + high) / <T as NumCast>::from(2.0f64).unwrap()
                }
            },
            QuantileMethod::Linear => {
                let upper = position.ceil() as usize;
                if upper == idx {
                    fetch(idx)
                } else {
                    // Interpolate between the neighbours by the fractional part.
                    let low = fetch(idx);
                    let high = fetch(upper);
                    let weight = <T as NumCast>::from(position - idx as f64).unwrap();
                    weight * (high - low) + low
                }
            },
            _ => fetch(idx),
        };
        Some(value)
    }

    /// Delegates the `min_periods` check to the underlying buffer.
    fn is_valid(&self, min_periods: usize) -> bool {
        self.sorted.is_valid(min_periods)
    }

    /// Delegates to the underlying buffer's `slice_len`.
    fn slice_len(&self) -> usize {
        self.sorted.slice_len()
    }
}
/// Computes a rolling quantile over `arr`, taking its validity bitmap into account.
///
/// `center` selects between the centered and the trailing offset function;
/// `params` is forwarded unchanged to the window aggregation.
///
/// # Panics
///
/// Panics when `weights` is `Some`, and when `arr` has no validity bitmap.
pub fn rolling_quantile<T>(
    arr: &PrimitiveArray<T>,
    window_size: usize,
    min_periods: usize,
    center: bool,
    weights: Option<&[f64]>,
    params: Option<RollingFnParams>,
) -> ArrayRef
where
    T: NativeType
        + IsFloat
        + Float
        + std::iter::Sum
        + AddAssign
        + SubAssign
        + Div<Output = T>
        + NumCast
        + One
        + Zero
        + SealedRolling
        + PartialOrd
        + Sub<Output = T>,
{
    assert!(
        weights.is_none(),
        "weights not yet supported on array with null values"
    );

    let offset_fn = if center {
        det_offsets_center
    } else {
        det_offsets
    };

    // The validity lookup stays inline: it borrows from a temporary `Option`.
    rolling_apply_agg_window::<QuantileWindow<T>, _, _, _>(
        arr.values().as_slice(),
        arr.validity().as_ref().unwrap(),
        window_size,
        min_periods,
        offset_fn,
        params,
    )
}
#[cfg(test)]
mod test {
    use arrow::datatypes::ArrowDataType;
    use polars_buffer::Buffer;

    use super::*;

    /// Downcasts a rolling result to `PrimitiveArray<f64>` and copies its
    /// entries into a `Vec` so they can be compared with `assert_eq!`.
    fn as_f64_options(out: ArrayRef) -> Vec<Option<f64>> {
        out.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap()
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    /// Linear median over `[1.0, null, 3.0, 4.0]` for several window setups.
    #[test]
    fn test_rolling_median_nulls() {
        let buf = Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0]);
        let arr = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, true, true])),
        );
        let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
            prob: 0.5,
            method: QuantileMethod::Linear,
        }));
        let median = |window: usize, min_periods: usize, center: bool| {
            as_f64_options(rolling_quantile(
                arr,
                window,
                min_periods,
                center,
                None,
                med_pars,
            ))
        };

        assert_eq!(median(2, 2, false), &[None, None, None, Some(3.5)]);
        assert_eq!(
            median(2, 1, false),
            &[Some(1.0), Some(1.0), Some(3.0), Some(3.5)]
        );
        assert_eq!(
            median(4, 1, false),
            &[Some(1.0), Some(1.0), Some(2.0), Some(3.0)]
        );
        assert_eq!(
            median(4, 1, true),
            &[Some(1.0), Some(2.0), Some(3.0), Some(3.5)]
        );
        assert_eq!(median(4, 4, true), &[None, None, None, None]);
    }

    /// For every method, the 0.0 and 1.0 quantiles must agree with the
    /// rolling minimum and maximum respectively.
    #[test]
    fn test_rolling_quantile_nulls_limits() {
        let buf = Buffer::<f64>::from(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let values = &PrimitiveArray::new(
            ArrowDataType::Float64,
            buf,
            Some(Bitmap::from(&[true, false, false, true, true])),
        );

        for method in [
            QuantileMethod::Lower,
            QuantileMethod::Higher,
            QuantileMethod::Nearest,
            QuantileMethod::Midpoint,
            QuantileMethod::Linear,
            QuantileMethod::Equiprobable,
        ] {
            let quantile_at = |prob: f64| {
                let params = Some(RollingFnParams::Quantile(RollingQuantileParams {
                    prob,
                    method,
                }));
                as_f64_options(rolling_quantile(values, 2, 1, false, None, params))
            };

            let expected_min = as_f64_options(rolling_min(values, 2, 1, false, None, None));
            assert_eq!(expected_min, quantile_at(0.0));

            let expected_max = as_f64_options(rolling_max(values, 2, 1, false, None, None));
            assert_eq!(expected_max, quantile_at(1.0));
        }
    }
}