Skip to main content

high_roller/
rolling_sum.rs

1//! # Rolling Sum
2//!
3//! [`RollingSum`] tracks the sum of values in a fixed-size window.
4//! When the window is full, pushing a new value evicts the oldest.
5//!
6//! API guarantees:
7//! - Add is O(1).
8//! - Total is O(1).
9//! - Overflow and underflow are detected and recoverable.
10//! - Zero heap allocations.
11//!
//! The value of this implementation is error recovery.
//! [`RollingSum::total`] returns [`None`] for as long as the true
//! sum of the window is above `T::MAX` or below `T::MIN`. It resumes
//! yielding `Some(&T)` as soon as the true sum is back in range.
16//!
17//! ```
18//! use high_roller::rolling_sum::RollingSum;
19//!
20//! let mut rsum: RollingSum<i8, 3> = RollingSum::default();
21//!
22//! // Add an initial value.
23//! rsum.add(5);
24//! assert_eq!(rsum.total().copied(), Some(5));
25//!
26//! // Stream of -1s takes over the window.
27//! for _ in 0..100 {
28//!     rsum.add(-1);
29//! }
30//! assert_eq!(rsum.total().copied(), Some(-3));
31//!
32//! // Underflow forces total to None.
33//! rsum.add(i8::MIN);
34//! assert_eq!(rsum.total(), None);
35//!
36//! // Balanced back to zero.
37//! rsum.add(i8::MAX);
38//! rsum.add(1);
39//! assert_eq!(rsum.total().copied(), Some(0));
40//!
41//! // Evicting i8::MIN causes overflow.
42//! rsum.add(0);
43//! assert_eq!(rsum.total(), None);
44//!
45//! // Cause multiple overflows.
46//! for _ in 0..100 {
47//!     rsum.add(i8::MAX);
48//! }
49//! assert_eq!(rsum.total(), None);
50//!
51//! // And recover back into an expected range.
52//! for _ in 0..100 {
53//!     rsum.add(1);
54//! }
55//! assert_eq!(rsum.total().copied(), Some(3));
56//! ```
57
58use arraydeque::ArrayDeque;
59use num_traits::CheckedAdd;
60use num_traits::CheckedSub;
61use num_traits::WrappingAdd;
62use num_traits::WrappingSub;
63
64// PERF: using a bitvec or double bitvec for RollingSum on
65// applicable types could be a pretty nice space optimization.
66// Future project :)
67
68/// Tracks a rolling sum of at most `WINDOW` values.
69///
70/// This type is stack allocated. However a large window
71/// might warrant boxing. Clustering multiple instances
72/// in the same allocation might offer better cache
73/// access patterns.
74///
75/// ```
76/// use high_roller::rolling_sum::RollingSum;
77///
78/// struct Average {
79///     total: RollingSum<u32, 6000>,
80///     samples: RollingSum<u8, 6000>
81/// }
82///
83/// // Probably want to box this.
84/// const _: () = assert!(core::mem::size_of::<Average>() == 30064);
85/// ```
86#[derive(Debug)]
87pub struct RollingSum<T, const WINDOW: usize> {
88    deq: ArrayDeque<T, WINDOW>,
89    total: T,
90    zero: T,
91    balance: isize,
92}
93
94impl<T, const W: usize> Default for RollingSum<T, W>
95where
96    T: Default,
97{
98    /// Constructs a new [`RollingSum`] with sum and zero
99    /// values equal to `T::default()`.
100    fn default() -> Self {
101        Self::new(T::default(), T::default())
102    }
103}
104
105impl<T, const WINDOW: usize> RollingSum<T, WINDOW>
106where
107    T: Default,
108{
109    /// Constructs a new [`RollingSum`].
110    ///
111    /// Prefer using [`RollingSum::default`] for a cleaner API.
112    /// `init` is the sum's initial value. `zero` is value separating
113    /// positive and negative for this type. While identifying zero is strange,
114    /// passing as a parameter avoids requiring a `Zero` trait impl, which isn't
115    /// fun for anybody. Zero is required for differentiating overflow
116    /// and underflow.
117    ///
118    /// ```
119    /// use high_roller::rolling_sum::RollingSum;
120    ///
121    /// let _sum: RollingSum::<usize, 10> = RollingSum::new(0, 0);
122    /// ```
123    ///
124    /// # Compile errors
125    ///
126    /// Using `WINDOW == 0` is a compile-time error.
127    ///
128    /// ```compile_fail
129    /// # use high_roller::rolling_sum::RollingSum;
130    /// // This will fail to compile. A rolling sum of capacity 0 is nonsensical.
131    /// let _sum: RollingSum::<usize, 0> = RollingSum::default();
132    /// ```
133    #[must_use]
134    pub const fn new(init: T, zero: T) -> Self {
135        const { assert!(WINDOW != 0, "RollingSum with WINDOW == 0 is not permitted") };
136        Self {
137            deq: ArrayDeque::new(),
138            total: init,
139            balance: 0,
140            zero,
141        }
142    }
143}
144
145impl<T, const WINDOW: usize> RollingSum<T, WINDOW>
146where
147    T: WrappingAdd + WrappingSub + CheckedAdd + CheckedSub + PartialOrd + Copy + Default,
148{
149    /// Adds `T` to the rolling sum, displacing the oldest
150    /// member if the window is full to capacity.
151    ///
152    /// If adding `T` causes numerical overflow, subsequent
153    /// calls to [`RollingSum::total`] will return [`None`] until window
154    /// expirations cause underflow commensurate to the overflow.
155    ///
156    /// ```
157    /// use high_roller::rolling_sum::RollingSum;
158    ///
159    /// let mut rsum: RollingSum<i32, 1000> = RollingSum::default();
160    ///
161    /// rsum.add(100);
162    /// rsum.add(1);
163    /// rsum.add(-2);
164    ///
165    /// assert_eq!(rsum.total().copied(), Some(99));
166    /// ```
167    ///
168    /// # Panics
169    ///
170    /// This function panics if the [`isize`] signed overflow balance
171    /// counter itself overflows. Such a panic is impossible as long
172    /// as `WINDOW <= isize::MAX`. And an array of size `isize::MAX`
173    /// bytes is fantasy on all current hardware anyway.
174    //
175    // Clippy allow:
176    // Explained inline. This is easily provable, will never occur,
177    // and should not be exposed to the user.
178    #[allow(clippy::expect_used)]
179    #[allow(clippy::missing_panics_doc)]
180    pub fn add(&mut self, val: T) {
181        if self.deq.is_full() {
182            // Construction has a const assertion that WINDOW is not zero.
183            // So `is_full` guarantees there's something to pop.
184            let popped = self.deq.pop_front().expect(
185                "len is equal to capacity, and capacity is nonzero. So an element must exist.",
186            );
187
188            let changed = self.total.checked_sub(&popped).is_none();
189            self.total = self.total.wrapping_sub(&popped);
190
191            if changed {
192                self.balance = self
193                    .balance
194                    .checked_add(if popped >= self.zero { -1 } else { 1 })
195                    .expect("overflow count itself overflowed");
196            }
197        }
198
199        let changed = self.total.checked_add(&val).is_none();
200        self.total = self.total.wrapping_add(&val);
201
202        if changed {
203            self.balance = self
204                .balance
205                .checked_add(if val >= self.zero { 1 } else { -1 })
206                .expect("overflow count itself overflowed");
207        }
208
209        // The `if` condition above guarantees the deque
210        // is not full. So there's space to push a value.
211        self.deq.push_back(val).expect("deq is not full");
212    }
213
214    /// Returns the accumulated total of all added values that
215    /// fit within the rolling window's capacity.
216    ///
217    /// Returns [`None`] if the window has overflowed. [`RollingSum`]
218    /// will resume returning `Some(&T)` when the last element
219    /// causing overflow is pushed out of the window.
220    ///
221    /// ```
222    /// use high_roller::rolling_sum::RollingSum;
223    ///
224    /// let mut rsum: RollingSum<i32, 100> = RollingSum::default();
225    ///
226    /// rsum.add(i32::MIN);
227    /// assert_eq!(rsum.total().copied(), Some(i32::MIN));
228    ///
229    /// for _ in 0..1000 {
230    ///     rsum.add(i32::MIN);
231    ///     assert_eq!(rsum.total(), None);
232    /// }
233    ///
234    /// for _ in 0..100 {
235    ///     rsum.add(-1);
236    /// }
237    /// assert_eq!(rsum.total().copied(), Some(-100));
238    /// ```
239    #[must_use]
240    pub fn total(&self) -> Option<&T> {
241        (self.balance == 0).then_some(&self.total)
242    }
243}
244
#[cfg(test)]
pub mod for_tests {
    use arraydeque::ArrayDeque;
    use arraydeque::Wrapping;
    use num_bigint::BigInt;

    /// A simple reference implementation mirroring the API of
    /// this crate's `RollingSum` type. This is used for both
    /// correctness and performance testing.
    ///
    /// The sum is kept as an exact, arbitrary-precision [`BigInt`], so it
    /// never wraps. [`NaiveRollingSum::total`] reports [`None`] whenever
    /// that exact sum cannot be converted back into `T`.
    ///
    /// Unlike `RollingSum`, `total` returns the value itself rather than a
    /// reference.
    ///
    /// NOTE(review): an earlier comment called this "not as robust against
    /// underflow as RollingSum" and pointed at a note in `total` that no
    /// longer exists. With an exact `BigInt` sum, the only lossy steps are
    /// the `T` <-> `BigInt` conversions. Confirm whether that caveat still
    /// applies.
    #[derive(Debug, Default)]
    pub struct NaiveRollingSum<T, const WINDOW: usize> {
        /// Window contents. With the `Wrapping` behavior, `push_back` hands
        /// back the element it displaces once the deque is full.
        deq: ArrayDeque<T, WINDOW, Wrapping>,
        /// Exact sum of the initial value and the window contents.
        sum: BigInt,
    }

    impl<T, const WINDOW: usize> NaiveRollingSum<T, WINDOW>
    where
        T: Clone + Into<BigInt> + for<'a> TryFrom<&'a BigInt>,
    {
        /// Creates an empty window whose sum starts at `init`.
        #[must_use]
        pub fn new(init: T) -> Self {
            const { assert!(WINDOW != 0, "RollingSum with WINDOW == 0 is not permitted") };
            Self {
                deq: ArrayDeque::new(),
                sum: init.into(),
            }
        }

        /// Pushes `val` into the window and subtracts any element it
        /// displaces from the exact sum.
        pub fn add(&mut self, val: T) {
            // The explicit target type disambiguates BigInt's many
            // `AddAssign` impls.
            self.sum += Into::<BigInt>::into(val.clone());
            if let Some(replaced) = self.deq.push_back(val) {
                self.sum -= Into::<BigInt>::into(replaced);
            }
        }

        /// Returns the exact sum converted to `T`, or [`None`] if it does
        /// not fit.
        #[must_use]
        pub fn total(&self) -> Option<T> {
            (&self.sum).try_into().ok()
        }
    }
}
289
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::decimal::D1;
    use crate::decimal::D5;
    use crate::rolling_sum::for_tests::NaiveRollingSum;
    use core::fmt::Debug;
    use rand::distr::Uniform;
    use rand::rngs::SmallRng;
    use rand::RngExt;
    use rand::SeedableRng;

    /// Smoke test for RollingSum correctness.
    ///
    /// Feeds the same seeded pseudo-random stream into a RollingSum and a
    /// NaiveRollingSum. After every add, it verifies their totals are
    /// identical.
    #[test]
    fn rng_with_naive() {
        const QLEN: usize = 1000;
        const STREAM_LEN: usize = 100_000;

        let sample = SmallRng::seed_from_u64(57).sample_iter(Uniform::new(-800f32, 800.).unwrap());
        let mut roller = RollingSum::<D5, QLEN>::default();
        let mut naive = NaiveRollingSum::<D5, QLEN>::default();

        for val in sample.take(STREAM_LEN) {
            let d5 = D5::cast(val.into());
            roller.add(d5);
            naive.add(d5);
            assert_eq!(roller.total(), naive.total().as_ref());
        }
    }

    /// Verifies that total() returns `init` before any values are added.
    #[test]
    fn total_before_any_add_is_init() {
        let rs: RollingSum<u32, 3> = RollingSum::default();
        assert_eq!(rs.total(), Some(&0u32));
    }

    /// A single add must accumulate into total without triggering eviction.
    #[test]
    fn single_add_below_capacity() {
        let mut rs: RollingSum<u32, 3> = RollingSum::default();
        rs.add(10);
        assert_eq!(rs.total(), Some(&10));
    }

    /// Filling exactly to capacity must sum all values with no eviction.
    #[test]
    fn fill_to_capacity_no_eviction() {
        self::expect_total::<u32, 3>([1, 2, 3].into_iter().zip([1, 3, 6]));
    }

    /// The (capacity+1)th add must evict the oldest element.
    #[test]
    fn first_eviction_at_capacity_plus_one() {
        // Window = [2, 3, 4] after evicting 1.
        self::expect_total::<u32, 3>([1, 2, 3, 4].into_iter().zip([1, 3, 6, 9]));
    }

    /// Step through a longer sequence to verify correct FIFO eviction ordering.
    #[test]
    fn sliding_window_trace() {
        // cap=3: [5]=5, [5,3]=8, [5,3,8]=16, [3,8,2]=13, [8,2,6]=16
        self::expect_total::<u32, 3>([5, 3, 8, 2, 6].into_iter().zip([5, 8, 16, 13, 16]));
    }

    /// capacity=1: each add completely replaces the previous value.
    #[test]
    fn window_of_one() {
        self::expect_total::<u32, 1>([5, 3, 9, 1].into_iter().zip([5, 3, 9, 1]));
    }

    /// Window larger than the input: no eviction ever occurs.
    #[test]
    fn window_larger_than_input() {
        self::expect_total::<u32, 100>([1, 2, 3, 4, 5].into_iter().zip([1, 3, 6, 10, 15]));
    }

    /// Signed integers: negative values must be summed and evicted correctly.
    #[test]
    fn signed_integers() {
        // cap=2: [-3]=-3, [-3,5]=2, [5,-2]=3, [-2,4]=2
        self::expect_total::<i32, 2>([-3, 5, -2, 4].into_iter().zip([-3, 2, 3, 2]));
    }

    /// Hitting u8::MAX exactly (no overflow) then recovering on eviction.
    #[test]
    fn u8_boundary_exact() {
        let mut rs = RollingSum::<u8, 3>::default();
        rs.add(100);
        rs.add(100);
        rs.add(55);
        assert_eq!(rs.total(), Some(&255)); // 100+100+55 = u8::MAX exactly
        rs.add(0);
        assert_eq!(rs.total(), Some(&155)); // evicted 100 → 100+55+0
    }

    /// total() returns None while the window sum has overflowed, and recovers to Some
    /// once all overflowing elements have been evicted.
    #[test]
    fn overflow_detected_then_recovered() {
        let mut rs = RollingSum::<u8, 2>::default();
        rs.add(200);
        assert_eq!(rs.total(), Some(&200)); // single element, no overflow
        rs.add(100); // 200+100=300 > u8::MAX → balance=1
        assert_eq!(rs.total(), None); // overflow detected
        rs.add(50); // evicts 200; window=[100,50], sum=150
        assert_eq!(rs.total(), Some(&150)); // overflow healed
    }

    /// Fills the window with three large values so the balance reaches 2.
    /// Then slides in small values to evict them one at a time. Confirms
    /// that total() stays None while the true sum is out of range, and
    /// returns exact Some values once it is back in range.
    #[test]
    fn double_overflow_and_full_recovery() {
        let mut rs = RollingSum::<u8, 3>::default();

        rs.add(200);
        assert_eq!(rs.total(), Some(&200)); // [200], true=200, balance=0
        rs.add(200);
        assert_eq!(rs.total(), None); // [200,200], true=400, balance=1
        rs.add(200);
        assert_eq!(rs.total(), None); // [200,200,200], true=600, balance=2

        rs.add(10);
        assert_eq!(rs.total(), None); // evict 200 → balance=1; [200,200,10], true=410
        rs.add(10);
        assert_eq!(rs.total(), Some(&220)); // evict 200 → balance=0; [200,10,10], true=220
        rs.add(10);
        assert_eq!(rs.total(), Some(&30)); // evict 200, no wrap; [10,10,10], true=30
    }

    /// Large u64 values that stay within range: verifies no spurious overflow.
    #[test]
    fn u64_large_values_no_overflow() {
        // 2^63 - 1, so two copies sum to u64::MAX - 1 without wrapping.
        const HALF: u64 = u64::MAX / 2;
        let mut rs = RollingSum::<_, 2>::default();
        rs.add(HALF);
        rs.add(HALF);
        assert_eq!(rs.total(), Some(&(HALF * 2))); // u64::MAX - 1, no overflow
        rs.add(1);
        assert_eq!(rs.total(), Some(&(HALF + 1))); // evicted HALF → HALF + 1
    }

    /// Signed overflow: the extremes cancel while MIN is in the window.
    /// The sum overflows once MIN is evicted, and recovers when a MAX
    /// leaves.
    #[test]
    fn overflow_negative() {
        let mut rs = RollingSum::<i32, 3>::default();

        rs.add(i32::MAX); // Total = MAX
        assert!(rs.total().is_some());

        rs.add(i32::MIN); // Total = MAX + MIN
        assert!(rs.total().is_some());

        rs.add(i32::MAX); // Total = MAX + MIN + MAX
        assert!(rs.total().is_some());

        rs.add(i32::MAX); // Total = MIN + MAX + MAX
        assert!(rs.total().is_some());

        rs.add(0); // Total = MAX + MAX + 0
        assert!(rs.total().is_none());

        rs.add(0); // Total = MAX + 0 + 0
        assert_eq!(rs.total(), Some(&i32::MAX));
    }

    /// Signed underflow below i32::MIN is undone by an offsetting addition,
    /// without waiting for any eviction.
    #[test]
    fn underflow_negative() {
        let mut rs = RollingSum::<i32, 3>::default();

        rs.add(i32::MIN);
        assert!(rs.total().is_some());

        rs.add(-1);
        assert!(rs.total().is_none());

        rs.add(1);
        assert_eq!(rs.total(), Some(&i32::MIN));
    }

    /// Repeated overflow with a decimal type, then recovery once every
    /// D1::MAX has been evicted.
    #[test]
    fn decimal_overflow() {
        let mut rs = RollingSum::<D1, 4>::default();

        rs.add(D1::MAX);
        assert!(rs.total().is_some());

        for _ in 0..100 {
            rs.add(D1::MAX);
            assert!(rs.total().is_none());
        }

        for _ in 0..3 {
            rs.add(D1::ZERO);
        }

        rs.add(D1::MIN_UNIT);
        assert!(matches!(rs.total(), Some(&D1::MIN_UNIT)));
    }

    /// Feeds inputs from an `(input, expected)` iterator into
    /// a RollingSum. Compares each total to `expected` and panics
    /// if they're not equal.
    fn expect_total<T, const WINDOW: usize>(input_and_expected: impl Iterator<Item = (T, T)>)
    where
        T: WrappingAdd
            + WrappingSub
            + CheckedAdd
            + CheckedSub
            + PartialOrd
            + Copy
            + Default
            + Debug,
    {
        let mut roll: RollingSum<T, WINDOW> = RollingSum::default();
        for (input, expected) in input_and_expected {
            roll.add(input);
            assert_eq!(*roll.total().unwrap(), expected);
        }
    }
}