Skip to main content

high_roller/
rolling_sum.rs

1use arraydeque::ArrayDeque;
2use num_traits::{CheckedAdd, CheckedSub, WrappingAdd, WrappingSub};
3
// TODO: write docs.
// Also, there's a bitvec optimization we could do
// when T = bool, but that's a project for later.
7
8#[derive(Debug)]
9pub struct RollingSum<T, const WINDOW: usize> {
10    deq: ArrayDeque<T, WINDOW>,
11    total: T,
12    zero: T,
13    balance: isize,
14}
15
16impl<T, const W: usize> Default for RollingSum<T, W>
17where
18    T: Default,
19{
20    fn default() -> Self {
21        Self::new(T::default(), T::default())
22    }
23}
24
25impl<T, const WINDOW: usize> RollingSum<T, WINDOW>
26where
27    T: Default,
28{
29    #[must_use]
30    pub const fn new(init: T, zero: T) -> Self {
31        const { assert!(WINDOW != 0, "RollingSum with WINDOW == 0 is not permitted") };
32        Self {
33            deq: ArrayDeque::new(),
34            total: init,
35            balance: 0,
36            zero,
37        }
38    }
39}
40
41impl<T, const WINDOW: usize> RollingSum<T, WINDOW>
42where
43    T: WrappingAdd + WrappingSub + CheckedAdd + CheckedSub + PartialOrd + Copy + Default,
44{
45    /// Adds `T` to the rolling sum, displacing the oldest
46    /// member if the window is full to capacity.
47    ///
48    /// If adding `T` causes numerical overflow, subsequent
49    /// calls to `total` will return None until window
50    /// expirations cause underflow commensurate to the overflow.
51    ///
52    /// # Panics
53    ///
54    /// This function panics if the `usize` variable tracking the
55    /// number of times the sum has overflowed itself overflows.
56    /// A window should be sized such that this never occurs.
57    //
58    // Clippy allow:
59    // Explained inline. This is easily provable, will never occur,
60    // and should not be exposed to the user.
61    #[allow(clippy::expect_used)]
62    #[allow(clippy::missing_panics_doc)]
63    pub fn add(&mut self, val: T) {
64        if self.deq.is_full() {
65            // Construction has a const assertion that WINDOW is not zero.
66            // So `is_full` guarantees there's something to pop.
67            let popped = self.deq.pop_front().expect(
68                "len is equal to capacity, and capacity is nonzero. So an element must exist.",
69            );
70
71            let changed = self.total.checked_sub(&popped).is_none();
72            self.total = self.total.wrapping_sub(&popped);
73
74            if changed {
75                self.balance = self
76                    .balance
77                    .checked_add(if popped >= self.zero { -1 } else { 1 })
78                    .expect("overflow count itself overflowed");
79            }
80        }
81
82        let changed = self.total.checked_add(&val).is_none();
83        self.total = self.total.wrapping_add(&val);
84
85        if changed {
86            self.balance = self
87                .balance
88                .checked_add(if val >= self.zero { 1 } else { -1 })
89                .expect("overflow count itself overflowed");
90        }
91
92        // The `if` condition above guarantees the deque
93        // is not full. So there's space to push a value.
94        self.deq.push_back(val).expect("deq is not full");
95    }
96
97    /// Returns the accumulated total of all added
98    /// values that fit within the rolling window's
99    /// capacity.
100    ///
101    /// Returns None if the window has overflowed.
102    /// In that case, it will return to Some(..) when
103    /// the last element causing overflow is pushed out.
104    #[must_use]
105    pub fn total(&self) -> Option<&T> {
106        (self.balance == 0).then_some(&self.total)
107    }
108}
109
#[cfg(test)]
pub mod for_tests {
    use arraydeque::{ArrayDeque, Wrapping};
    use num_bigint::BigInt;

    /// A simple implementation mirroring the API of this crate's
    /// `RollingSum` type. This is used for both correctness and
    /// performance testing.
    ///
    /// It keeps the exact sum of the window in a `BigInt`, so it never
    /// wraps. [`total`](Self::total) returns `None` whenever that exact
    /// sum cannot be converted back into `T`.
    ///
    /// NOTE(review): an earlier version of this comment said this type "is
    /// not as robust against underflow as RollingSum". It pointed to a note
    /// in `total` that did not exist — confirm what limitation was meant.
    #[derive(Debug, Default)]
    pub struct NaiveRollingSum<T, const WINDOW: usize> {
        /// Window contents. The `Wrapping` behavior makes `push_back` evict
        /// and return the front element once the deque is full.
        deq: ArrayDeque<T, WINDOW, Wrapping>,
        /// Exact sum of the starting value and every element in `deq`.
        sum: BigInt,
    }

    impl<T, const WINDOW: usize> NaiveRollingSum<T, WINDOW>
    where
        T: Clone + Into<BigInt> + for<'a> TryFrom<&'a BigInt>,
    {
        /// Creates an empty window whose sum starts at `init`.
        #[must_use]
        pub fn new(init: T) -> Self {
            const { assert!(WINDOW != 0, "RollingSum with WINDOW == 0 is not permitted") };
            Self {
                deq: ArrayDeque::new(),
                sum: init.into(),
            }
        }

        /// Adds `val` to the sum. If the window was full, the evicted
        /// element is subtracted back out.
        pub fn add(&mut self, val: T) {
            // The explicit `Into::<BigInt>` avoids ambiguity between the
            // several `AddAssign`/`SubAssign` impls on `BigInt`.
            self.sum += Into::<BigInt>::into(val.clone());
            if let Some(replaced) = self.deq.push_back(val) {
                self.sum -= Into::<BigInt>::into(replaced);
            }
        }

        /// Returns the exact sum converted to `T`, or `None` if the
        /// `TryFrom<&BigInt>` conversion fails (the sum does not fit).
        #[must_use]
        pub fn total(&self) -> Option<T> {
            (&self.sum).try_into().ok()
        }
    }
}
153
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::{
        decimal::{D1, D5},
        rolling_sum::for_tests::NaiveRollingSum,
    };
    use core::fmt::Debug;
    use rand::{distr::Uniform, rngs::SmallRng, RngExt, SeedableRng};

    /// Smoke test for RollingSum correctness.
    ///
    /// Feeds the same pseudo-random stream into a RollingSum and a
    /// NaiveRollingSum and verifies their totals are identical after
    /// every add.
    #[test]
    fn rng_with_naive() {
        const QLEN: usize = 1000;
        const STREAM_LEN: usize = 100_000;

        let sample = SmallRng::seed_from_u64(57).sample_iter(Uniform::new(-800f32, 800.).unwrap());
        let mut roller = RollingSum::<D5, QLEN>::default();
        let mut naive = NaiveRollingSum::<D5, QLEN>::default();

        for val in sample.take(STREAM_LEN) {
            let d5 = D5::cast(val.into());
            roller.add(d5);
            naive.add(d5);
            assert_eq!(roller.total(), naive.total().as_ref());
        }
    }

    /// Verifies that total() returns `init` before any values are added.
    #[test]
    fn total_before_any_add_is_init() {
        let rs: RollingSum<u32, 3> = RollingSum::default();
        assert_eq!(rs.total(), Some(&0u32));
    }

    /// A single add must accumulate into total without triggering eviction.
    #[test]
    fn single_add_below_capacity() {
        let mut rs: RollingSum<u32, 3> = RollingSum::default();
        rs.add(10);
        assert_eq!(rs.total(), Some(&10));
    }

    /// Filling exactly to capacity must sum all values with no eviction.
    #[test]
    fn fill_to_capacity_no_eviction() {
        self::expect_total::<u32, 3>([1, 2, 3].into_iter().zip([1, 3, 6]));
    }

    /// The (capacity+1)th add must evict the oldest element.
    #[test]
    fn first_eviction_at_capacity_plus_one() {
        // Window = [2, 3, 4] after evicting 1.
        self::expect_total::<u32, 3>([1, 2, 3, 4].into_iter().zip([1, 3, 6, 9]));
    }

    /// Step through a longer sequence to verify correct FIFO eviction ordering.
    #[test]
    fn sliding_window_trace() {
        // cap=3: [5]=5, [5,3]=8, [5,3,8]=16, [3,8,2]=13, [8,2,6]=16
        self::expect_total::<u32, 3>([5, 3, 8, 2, 6].into_iter().zip([5, 8, 16, 13, 16]));
    }

    /// capacity=1: each add completely replaces the previous value.
    #[test]
    fn window_of_one() {
        self::expect_total::<u32, 1>([5, 3, 9, 1].into_iter().zip([5, 3, 9, 1]));
    }

    /// Window larger than the input: no eviction ever occurs.
    #[test]
    fn window_larger_than_input() {
        self::expect_total::<u32, 100>([1, 2, 3, 4, 5].into_iter().zip([1, 3, 6, 10, 15]));
    }

    /// Signed integers: negative values must be summed and evicted correctly.
    #[test]
    fn signed_integers() {
        // cap=2: [-3]=-3, [-3,5]=2, [5,-2]=3, [-2,4]=2
        self::expect_total::<i32, 2>([-3, 5, -2, 4].into_iter().zip([-3, 2, 3, 2]));
    }

    /// Hitting u8::MAX exactly (no overflow) then recovering on eviction.
    #[test]
    fn u8_boundary_exact() {
        let mut rs = RollingSum::<u8, 3>::default();
        rs.add(100);
        rs.add(100);
        rs.add(55);
        assert_eq!(rs.total(), Some(&255)); // 100+100+55 = u8::MAX exactly
        rs.add(0);
        assert_eq!(rs.total(), Some(&155)); // evicted 100 → 100+55+0
    }

    /// total() returns None while the window sum has overflowed, and recovers to Some
    /// once the overflowing element has been evicted.
    #[test]
    fn overflow_detected_then_recovered() {
        let mut rs = RollingSum::<u8, 2>::default();
        rs.add(200);
        assert_eq!(rs.total(), Some(&200)); // single element, no overflow
        rs.add(100); // 200+100=300 > u8::MAX → balance=1
        assert_eq!(rs.total(), None); // overflow detected
        rs.add(50); // evicts 200; window=[100,50], sum=150
        assert_eq!(rs.total(), Some(&150)); // overflow healed
    }

    /// Fills the window with three 200s, wrapping u8 twice (balance reaches 2).
    /// Then slides in small values to evict them one at a time, confirming that
    /// total() stays None while the exact window sum exceeds u8::MAX. It must
    /// return exact Some values as soon as the sum fits again, even though
    /// one 200 is still in the window at that point.
    #[test]
    fn double_overflow_and_full_recovery() {
        let mut rs = RollingSum::<u8, 3>::default();

        rs.add(200);
        assert_eq!(rs.total(), Some(&200)); // [200], exact=200, balance=0
        rs.add(200);
        assert_eq!(rs.total(), None); // [200,200], exact=400, balance=1
        rs.add(200);
        assert_eq!(rs.total(), None); // [200,200,200], exact=600, balance=2

        rs.add(10);
        assert_eq!(rs.total(), None); // evict 200 → balance=1; [200,200,10], exact=410
        rs.add(10);
        assert_eq!(rs.total(), Some(&220)); // evict 200 → balance=0; [200,10,10], exact=220
        rs.add(10);
        assert_eq!(rs.total(), Some(&30)); // evict 200, no wrap; [10,10,10], exact=30
    }

    /// Large u64 values that stay within range: verifies no spurious overflow.
    #[test]
    fn u64_large_values_no_overflow() {
        // 2^63 - 1, so two of them sum to 2^64 - 2 = u64::MAX - 1.
        const HALF: u64 = u64::MAX / 2;
        let mut rs = RollingSum::<_, 2>::default();
        rs.add(HALF);
        rs.add(HALF);
        assert_eq!(rs.total(), Some(&(HALF * 2))); // u64::MAX - 1, no overflow
        rs.add(1);
        assert_eq!(rs.total(), Some(&(HALF + 1))); // evicted one HALF → HALF + 1
    }

    /// Signed overflow above i32::MAX is reported as None, and the total
    /// recovers once the exact window sum fits in i32 again.
    #[test]
    fn overflow_negative() {
        let mut rs = RollingSum::<i32, 3>::default();

        rs.add(i32::MAX); // Total = MAX
        assert!(rs.total().is_some());

        rs.add(i32::MIN); // Total = MAX + MIN
        assert!(rs.total().is_some());

        rs.add(i32::MAX); // Total = MAX + MIN + MAX
        assert!(rs.total().is_some());

        rs.add(i32::MAX); // Total = MIN + MAX + MAX
        assert!(rs.total().is_some());

        rs.add(0); // Total = MAX + MAX + 0
        assert!(rs.total().is_none());

        rs.add(0); // Total = MAX + 0 + 0
        assert_eq!(rs.total(), Some(&i32::MAX));
    }

    /// Signed underflow below i32::MIN is reported as None, and a later
    /// positive value brings the total back into range.
    #[test]
    fn underflow_negative() {
        let mut rs = RollingSum::<i32, 3>::default();

        rs.add(i32::MIN);
        assert!(rs.total().is_some());

        rs.add(-1);
        assert!(rs.total().is_none());

        rs.add(1);
        assert_eq!(rs.total(), Some(&i32::MIN));
    }

    /// Repeatedly adding D1::MAX keeps the total None across many wraps. Once
    /// every MAX has been evicted, the exact small total is reported again.
    #[test]
    fn decimal_overflow() {
        let mut rs = RollingSum::<D1, 4>::default();

        rs.add(D1::MAX);
        assert!(rs.total().is_some());

        for _ in 0..100 {
            rs.add(D1::MAX);
            assert!(rs.total().is_none());
        }

        for _ in 0..3 {
            rs.add(D1::ZERO);
        }

        rs.add(D1::MIN_UNIT);
        assert!(matches!(rs.total(), Some(&D1::MIN_UNIT)));
    }

    /// Feeds inputs from an `(input, expected)` iterator into
    /// a RollingSum. Compares each total to `expected` and panics
    /// if they're not equal.
    fn expect_total<T, const WINDOW: usize>(input_and_expected: impl Iterator<Item = (T, T)>)
    where
        T: WrappingAdd
            + WrappingSub
            + CheckedAdd
            + CheckedSub
            + PartialOrd
            + Copy
            + Default
            + Debug,
    {
        let mut roll: RollingSum<T, WINDOW> = RollingSum::default();
        for (input, expected) in input_and_expected {
            roll.add(input);
            assert_eq!(*roll.total().unwrap(), expected);
        }
    }
}