polars_compute/rolling/nulls/
mean.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3
4pub struct MeanWindow<'a, T> {
5    sum: SumWindow<'a, T, f64>,
6}
7
8impl<
9    'a,
10    T: NativeType
11        + IsFloat
12        + Add<Output = T>
13        + Sub<Output = T>
14        + NumCast
15        + Div<Output = T>
16        + AddAssign
17        + SubAssign,
18> RollingAggWindowNulls<'a, T> for MeanWindow<'a, T>
19{
20    unsafe fn new(
21        slice: &'a [T],
22        validity: &'a Bitmap,
23        start: usize,
24        end: usize,
25        params: Option<RollingFnParams>,
26        window_size: Option<usize>,
27    ) -> Self {
28        Self {
29            sum: SumWindow::new(slice, validity, start, end, params, window_size),
30        }
31    }
32
33    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
34        let sum = self.sum.update(start, end);
35        sum.map(|sum| sum / NumCast::from(end - start - self.sum.null_count).unwrap())
36    }
37    fn is_valid(&self, min_periods: usize) -> bool {
38        self.sum.is_valid(min_periods)
39    }
40}
41
42pub fn rolling_mean<T>(
43    arr: &PrimitiveArray<T>,
44    window_size: usize,
45    min_periods: usize,
46    center: bool,
47    weights: Option<&[f64]>,
48    _params: Option<RollingFnParams>,
49) -> ArrayRef
50where
51    T: NativeType
52        + IsFloat
53        + PartialOrd
54        + Add<Output = T>
55        + Sub<Output = T>
56        + NumCast
57        + AddAssign
58        + SubAssign
59        + Div<Output = T>,
60{
61    if weights.is_some() {
62        panic!("weights not yet supported on array with null values")
63    }
64    if center {
65        rolling_apply_agg_window::<MeanWindow<_>, _, _>(
66            arr.values().as_slice(),
67            arr.validity().as_ref().unwrap(),
68            window_size,
69            min_periods,
70            det_offsets_center,
71            None,
72        )
73    } else {
74        rolling_apply_agg_window::<MeanWindow<_>, _, _>(
75            arr.values().as_slice(),
76            arr.validity().as_ref().unwrap(),
77            window_size,
78            min_periods,
79            det_offsets,
80            None,
81        )
82    }
83}