// high_roller/rolling_sum.rs
use num_traits::{CheckedAdd, CheckedSub, WrappingAdd, WrappingSub};
use std::{collections::VecDeque, num::NonZeroUsize};
4#[derive(Debug)]
9pub struct RollingSum<T> {
10 deq: VecDeque<T>,
11 total: T,
12 capacity: usize,
13 wrap_ct: usize,
14}
15
16impl<T> RollingSum<T>
17where
18 T: WrappingAdd + WrappingSub + CheckedAdd + CheckedSub + PartialOrd + Copy + Default,
19{
20 #[must_use]
21 pub fn new(capacity: NonZeroUsize) -> Self {
22 Self {
23 deq: VecDeque::with_capacity(capacity.into()),
24 total: T::default(),
25 capacity: capacity.into(),
26 wrap_ct: 0,
27 }
28 }
29
30 pub fn add(&mut self, val: T) -> bool {
37 if self.deq.len() == self.capacity {
38 let popped = self.deq.pop_front().expect(
39 "len is equal to capacity, and capacity is nonzero. So an element must exist.",
40 );
41 let before = self.total;
42 self.total = self.total.wrapping_sub(&popped);
43 if before.checked_sub(&popped).is_none() {
44 self.wrap_ct -= 1;
45 }
46 }
47
48 let before_add = self.total;
49 self.total = self.total.wrapping_add(&val);
50 self.wrap_ct += before_add.checked_add(&val).is_none() as usize;
51
52 self.deq.push_back(val);
53 true
54 }
55
56 #[must_use]
64 pub fn total(&self) -> Option<&T> {
65 (self.wrap_ct == 0).then_some(&self.total)
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `NonZeroUsize`; panics on zero (test-only).
    fn nz(n: usize) -> NonZeroUsize {
        NonZeroUsize::new(n).unwrap()
    }

    /// Feeds `vals` into a fresh window of size `cap` and collects the total
    /// reported after each push. Panics if any total is `None`.
    fn totals<T>(vals: &[T], cap: usize) -> Vec<T>
    where
        T: WrappingAdd + WrappingSub + CheckedAdd + CheckedSub + PartialOrd + Copy + Default,
    {
        let mut rs = RollingSum::new(nz(cap));
        vals.iter()
            .map(|&v| {
                rs.add(v);
                *rs.total().expect("window sum should fit in T")
            })
            .collect()
    }

    #[test]
    fn total_before_any_add_is_init() {
        let rs: RollingSum<u32> = RollingSum::new(nz(3));
        assert_eq!(rs.total(), Some(&0u32));
    }

    #[test]
    fn single_add_below_capacity() {
        let mut rs: RollingSum<u32> = RollingSum::new(nz(3));
        rs.add(10);
        assert_eq!(rs.total(), Some(&10));
    }

    #[test]
    fn fill_to_capacity_no_eviction() {
        assert_eq!(totals(&[1u32, 2, 3], 3), vec![1, 3, 6]);
    }

    #[test]
    fn first_eviction_at_capacity_plus_one() {
        assert_eq!(totals(&[1u32, 2, 3, 4], 3), vec![1, 3, 6, 9]);
    }

    #[test]
    fn sliding_window_trace() {
        assert_eq!(totals(&[5u32, 3, 8, 2, 6], 3), vec![5, 8, 16, 13, 16]);
    }

    #[test]
    fn window_of_one() {
        assert_eq!(totals(&[5u32, 3, 9, 1], 1), vec![5, 3, 9, 1]);
    }

    #[test]
    fn window_larger_than_input() {
        assert_eq!(totals(&[1u32, 2, 3, 4, 5], 100), vec![1, 3, 6, 10, 15]);
    }

    #[test]
    fn signed_integers() {
        assert_eq!(totals(&[-3i32, 5, -2, 4], 2), vec![-3, 2, 3, 2]);
    }

    #[test]
    fn u8_boundary_exact() {
        // The element type must be pinned: without it the integer literals
        // default to i32 and the u8::MAX boundary is never exercised.
        let mut rs = RollingSum::<u8>::new(nz(3));
        rs.add(100);
        rs.add(100);
        rs.add(55);
        assert_eq!(rs.total(), Some(&u8::MAX));
        rs.add(0);
        assert_eq!(rs.total(), Some(&155));
    }

    #[test]
    fn overflow_detected_then_recovered() {
        let mut rs = RollingSum::<u8>::new(nz(2));
        rs.add(200);
        assert_eq!(rs.total(), Some(&200));
        rs.add(100); // 300 does not fit in u8
        assert_eq!(rs.total(), None);
        rs.add(50); // window is now [100, 50]
        assert_eq!(rs.total(), Some(&150));
    }

    #[test]
    fn double_overflow_and_full_recovery() {
        let mut rs = RollingSum::<u8>::new(nz(3));

        rs.add(200);
        assert_eq!(rs.total(), Some(&200));
        rs.add(200); // 400: one wrap
        assert_eq!(rs.total(), None);
        rs.add(200); // 600: two wraps
        assert_eq!(rs.total(), None);
        rs.add(10); // window [200, 200, 10] = 410
        assert_eq!(rs.total(), None);
        rs.add(10); // window [200, 10, 10] = 220
        assert_eq!(rs.total(), Some(&220));
        rs.add(10); // window [10, 10, 10] = 30
        assert_eq!(rs.total(), Some(&30));
    }

    #[test]
    fn u64_large_values_no_overflow() {
        let half = u64::MAX / 2;
        let mut rs = RollingSum::new(nz(2));
        rs.add(half);
        rs.add(half);
        assert_eq!(rs.total(), Some(&(half * 2)));
        rs.add(1);
        assert_eq!(rs.total(), Some(&(half + 1)));
    }
}