polars_compute/rolling/nulls/
sum.rs1#![allow(unsafe_op_in_unsafe_fn)]
2
3use super::*;
4
5pub struct SumWindow<'a, T, S> {
6 slice: &'a [T],
7 validity: &'a Bitmap,
8 sum: Option<S>,
9 err: S,
10 last_start: usize,
11 last_end: usize,
12 pub(super) null_count: usize,
13}
14
15impl<T, S> SumWindow<'_, T, S>
16where
17 T: NativeType + IsFloat + Sub<Output = T> + NumCast,
18 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
19{
20 fn add(&mut self, val: T) {
22 if T::is_float() && val.is_finite() {
23 self.sum = self.sum.map(|sum| {
24 let val: S = NumCast::from(val).unwrap();
25 let y = val - self.err;
26 let new_sum = sum + y;
27 self.err = (new_sum - sum) - y;
28 new_sum
29 });
30 } else {
31 let val: S = NumCast::from(val).unwrap();
32 self.sum = self.sum.map(|v| v + val)
33 }
34 }
35
36 fn sub(&mut self, val: T) {
37 if T::is_float() && val.is_finite() {
38 self.add(T::zeroed() - val)
39 } else {
40 let val: S = NumCast::from(val).unwrap();
41 self.sum = self.sum.map(|v| v - val)
42 }
43 }
44}
45
46impl<T, S> SumWindow<'_, T, S>
47where
48 T: NativeType + IsFloat + Sub<Output = T> + NumCast,
49 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
50{
51 unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option<S> {
53 let mut sum = None;
54 let mut idx = start;
55 self.null_count = 0;
56 for value in &self.slice[start..end] {
57 let value: S = NumCast::from(*value).unwrap();
58 let valid = self.validity.get_bit_unchecked(idx);
59 if valid {
60 match sum {
61 None => sum = Some(value),
62 Some(current) => sum = Some(value + current),
63 }
64 } else {
65 self.null_count += 1;
66 }
67 idx += 1;
68 }
69 self.sum = sum;
70 sum
71 }
72}
73
impl<'a, T, S> RollingAggWindowNulls<'a, T> for SumWindow<'a, T, S>
where
    T: NativeType + IsFloat + Sub<Output = T> + NumCast,
    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
{
    /// Creates a window over `slice[start..end]` and computes its initial sum and
    /// null count from scratch.
    ///
    /// `_params` and `_window_size` are accepted to satisfy the trait and are ignored.
    ///
    /// # Safety
    /// `start..end` must be a valid range into `validity` (bits are read unchecked).
    unsafe fn new(
        slice: &'a [T],
        validity: &'a Bitmap,
        start: usize,
        end: usize,
        _params: Option<RollingFnParams>,
        _window_size: Option<usize>,
    ) -> Self {
        let mut out = Self {
            slice,
            validity,
            sum: None,
            // Kahan compensation starts at zero.
            err: S::zeroed(),
            last_start: start,
            last_end: end,
            null_count: 0,
        };
        out.compute_sum_and_null_count(start, end);
        out
    }

    /// Slides the window to `start..end` and returns the sum of its valid values,
    /// cast back to `T`. Returns `None` when there is no running sum or that cast fails.
    ///
    /// Values leaving the window (`last_start..start`) are subtracted and values
    /// entering it (`last_end..end`) are added. The sum is recomputed from scratch
    /// instead when:
    /// - the new window does not overlap the previous one;
    /// - a non-finite float leaves the window;
    /// - a null leaves while there is no running sum.
    ///
    /// # Safety
    /// Both the previous and the new window must be in bounds of `slice` and
    /// `validity`; elements and bits are read with unchecked accessors.
    /// NOTE(review): this appears to assume `start >= last_start` and
    /// `end >= last_end`; otherwise the leave/enter ranges below are simply empty.
    /// Confirm against the offset functions callers use.
    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
        // The new window starts at or past the old end: nothing carries over,
        // so recompute.
        let recompute_sum = if start >= self.last_end {
            true
        } else {
            let mut recompute_sum = false;
            // Remove the elements that fall out of the front of the window.
            for idx in self.last_start..start {
                let valid = self.validity.get_bit_unchecked(idx);
                if valid {
                    let leaving_value = self.slice.get_unchecked(idx);

                    // Subtracting a NaN/infinity cannot undo its effect on the
                    // sum, so fall back to a full recompute.
                    if T::is_float() && !leaving_value.is_finite() {
                        recompute_sum = true;
                        break;
                    }
                    self.sub(*leaving_value);
                } else {
                    // A null slot leaves the window.
                    self.null_count -= 1;

                    // No running sum to update incrementally; recompute the new
                    // window from scratch instead.
                    if self.sum.is_none() {
                        recompute_sum = true;
                        break;
                    }
                }
            }
            recompute_sum
        };

        self.last_start = start;

        if recompute_sum {
            // Resets both `self.sum` and `self.null_count` for `start..end`.
            self.compute_sum_and_null_count(start, end);
        } else {
            // Add the elements entering at the back; `self.last_end` still holds
            // the previous end here.
            for idx in self.last_end..end {
                let valid = self.validity.get_bit_unchecked(idx);

                if valid {
                    let value = *self.slice.get_unchecked(idx);
                    match self.sum {
                        // The first valid value seeds the sum directly (a failed
                        // cast leaves it `None`). NOTE(review): `self.err` is not
                        // reset on this path.
                        None => self.sum = NumCast::from(value),
                        _ => self.add(value),
                    }
                } else {
                    // A null slot enters the window.
                    self.null_count += 1;
                }
            }
        }
        self.last_end = end;
        self.sum.and_then(NumCast::from)
    }

    /// Returns whether the current window holds at least `min_periods` non-null
    /// values (window length minus the null count).
    fn is_valid(&self, min_periods: usize) -> bool {
        ((self.last_end - self.last_start) - self.null_count) >= min_periods
    }
}
165
166pub fn rolling_sum<T>(
167 arr: &PrimitiveArray<T>,
168 window_size: usize,
169 min_periods: usize,
170 center: bool,
171 weights: Option<&[f64]>,
172 _params: Option<RollingFnParams>,
173) -> ArrayRef
174where
175 T: NativeType
176 + IsFloat
177 + PartialOrd
178 + Add<Output = T>
179 + Sub<Output = T>
180 + SubAssign
181 + AddAssign
182 + NumCast,
183{
184 if weights.is_some() {
185 panic!("weights not yet supported on array with null values")
186 }
187 if center {
188 rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
189 arr.values().as_slice(),
190 arr.validity().as_ref().unwrap(),
191 window_size,
192 min_periods,
193 det_offsets_center,
194 None,
195 )
196 } else {
197 rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
198 arr.values().as_slice(),
199 arr.validity().as_ref().unwrap(),
200 window_size,
201 min_periods,
202 det_offsets,
203 None,
204 )
205 }
206}