// polars_compute/rolling/nulls/mod.rs

1mod mean;
2mod min_max;
3mod moment;
4mod quantile;
5mod sum;
6
7use arrow::legacy::utils::CustomIterTools;
8pub use mean::*;
9pub use min_max::*;
10pub use moment::*;
11pub use quantile::*;
12pub use sum::*;
13
14use super::*;
15
16pub trait RollingAggWindowNulls<'a, T: NativeType> {
17    /// # Safety
18    /// `start` and `end` must be in bounds for `slice` and `validity`
19    unsafe fn new(
20        slice: &'a [T],
21        validity: &'a Bitmap,
22        start: usize,
23        end: usize,
24        params: Option<RollingFnParams>,
25        window_size: Option<usize>,
26    ) -> Self;
27
28    /// # Safety
29    /// `start` and `end` must be in bounds of `slice` and `bitmap`
30    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
31
32    fn is_valid(&self, min_periods: usize) -> bool;
33}
34
35// Use an aggregation window that maintains the state
36pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
37    values: &'a [T],
38    validity: &'a Bitmap,
39    window_size: usize,
40    min_periods: usize,
41    det_offsets_fn: Fo,
42    params: Option<RollingFnParams>,
43) -> ArrayRef
44where
45    Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy,
46    Agg: RollingAggWindowNulls<'a, T>,
47    T: IsFloat + NativeType,
48{
49    let len = values.len();
50    let (start, end) = det_offsets_fn(0, window_size, len);
51    // SAFETY; we are in bounds
52    let mut agg_window =
53        unsafe { Agg::new(values, validity, start, end, params, Some(window_size)) };
54
55    let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn)
56        .unwrap_or_else(|| {
57            let mut validity = MutableBitmap::with_capacity(len);
58            validity.extend_constant(len, true);
59            validity
60        });
61
62    let out = (0..len)
63        .map(|idx| {
64            let (start, end) = det_offsets_fn(idx, window_size, len);
65            // SAFETY:
66            // we are in bounds
67            let agg = unsafe { agg_window.update(start, end) };
68            match agg {
69                Some(val) => {
70                    if agg_window.is_valid(min_periods) {
71                        val
72                    } else {
73                        // SAFETY: we are in bounds
74                        unsafe { validity.set_unchecked(idx, false) };
75                        T::default()
76                    }
77                },
78                None => {
79                    // SAFETY: we are in bounds
80                    unsafe { validity.set_unchecked(idx, false) };
81                    T::default()
82                },
83            }
84        })
85        .collect_trusted::<Vec<_>>();
86
87    Box::new(PrimitiveArray::new(
88        T::PRIMITIVE.into(),
89        out.into(),
90        Some(validity.into()),
91    ))
92}
93
#[cfg(test)]
mod test {
    use arrow::array::{Array, Int32Array};
    use arrow::buffer::Buffer;
    use arrow::datatypes::ArrowDataType;
    use polars_utils::min_max::MaxIgnoreNan;

    use super::*;
    use crate::rolling::min_max::MinMaxWindow;

    /// Builds a `Float64` array from `values` with the given per-slot validity.
    fn f64_array(values: Vec<f64>, validity: &[bool]) -> PrimitiveArray<f64> {
        PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(values),
            Some(Bitmap::from(validity)),
        )
    }

    /// Builds `[1, None, -1, 4]`; the null slot stores `0.0` in the values buffer.
    fn get_null_arr() -> PrimitiveArray<f64> {
        f64_array(vec![1.0, 0.0, -1.0, 4.0], &[true, false, true, true])
    }

    /// Downcasts a rolling kernel's output to `PrimitiveArray<f64>` and collects it.
    fn collect_f64(out: &ArrayRef) -> Vec<Option<f64>> {
        out.as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .expect("rolling kernels on f64 input return a PrimitiveArray<f64>")
            .into_iter()
            .map(|v| v.copied())
            .collect()
    }

    #[test]
    fn test_rolling_sum_nulls() {
        let arr = &f64_array(vec![1.0, 2.0, 3.0, 4.0], &[true, false, true, true]);

        let out = collect_f64(&rolling_sum(arr, 2, 2, false, None, None));
        assert_eq!(out, &[None, None, None, Some(7.0)]);

        let out = collect_f64(&rolling_sum(arr, 2, 1, false, None, None));
        assert_eq!(out, &[Some(1.0), Some(1.0), Some(3.0), Some(7.0)]);

        let out = collect_f64(&rolling_sum(arr, 4, 1, false, None, None));
        assert_eq!(out, &[Some(1.0), Some(1.0), Some(4.0), Some(8.0)]);

        let out = collect_f64(&rolling_sum(arr, 4, 1, true, None, None));
        assert_eq!(out, &[Some(1.0), Some(4.0), Some(8.0), Some(7.0)]);

        let out = collect_f64(&rolling_sum(arr, 4, 4, true, None, None));
        assert_eq!(out, &[None, None, None, None]);
    }

    #[test]
    fn test_rolling_mean_nulls() {
        let arr = &get_null_arr();

        let out = collect_f64(&rolling_mean(arr, 2, 2, false, None, None));
        assert_eq!(out, &[None, None, None, Some(1.5)]);

        let out = collect_f64(&rolling_mean(arr, 2, 1, false, None, None));
        assert_eq!(out, &[Some(1.0), Some(1.0), Some(-1.0), Some(1.5)]);

        let out = collect_f64(&rolling_mean(arr, 4, 1, false, None, None));
        assert_eq!(out, &[Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]);
    }

    #[test]
    fn test_rolling_var_nulls() {
        let arr = &get_null_arr();
        let testpars = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));

        let out = collect_f64(&rolling_var(arr, 3, 1, false, None, None));
        assert_eq!(out, &[None, None, Some(2.0), Some(12.5)]);

        let out = collect_f64(&rolling_var(arr, 3, 1, false, None, testpars));
        assert_eq!(out, &[Some(0.0), Some(0.0), Some(1.0), Some(6.25)]);

        let out = collect_f64(&rolling_var(arr, 4, 1, false, None, None));
        assert_eq!(out, &[None, None, Some(2.0), Some(6.333333333333334)]);

        let out = collect_f64(&rolling_var(arr, 4, 1, false, None, testpars));
        assert_eq!(
            out,
            &[Some(0.), Some(0.0), Some(1.0), Some(4.222222222222222)]
        );
    }

    #[test]
    fn test_rolling_max_no_nulls() {
        let all_valid = [true, true, true, true];

        let arr = &f64_array(vec![1.0, 2.0, 3.0, 4.0], &all_valid);
        let out = collect_f64(&rolling_max(arr, 4, 1, false, None, None));
        assert_eq!(out, &[Some(1.0), Some(2.0), Some(3.0), Some(4.0)]);

        let out = collect_f64(&rolling_max(arr, 2, 2, false, None, None));
        assert_eq!(out, &[None, Some(2.0), Some(3.0), Some(4.0)]);

        let out = collect_f64(&rolling_max(arr, 4, 4, false, None, None));
        assert_eq!(out, &[None, None, None, Some(4.0)]);

        let arr = &f64_array(vec![4.0, 3.0, 2.0, 1.0], &all_valid);
        let out = collect_f64(&rolling_max(arr, 2, 1, false, None, None));
        assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);

        // The no-nulls kernel must agree with the nulls kernel on fully valid input.
        let out = collect_f64(
            &super::no_nulls::rolling_max(arr.values().as_slice(), 2, 1, false, None, None)
                .unwrap(),
        );
        assert_eq!(out, &[Some(4.0), Some(4.0), Some(3.0), Some(2.0)]);
    }

    #[test]
    fn test_rolling_extrema_nulls() {
        let vals = vec![3, 3, 3, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
        let validity = Bitmap::new_with_value(true, vals.len());
        let window_size = 3;
        let min_periods = 3;

        let arr = Int32Array::new(ArrowDataType::Int32, vals.into(), Some(validity));

        let out = rolling_apply_agg_window::<MinMaxWindow<i32, MaxIgnoreNan>, _, _>(
            arr.values().as_slice(),
            arr.validity().as_ref().unwrap(),
            window_size,
            min_periods,
            det_offsets,
            None,
        );
        let arr = out.as_any().downcast_ref::<Int32Array>().unwrap();

        // The first two windows hold fewer than `min_periods` values and must be null.
        assert_eq!(arr.null_count(), 2);
        assert!(arr.is_null(0));
        assert!(arr.is_null(1));
        assert_eq!(
            &arr.values().as_slice()[2..],
            &[3, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3]
        );
    }
}