Skip to main content

high_roller/
rolling_sum.rs

1use num_traits::{CheckedAdd, CheckedSub, WrappingAdd, WrappingSub};
2use std::{collections::VecDeque, num::NonZeroUsize};
3
// TODO: write module-level docs.
// A bit-vector representation could make this much more compact
// when T = bool; that optimization is left for later.
7
8#[derive(Debug)]
9pub struct RollingSum<T> {
10    deq: VecDeque<T>,
11    total: T,
12    capacity: usize,
13    wrap_ct: usize,
14}
15
16impl<T> RollingSum<T>
17where
18    T: WrappingAdd + WrappingSub + CheckedAdd + CheckedSub + PartialOrd + Copy + Default,
19{
20    #[must_use]
21    pub fn new(capacity: NonZeroUsize) -> Self {
22        Self {
23            deq: VecDeque::with_capacity(capacity.into()),
24            total: T::default(),
25            capacity: capacity.into(),
26            wrap_ct: 0,
27        }
28    }
29
30    /// Adds `T` to the rolling sum, displacing the oldest
31    /// member if the window is full to capacity.
32    ///
33    /// If adding `T` causes numerical overflow, subsequent
34    /// calls to `total` will fail until the sum returns to an
35    /// un-overflowed state.
36    pub fn add(&mut self, val: T) -> bool {
37        if self.deq.len() == self.capacity {
38            let popped = self.deq.pop_front().expect(
39                "len is equal to capacity, and capacity is nonzero. So an element must exist.",
40            );
41            let before = self.total;
42            self.total = self.total.wrapping_sub(&popped);
43            if before.checked_sub(&popped).is_none() {
44                self.wrap_ct -= 1;
45            }
46        }
47
48        let before_add = self.total;
49        self.total = self.total.wrapping_add(&val);
50        self.wrap_ct += before_add.checked_add(&val).is_none() as usize;
51
52        self.deq.push_back(val);
53        true
54    }
55
56    /// Returns the accumulated total of all added
57    /// values that fit within the rolling window's
58    /// capacity.
59    ///
60    /// Returns None if the window has overflowed.
61    /// In that case, it will return to Some(..) when
62    /// the last element causing overflow is pushed out.
63    #[must_use]
64    pub fn total(&self) -> Option<&T> {
65        (self.wrap_ct == 0).then_some(&self.total)
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    fn nz(n: usize) -> NonZeroUsize {
        NonZeroUsize::new(n).unwrap()
    }

    /// Push every value and collect total() snapshots after each add.
    fn totals<T>(vals: &[T], cap: usize) -> Vec<T>
    where
        T: WrappingAdd + WrappingSub + CheckedAdd + CheckedSub + PartialOrd + Copy + Default,
    {
        let mut rs = RollingSum::new(nz(cap));
        vals.iter()
            .map(|&v| {
                rs.add(v);
                *rs.total().unwrap()
            })
            .collect()
    }

    /// Verifies that total() returns `T::default()` (zero) before any values are added.
    #[test]
    fn total_before_any_add_is_zero() {
        let rs: RollingSum<u32> = RollingSum::new(nz(3));
        assert_eq!(rs.total(), Some(&0u32));
    }

    /// A single add must accumulate into total without triggering eviction.
    #[test]
    fn single_add_below_capacity() {
        let mut rs: RollingSum<u32> = RollingSum::new(nz(3));
        rs.add(10);
        assert_eq!(rs.total(), Some(&10));
    }

    /// Filling exactly to capacity must sum all values with no eviction.
    #[test]
    fn fill_to_capacity_no_eviction() {
        assert_eq!(totals(&[1u32, 2, 3], 3), vec![1, 3, 6]);
    }

    /// The (capacity+1)th add must evict the oldest element.
    #[test]
    fn first_eviction_at_capacity_plus_one() {
        // Window = [2, 3, 4] after evicting 1.
        assert_eq!(totals(&[1u32, 2, 3, 4], 3), vec![1, 3, 6, 9]);
    }

    /// Step through a longer sequence to verify correct FIFO eviction ordering.
    #[test]
    fn sliding_window_trace() {
        // cap=3: [5]=5, [5,3]=8, [5,3,8]=16, [3,8,2]=13, [8,2,6]=16
        assert_eq!(totals(&[5u32, 3, 8, 2, 6], 3), vec![5, 8, 16, 13, 16]);
    }

    /// capacity=1: each add completely replaces the previous value.
    #[test]
    fn window_of_one() {
        assert_eq!(totals(&[5u32, 3, 9, 1], 1), vec![5, 3, 9, 1]);
    }

    /// Window larger than the input: no eviction ever occurs.
    #[test]
    fn window_larger_than_input() {
        assert_eq!(totals(&[1u32, 2, 3, 4, 5], 100), vec![1, 3, 6, 10, 15]);
    }

    /// Signed integers: negative values must be summed and evicted correctly.
    #[test]
    fn signed_integers() {
        // cap=2: [-3]=-3, [-3,5]=2, [5,-2]=3, [-2,4]=2
        assert_eq!(totals(&[-3i32, 5, -2, 4], 2), vec![-3, 2, 3, 2]);
    }

    /// Hitting u8::MAX exactly (no overflow) then recovering on eviction.
    #[test]
    fn u8_boundary_exact() {
        // The turbofish is required: without it the literals default to i32
        // and the u8 boundary is never exercised.
        let mut rs = RollingSum::<u8>::new(nz(3));
        rs.add(100);
        rs.add(100);
        rs.add(55);
        assert_eq!(rs.total(), Some(&255)); // 100+100+55 = u8::MAX exactly
        rs.add(0);
        assert_eq!(rs.total(), Some(&155)); // evicted 100 -> 100+55+0
    }

    /// total() returns None while the window sum has overflowed, and recovers to Some
    /// once all overflowing elements have been evicted.
    #[test]
    fn overflow_detected_then_recovered() {
        let mut rs = RollingSum::<u8>::new(nz(2));
        rs.add(200);
        assert_eq!(rs.total(), Some(&200)); // single element, no overflow
        rs.add(100); // 200+100=300 > u8::MAX -> wrapped once
        assert_eq!(rs.total(), None); // overflow detected
        rs.add(50); // evicts 200; window=[100,50], sum=150
        assert_eq!(rs.total(), Some(&150)); // overflow healed
    }

    /// Signed sums that fall below T::MIN are reported as None and recover once
    /// the window's exact sum is back in range.
    #[test]
    fn signed_underflow_detected_then_recovered() {
        let mut rs = RollingSum::<i8>::new(nz(2));
        rs.add(-100);
        assert_eq!(rs.total(), Some(&-100));
        rs.add(-100); // exact sum -200 < i8::MIN
        assert_eq!(rs.total(), None);
        rs.add(50); // evicts -100; window=[-100,50], sum=-50
        assert_eq!(rs.total(), Some(&-50));
    }

    /// Fills the window with three individually-overflowing values (two wraps),
    /// then slides in small values to evict them one at a time, confirming that
    /// total() stays None while any overflow-causing element remains in the window
    /// and returns exact Some values as each one is expelled.
    #[test]
    fn double_overflow_and_full_recovery() {
        let mut rs = RollingSum::<u8>::new(nz(3));

        rs.add(200);
        assert_eq!(rs.total(), Some(&200)); // [200], true=200, 0 wraps
        rs.add(200);
        assert_eq!(rs.total(), None); // [200,200], true=400, 1 wrap
        rs.add(200);
        assert_eq!(rs.total(), None); // [200,200,200], true=600, 2 wraps

        rs.add(10);
        assert_eq!(rs.total(), None); // evict 200 -> 1 wrap; [200,200,10], true=410
        rs.add(10);
        assert_eq!(rs.total(), Some(&220)); // evict 200 -> 0 wraps; [200,10,10], true=220
        rs.add(10);
        assert_eq!(rs.total(), Some(&30)); // evict 200, no wrap; [10,10,10], true=30
    }

    /// Large u64 values that stay within range: verifies no spurious overflow.
    #[test]
    fn u64_large_values_no_overflow() {
        let half = u64::MAX / 2;
        let mut rs = RollingSum::new(nz(2));
        rs.add(half);
        rs.add(half);
        assert_eq!(rs.total(), Some(&(half * 2))); // 2^64 - 2 (u64::MAX - 1), no overflow
        rs.add(1);
        assert_eq!(rs.total(), Some(&(half + 1))); // evicted half -> half + 1
    }
}