polars_compute/rolling/no_nulls/
mean.rs1#![allow(unsafe_op_in_unsafe_fn)]
2
3use polars_error::polars_ensure;
4
5use super::*;
6
7pub struct MeanWindow<'a, T> {
8 sum: SumWindow<'a, T, f64>,
9}
10
11impl<'a, T> RollingAggWindowNoNulls<'a, T> for MeanWindow<'a, T>
12where
13 T: NativeType
14 + IsFloat
15 + std::iter::Sum
16 + AddAssign
17 + SubAssign
18 + Div<Output = T>
19 + NumCast
20 + Add<Output = T>
21 + Sub<Output = T>,
22{
23 fn new(
24 slice: &'a [T],
25 start: usize,
26 end: usize,
27 params: Option<RollingFnParams>,
28 window_size: Option<usize>,
29 ) -> Self {
30 Self {
31 sum: SumWindow::<T, f64>::new(slice, start, end, params, window_size),
32 }
33 }
34
35 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
36 let sum = self.sum.update(start, end).unwrap_unchecked();
37 Some(sum / NumCast::from(end - start).unwrap())
38 }
39}
40
41pub fn rolling_mean<T>(
42 values: &[T],
43 window_size: usize,
44 min_periods: usize,
45 center: bool,
46 weights: Option<&[f64]>,
47 _params: Option<RollingFnParams>,
48) -> PolarsResult<ArrayRef>
49where
50 T: NativeType + Float + std::iter::Sum<T> + SubAssign + AddAssign + IsFloat,
51{
52 let offset_fn = match center {
53 true => det_offsets_center,
54 false => det_offsets,
55 };
56 match weights {
57 None => rolling_apply_agg_window::<MeanWindow<_>, _, _>(
58 values,
59 window_size,
60 min_periods,
61 offset_fn,
62 None,
63 ),
64 Some(weights) => {
65 let mut wts = no_nulls::coerce_weights(weights);
67 let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x);
68 polars_ensure!(
69 wsum != T::zero(),
70 ComputeError: "Weighted mean is undefined if weights sum to 0"
71 );
72 wts.iter_mut().for_each(|w| *w = *w / wsum);
73 no_nulls::rolling_apply_weights(
74 values,
75 window_size,
76 min_periods,
77 offset_fn,
78 no_nulls::compute_sum_weights,
79 &wts,
80 )
81 },
82 }
83}