polars_compute/rolling/nulls/
sum.rs

1#![allow(unsafe_op_in_unsafe_fn)]
2
3use super::*;
4
/// State of a rolling sum over a value slice that is paired with a validity
/// bitmap (a cleared bit marks the corresponding value as null).
///
/// `T` is the element type of the input slice; `S` is the type the running
/// sum is accumulated in.
pub struct SumWindow<'a, T, S> {
    /// Values the window moves over.
    slice: &'a [T],
    /// Validity bits, read with the same indices as `slice`.
    validity: &'a Bitmap,
    /// Running sum of the valid values in the current window; `None` while no
    /// valid value has been accumulated.
    sum: Option<S>,
    /// Running compensation term used by the Kahan summation in `add`.
    err: S,
    /// Start index (inclusive) of the most recently processed window.
    last_start: usize,
    /// End index (exclusive) of the most recently processed window.
    last_end: usize,
    /// Number of null entries inside the current window.
    pub(super) null_count: usize,
}
14
15impl<T, S> SumWindow<'_, T, S>
16where
17    T: NativeType + IsFloat + Sub<Output = T> + NumCast,
18    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
19{
20    // Kahan summation
21    fn add(&mut self, val: T) {
22        if T::is_float() && val.is_finite() {
23            self.sum = self.sum.map(|sum| {
24                let val: S = NumCast::from(val).unwrap();
25                let y = val - self.err;
26                let new_sum = sum + y;
27                self.err = (new_sum - sum) - y;
28                new_sum
29            });
30        } else {
31            let val: S = NumCast::from(val).unwrap();
32            self.sum = self.sum.map(|v| v + val)
33        }
34    }
35
36    fn sub(&mut self, val: T) {
37        if T::is_float() && val.is_finite() {
38            self.add(T::zeroed() - val)
39        } else {
40            let val: S = NumCast::from(val).unwrap();
41            self.sum = self.sum.map(|v| v - val)
42        }
43    }
44}
45
46impl<T, S> SumWindow<'_, T, S>
47where
48    T: NativeType + IsFloat + Sub<Output = T> + NumCast,
49    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
50{
51    // compute sum from the entire window
52    unsafe fn compute_sum_and_null_count(&mut self, start: usize, end: usize) -> Option<S> {
53        let mut sum = None;
54        let mut idx = start;
55        self.null_count = 0;
56        for value in &self.slice[start..end] {
57            let value: S = NumCast::from(*value).unwrap();
58            let valid = self.validity.get_bit_unchecked(idx);
59            if valid {
60                match sum {
61                    None => sum = Some(value),
62                    Some(current) => sum = Some(value + current),
63                }
64            } else {
65                self.null_count += 1;
66            }
67            idx += 1;
68        }
69        self.sum = sum;
70        sum
71    }
72}
73
74impl<'a, T, S> RollingAggWindowNulls<'a, T> for SumWindow<'a, T, S>
75where
76    T: NativeType + IsFloat + Sub<Output = T> + NumCast,
77    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
78{
79    unsafe fn new(
80        slice: &'a [T],
81        validity: &'a Bitmap,
82        start: usize,
83        end: usize,
84        _params: Option<RollingFnParams>,
85        _window_size: Option<usize>,
86    ) -> Self {
87        let mut out = Self {
88            slice,
89            validity,
90            sum: None,
91            err: S::zeroed(),
92            last_start: start,
93            last_end: end,
94            null_count: 0,
95        };
96        out.compute_sum_and_null_count(start, end);
97        out
98    }
99
100    unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
101        // if we exceed the end, we have a completely new window
102        // so we recompute
103        let recompute_sum = if start >= self.last_end {
104            true
105        } else {
106            // remove elements that should leave the window
107            let mut recompute_sum = false;
108            for idx in self.last_start..start {
109                // SAFETY:
110                // we are in bounds
111                let valid = self.validity.get_bit_unchecked(idx);
112                if valid {
113                    let leaving_value = self.slice.get_unchecked(idx);
114
115                    // if the leaving value is nan we need to recompute the window
116                    if T::is_float() && !leaving_value.is_finite() {
117                        recompute_sum = true;
118                        break;
119                    }
120                    self.sub(*leaving_value);
121                } else {
122                    // null value leaving the window
123                    self.null_count -= 1;
124
125                    // self.sum is None and the leaving value is None
126                    // if the entering value is valid, we might get a new sum.
127                    if self.sum.is_none() {
128                        recompute_sum = true;
129                        break;
130                    }
131                }
132            }
133            recompute_sum
134        };
135
136        self.last_start = start;
137
138        // we traverse all values and compute
139        if recompute_sum {
140            self.compute_sum_and_null_count(start, end);
141        } else {
142            for idx in self.last_end..end {
143                let valid = self.validity.get_bit_unchecked(idx);
144
145                if valid {
146                    let value = *self.slice.get_unchecked(idx);
147                    match self.sum {
148                        None => self.sum = NumCast::from(value),
149                        _ => self.add(value),
150                    }
151                } else {
152                    // null value entering the window
153                    self.null_count += 1;
154                }
155            }
156        }
157        self.last_end = end;
158        self.sum.and_then(NumCast::from)
159    }
160
161    fn is_valid(&self, min_periods: usize) -> bool {
162        ((self.last_end - self.last_start) - self.null_count) >= min_periods
163    }
164}
165
166pub fn rolling_sum<T>(
167    arr: &PrimitiveArray<T>,
168    window_size: usize,
169    min_periods: usize,
170    center: bool,
171    weights: Option<&[f64]>,
172    _params: Option<RollingFnParams>,
173) -> ArrayRef
174where
175    T: NativeType
176        + IsFloat
177        + PartialOrd
178        + Add<Output = T>
179        + Sub<Output = T>
180        + SubAssign
181        + AddAssign
182        + NumCast,
183{
184    if weights.is_some() {
185        panic!("weights not yet supported on array with null values")
186    }
187    if center {
188        rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
189            arr.values().as_slice(),
190            arr.validity().as_ref().unwrap(),
191            window_size,
192            min_periods,
193            det_offsets_center,
194            None,
195        )
196    } else {
197        rolling_apply_agg_window::<SumWindow<T, T>, _, _>(
198            arr.values().as_slice(),
199            arr.validity().as_ref().unwrap(),
200            window_size,
201            min_periods,
202            det_offsets,
203            None,
204        )
205    }
206}