polars_compute/rolling/no_nulls/
sum.rs1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3
4fn sum_kahan<
5 T: NativeType
6 + IsFloat
7 + std::iter::Sum
8 + AddAssign
9 + SubAssign
10 + Sub<Output = T>
11 + Add<Output = T>,
12>(
13 vals: &[T],
14) -> (T, T) {
15 if T::is_float() {
16 let mut sum = T::zeroed();
17 let mut err = T::zeroed();
18
19 for val in vals.iter().copied() {
20 if val.is_finite() {
21 let y = val - err;
22 let new_sum = sum + y;
23 err = (new_sum - sum) - y;
24 sum = new_sum;
25 } else {
26 sum += val
27 }
28 }
29 (sum, err)
30 } else {
31 (vals.iter().copied().sum::<T>(), T::zeroed())
32 }
33}
34
/// Incremental rolling-sum state over `slice`.
///
/// Input values of type `T` are accumulated into a running total of type `S`. For float inputs
/// `err` holds the Kahan compensation term. `last_start..last_end` is the window that `sum`
/// currently covers.
pub struct SumWindow<'a, T, S> {
    // Window bounds (indices into `slice`) that `sum` currently reflects.
    last_start: usize,
    last_end: usize,
    // Running window total and its Kahan compensation term.
    sum: S,
    err: S,
    // The full input that windows index into.
    slice: &'a [T],
}

43impl<T, S> SumWindow<'_, T, S>
44where
45 T: NativeType + IsFloat + Sub<Output = T> + NumCast,
46 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
47{
48 fn add(&mut self, val: T) {
50 if T::is_float() && val.is_finite() {
51 let val: S = NumCast::from(val).unwrap();
52 let y = val - self.err;
53 let new_sum = self.sum + y;
54 self.err = (new_sum - self.sum) - y;
55 self.sum = new_sum;
56 } else {
57 let val: S = NumCast::from(val).unwrap();
58 self.sum += val;
59 }
60 }
61
62 fn sub(&mut self, val: T) {
63 if T::is_float() {
64 self.add(T::zeroed() - val)
65 } else {
66 let val: S = NumCast::from(val).unwrap();
67 self.sum -= val;
68 }
69 }
70}
71
72impl<'a, T, S> RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T, S>
73where
74 T: NativeType
75 + IsFloat
76 + Sub<Output = T>
77 + std::iter::Sum
78 + AddAssign
79 + SubAssign
80 + Add<Output = T>
81 + NumCast,
82 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
83{
84 fn new(
85 slice: &'a [T],
86 start: usize,
87 end: usize,
88 _params: Option<RollingFnParams>,
89 _window_size: Option<usize>,
90 ) -> Self {
91 let (sum, err) = sum_kahan(&slice[start..end]);
92 Self {
93 slice,
94 sum: NumCast::from(sum).unwrap(),
95 err: NumCast::from(err).unwrap(),
96 last_start: start,
97 last_end: end,
98 }
99 }
100
101 unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
102 let recompute_sum = if start >= self.last_end {
105 true
106 } else {
107 let mut recompute_sum = false;
109 for idx in self.last_start..start {
110 let leaving_value = self.slice.get_unchecked(idx);
113
114 if T::is_float() && !leaving_value.is_finite() {
115 recompute_sum = true;
116 break;
117 }
118
119 self.sub(*leaving_value);
120 }
121 recompute_sum
122 };
123 self.last_start = start;
124
125 if recompute_sum {
127 let vals = self.slice.get_unchecked(start..end);
128 let (sum, err) = sum_kahan(vals);
129 self.sum = NumCast::from(sum).unwrap();
130 self.err = NumCast::from(err).unwrap();
131 }
132 else {
134 for idx in self.last_end..end {
135 self.add(*self.slice.get_unchecked(idx))
136 }
137 }
138 self.last_end = end;
139 NumCast::from(self.sum)
140 }
141}
142
143pub fn rolling_sum<T>(
144 values: &[T],
145 window_size: usize,
146 min_periods: usize,
147 center: bool,
148 weights: Option<&[f64]>,
149 _params: Option<RollingFnParams>,
150) -> PolarsResult<ArrayRef>
151where
152 T: NativeType
153 + std::iter::Sum
154 + NumCast
155 + Mul<Output = T>
156 + AddAssign
157 + SubAssign
158 + IsFloat
159 + Num,
160{
161 match (center, weights) {
162 (true, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
163 values,
164 window_size,
165 min_periods,
166 det_offsets_center,
167 None,
168 ),
169 (false, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
170 values,
171 window_size,
172 min_periods,
173 det_offsets,
174 None,
175 ),
176 (true, Some(weights)) => {
177 let weights = no_nulls::coerce_weights(weights);
178 no_nulls::rolling_apply_weights(
179 values,
180 window_size,
181 min_periods,
182 det_offsets_center,
183 no_nulls::compute_sum_weights,
184 &weights,
185 )
186 },
187 (false, Some(weights)) => {
188 let weights = no_nulls::coerce_weights(weights);
189 no_nulls::rolling_apply_weights(
190 values,
191 window_size,
192 min_periods,
193 det_offsets,
194 no_nulls::compute_sum_weights,
195 &weights,
196 )
197 },
198 }
199}
200
#[cfg(test)]
mod test {
    use super::*;

    /// Runs an unweighted `rolling_sum` and collects the output as `Option<f64>` values.
    fn sum_f64(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        center: bool,
    ) -> Vec<Option<f64>> {
        let out = rolling_sum(values, window_size, min_periods, center, None, None).unwrap();
        let arr = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        arr.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_sum() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        assert_eq!(
            sum_f64(values, 2, 2, false),
            &[None, Some(3.0), Some(5.0), Some(7.0)]
        );
        assert_eq!(
            sum_f64(values, 2, 1, false),
            &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]
        );
        assert_eq!(
            sum_f64(values, 4, 1, false),
            &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]
        );
        assert_eq!(
            sum_f64(values, 4, 1, true),
            &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]
        );
        assert_eq!(sum_f64(values, 4, 4, true), &[None, None, Some(10.0), None]);

        let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let out = sum_f64(values, 3, 3, false);

        // NaN != NaN, so compare the Debug renderings rather than the values themselves.
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!(
                "{:?}",
                &[
                    None,
                    None,
                    Some(6.0),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(18.0)
                ]
            )
        );
    }
}