polars_compute/rolling/
moment.rs

1use num_traits::{FromPrimitive, ToPrimitive};
2
3use super::no_nulls::RollingAggWindowNoNulls;
4use super::nulls::RollingAggWindowNulls;
5use super::*;
6use crate::moment::{KurtosisState, SkewState, VarState};
7
8pub trait StateUpdate {
9    fn new(params: Option<RollingFnParams>) -> Self;
10    fn reset(&mut self);
11    fn insert_one(&mut self, x: f64);
12    fn remove_one(&mut self, x: f64);
13    fn finalize(&self) -> Option<f64>;
14}
15
16pub struct VarianceMoment {
17    state: VarState,
18    ddof: u8,
19}
20
21impl StateUpdate for VarianceMoment {
22    fn new(params: Option<RollingFnParams>) -> Self {
23        let ddof = if let Some(RollingFnParams::Var(params)) = params {
24            params.ddof
25        } else {
26            1
27        };
28
29        Self {
30            state: VarState::default(),
31            ddof,
32        }
33    }
34
35    #[inline(always)]
36    fn reset(&mut self) {
37        self.state = VarState::default();
38    }
39
40    #[inline(always)]
41    fn insert_one(&mut self, x: f64) {
42        self.state.insert_one(x);
43    }
44
45    #[inline(always)]
46    fn remove_one(&mut self, x: f64) {
47        self.state.remove_one(x);
48    }
49
50    #[inline(always)]
51    fn finalize(&self) -> Option<f64> {
52        self.state.finalize(self.ddof)
53    }
54}
55
56pub struct KurtosisMoment {
57    state: KurtosisState,
58    fisher: bool,
59    bias: bool,
60}
61
62impl StateUpdate for KurtosisMoment {
63    fn new(params: Option<RollingFnParams>) -> Self {
64        let (fisher, bias) = if let Some(RollingFnParams::Kurtosis { fisher, bias }) = params {
65            (fisher, bias)
66        } else {
67            (false, false)
68        };
69
70        Self {
71            state: KurtosisState::default(),
72            fisher,
73            bias,
74        }
75    }
76
77    #[inline(always)]
78    fn reset(&mut self) {
79        self.state = KurtosisState::default();
80    }
81
82    #[inline(always)]
83    fn insert_one(&mut self, x: f64) {
84        self.state.insert_one(x);
85    }
86
87    #[inline(always)]
88    fn remove_one(&mut self, x: f64) {
89        self.state.remove_one(x);
90    }
91
92    #[inline(always)]
93    fn finalize(&self) -> Option<f64> {
94        self.state.finalize(self.fisher, self.bias)
95    }
96}
97
98pub struct SkewMoment {
99    state: SkewState,
100    bias: bool,
101}
102
103impl StateUpdate for SkewMoment {
104    fn new(params: Option<RollingFnParams>) -> Self {
105        let bias = if let Some(RollingFnParams::Skew { bias }) = params {
106            bias
107        } else {
108            false
109        };
110
111        Self {
112            state: SkewState::default(),
113            bias,
114        }
115    }
116
117    #[inline(always)]
118    fn reset(&mut self) {
119        self.state = SkewState::default();
120    }
121
122    #[inline(always)]
123    fn insert_one(&mut self, x: f64) {
124        self.state.insert_one(x);
125    }
126
127    #[inline(always)]
128    fn remove_one(&mut self, x: f64) {
129        self.state.remove_one(x);
130    }
131
132    #[inline(always)]
133    fn finalize(&self) -> Option<f64> {
134        self.state.finalize(self.bias)
135    }
136}
137
138pub struct MomentWindow<'a, T, M: StateUpdate> {
139    slice: &'a [T],
140    validity: Option<&'a Bitmap>,
141    moment: M,
142    non_finite_count: usize, // NaN or infinity.
143    null_count: usize,
144    last_start: usize,
145    last_end: usize,
146}
147
148impl<'a, T, M> MomentWindow<'a, T, M>
149where
150    T: NativeType + ToPrimitive + IsFloat + FromPrimitive,
151    M: StateUpdate,
152{
153    fn new_impl(
154        slice: &'a [T],
155        validity: Option<&'a Bitmap>,
156        params: Option<RollingFnParams>,
157    ) -> Self {
158        Self {
159            slice,
160            validity,
161            moment: M::new(params),
162            non_finite_count: 0,
163            null_count: 0,
164            last_start: 0,
165            last_end: 0,
166        }
167    }
168
169    #[inline(always)]
170    fn reset(&mut self) {
171        self.moment.reset();
172        self.non_finite_count = 0;
173        self.null_count = 0;
174    }
175
176    #[inline(always)]
177    fn insert(&mut self, val: T) {
178        if val.is_finite() {
179            self.moment.insert_one(NumCast::from(val).unwrap());
180        } else {
181            self.moment.insert_one(0.0); // A hack to replicate ddof null behavior.
182            self.non_finite_count += 1;
183        }
184    }
185
186    #[inline(always)]
187    fn remove(&mut self, val: T) {
188        if val.is_finite() {
189            self.moment.remove_one(NumCast::from(val).unwrap());
190        } else {
191            self.moment.remove_one(0.0); // A hack to replicate ddof null behavior.
192            self.non_finite_count -= 1;
193        }
194    }
195
196    #[inline(always)]
197    fn finalize(&self) -> Option<T> {
198        if self.non_finite_count > 0 {
199            self.moment
200                .finalize()
201                .map(|_v| T::from_f64(f64::NAN).unwrap())
202        } else {
203            self.moment.finalize().map(|v| T::from_f64(v).unwrap())
204        }
205    }
206}
207
208impl<'a, T, M> RollingAggWindowNoNulls<'a, T> for MomentWindow<'a, T, M>
209where
210    T: NativeType + ToPrimitive + IsFloat + FromPrimitive,
211    M: StateUpdate,
212{
213    fn new(
214        slice: &'a [T],
215        start: usize,
216        end: usize,
217        params: Option<RollingFnParams>,
218        _window_size: Option<usize>,
219    ) -> Self {
220        let mut out = Self::new_impl(slice, None, params);
221        unsafe { RollingAggWindowNoNulls::update(&mut out, start, end) };
222        out
223    }
224
225    // # Safety
226    // The start, end range must be in-bounds.
227    #[inline]
228    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
229        if start >= self.last_end {
230            self.reset();
231            self.last_start = start;
232            self.last_end = start;
233        }
234
235        for val in &self.slice[self.last_start..start] {
236            self.remove(*val);
237        }
238
239        for val in &self.slice[self.last_end..end] {
240            self.insert(*val);
241        }
242
243        self.last_start = start;
244        self.last_end = end;
245        self.finalize()
246    }
247}
248
249impl<'a, T, M> RollingAggWindowNulls<'a, T> for MomentWindow<'a, T, M>
250where
251    T: NativeType + ToPrimitive + IsFloat + FromPrimitive,
252    M: StateUpdate,
253{
254    unsafe fn new(
255        slice: &'a [T],
256        validity: &'a Bitmap,
257        start: usize,
258        end: usize,
259        params: Option<RollingFnParams>,
260        _window_size: Option<usize>,
261    ) -> Self {
262        let mut out = Self::new_impl(slice, Some(validity), params);
263        unsafe { RollingAggWindowNulls::update(&mut out, start, end) };
264        out
265    }
266
267    // # Safety
268    // The start, end range must be in-bounds.
269    #[inline]
270    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
271        let validity = unsafe { self.validity.unwrap_unchecked() };
272
273        if start >= self.last_end {
274            self.reset();
275            self.last_start = start;
276            self.last_end = start;
277        }
278
279        for idx in self.last_start..start {
280            let valid = unsafe { validity.get_bit_unchecked(idx) };
281            if valid {
282                self.remove(unsafe { *self.slice.get_unchecked(idx) });
283            } else {
284                self.null_count -= 1;
285            }
286        }
287
288        for idx in self.last_end..end {
289            let valid = unsafe { validity.get_bit_unchecked(idx) };
290            if valid {
291                self.insert(unsafe { *self.slice.get_unchecked(idx) });
292            } else {
293                self.null_count += 1;
294            }
295        }
296
297        self.last_start = start;
298        self.last_end = end;
299        self.finalize()
300    }
301
302    #[inline(always)]
303    fn is_valid(&self, min_periods: usize) -> bool {
304        ((self.last_end - self.last_start) - self.null_count) >= min_periods
305    }
306}