1use num_traits::{FromPrimitive, ToPrimitive};
2
3use super::no_nulls::RollingAggWindowNoNulls;
4use super::nulls::RollingAggWindowNulls;
5use super::*;
6use crate::moment::{KurtosisState, SkewState, VarState};
7
8pub trait StateUpdate {
9 fn new(params: Option<RollingFnParams>) -> Self;
10 fn reset(&mut self);
11 fn insert_one(&mut self, x: f64);
12 fn remove_one(&mut self, x: f64);
13 fn finalize(&self) -> Option<f64>;
14}
15
16pub struct VarianceMoment {
17 state: VarState,
18 ddof: u8,
19}
20
21impl StateUpdate for VarianceMoment {
22 fn new(params: Option<RollingFnParams>) -> Self {
23 let ddof = if let Some(RollingFnParams::Var(params)) = params {
24 params.ddof
25 } else {
26 1
27 };
28
29 Self {
30 state: VarState::default(),
31 ddof,
32 }
33 }
34
35 #[inline(always)]
36 fn reset(&mut self) {
37 self.state = VarState::default();
38 }
39
40 #[inline(always)]
41 fn insert_one(&mut self, x: f64) {
42 self.state.insert_one(x);
43 }
44
45 #[inline(always)]
46 fn remove_one(&mut self, x: f64) {
47 self.state.remove_one(x);
48 }
49
50 #[inline(always)]
51 fn finalize(&self) -> Option<f64> {
52 self.state.finalize(self.ddof)
53 }
54}
55
56pub struct KurtosisMoment {
57 state: KurtosisState,
58 fisher: bool,
59 bias: bool,
60}
61
62impl StateUpdate for KurtosisMoment {
63 fn new(params: Option<RollingFnParams>) -> Self {
64 let (fisher, bias) = if let Some(RollingFnParams::Kurtosis { fisher, bias }) = params {
65 (fisher, bias)
66 } else {
67 (false, false)
68 };
69
70 Self {
71 state: KurtosisState::default(),
72 fisher,
73 bias,
74 }
75 }
76
77 #[inline(always)]
78 fn reset(&mut self) {
79 self.state = KurtosisState::default();
80 }
81
82 #[inline(always)]
83 fn insert_one(&mut self, x: f64) {
84 self.state.insert_one(x);
85 }
86
87 #[inline(always)]
88 fn remove_one(&mut self, x: f64) {
89 self.state.remove_one(x);
90 }
91
92 #[inline(always)]
93 fn finalize(&self) -> Option<f64> {
94 self.state.finalize(self.fisher, self.bias)
95 }
96}
97
98pub struct SkewMoment {
99 state: SkewState,
100 bias: bool,
101}
102
103impl StateUpdate for SkewMoment {
104 fn new(params: Option<RollingFnParams>) -> Self {
105 let bias = if let Some(RollingFnParams::Skew { bias }) = params {
106 bias
107 } else {
108 false
109 };
110
111 Self {
112 state: SkewState::default(),
113 bias,
114 }
115 }
116
117 #[inline(always)]
118 fn reset(&mut self) {
119 self.state = SkewState::default();
120 }
121
122 #[inline(always)]
123 fn insert_one(&mut self, x: f64) {
124 self.state.insert_one(x);
125 }
126
127 #[inline(always)]
128 fn remove_one(&mut self, x: f64) {
129 self.state.remove_one(x);
130 }
131
132 #[inline(always)]
133 fn finalize(&self) -> Option<f64> {
134 self.state.finalize(self.bias)
135 }
136}
137
138pub struct MomentWindow<'a, T, M: StateUpdate> {
139 slice: &'a [T],
140 validity: Option<&'a Bitmap>,
141 moment: M,
142 non_finite_count: usize, null_count: usize,
144 last_start: usize,
145 last_end: usize,
146}
147
148impl<'a, T, M> MomentWindow<'a, T, M>
149where
150 T: NativeType + ToPrimitive + IsFloat + FromPrimitive,
151 M: StateUpdate,
152{
153 fn new_impl(
154 slice: &'a [T],
155 validity: Option<&'a Bitmap>,
156 params: Option<RollingFnParams>,
157 ) -> Self {
158 Self {
159 slice,
160 validity,
161 moment: M::new(params),
162 non_finite_count: 0,
163 null_count: 0,
164 last_start: 0,
165 last_end: 0,
166 }
167 }
168
169 #[inline(always)]
170 fn reset(&mut self) {
171 self.moment.reset();
172 self.non_finite_count = 0;
173 self.null_count = 0;
174 }
175
176 #[inline(always)]
177 fn insert(&mut self, val: T) {
178 if val.is_finite() {
179 self.moment.insert_one(NumCast::from(val).unwrap());
180 } else {
181 self.moment.insert_one(0.0); self.non_finite_count += 1;
183 }
184 }
185
186 #[inline(always)]
187 fn remove(&mut self, val: T) {
188 if val.is_finite() {
189 self.moment.remove_one(NumCast::from(val).unwrap());
190 } else {
191 self.moment.remove_one(0.0); self.non_finite_count -= 1;
193 }
194 }
195
196 #[inline(always)]
197 fn finalize(&self) -> Option<T> {
198 if self.non_finite_count > 0 {
199 self.moment
200 .finalize()
201 .map(|_v| T::from_f64(f64::NAN).unwrap())
202 } else {
203 self.moment.finalize().map(|v| T::from_f64(v).unwrap())
204 }
205 }
206}
207
208impl<'a, T, M> RollingAggWindowNoNulls<'a, T> for MomentWindow<'a, T, M>
209where
210 T: NativeType + ToPrimitive + IsFloat + FromPrimitive,
211 M: StateUpdate,
212{
213 fn new(
214 slice: &'a [T],
215 start: usize,
216 end: usize,
217 params: Option<RollingFnParams>,
218 _window_size: Option<usize>,
219 ) -> Self {
220 let mut out = Self::new_impl(slice, None, params);
221 unsafe { RollingAggWindowNoNulls::update(&mut out, start, end) };
222 out
223 }
224
225 #[inline]
228 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
229 if start >= self.last_end {
230 self.reset();
231 self.last_start = start;
232 self.last_end = start;
233 }
234
235 for val in &self.slice[self.last_start..start] {
236 self.remove(*val);
237 }
238
239 for val in &self.slice[self.last_end..end] {
240 self.insert(*val);
241 }
242
243 self.last_start = start;
244 self.last_end = end;
245 self.finalize()
246 }
247}
248
249impl<'a, T, M> RollingAggWindowNulls<'a, T> for MomentWindow<'a, T, M>
250where
251 T: NativeType + ToPrimitive + IsFloat + FromPrimitive,
252 M: StateUpdate,
253{
254 unsafe fn new(
255 slice: &'a [T],
256 validity: &'a Bitmap,
257 start: usize,
258 end: usize,
259 params: Option<RollingFnParams>,
260 _window_size: Option<usize>,
261 ) -> Self {
262 let mut out = Self::new_impl(slice, Some(validity), params);
263 unsafe { RollingAggWindowNulls::update(&mut out, start, end) };
264 out
265 }
266
267 #[inline]
270 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
271 let validity = unsafe { self.validity.unwrap_unchecked() };
272
273 if start >= self.last_end {
274 self.reset();
275 self.last_start = start;
276 self.last_end = start;
277 }
278
279 for idx in self.last_start..start {
280 let valid = unsafe { validity.get_bit_unchecked(idx) };
281 if valid {
282 self.remove(unsafe { *self.slice.get_unchecked(idx) });
283 } else {
284 self.null_count -= 1;
285 }
286 }
287
288 for idx in self.last_end..end {
289 let valid = unsafe { validity.get_bit_unchecked(idx) };
290 if valid {
291 self.insert(unsafe { *self.slice.get_unchecked(idx) });
292 } else {
293 self.null_count += 1;
294 }
295 }
296
297 self.last_start = start;
298 self.last_end = end;
299 self.finalize()
300 }
301
302 #[inline(always)]
303 fn is_valid(&self, min_periods: usize) -> bool {
304 ((self.last_end - self.last_start) - self.null_count) >= min_periods
305 }
306}