1use arraydeque::ArrayDeque;
2use num_traits::{CheckedAdd, CheckedSub, WrappingAdd, WrappingSub};
3
/// A rolling (sliding-window) sum over the most recent `WINDOW` values.
///
/// The running total is maintained with wrapping arithmetic, so an intermediate overflow
/// never panics. Instead, every wrap is counted in `balance`, and `total()` only reports the
/// stored total when the wraps cancel out.
#[derive(Debug)]
pub struct RollingSum<T, const WINDOW: usize> {
    /// The values currently in the window: pushed at the back, evicted from the front.
    deq: ArrayDeque<T, WINDOW>,
    /// Wrapped running total: the initial value plus every value currently in `deq`.
    total: T,
    /// The value treated as zero; values are compared against it to decide which way an
    /// overflowing update wrapped.
    zero: T,
    /// Net wrap count: +1 for each wrap past the maximum, -1 for each wrap past the minimum.
    /// `total()` reports the total only when this is 0.
    balance: isize,
}
15
16impl<T, const W: usize> Default for RollingSum<T, W>
17where
18 T: Default,
19{
20 fn default() -> Self {
21 Self::new(T::default(), T::default())
22 }
23}
24
25impl<T, const WINDOW: usize> RollingSum<T, WINDOW>
26where
27 T: Default,
28{
29 #[must_use]
30 pub const fn new(init: T, zero: T) -> Self {
31 const { assert!(WINDOW != 0, "RollingSum with WINDOW == 0 is not permitted") };
32 Self {
33 deq: ArrayDeque::new(),
34 total: init,
35 balance: 0,
36 zero,
37 }
38 }
39}
40
41impl<T, const WINDOW: usize> RollingSum<T, WINDOW>
42where
43 T: WrappingAdd + WrappingSub + CheckedAdd + CheckedSub + PartialOrd + Copy + Default,
44{
45 #[allow(clippy::expect_used)]
62 #[allow(clippy::missing_panics_doc)]
63 pub fn add(&mut self, val: T) {
64 if self.deq.is_full() {
65 let popped = self.deq.pop_front().expect(
68 "len is equal to capacity, and capacity is nonzero. So an element must exist.",
69 );
70
71 let changed = self.total.checked_sub(&popped).is_none();
72 self.total = self.total.wrapping_sub(&popped);
73
74 if changed {
75 self.balance = self
76 .balance
77 .checked_add(if popped >= self.zero { -1 } else { 1 })
78 .expect("overflow count itself overflowed");
79 }
80 }
81
82 let changed = self.total.checked_add(&val).is_none();
83 self.total = self.total.wrapping_add(&val);
84
85 if changed {
86 self.balance = self
87 .balance
88 .checked_add(if val >= self.zero { 1 } else { -1 })
89 .expect("overflow count itself overflowed");
90 }
91
92 self.deq.push_back(val).expect("deq is not full");
95 }
96
97 #[must_use]
105 pub fn total(&self) -> Option<&T> {
106 (self.balance == 0).then_some(&self.total)
107 }
108}
109
#[cfg(test)]
pub mod for_tests {
    //! Test-only reference implementations.

    use arraydeque::{ArrayDeque, Wrapping};
    use num_bigint::BigInt;

    /// A straightforward rolling sum, used in tests as an oracle for `RollingSum`.
    ///
    /// The running sum is kept as an arbitrary-precision `BigInt`, so it never overflows.
    /// `total` returns `None` only when that exact sum cannot be converted back into `T`.
    #[derive(Debug)]
    pub struct NaiveRollingSum<T, const WINDOW: usize> {
        /// The last `WINDOW` values. With `Wrapping`, pushing onto a full deque evicts the
        /// oldest element and returns it.
        deq: ArrayDeque<T, WINDOW, Wrapping>,
        /// Exact sum of the initial value plus every value currently in `deq`.
        sum: BigInt,
    }

    /// Starts from `T::default()` and goes through `new`, so the `WINDOW != 0` check and the
    /// initial value match `RollingSum::default`. A derived `Default` would skip both.
    impl<T, const WINDOW: usize> Default for NaiveRollingSum<T, WINDOW>
    where
        T: Clone + Default + Into<BigInt> + for<'a> TryFrom<&'a BigInt>,
    {
        fn default() -> Self {
            Self::new(T::default())
        }
    }

    impl<T, const WINDOW: usize> NaiveRollingSum<T, WINDOW>
    where
        T: Clone + Default + Into<BigInt> + for<'a> TryFrom<&'a BigInt>,
    {
        /// Creates an empty naive rolling sum whose exact sum starts at `init`.
        #[must_use]
        pub fn new(init: T) -> Self {
            const { assert!(WINDOW != 0, "NaiveRollingSum with WINDOW == 0 is not permitted") };
            Self {
                deq: ArrayDeque::new(),
                sum: init.into(),
            }
        }

        /// Pushes `val`, subtracting whichever value the full window evicted.
        pub fn add(&mut self, val: T) {
            // Turbofish needed: `BigInt` has many `AddAssign`/`SubAssign` impls, so a bare
            // `.into()` would be ambiguous.
            self.sum += Into::<BigInt>::into(val.clone());
            if let Some(replaced) = self.deq.push_back(val) {
                self.sum -= Into::<BigInt>::into(replaced);
            }
        }

        /// Returns the exact sum converted to `T`, or `None` if it does not fit.
        #[must_use]
        pub fn total(&self) -> Option<T> {
            (&self.sum).try_into().ok()
        }
    }
}
153
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::{
        decimal::{D1, D5},
        rolling_sum::for_tests::NaiveRollingSum,
    };
    use core::fmt::Debug;
    use rand::{distr::Uniform, rngs::SmallRng, RngExt, SeedableRng};

    /// Cross-checks `RollingSum` against the `BigInt`-backed `NaiveRollingSum` on a long,
    /// seeded random stream. The two must agree after every push.
    #[test]
    fn rng_with_naive() {
        const QLEN: usize = 1000;
        const STREAM_LEN: usize = 100_000;

        let distribution = Uniform::new(-800f32, 800.).unwrap();
        let samples = SmallRng::seed_from_u64(57).sample_iter(distribution);
        let mut fast = RollingSum::<D5, QLEN>::default();
        let mut reference = NaiveRollingSum::<D5, QLEN>::default();

        for raw in samples.take(STREAM_LEN) {
            let value = D5::cast(raw.into());
            fast.add(value);
            reference.add(value);
            assert_eq!(fast.total(), reference.total().as_ref());
        }
    }

    /// A fresh accumulator reports its default starting total.
    #[test]
    fn total_before_any_add_is_init() {
        let rs = RollingSum::<u32, 3>::default();
        assert_eq!(rs.total(), Some(&0u32));
    }

    #[test]
    fn single_add_below_capacity() {
        let mut rs = RollingSum::<u32, 3>::default();
        rs.add(10);
        assert_eq!(rs.total(), Some(&10));
    }

    #[test]
    fn fill_to_capacity_no_eviction() {
        let inputs = [1, 2, 3];
        let totals = [1, 3, 6];
        expect_total::<u32, 3>(inputs.into_iter().zip(totals));
    }

    /// The fourth push evicts the first value: 2 + 3 + 4 = 9.
    #[test]
    fn first_eviction_at_capacity_plus_one() {
        let inputs = [1, 2, 3, 4];
        let totals = [1, 3, 6, 9];
        expect_total::<u32, 3>(inputs.into_iter().zip(totals));
    }

    #[test]
    fn sliding_window_trace() {
        let inputs = [5, 3, 8, 2, 6];
        let totals = [5, 8, 16, 13, 16];
        expect_total::<u32, 3>(inputs.into_iter().zip(totals));
    }

    /// With a window of one, the total is always just the latest value.
    #[test]
    fn window_of_one() {
        let inputs = [5, 3, 9, 1];
        let totals = [5, 3, 9, 1];
        expect_total::<u32, 1>(inputs.into_iter().zip(totals));
    }

    #[test]
    fn window_larger_than_input() {
        let inputs = [1, 2, 3, 4, 5];
        let totals = [1, 3, 6, 10, 15];
        expect_total::<u32, 100>(inputs.into_iter().zip(totals));
    }

    #[test]
    fn signed_integers() {
        let inputs = [-3, 5, -2, 4];
        let totals = [-3, 2, 3, 2];
        expect_total::<i32, 2>(inputs.into_iter().zip(totals));
    }

    /// Reaching exactly `u8::MAX` is not an overflow.
    #[test]
    fn u8_boundary_exact() {
        let mut rs = RollingSum::<u8, 3>::default();
        for v in [100, 100, 55] {
            rs.add(v);
        }
        assert_eq!(rs.total(), Some(&255));

        rs.add(0);
        assert_eq!(rs.total(), Some(&155));
    }

    /// 200 + 100 overflows `u8`. Evicting the 200 brings the window back in range.
    #[test]
    fn overflow_detected_then_recovered() {
        let mut rs = RollingSum::<u8, 2>::default();
        let steps: [(u8, Option<u8>); 3] = [(200, Some(200)), (100, None), (50, Some(150))];
        for (input, expected) in steps {
            rs.add(input);
            assert_eq!(rs.total(), expected.as_ref());
        }
    }

    /// The stored total wraps more than once, and the sum recovers only after every large
    /// value has left the window.
    #[test]
    fn double_overflow_and_full_recovery() {
        let mut rs = RollingSum::<u8, 3>::default();
        let steps: [(u8, Option<u8>); 6] = [
            (200, Some(200)),
            (200, None),
            (200, None),
            (10, None),       // [200, 200, 10] = 410
            (10, Some(220)),  // [200, 10, 10]
            (10, Some(30)),   // [10, 10, 10]
        ];
        for (input, expected) in steps {
            rs.add(input);
            assert_eq!(rs.total(), expected.as_ref());
        }
    }

    #[test]
    fn u64_large_values_no_overflow() {
        const HALF: u64 = (u64::MAX as f64 / 2.) as u64 - 1;
        let mut rs = RollingSum::<_, 2>::default();

        rs.add(HALF);
        rs.add(HALF);
        assert_eq!(rs.total(), Some(&(HALF * 2)));

        rs.add(1);
        assert_eq!(rs.total(), Some(&(HALF + 1)));
    }

    #[test]
    fn overflow_negative() {
        let mut rs = RollingSum::<i32, 3>::default();

        // [MAX] -> MAX, [MAX, MIN] -> -1, [MAX, MIN, MAX] -> MAX - 1: all in range.
        rs.add(i32::MAX);
        assert!(rs.total().is_some());
        rs.add(i32::MIN);
        assert!(rs.total().is_some());
        rs.add(i32::MAX);
        assert!(rs.total().is_some());

        // Evicts MAX: [MIN, MAX, MAX] -> MAX - 1.
        rs.add(i32::MAX);
        assert!(rs.total().is_some());

        // Evicts MIN: [MAX, MAX, 0] -> 2 * MAX, which overflows.
        rs.add(0);
        assert!(rs.total().is_none());

        // Evicts MAX: [MAX, 0, 0] -> MAX.
        rs.add(0);
        assert_eq!(rs.total(), Some(&i32::MAX));
    }

    #[test]
    fn underflow_negative() {
        let mut rs = RollingSum::<i32, 3>::default();

        rs.add(i32::MIN);
        assert!(rs.total().is_some());

        // MIN - 1 underflows.
        rs.add(-1);
        assert!(rs.total().is_none());

        // MIN - 1 + 1 is back in range.
        rs.add(1);
        assert_eq!(rs.total(), Some(&i32::MIN));
    }

    #[test]
    fn decimal_overflow() {
        let mut rs = RollingSum::<D1, 4>::default();

        rs.add(D1::MAX);
        assert!(rs.total().is_some());

        // The window stays packed with MAX values, so the true sum never fits.
        for _ in 0..100 {
            rs.add(D1::MAX);
            assert!(rs.total().is_none());
        }

        // Flush three of the four MAX values out of the window.
        (0..3).for_each(|_| rs.add(D1::ZERO));

        // Evicts the last MAX: the window is [0, 0, 0, MIN_UNIT].
        rs.add(D1::MIN_UNIT);
        assert!(matches!(rs.total(), Some(&D1::MIN_UNIT)));
    }

    /// Feeds each input into a fresh `RollingSum` and asserts that the reported total is
    /// `Some` of the paired expected value after every push.
    fn expect_total<T, const WINDOW: usize>(input_and_expected: impl Iterator<Item = (T, T)>)
    where
        T: WrappingAdd
            + WrappingSub
            + CheckedAdd
            + CheckedSub
            + PartialOrd
            + Copy
            + Default
            + Debug,
    {
        let mut roll = RollingSum::<T, WINDOW>::default();
        input_and_expected.for_each(|(input, expected)| {
            roll.add(input);
            assert_eq!(roll.total().copied(), Some(expected));
        });
    }
}