use std::fmt::Debug;
use arrow::array::PrimitiveArray;
use arrow::datatypes::ArrowDataType;
use arrow::legacy::error::PolarsResult;
use arrow::legacy::utils::CustomIterTools;
use arrow::types::NativeType;
use num_traits::{Float, Num, NumCast};
mod mean;
mod min_max;
mod moment;
mod quantile;
pub mod rank;
mod sum;
pub use mean::*;
pub use min_max::*;
pub use moment::*;
pub use quantile::*;
pub use rank::*;
pub use sum::*;
use super::*;
/// Incremental rolling-window aggregation state over a slice that contains no nulls.
///
/// In this module the state is driven by `rolling_apply_agg_window`. That function
/// constructs it once with the initial window, then for every output row calls `update`
/// to move the window and `get_agg` to read the aggregate.
pub trait RollingAggWindowNoNulls<T: NativeType, Out: NativeType = T> {
/// Concrete state type produced by [`Self::new`]; it may borrow from the input slice.
type This<'a>: RollingAggWindowNoNulls<T, Out>;
/// Creates the aggregation state over `slice`, positioned on the window `start..end`.
///
/// `params` carries optional aggregation-specific parameters. `window_size` is the
/// nominal window length; `rolling_apply_agg_window` passes `Some(window_size)`.
fn new(
slice: &[T],
start: usize,
end: usize,
params: Option<RollingFnParams>,
window_size: Option<usize>,
) -> Self::This<'_>;
/// Moves the window to `new_start..new_end`.
///
/// # Safety
///
/// NOTE(review): the exact contract is not stated here. Presumably the caller must
/// guarantee `new_start <= new_end <= slice.len()`; confirm against implementors.
unsafe fn update(&mut self, new_start: usize, new_end: usize);
/// Returns the aggregate of the current window; `None` becomes a null in the output array.
///
/// `idx` is the output row being computed (`rolling_apply_agg_window` passes the row index).
fn get_agg(&self, idx: usize) -> Option<Out>;
/// Length of the slice the state was built over — presumably; confirm with implementors.
fn slice_len(&self) -> usize;
}
/// Runs a [`RollingAggWindowNoNulls`] aggregator over every output row of `values`.
///
/// The aggregator is seeded with the window `det_offsets_fn(0, window_size, len)`. For each
/// row `i`, the window `lo..hi` is recomputed. Windows holding fewer than `min_periods`
/// elements yield a null without touching the aggregator state. All other windows advance
/// the aggregator and emit `get_agg(i)`.
pub(super) fn rolling_apply_agg_window<Agg, T, O, Fo>(
    values: &[T],
    window_size: usize,
    min_periods: usize,
    det_offsets_fn: Fo,
    params: Option<RollingFnParams>,
) -> PolarsResult<ArrayRef>
where
    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
    Agg: RollingAggWindowNoNulls<T, O>,
    T: Debug + NativeType + Num,
    O: Debug + NativeType + Num,
{
    let n = values.len();
    let (first_start, first_end) = det_offsets_fn(0, window_size, n);
    let mut window = Agg::new(values, first_start, first_end, params, Some(window_size));

    let aggregated = (0..n).map(|i| {
        let (lo, hi) = det_offsets_fn(i, window_size, n);
        let enough_periods = hi - lo >= min_periods;
        if !enough_periods {
            return None;
        }
        // SAFETY: `lo..hi` comes from the offset function, which is expected to
        // return bounds within `values`.
        unsafe { window.update(lo, hi) };
        window.get_agg(i)
    });

    let array = PrimitiveArray::from_trusted_len_iter(aggregated);
    Ok(Box::new(array))
}
/// Applies a weighted rolling aggregation over a null-free `values` slice.
///
/// For every output row `idx`, `det_offsets_fn(idx, window_size, len)` yields the window
/// `start..end` of `values`. A same-length sub-slice of `weights` is chosen so that weights
/// stay aligned with window positions when the window is truncated at the array edges (see
/// `weights_window_start`). The output value is `aggregator(window_values, window_weights)`.
///
/// Rows whose window is shorter than `min_periods` are still computed. They are masked out
/// by the validity produced by `create_validity`.
///
/// # Panics
///
/// Panics if `weights.len() != window_size`, or if the derived weight range falls outside
/// `weights`. In debug builds it also panics if `det_offsets_fn` returns bounds outside
/// `values`.
pub(super) fn rolling_apply_weights<T, Fo, Fa>(
    values: &[T],
    window_size: usize,
    min_periods: usize,
    det_offsets_fn: Fo,
    aggregator: Fa,
    weights: &[T],
    centered: bool,
) -> PolarsResult<ArrayRef>
where
    T: NativeType + num_traits::Zero + std::ops::Div<Output = T> + Copy,
    Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
    Fa: Fn(&[T], &[T]) -> T,
{
    assert_eq!(weights.len(), window_size);
    let len = values.len();
    let out = (0..len)
        .map(|idx| {
            let (start, end) = det_offsets_fn(idx, window_size, len);
            // The offsets come from a caller-supplied closure; catch bad bounds in debug
            // builds instead of silently reading out of bounds below.
            debug_assert!(
                start <= end && end <= len,
                "rolling window offsets out of bounds: {start}..{end} for length {len}"
            );
            // SAFETY: `det_offsets_fn` must return `start <= end <= len` (checked above
            // in debug builds), so the range lies within `values`.
            let vals = unsafe { values.get_unchecked(start..end) };
            let win_len = end - start;
            let weights_start = weights_window_start(idx, start, win_len, window_size, centered);
            let weights_slice = &weights[weights_start..weights_start + win_len];
            aggregator(vals, weights_slice)
        })
        .collect_trusted::<Vec<T>>();
    let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
    Ok(Box::new(PrimitiveArray::new(
        ArrowDataType::from(T::PRIMITIVE),
        out.into(),
        validity.map(|b| b.into()),
    )))
}

/// Returns the index into the weight vector (length `window_size`) at which the weights
/// begin for output row `idx`, whose window starts at `start` and holds `win_len` values.
///
/// Centered windows:
/// - The weight at `window_size / 2` is aligned with row `idx`.
/// - When the window is cut off on the left, the weights for the missing leading
///   elements are skipped.
///
/// Non-centered windows:
/// - A window starting at 0 uses the trailing `win_len` weights.
/// - Any other window uses the weights from index 0.
///
/// NOTE(review): this assumes the offset functions used by callers follow exactly that
/// convention (centered row at position `window_size / 2`; non-centered windows only
/// truncated on the left) — confirm against them.
fn weights_window_start(
    idx: usize,
    start: usize,
    win_len: usize,
    window_size: usize,
    centered: bool,
) -> usize {
    if centered {
        let center = (window_size / 2) as isize;
        let offset = center - (idx as isize - start as isize);
        offset.max(0) as usize
    } else if start == 0 {
        window_size - win_len
    } else {
        0
    }
}
/// Weighted population variance of `vals` under `weights`.
///
/// Computes `sum(w * (v - mean)^2) / sum(w)` with `mean = sum(w * v) / sum(w)`.
/// Returns zero when the weights sum to zero. The slices are zipped, so trailing
/// elements of the longer one are ignored.
fn compute_var_weights<T>(vals: &[T], weights: &[T]) -> T
where
    T: Float + std::ops::AddAssign,
{
    let mut weighted_sum = T::zero();
    let mut total_weight = T::zero();
    for (&v, &w) in vals.iter().zip(weights) {
        weighted_sum += v * w;
        total_weight += w;
    }
    if total_weight.is_zero() {
        return T::zero();
    }
    let mean = weighted_sum / total_weight;
    // Sum squared deviations from the mean in a second pass. Unlike E[x^2] - E[x]^2,
    // this avoids catastrophic cancellation. It cannot go negative for non-negative
    // weights, e.g. a constant window now yields exactly 0.
    let mut weighted_sq_dev = T::zero();
    for (&v, &w) in vals.iter().zip(weights) {
        let dev = v - mean;
        weighted_sq_dev += w * dev * dev;
    }
    weighted_sq_dev / total_weight
}
/// Weighted sum `values[0] * weights[0] + values[1] * weights[1] + ...`.
///
/// The slices are zipped, so pairs past the end of the shorter slice are ignored.
pub(crate) fn compute_sum_weights<T>(values: &[T], weights: &[T]) -> T
where
    T: std::iter::Sum<T> + Copy + std::ops::Mul<Output = T>,
{
    let products = values
        .iter()
        .copied()
        .zip(weights.iter().copied())
        .map(|(value, weight)| value * weight);
    products.sum::<T>()
}
/// Weighted arithmetic mean `sum(v * w) / sum(w)` of `values` under `weights`.
///
/// Returns zero when the weights sum to zero. The slices are zipped, so trailing
/// elements of the longer one are ignored.
pub(crate) fn compute_mean_weights<T>(values: &[T], weights: &[T]) -> T
where
    T: std::iter::Sum<T>
        + Copy
        + std::ops::Mul<Output = T>
        + std::ops::Div<Output = T>
        + num_traits::Zero,
{
    let mut numerator = T::zero();
    let mut denominator = T::zero();
    for (&value, &weight) in values.iter().zip(weights) {
        numerator = numerator + value * weight;
        denominator = denominator + weight;
    }
    if !denominator.is_zero() {
        numerator / denominator
    } else {
        T::zero()
    }
}
/// Converts `f64` weights into the numeric type `T` used by a rolling kernel.
///
/// # Panics
///
/// Panics if a weight cannot be represented in `T`, i.e. [`NumCast::from`] returns
/// `None`. One example is a NaN or out-of-range weight when `T` is an integer type.
/// The panic message names the offending weight.
pub(super) fn coerce_weights<T: NumCast>(weights: &[f64]) -> Vec<T> {
    weights
        .iter()
        .map(|&w| {
            <T as NumCast>::from(w).unwrap_or_else(|| {
                panic!("rolling weight {w} cannot be represented in the target numeric type")
            })
        })
        .collect()
}