polars_compute/rolling/no_nulls/sum.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2use super::*;
3
4fn sum_kahan<
5    T: NativeType
6        + IsFloat
7        + std::iter::Sum
8        + AddAssign
9        + SubAssign
10        + Sub<Output = T>
11        + Add<Output = T>,
12>(
13    vals: &[T],
14) -> (T, T) {
15    if T::is_float() {
16        let mut sum = T::zeroed();
17        let mut err = T::zeroed();
18
19        for val in vals.iter().copied() {
20            if val.is_finite() {
21                let y = val - err;
22                let new_sum = sum + y;
23                err = (new_sum - sum) - y;
24                sum = new_sum;
25            } else {
26                sum += val
27            }
28        }
29        (sum, err)
30    } else {
31        (vals.iter().copied().sum::<T>(), T::zeroed())
32    }
33}
34
/// Incremental state for a rolling sum over a slice that contains no nulls.
///
/// `T` is the element type of the input slice; `S` is the type the running sum
/// (and its compensation term) is accumulated in.
pub struct SumWindow<'a, T, S> {
    /// The complete input that successive windows are taken from.
    slice: &'a [T],
    /// Start (inclusive) of the most recently computed window.
    last_start: usize,
    /// End (exclusive) of the most recently computed window.
    last_end: usize,
    /// Running sum of the current window.
    sum: S,
    /// Kahan compensation term that accompanies `sum`.
    err: S,
}
42
43impl<T, S> SumWindow<'_, T, S>
44where
45    T: NativeType + IsFloat + Sub<Output = T> + NumCast,
46    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
47{
48    // Kahan summation
49    fn add(&mut self, val: T) {
50        if T::is_float() && val.is_finite() {
51            let val: S = NumCast::from(val).unwrap();
52            let y = val - self.err;
53            let new_sum = self.sum + y;
54            self.err = (new_sum - self.sum) - y;
55            self.sum = new_sum;
56        } else {
57            let val: S = NumCast::from(val).unwrap();
58            self.sum += val;
59        }
60    }
61
62    fn sub(&mut self, val: T) {
63        if T::is_float() {
64            self.add(T::zeroed() - val)
65        } else {
66            let val: S = NumCast::from(val).unwrap();
67            self.sum -= val;
68        }
69    }
70}
71
72impl<'a, T, S> RollingAggWindowNoNulls<'a, T> for SumWindow<'a, T, S>
73where
74    T: NativeType
75        + IsFloat
76        + Sub<Output = T>
77        + std::iter::Sum
78        + AddAssign
79        + SubAssign
80        + Add<Output = T>
81        + NumCast,
82    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
83{
84    fn new(
85        slice: &'a [T],
86        start: usize,
87        end: usize,
88        _params: Option<RollingFnParams>,
89        _window_size: Option<usize>,
90    ) -> Self {
91        let (sum, err) = sum_kahan(&slice[start..end]);
92        Self {
93            slice,
94            sum: NumCast::from(sum).unwrap(),
95            err: NumCast::from(err).unwrap(),
96            last_start: start,
97            last_end: end,
98        }
99    }
100
101    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
102        // if we exceed the end, we have a completely new window
103        // so we recompute
104        let recompute_sum = if start >= self.last_end {
105            true
106        } else {
107            // remove elements that should leave the window
108            let mut recompute_sum = false;
109            for idx in self.last_start..start {
110                // SAFETY:
111                // we are in bounds
112                let leaving_value = self.slice.get_unchecked(idx);
113
114                if T::is_float() && !leaving_value.is_finite() {
115                    recompute_sum = true;
116                    break;
117                }
118
119                self.sub(*leaving_value);
120            }
121            recompute_sum
122        };
123        self.last_start = start;
124
125        // we traverse all values and compute
126        if recompute_sum {
127            let vals = self.slice.get_unchecked(start..end);
128            let (sum, err) = sum_kahan(vals);
129            self.sum = NumCast::from(sum).unwrap();
130            self.err = NumCast::from(err).unwrap();
131        }
132        // add entering values.
133        else {
134            for idx in self.last_end..end {
135                self.add(*self.slice.get_unchecked(idx))
136            }
137        }
138        self.last_end = end;
139        NumCast::from(self.sum)
140    }
141}
142
143pub fn rolling_sum<T>(
144    values: &[T],
145    window_size: usize,
146    min_periods: usize,
147    center: bool,
148    weights: Option<&[f64]>,
149    _params: Option<RollingFnParams>,
150) -> PolarsResult<ArrayRef>
151where
152    T: NativeType
153        + std::iter::Sum
154        + NumCast
155        + Mul<Output = T>
156        + AddAssign
157        + SubAssign
158        + IsFloat
159        + Num,
160{
161    match (center, weights) {
162        (true, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
163            values,
164            window_size,
165            min_periods,
166            det_offsets_center,
167            None,
168        ),
169        (false, None) => rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
170            values,
171            window_size,
172            min_periods,
173            det_offsets,
174            None,
175        ),
176        (true, Some(weights)) => {
177            let weights = no_nulls::coerce_weights(weights);
178            no_nulls::rolling_apply_weights(
179                values,
180                window_size,
181                min_periods,
182                det_offsets_center,
183                no_nulls::compute_sum_weights,
184                &weights,
185            )
186        },
187        (false, Some(weights)) => {
188            let weights = no_nulls::coerce_weights(weights);
189            no_nulls::rolling_apply_weights(
190                values,
191                window_size,
192                min_periods,
193                det_offsets,
194                no_nulls::compute_sum_weights,
195                &weights,
196            )
197        },
198    }
199}
200
#[cfg(test)]
mod test {
    use super::*;

    /// Runs an unweighted `rolling_sum` over `values` and returns the output as
    /// plain `Option<f64>`s, with nulls as `None`.
    fn sum_as_options(
        values: &[f64],
        window_size: usize,
        min_periods: usize,
        center: bool,
    ) -> Vec<Option<f64>> {
        let arr = rolling_sum(values, window_size, min_periods, center, None, None).unwrap();
        let arr = arr.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
        arr.into_iter().map(|v| v.copied()).collect()
    }

    #[test]
    fn test_rolling_sum() {
        let values = &[1.0f64, 2.0, 3.0, 4.0];

        assert_eq!(
            sum_as_options(values, 2, 2, false),
            &[None, Some(3.0), Some(5.0), Some(7.0)]
        );
        assert_eq!(
            sum_as_options(values, 2, 1, false),
            &[Some(1.0), Some(3.0), Some(5.0), Some(7.0)]
        );
        assert_eq!(
            sum_as_options(values, 4, 1, false),
            &[Some(1.0), Some(3.0), Some(6.0), Some(10.0)]
        );
        assert_eq!(
            sum_as_options(values, 4, 1, true),
            &[Some(3.0), Some(6.0), Some(10.0), Some(9.0)]
        );
        assert_eq!(
            sum_as_options(values, 4, 4, true),
            &[None, None, Some(10.0), None]
        );

        // NaN handling: every window containing the NaN sums to NaN.
        let values = &[1.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
        let out = sum_as_options(values, 3, 3, false);

        // NaN != NaN, so compare the debug renderings instead of the values.
        assert_eq!(
            format!("{:?}", out.as_slice()),
            format!(
                "{:?}",
                &[
                    None,
                    None,
                    Some(6.0),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(f64::nan()),
                    Some(18.0)
                ]
            )
        );
    }
}