Skip to main content

polars_compute/rolling/
sum.rs

1use std::ops::{Add, AddAssign, Sub, SubAssign};
2
3use super::no_nulls::RollingAggWindowNoNulls;
4use super::nulls::RollingAggWindowNulls;
5use super::*;
6
7pub struct SumWindow<'a, T, S> {
8    slice: &'a [T],
9    validity: Option<&'a Bitmap>,
10    sum: S,
11    err_add: S,
12    err_sub: S,
13    non_finite_count: usize, // NaN or infinity.
14    pos_inf_count: usize,
15    neg_inf_count: usize,
16    pub(super) null_count: usize,
17    pub(super) start: usize,
18    pub(super) end: usize,
19}
20
21impl<'a, T, S> SumWindow<'a, T, S>
22where
23    T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
24    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
25{
26    fn new_impl(slice: &'a [T], validity: Option<&'a Bitmap>) -> Self {
27        Self {
28            slice,
29            validity,
30            sum: S::zeroed(),
31            err_add: S::zeroed(),
32            err_sub: S::zeroed(),
33            non_finite_count: 0,
34            pos_inf_count: 0,
35            neg_inf_count: 0,
36            null_count: 0,
37            start: 0,
38            end: 0,
39        }
40    }
41
42    fn reset(&mut self) {
43        self.sum = S::zeroed();
44        self.err_add = S::zeroed();
45        self.err_sub = S::zeroed();
46        self.non_finite_count = 0;
47        self.pos_inf_count = 0;
48        self.neg_inf_count = 0;
49        self.null_count = 0;
50    }
51
52    fn add_finite_kahan(&mut self, val: T) {
53        let val: S = NumCast::from(val).unwrap();
54        let y = val - self.err_add;
55        let new_sum = self.sum + y;
56        self.err_add = (new_sum - self.sum) - y;
57        self.sum = new_sum;
58    }
59
60    fn sub_finite_kahan(&mut self, val: T) {
61        let val: S = NumCast::from(T::zeroed() - val).unwrap();
62        let y = val - self.err_sub;
63        let new_sum = self.sum + y;
64        self.err_sub = (new_sum - self.sum) - y;
65        self.sum = new_sum;
66    }
67
68    fn add(&mut self, val: T) {
69        if T::is_float() {
70            if val.is_finite() {
71                self.add_finite_kahan(val);
72            } else {
73                self.non_finite_count += 1;
74                self.pos_inf_count += (val > T::zeroed()) as usize;
75                self.neg_inf_count += (val < T::zeroed()) as usize;
76            }
77        } else {
78            let val: S = NumCast::from(val).unwrap();
79            self.sum += val;
80        }
81    }
82
83    fn sub(&mut self, val: T) {
84        if T::is_float() {
85            if val.is_finite() {
86                self.sub_finite_kahan(val);
87            } else {
88                self.non_finite_count -= 1;
89                self.pos_inf_count -= (val > T::zeroed()) as usize;
90                self.neg_inf_count -= (val < T::zeroed()) as usize;
91            }
92        } else {
93            let val: S = NumCast::from(val).unwrap();
94            self.sum -= val;
95        }
96    }
97
98    fn get_sum(&self) -> Option<T> {
99        if self.non_finite_count == 0 {
100            NumCast::from(self.sum)
101        } else if self.non_finite_count == self.pos_inf_count {
102            Some(T::pos_inf_value())
103        } else if self.non_finite_count == self.neg_inf_count {
104            Some(T::neg_inf_value())
105        } else {
106            Some(T::nan_value())
107        }
108    }
109}
110
111impl<T, S> RollingAggWindowNoNulls<T> for SumWindow<'_, T, S>
112where
113    T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
114    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
115{
116    type This<'a> = SumWindow<'a, T, S>;
117
118    fn new<'a>(
119        slice: &'a [T],
120        start: usize,
121        end: usize,
122        _params: Option<RollingFnParams>,
123        _window_size: Option<usize>,
124    ) -> Self::This<'a> {
125        let mut out = SumWindow::new_impl(slice, None);
126        unsafe { RollingAggWindowNoNulls::update(&mut out, start, end) };
127        out
128    }
129
130    // # Safety
131    // The start, end range must be in-bounds.
132    unsafe fn update(&mut self, new_start: usize, new_end: usize) {
133        if new_start >= self.end {
134            self.reset();
135            self.start = new_start;
136            self.end = new_start;
137        }
138
139        for val in &self.slice[self.start..new_start] {
140            self.sub(*val);
141        }
142
143        for val in &self.slice[self.end..new_end] {
144            self.add(*val);
145        }
146
147        self.start = new_start;
148        self.end = new_end;
149    }
150
151    fn get_agg(&self, _idx: usize) -> Option<T> {
152        self.get_sum()
153    }
154
155    fn slice_len(&self) -> usize {
156        self.slice.len()
157    }
158}
159
160impl<T, S> RollingAggWindowNulls<T> for SumWindow<'_, T, S>
161where
162    T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
163    S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
164{
165    type This<'a> = SumWindow<'a, T, S>;
166
167    fn new<'a>(
168        slice: &'a [T],
169        validity: &'a Bitmap,
170        start: usize,
171        end: usize,
172        _params: Option<RollingFnParams>,
173        _window_size: Option<usize>,
174    ) -> Self::This<'a> {
175        assert!(start <= slice.len() && end <= slice.len() && start <= end);
176        let mut out = SumWindow::new_impl(slice, Some(validity));
177        // SAFETY: We bounds checked `start` and `end`.
178        unsafe { RollingAggWindowNulls::update(&mut out, start, end) };
179        out
180    }
181
182    // # Safety
183    // The start, end range must be in-bounds.
184    unsafe fn update(&mut self, new_start: usize, new_end: usize) {
185        let validity = unsafe { self.validity.unwrap_unchecked() };
186
187        if new_start >= self.end {
188            self.reset();
189            self.start = new_start;
190            self.end = new_start;
191        }
192
193        for idx in self.start..new_start {
194            let valid = unsafe { validity.get_bit_unchecked(idx) };
195            if valid {
196                self.sub(unsafe { *self.slice.get_unchecked(idx) });
197            } else {
198                self.null_count -= 1;
199            }
200        }
201
202        for idx in self.end..new_end {
203            let valid = unsafe { validity.get_bit_unchecked(idx) };
204            if valid {
205                self.add(unsafe { *self.slice.get_unchecked(idx) });
206            } else {
207                self.null_count += 1;
208            }
209        }
210
211        self.start = new_start;
212        self.end = new_end;
213    }
214
215    fn get_agg(&self, _idx: usize) -> Option<T> {
216        self.get_sum()
217    }
218
219    fn is_valid(&self, min_periods: usize) -> bool {
220        ((self.end - self.start) - self.null_count) >= min_periods
221    }
222
223    fn slice_len(&self) -> usize {
224        self.slice.len()
225    }
226}