polars_compute/rolling/
sum.rs1use std::ops::{Add, AddAssign, Sub, SubAssign};
2
3use super::no_nulls::RollingAggWindowNoNulls;
4use super::nulls::RollingAggWindowNulls;
5use super::*;
6
7pub struct SumWindow<'a, T, S> {
8 slice: &'a [T],
9 validity: Option<&'a Bitmap>,
10 sum: S,
11 err_add: S,
12 err_sub: S,
13 non_finite_count: usize, pos_inf_count: usize,
15 neg_inf_count: usize,
16 pub(super) null_count: usize,
17 pub(super) start: usize,
18 pub(super) end: usize,
19}
20
21impl<'a, T, S> SumWindow<'a, T, S>
22where
23 T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
24 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
25{
26 fn new_impl(slice: &'a [T], validity: Option<&'a Bitmap>) -> Self {
27 Self {
28 slice,
29 validity,
30 sum: S::zeroed(),
31 err_add: S::zeroed(),
32 err_sub: S::zeroed(),
33 non_finite_count: 0,
34 pos_inf_count: 0,
35 neg_inf_count: 0,
36 null_count: 0,
37 start: 0,
38 end: 0,
39 }
40 }
41
42 fn reset(&mut self) {
43 self.sum = S::zeroed();
44 self.err_add = S::zeroed();
45 self.err_sub = S::zeroed();
46 self.non_finite_count = 0;
47 self.pos_inf_count = 0;
48 self.neg_inf_count = 0;
49 self.null_count = 0;
50 }
51
52 fn add_finite_kahan(&mut self, val: T) {
53 let val: S = NumCast::from(val).unwrap();
54 let y = val - self.err_add;
55 let new_sum = self.sum + y;
56 self.err_add = (new_sum - self.sum) - y;
57 self.sum = new_sum;
58 }
59
60 fn sub_finite_kahan(&mut self, val: T) {
61 let val: S = NumCast::from(T::zeroed() - val).unwrap();
62 let y = val - self.err_sub;
63 let new_sum = self.sum + y;
64 self.err_sub = (new_sum - self.sum) - y;
65 self.sum = new_sum;
66 }
67
68 fn add(&mut self, val: T) {
69 if T::is_float() {
70 if val.is_finite() {
71 self.add_finite_kahan(val);
72 } else {
73 self.non_finite_count += 1;
74 self.pos_inf_count += (val > T::zeroed()) as usize;
75 self.neg_inf_count += (val < T::zeroed()) as usize;
76 }
77 } else {
78 let val: S = NumCast::from(val).unwrap();
79 self.sum += val;
80 }
81 }
82
83 fn sub(&mut self, val: T) {
84 if T::is_float() {
85 if val.is_finite() {
86 self.sub_finite_kahan(val);
87 } else {
88 self.non_finite_count -= 1;
89 self.pos_inf_count -= (val > T::zeroed()) as usize;
90 self.neg_inf_count -= (val < T::zeroed()) as usize;
91 }
92 } else {
93 let val: S = NumCast::from(val).unwrap();
94 self.sum -= val;
95 }
96 }
97
98 fn get_sum(&self) -> Option<T> {
99 if self.non_finite_count == 0 {
100 NumCast::from(self.sum)
101 } else if self.non_finite_count == self.pos_inf_count {
102 Some(T::pos_inf_value())
103 } else if self.non_finite_count == self.neg_inf_count {
104 Some(T::neg_inf_value())
105 } else {
106 Some(T::nan_value())
107 }
108 }
109}
110
111impl<T, S> RollingAggWindowNoNulls<T> for SumWindow<'_, T, S>
112where
113 T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
114 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
115{
116 type This<'a> = SumWindow<'a, T, S>;
117
118 fn new<'a>(
119 slice: &'a [T],
120 start: usize,
121 end: usize,
122 _params: Option<RollingFnParams>,
123 _window_size: Option<usize>,
124 ) -> Self::This<'a> {
125 let mut out = SumWindow::new_impl(slice, None);
126 unsafe { RollingAggWindowNoNulls::update(&mut out, start, end) };
127 out
128 }
129
130 unsafe fn update(&mut self, new_start: usize, new_end: usize) {
133 if new_start >= self.end {
134 self.reset();
135 self.start = new_start;
136 self.end = new_start;
137 }
138
139 for val in &self.slice[self.start..new_start] {
140 self.sub(*val);
141 }
142
143 for val in &self.slice[self.end..new_end] {
144 self.add(*val);
145 }
146
147 self.start = new_start;
148 self.end = new_end;
149 }
150
151 fn get_agg(&self, _idx: usize) -> Option<T> {
152 self.get_sum()
153 }
154
155 fn slice_len(&self) -> usize {
156 self.slice.len()
157 }
158}
159
160impl<T, S> RollingAggWindowNulls<T> for SumWindow<'_, T, S>
161where
162 T: NativeType + IsFloat + Sub<Output = T> + NumCast + PartialOrd,
163 S: NativeType + AddAssign + SubAssign + Sub<Output = S> + Add<Output = S> + NumCast,
164{
165 type This<'a> = SumWindow<'a, T, S>;
166
167 fn new<'a>(
168 slice: &'a [T],
169 validity: &'a Bitmap,
170 start: usize,
171 end: usize,
172 _params: Option<RollingFnParams>,
173 _window_size: Option<usize>,
174 ) -> Self::This<'a> {
175 assert!(start <= slice.len() && end <= slice.len() && start <= end);
176 let mut out = SumWindow::new_impl(slice, Some(validity));
177 unsafe { RollingAggWindowNulls::update(&mut out, start, end) };
179 out
180 }
181
182 unsafe fn update(&mut self, new_start: usize, new_end: usize) {
185 let validity = unsafe { self.validity.unwrap_unchecked() };
186
187 if new_start >= self.end {
188 self.reset();
189 self.start = new_start;
190 self.end = new_start;
191 }
192
193 for idx in self.start..new_start {
194 let valid = unsafe { validity.get_bit_unchecked(idx) };
195 if valid {
196 self.sub(unsafe { *self.slice.get_unchecked(idx) });
197 } else {
198 self.null_count -= 1;
199 }
200 }
201
202 for idx in self.end..new_end {
203 let valid = unsafe { validity.get_bit_unchecked(idx) };
204 if valid {
205 self.add(unsafe { *self.slice.get_unchecked(idx) });
206 } else {
207 self.null_count += 1;
208 }
209 }
210
211 self.start = new_start;
212 self.end = new_end;
213 }
214
215 fn get_agg(&self, _idx: usize) -> Option<T> {
216 self.get_sum()
217 }
218
219 fn is_valid(&self, min_periods: usize) -> bool {
220 ((self.end - self.start) - self.null_count) >= min_periods
221 }
222
223 fn slice_len(&self) -> usize {
224 self.slice.len()
225 }
226}