polars_compute/rolling/no_nulls/
moment.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use num_traits::{FromPrimitive, ToPrimitive};
3use polars_error::polars_ensure;
4
5pub use super::super::moment::*;
6use super::*;
7
8pub struct MomentWindow<'a, T, M: StateUpdate> {
9 slice: &'a [T],
10 moment: M,
11 last_start: usize,
12 last_end: usize,
13 params: Option<RollingFnParams>,
14}
15
16impl<T: ToPrimitive + Copy, M: StateUpdate> MomentWindow<'_, T, M> {
17 fn compute_var(&mut self, start: usize, end: usize) {
18 self.moment = M::new(self.params);
19 for value in &self.slice[start..end] {
20 let value: f64 = NumCast::from(*value).unwrap();
21 self.moment.insert_one(value);
22 }
23 }
24}
25
26impl<'a, T: NativeType + IsFloat + Float + ToPrimitive + FromPrimitive, M: StateUpdate>
27 RollingAggWindowNoNulls<'a, T> for MomentWindow<'a, T, M>
28{
29 fn new(
30 slice: &'a [T],
31 start: usize,
32 end: usize,
33 params: Option<RollingFnParams>,
34 _window_size: Option<usize>,
35 ) -> Self {
36 let mut out = Self {
37 slice,
38 moment: M::new(params),
39 last_start: start,
40 last_end: end,
41 params,
42 };
43 out.compute_var(start, end);
44 out
45 }
46
47 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
48 let recompute_var = if start >= self.last_end {
49 true
50 } else {
51 let mut recompute_var = false;
53 for idx in self.last_start..start {
54 let leaving_value = *self.slice.get_unchecked(idx);
56
57 if T::is_float() && !leaving_value.is_finite() {
59 recompute_var = true;
60 break;
61 }
62 let leaving_value: f64 = NumCast::from(leaving_value).unwrap();
63 self.moment.remove_one(leaving_value);
64 }
65 recompute_var
66 };
67
68 self.last_start = start;
69
70 if recompute_var {
72 self.compute_var(start, end);
73 } else {
74 for idx in self.last_end..end {
75 let entering_value = *self.slice.get_unchecked(idx);
76 let entering_value: f64 = NumCast::from(entering_value).unwrap();
77
78 self.moment.insert_one(entering_value);
79 }
80 }
81 self.last_end = end;
82 self.moment.finalize().map(|v| T::from_f64(v).unwrap())
83 }
84}
85
86pub fn rolling_var<T>(
87 values: &[T],
88 window_size: usize,
89 min_periods: usize,
90 center: bool,
91 weights: Option<&[f64]>,
92 params: Option<RollingFnParams>,
93) -> PolarsResult<ArrayRef>
94where
95 T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,
96{
97 let offset_fn = match center {
98 true => det_offsets_center,
99 false => det_offsets,
100 };
101 match weights {
102 None => rolling_apply_agg_window::<MomentWindow<_, VarianceMoment>, _, _>(
103 values,
104 window_size,
105 min_periods,
106 offset_fn,
107 params,
108 ),
109 Some(weights) => {
110 let mut wts = no_nulls::coerce_weights(weights);
113 let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x);
114 polars_ensure!(
115 wsum != T::zero(),
116 ComputeError: "Weighted variance is undefined if weights sum to 0"
117 );
118 wts.iter_mut().for_each(|w| *w = *w / wsum);
119 super::rolling_apply_weights(
120 values,
121 window_size,
122 min_periods,
123 offset_fn,
124 compute_var_weights,
125 &wts,
126 )
127 },
128 }
129}
130
131pub fn rolling_skew<T>(
132 values: &[T],
133 window_size: usize,
134 min_periods: usize,
135 center: bool,
136 params: Option<RollingFnParams>,
137) -> PolarsResult<ArrayRef>
138where
139 T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,
140{
141 let offset_fn = match center {
142 true => det_offsets_center,
143 false => det_offsets,
144 };
145 rolling_apply_agg_window::<MomentWindow<_, SkewMoment>, _, _>(
146 values,
147 window_size,
148 min_periods,
149 offset_fn,
150 params,
151 )
152}
153
154pub fn rolling_kurtosis<T>(
155 values: &[T],
156 window_size: usize,
157 min_periods: usize,
158 center: bool,
159 params: Option<RollingFnParams>,
160) -> PolarsResult<ArrayRef>
161where
162 T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,
163{
164 let offset_fn = match center {
165 true => det_offsets_center,
166 false => det_offsets,
167 };
168 rolling_apply_agg_window::<MomentWindow<_, KurtosisMoment>, _, _>(
169 values,
170 window_size,
171 min_periods,
172 offset_fn,
173 params,
174 )
175}
176
#[cfg(test)]
mod test {
    use super::*;

    /// Runs an unweighted, non-centered `rolling_var` over `values` and collects
    /// the result into a plain vector of optional values.
    fn run_var(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        params: Option<RollingFnParams>,
    ) -> Vec<Option<f64>> {
        let array = rolling_var(values, window_size, min_periods, false, None, params).unwrap();
        let array = array
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap();
        array.into_iter().map(|v| v.copied()).collect::<Vec<_>>()
    }

    #[test]
    fn test_rolling_var() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];

        // Default parameters.
        let out = run_var(values, 2, 2, None);
        assert_eq!(out, &[None, Some(8.0), Some(2.0), Some(0.5)]);

        // Explicit `ddof: 0`.
        let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));
        let out = run_var(values, 2, 2, testpars);
        assert_eq!(out, &[None, Some(4.0), Some(1.0), Some(0.25)]);

        // `min_periods` of 1; compared through the debug representation.
        let out = run_var(values, 2, 1, None);
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)])
        );

        // NaN never compares equal to itself, so compare the debug strings.
        let values = &[-10.0, 2.0, 3.0, f64::NAN, 5.0, 6.0, 7.0];
        let out = run_var(values, 3, 3, None);
        let expected = [
            None,
            None,
            Some(52.33333333333333),
            Some(f64::NAN),
            Some(f64::NAN),
            Some(f64::NAN),
            Some(1.0),
        ];
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!("{:?}", &expected)
        );
    }
}