polars_compute/rolling/no_nulls/mean.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2
3use polars_error::polars_ensure;
4
5use super::*;
6
7pub struct MeanWindow<'a, T> {
8    sum: SumWindow<'a, T, f64>,
9}
10
11impl<'a, T> RollingAggWindowNoNulls<'a, T> for MeanWindow<'a, T>
12where
13    T: NativeType
14        + IsFloat
15        + std::iter::Sum
16        + AddAssign
17        + SubAssign
18        + Div<Output = T>
19        + NumCast
20        + Add<Output = T>
21        + Sub<Output = T>,
22{
23    fn new(
24        slice: &'a [T],
25        start: usize,
26        end: usize,
27        params: Option<RollingFnParams>,
28        window_size: Option<usize>,
29    ) -> Self {
30        Self {
31            sum: SumWindow::<T, f64>::new(slice, start, end, params, window_size),
32        }
33    }
34
35    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
36        let sum = self.sum.update(start, end).unwrap_unchecked();
37        Some(sum / NumCast::from(end - start).unwrap())
38    }
39}
40
41pub fn rolling_mean<T>(
42    values: &[T],
43    window_size: usize,
44    min_periods: usize,
45    center: bool,
46    weights: Option<&[f64]>,
47    _params: Option<RollingFnParams>,
48) -> PolarsResult<ArrayRef>
49where
50    T: NativeType + Float + std::iter::Sum<T> + SubAssign + AddAssign + IsFloat,
51{
52    let offset_fn = match center {
53        true => det_offsets_center,
54        false => det_offsets,
55    };
56    match weights {
57        None => rolling_apply_agg_window::<MeanWindow<_>, _, _>(
58            values,
59            window_size,
60            min_periods,
61            offset_fn,
62            None,
63        ),
64        Some(weights) => {
65            // A weighted mean is a weighted sum with normalized weights
66            let mut wts = no_nulls::coerce_weights(weights);
67            let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x);
68            polars_ensure!(
69                wsum != T::zero(),
70                ComputeError: "Weighted mean is undefined if weights sum to 0"
71            );
72            wts.iter_mut().for_each(|w| *w = *w / wsum);
73            no_nulls::rolling_apply_weights(
74                values,
75                window_size,
76                min_periods,
77                offset_fn,
78                no_nulls::compute_sum_weights,
79                &wts,
80            )
81        },
82    }
83}