// polars_compute/rolling/mean.rs

1use super::no_nulls::RollingAggWindowNoNulls;
2use super::nulls::RollingAggWindowNulls;
3use super::sum::SumWindow;
4use super::*;
5
6pub struct MeanWindow<'a, T> {
7    sum: SumWindow<'a, T, f64>,
8}
9
10impl<T> RollingAggWindowNoNulls<T> for MeanWindow<'_, T>
11where
12    T: NativeType
13        + IsFloat
14        + std::iter::Sum
15        + AddAssign
16        + SubAssign
17        + Div<Output = T>
18        + NumCast
19        + Add<Output = T>
20        + Sub<Output = T>
21        + PartialOrd,
22{
23    type This<'a> = MeanWindow<'a, T>;
24
25    fn new<'a>(
26        slice: &'a [T],
27        start: usize,
28        end: usize,
29        params: Option<RollingFnParams>,
30        window_size: Option<usize>,
31    ) -> Self::This<'a> {
32        MeanWindow {
33            sum: <SumWindow<T, f64> as RollingAggWindowNoNulls<T>>::new(
34                slice,
35                start,
36                end,
37                params,
38                window_size,
39            ),
40        }
41    }
42
43    unsafe fn update(&mut self, new_start: usize, new_end: usize) {
44        unsafe {
45            RollingAggWindowNoNulls::update(&mut self.sum, new_start, new_end);
46        };
47    }
48
49    fn get_agg(&self, idx: usize) -> Option<T> {
50        let sum = RollingAggWindowNoNulls::get_agg(&self.sum, idx).unwrap();
51        (self.sum.start != self.sum.end)
52            .then(|| sum / NumCast::from(self.sum.end - self.sum.start).unwrap())
53    }
54
55    fn slice_len(&self) -> usize {
56        RollingAggWindowNulls::slice_len(&self.sum)
57    }
58}
59
60impl<
61    T: NativeType
62        + IsFloat
63        + Add<Output = T>
64        + Sub<Output = T>
65        + NumCast
66        + Div<Output = T>
67        + AddAssign
68        + SubAssign
69        + PartialOrd,
70> RollingAggWindowNulls<T> for MeanWindow<'_, T>
71{
72    type This<'a> = MeanWindow<'a, T>;
73
74    fn new<'a>(
75        slice: &'a [T],
76        validity: &'a Bitmap,
77        start: usize,
78        end: usize,
79        params: Option<RollingFnParams>,
80        window_size: Option<usize>,
81    ) -> Self::This<'a> {
82        MeanWindow {
83            sum: <SumWindow<T, f64> as RollingAggWindowNulls<T>>::new(
84                slice,
85                validity,
86                start,
87                end,
88                params,
89                window_size,
90            ),
91        }
92    }
93
94    unsafe fn update(&mut self, new_start: usize, new_end: usize) {
95        unsafe { RollingAggWindowNulls::update(&mut self.sum, new_start, new_end) };
96    }
97
98    fn get_agg(&self, idx: usize) -> Option<T> {
99        let sum = RollingAggWindowNulls::get_agg(&self.sum, idx);
100        let len = self.sum.end - self.sum.start;
101        if self.sum.null_count == len {
102            None
103        } else {
104            sum.map(|sum| {
105                sum / NumCast::from(self.sum.end - self.sum.start - self.sum.null_count).unwrap()
106            })
107        }
108    }
109
110    fn is_valid(&self, min_periods: usize) -> bool {
111        self.sum.is_valid(min_periods)
112    }
113
114    fn slice_len(&self) -> usize {
115        RollingAggWindowNulls::slice_len(&self.sum)
116    }
117}