1mod mean;
2mod min_max;
3mod moment;
4mod quantile;
5mod sum;
6
7use arrow::legacy::utils::CustomIterTools;
8pub use mean::*;
9pub use min_max::*;
10pub use moment::*;
11pub use quantile::*;
12pub use sum::*;
13
14use super::*;
15
16pub trait RollingAggWindowNulls<'a, T: NativeType> {
17 unsafe fn new(
20 slice: &'a [T],
21 validity: &'a Bitmap,
22 start: usize,
23 end: usize,
24 params: Option<RollingFnParams>,
25 window_size: Option<usize>,
26 ) -> Self;
27
28 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
31
32 fn is_valid(&self, min_periods: usize) -> bool;
33}
34
35pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
37 values: &'a [T],
38 validity: &'a Bitmap,
39 window_size: usize,
40 min_periods: usize,
41 det_offsets_fn: Fo,
42 params: Option<RollingFnParams>,
43) -> ArrayRef
44where
45 Fo: Fn(Idx, WindowSize, Len) -> (Start, End) + Copy,
46 Agg: RollingAggWindowNulls<'a, T>,
47 T: IsFloat + NativeType,
48{
49 let len = values.len();
50 let (start, end) = det_offsets_fn(0, window_size, len);
51 let mut agg_window =
53 unsafe { Agg::new(values, validity, start, end, params, Some(window_size)) };
54
55 let mut validity = create_validity(min_periods, len, window_size, det_offsets_fn)
56 .unwrap_or_else(|| {
57 let mut validity = MutableBitmap::with_capacity(len);
58 validity.extend_constant(len, true);
59 validity
60 });
61
62 let out = (0..len)
63 .map(|idx| {
64 let (start, end) = det_offsets_fn(idx, window_size, len);
65 let agg = unsafe { agg_window.update(start, end) };
68 match agg {
69 Some(val) => {
70 if agg_window.is_valid(min_periods) {
71 val
72 } else {
73 unsafe { validity.set_unchecked(idx, false) };
75 T::default()
76 }
77 },
78 None => {
79 unsafe { validity.set_unchecked(idx, false) };
81 T::default()
82 },
83 }
84 })
85 .collect_trusted::<Vec<_>>();
86
87 Box::new(PrimitiveArray::new(
88 T::PRIMITIVE.into(),
89 out.into(),
90 Some(validity.into()),
91 ))
92}
93
#[cfg(test)]
mod test {
    use arrow::array::{Array, Int32Array};
    use arrow::buffer::Buffer;
    use arrow::datatypes::ArrowDataType;
    use polars_utils::min_max::MaxIgnoreNan;

    use super::*;
    use crate::rolling::min_max::MinMaxWindow;

    /// Builds a four-element `Float64` array from raw values and a validity mask.
    fn f64_arr(values: [f64; 4], validity: [bool; 4]) -> PrimitiveArray<f64> {
        PrimitiveArray::new(
            ArrowDataType::Float64,
            Buffer::from(values.to_vec()),
            Some(Bitmap::from(&validity)),
        )
    }

    /// Downcasts a kernel output to `PrimitiveArray<f64>` and collects it,
    /// with null slots as `None`.
    fn f64_values(out: &dyn Array) -> Vec<Option<f64>> {
        let typed = out
            .as_any()
            .downcast_ref::<PrimitiveArray<f64>>()
            .unwrap();
        typed.into_iter().map(|v| v.copied()).collect()
    }

    /// `[1.0, null, -1.0, 4.0]`; the null slot stores `0.0` in the value buffer.
    fn get_null_arr() -> PrimitiveArray<f64> {
        f64_arr([1.0, 0.0, -1.0, 4.0], [true, false, true, true])
    }

    #[test]
    fn test_rolling_sum_nulls() {
        let arr = &f64_arr([1.0, 2.0, 3.0, 4.0], [true, false, true, true]);

        // (window_size, min_periods, center, expected)
        let cases: [(usize, usize, bool, [Option<f64>; 4]); 5] = [
            (2, 2, false, [None, None, None, Some(7.0)]),
            (2, 1, false, [Some(1.0), Some(1.0), Some(3.0), Some(7.0)]),
            (4, 1, false, [Some(1.0), Some(1.0), Some(4.0), Some(8.0)]),
            (4, 1, true, [Some(1.0), Some(4.0), Some(8.0), Some(7.0)]),
            (4, 4, true, [None, None, None, None]),
        ];
        for (window_size, min_periods, center, expected) in cases {
            let out = rolling_sum(arr, window_size, min_periods, center, None, None);
            assert_eq!(
                f64_values(&*out),
                expected,
                "window_size={window_size}, min_periods={min_periods}, center={center}"
            );
        }
    }

    #[test]
    fn test_rolling_mean_nulls() {
        let arr = &get_null_arr();

        // (window_size, min_periods, expected)
        let cases: [(usize, usize, [Option<f64>; 4]); 3] = [
            (2, 2, [None, None, None, Some(1.5)]),
            (2, 1, [Some(1.0), Some(1.0), Some(-1.0), Some(1.5)]),
            (4, 1, [Some(1.0), Some(1.0), Some(0.0), Some(4.0 / 3.0)]),
        ];
        for (window_size, min_periods, expected) in cases {
            let out = rolling_mean(arr, window_size, min_periods, false, None, None);
            assert_eq!(
                f64_values(&*out),
                expected,
                "window_size={window_size}, min_periods={min_periods}"
            );
        }
    }

    #[test]
    fn test_rolling_var_nulls() {
        let arr = &get_null_arr();
        let ddof_0 = Some(RollingFnParams::Var(RollingVarParams { ddof: 0 }));

        // (window_size, params, expected); min_periods = 1, not centered.
        let cases: [(usize, Option<RollingFnParams>, [Option<f64>; 4]); 4] = [
            (3, None, [None, None, Some(2.0), Some(12.5)]),
            (3, ddof_0, [Some(0.0), Some(0.0), Some(1.0), Some(6.25)]),
            (4, None, [None, None, Some(2.0), Some(6.333333333333334)]),
            (
                4,
                ddof_0,
                [Some(0.0), Some(0.0), Some(1.0), Some(4.222222222222222)],
            ),
        ];
        for (case_idx, (window_size, params, expected)) in cases.into_iter().enumerate() {
            let out = rolling_var(arr, window_size, 1, false, None, params);
            assert_eq!(
                f64_values(&*out),
                expected,
                "case #{case_idx}, window_size={window_size}"
            );
        }
    }

    #[test]
    fn test_rolling_max_no_nulls() {
        let all_valid = [true; 4];

        let ascending = &f64_arr([1.0, 2.0, 3.0, 4.0], all_valid);
        // (window_size, min_periods, expected)
        let cases: [(usize, usize, [Option<f64>; 4]); 3] = [
            (4, 1, [Some(1.0), Some(2.0), Some(3.0), Some(4.0)]),
            (2, 2, [None, Some(2.0), Some(3.0), Some(4.0)]),
            (4, 4, [None, None, None, Some(4.0)]),
        ];
        for (window_size, min_periods, expected) in cases {
            let out = rolling_max(ascending, window_size, min_periods, false, None, None);
            assert_eq!(
                f64_values(&*out),
                expected,
                "window_size={window_size}, min_periods={min_periods}"
            );
        }

        let descending = &f64_arr([4.0, 3.0, 2.0, 1.0], all_valid);
        let expected = [Some(4.0), Some(4.0), Some(3.0), Some(2.0)];

        let out = rolling_max(descending, 2, 1, false, None, None);
        assert_eq!(f64_values(&*out), expected);

        // The no-null kernel must agree with the null-aware one on the same data.
        let out =
            super::no_nulls::rolling_max(descending.values().as_slice(), 2, 1, false, None, None)
                .unwrap();
        assert_eq!(f64_values(&*out), expected);
    }

    #[test]
    fn test_rolling_extrema_nulls() {
        let vals = vec![3, 3, 3, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1];
        let len = vals.len();
        let input = Int32Array::new(
            ArrowDataType::Int32,
            vals.into(),
            Some(Bitmap::new_with_value(true, len)),
        );
        let (window_size, min_periods) = (3, 3);

        let out = rolling_apply_agg_window::<MinMaxWindow<i32, MaxIgnoreNan>, _, _>(
            input.values().as_slice(),
            input.validity().as_ref().unwrap(),
            window_size,
            min_periods,
            det_offsets,
            None,
        );
        let maxima = out.as_any().downcast_ref::<Int32Array>().unwrap();

        // Positions 0 and 1 see fewer than `min_periods` values, so they are null.
        assert_eq!(maxima.null_count(), 2);

        let expected_tail = [3, 10, 10, 10, 10, 10, 9, 8, 7, 6, 5, 4, 3];
        assert_eq!(&maxima.values().as_slice()[2..], &expected_tail);
    }
}