polars_compute/rolling/no_nulls/
moment.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use num_traits::{FromPrimitive, ToPrimitive};
3use polars_error::polars_ensure;
4
5pub use super::super::moment::*;
6use super::*;
7
8pub struct MomentWindow<'a, T, M: StateUpdate> {
9    slice: &'a [T],
10    moment: M,
11    last_start: usize,
12    last_end: usize,
13    params: Option<RollingFnParams>,
14}
15
16impl<T: ToPrimitive + Copy, M: StateUpdate> MomentWindow<'_, T, M> {
17    fn compute_var(&mut self, start: usize, end: usize) {
18        self.moment = M::new(self.params);
19        for value in &self.slice[start..end] {
20            let value: f64 = NumCast::from(*value).unwrap();
21            self.moment.insert_one(value);
22        }
23    }
24}
25
26impl<'a, T: NativeType + IsFloat + Float + ToPrimitive + FromPrimitive, M: StateUpdate>
27    RollingAggWindowNoNulls<'a, T> for MomentWindow<'a, T, M>
28{
29    fn new(
30        slice: &'a [T],
31        start: usize,
32        end: usize,
33        params: Option<RollingFnParams>,
34        _window_size: Option<usize>,
35    ) -> Self {
36        let mut out = Self {
37            slice,
38            moment: M::new(params),
39            last_start: start,
40            last_end: end,
41            params,
42        };
43        out.compute_var(start, end);
44        out
45    }
46
47    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
48        let recompute_var = if start >= self.last_end {
49            true
50        } else {
51            // remove elements that should leave the window
52            let mut recompute_var = false;
53            for idx in self.last_start..start {
54                // SAFETY: we are in bounds
55                let leaving_value = *self.slice.get_unchecked(idx);
56
57                // if the leaving value is nan we need to recompute the window
58                if T::is_float() && !leaving_value.is_finite() {
59                    recompute_var = true;
60                    break;
61                }
62                let leaving_value: f64 = NumCast::from(leaving_value).unwrap();
63                self.moment.remove_one(leaving_value);
64            }
65            recompute_var
66        };
67
68        self.last_start = start;
69
70        // we traverse all values and compute
71        if recompute_var {
72            self.compute_var(start, end);
73        } else {
74            for idx in self.last_end..end {
75                let entering_value = *self.slice.get_unchecked(idx);
76                let entering_value: f64 = NumCast::from(entering_value).unwrap();
77
78                self.moment.insert_one(entering_value);
79            }
80        }
81        self.last_end = end;
82        self.moment.finalize().map(|v| T::from_f64(v).unwrap())
83    }
84}
85
86pub fn rolling_var<T>(
87    values: &[T],
88    window_size: usize,
89    min_periods: usize,
90    center: bool,
91    weights: Option<&[f64]>,
92    params: Option<RollingFnParams>,
93) -> PolarsResult<ArrayRef>
94where
95    T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,
96{
97    let offset_fn = match center {
98        true => det_offsets_center,
99        false => det_offsets,
100    };
101    match weights {
102        None => rolling_apply_agg_window::<MomentWindow<_, VarianceMoment>, _, _>(
103            values,
104            window_size,
105            min_periods,
106            offset_fn,
107            params,
108        ),
109        Some(weights) => {
110            // Validate and standardize the weights like we do for the mean. This definition is fine
111            // because frequency weights and unbiasing don't make sense for rolling operations.
112            let mut wts = no_nulls::coerce_weights(weights);
113            let wsum = wts.iter().fold(T::zero(), |acc, x| acc + *x);
114            polars_ensure!(
115                wsum != T::zero(),
116                ComputeError: "Weighted variance is undefined if weights sum to 0"
117            );
118            wts.iter_mut().for_each(|w| *w = *w / wsum);
119            super::rolling_apply_weights(
120                values,
121                window_size,
122                min_periods,
123                offset_fn,
124                compute_var_weights,
125                &wts,
126            )
127        },
128    }
129}
130
131pub fn rolling_skew<T>(
132    values: &[T],
133    window_size: usize,
134    min_periods: usize,
135    center: bool,
136    params: Option<RollingFnParams>,
137) -> PolarsResult<ArrayRef>
138where
139    T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,
140{
141    let offset_fn = match center {
142        true => det_offsets_center,
143        false => det_offsets,
144    };
145    rolling_apply_agg_window::<MomentWindow<_, SkewMoment>, _, _>(
146        values,
147        window_size,
148        min_periods,
149        offset_fn,
150        params,
151    )
152}
153
154pub fn rolling_kurtosis<T>(
155    values: &[T],
156    window_size: usize,
157    min_periods: usize,
158    center: bool,
159    params: Option<RollingFnParams>,
160) -> PolarsResult<ArrayRef>
161where
162    T: NativeType + Float + IsFloat + ToPrimitive + FromPrimitive + AddAssign,
163{
164    let offset_fn = match center {
165        true => det_offsets_center,
166        false => det_offsets,
167    };
168    rolling_apply_agg_window::<MomentWindow<_, KurtosisMoment>, _, _>(
169        values,
170        window_size,
171        min_periods,
172        offset_fn,
173        params,
174    )
175}
176
#[cfg(test)]
mod test {
    use super::*;

    /// Unwraps a rolling result into a plain vector of optional `f64`s.
    fn collect_f64(arr: ArrayRef) -> Vec<Option<f64>> {
        let prim = arr.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        prim.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_var() {
        let values = &[1.0f64, 5.0, 3.0, 4.0];

        // Default parameters.
        let out = collect_f64(rolling_var(values, 2, 2, false, None, None).unwrap());
        assert_eq!(out, &[None, Some(8.0), Some(2.0), Some(0.5)]);

        // Explicit ddof of zero.
        let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));
        let out = collect_f64(rolling_var(values, 2, 2, false, None, testpars).unwrap());
        assert_eq!(out, &[None, Some(4.0), Some(1.0), Some(0.25)]);

        // Smaller `min_periods`; compared through the Debug representation.
        let out = collect_f64(rolling_var(values, 2, 1, false, None, None).unwrap());
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)])
        );

        // NaN handling: NaNs never compare equal, so the Debug strings are compared instead.
        let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let out = collect_f64(rolling_var(values, 3, 3, false, None, None).unwrap());
        let expected = [
            None,
            None,
            Some(52.33333333333333),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(f64::nan()),
            Some(1.0),
        ];
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!("{:?}", &expected)
        );
    }
}