polars_compute/rolling/nulls/
moment.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2
3use num_traits::{FromPrimitive, ToPrimitive};
4
5pub use super::super::moment::*;
6use super::*;
7
8pub struct MomentWindow<'a, T, M: StateUpdate> {
9    slice: &'a [T],
10    validity: &'a Bitmap,
11    moment: Option<M>,
12    last_start: usize,
13    last_end: usize,
14    null_count: usize,
15    params: Option<RollingFnParams>,
16}
17
18impl<T: NativeType + ToPrimitive, M: StateUpdate> MomentWindow<'_, T, M> {
19    // compute sum from the entire window
20    unsafe fn compute_moment_and_null_count(&mut self, start: usize, end: usize) {
21        self.moment = None;
22        let mut idx = start;
23        self.null_count = 0;
24        for value in &self.slice[start..end] {
25            let valid = self.validity.get_bit_unchecked(idx);
26            if valid {
27                let value: f64 = NumCast::from(*value).unwrap();
28                self.moment
29                    .get_or_insert_with(|| M::new(self.params))
30                    .insert_one(value);
31            } else {
32                self.null_count += 1;
33            }
34            idx += 1;
35        }
36    }
37}
38
39impl<'a, T: NativeType + ToPrimitive + IsFloat + FromPrimitive, M: StateUpdate>
40    RollingAggWindowNulls<'a, T> for MomentWindow<'a, T, M>
41{
42    unsafe fn new(
43        slice: &'a [T],
44        validity: &'a Bitmap,
45        start: usize,
46        end: usize,
47        params: Option<RollingFnParams>,
48        _window_size: Option<usize>,
49    ) -> Self {
50        let mut out = Self {
51            slice,
52            validity,
53            moment: None,
54            last_start: start,
55            last_end: end,
56            null_count: 0,
57            params,
58        };
59        out.compute_moment_and_null_count(start, end);
60        out
61    }
62
63    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
64        let recompute_var = if start >= self.last_end {
65            true
66        } else {
67            // remove elements that should leave the window
68            let mut recompute_var = false;
69            for idx in self.last_start..start {
70                // SAFETY:
71                // we are in bounds
72                let valid = self.validity.get_bit_unchecked(idx);
73                if valid {
74                    let leaving_value = *self.slice.get_unchecked(idx);
75
76                    // if the leaving value is nan we need to recompute the window
77                    if T::is_float() && !leaving_value.is_finite() {
78                        recompute_var = true;
79                        break;
80                    }
81                    let leaving_value: f64 = NumCast::from(leaving_value).unwrap();
82                    if let Some(v) = self.moment.as_mut() {
83                        v.remove_one(leaving_value)
84                    }
85                } else {
86                    // null value leaving the window
87                    self.null_count -= 1;
88
89                    // self.sum is None and the leaving value is None
90                    // if the entering value is valid, we might get a new sum.
91                    if self.moment.is_none() {
92                        recompute_var = true;
93                        break;
94                    }
95                }
96            }
97            recompute_var
98        };
99
100        self.last_start = start;
101
102        // we traverse all values and compute
103        if recompute_var {
104            self.compute_moment_and_null_count(start, end);
105        } else {
106            for idx in self.last_end..end {
107                let valid = self.validity.get_bit_unchecked(idx);
108
109                if valid {
110                    let entering_value = *self.slice.get_unchecked(idx);
111                    let entering_value: f64 = NumCast::from(entering_value).unwrap();
112                    self.moment
113                        .get_or_insert_with(|| M::new(self.params))
114                        .insert_one(entering_value);
115                } else {
116                    // null value entering the window
117                    self.null_count += 1;
118                }
119            }
120        }
121        self.last_end = end;
122        self.moment.as_ref().and_then(|v| {
123            let out = v.finalize();
124            out.map(|v| T::from_f64(v).unwrap())
125        })
126    }
127
128    fn is_valid(&self, min_periods: usize) -> bool {
129        ((self.last_end - self.last_start) - self.null_count) >= min_periods
130    }
131}
132
133pub fn rolling_var<T>(
134    arr: &PrimitiveArray<T>,
135    window_size: usize,
136    min_periods: usize,
137    center: bool,
138    weights: Option<&[f64]>,
139    params: Option<RollingFnParams>,
140) -> ArrayRef
141where
142    T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float,
143{
144    if weights.is_some() {
145        panic!("weights not yet supported on array with null values")
146    }
147    let offsets_fn = if center {
148        det_offsets_center
149    } else {
150        det_offsets
151    };
152    rolling_apply_agg_window::<MomentWindow<_, VarianceMoment>, _, _>(
153        arr.values().as_slice(),
154        arr.validity().as_ref().unwrap(),
155        window_size,
156        min_periods,
157        offsets_fn,
158        params,
159    )
160}
161
162pub fn rolling_skew<T>(
163    arr: &PrimitiveArray<T>,
164    window_size: usize,
165    min_periods: usize,
166    center: bool,
167    params: Option<RollingFnParams>,
168) -> ArrayRef
169where
170    T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float,
171{
172    let offsets_fn = if center {
173        det_offsets_center
174    } else {
175        det_offsets
176    };
177    rolling_apply_agg_window::<MomentWindow<_, SkewMoment>, _, _>(
178        arr.values().as_slice(),
179        arr.validity().as_ref().unwrap(),
180        window_size,
181        min_periods,
182        offsets_fn,
183        params,
184    )
185}
186
187pub fn rolling_kurtosis<T>(
188    arr: &PrimitiveArray<T>,
189    window_size: usize,
190    min_periods: usize,
191    center: bool,
192    params: Option<RollingFnParams>,
193) -> ArrayRef
194where
195    T: NativeType + ToPrimitive + FromPrimitive + IsFloat + Float,
196{
197    let offsets_fn = if center {
198        det_offsets_center
199    } else {
200        det_offsets
201    };
202    rolling_apply_agg_window::<MomentWindow<_, KurtosisMoment>, _, _>(
203        arr.values().as_slice(),
204        arr.validity().as_ref().unwrap(),
205        window_size,
206        min_periods,
207        offsets_fn,
208        params,
209    )
210}