use arraydeque::ArrayDeque;
use num_traits::CheckedAdd;
use num_traits::CheckedSub;
use num_traits::WrappingAdd;
use num_traits::WrappingSub;
/// Sum of the most recent `WINDOW` values, maintained incrementally with wrapping arithmetic.
///
/// Besides the wrapped running total, the struct keeps a signed counter of how many times
/// the total has wrapped in each direction, so that `total` can report `None` while the
/// stored value does not represent the true sum.
#[derive(Debug)]
pub struct RollingSum<T, const WINDOW: usize> {
/// Values currently inside the window; new values are pushed at the back and the oldest
/// one is popped from the front once the deque is full.
deq: ArrayDeque<T, WINDOW>,
/// Running total, updated with `wrapping_add` / `wrapping_sub`.
total: T,
/// Reference value a sample is compared against (`>=`) to decide in which direction a
/// wrap happened.
zero: T,
/// Net number of wraps applied to `total` (incremented for one direction, decremented for
/// the other); `total` is only reported while this is `0`.
balance: isize,
}
/// An empty [`RollingSum`] whose starting total and zero reference are both `T::default()`.
impl<T: Default, const WINDOW: usize> Default for RollingSum<T, WINDOW> {
    fn default() -> Self {
        let init = T::default();
        let zero = T::default();
        Self::new(init, zero)
    }
}
// NOTE(review): nothing in `new` calls `T::default()`; the `Default` bound looks droppable,
// but it is kept here so the public interface stays exactly as it was.
impl<T: Default, const WINDOW: usize> RollingSum<T, WINDOW> {
    /// Creates an empty rolling sum.
    ///
    /// * `init` - starting value of the running total.
    /// * `zero` - reference value that `add` compares samples against to decide in which
    ///   direction the total wrapped.
    ///
    /// Instantiating this function with `WINDOW == 0` is rejected at compile time by the
    /// inline `const` assertion.
    #[must_use]
    pub const fn new(init: T, zero: T) -> Self {
        const { assert!(WINDOW != 0, "RollingSum with WINDOW == 0 is not permitted") };
        Self {
            deq: ArrayDeque::new(),
            total: init,
            zero,
            balance: 0,
        }
    }
}
// NOTE(review): neither method below calls `T::default()`; the `Default` bound is kept only
// so the set of types this impl applies to is unchanged.
impl<T, const WINDOW: usize> RollingSum<T, WINDOW>
where
    T: WrappingAdd + WrappingSub + CheckedAdd + CheckedSub + PartialOrd + Copy + Default,
{
    /// Pushes `val` into the window, evicting the oldest value first when the window is full.
    ///
    /// The running total is always updated with wrapping arithmetic. Whenever the matching
    /// checked operation reports that the update did not fit in `T`, the wrap counter is
    /// moved by one; the direction is chosen by comparing the operand with `zero`:
    ///
    /// * adding a value `>= zero` counts as `+1`, adding any other value as `-1`;
    /// * removing a value `>= zero` counts as `-1`, removing any other value as `+1`.
    ///
    /// The `expect` calls guard internal invariants (a full deque yields an element, the
    /// deque has room after the eviction, the wrap counter stays within `isize`).
    #[allow(clippy::expect_used)]
    #[allow(clippy::missing_panics_doc)]
    pub fn add(&mut self, val: T) {
        // Make room first so the push at the end cannot fail.
        if self.deq.is_full() {
            let oldest = self.deq.pop_front().expect(
                "len is equal to capacity, and capacity is nonzero. So an element must exist.",
            );
            let exact = self.total.checked_sub(&oldest).is_some();
            self.total = self.total.wrapping_sub(&oldest);
            if !exact {
                let step = if oldest >= self.zero { -1 } else { 1 };
                self.shift_balance(step);
            }
        }

        let exact = self.total.checked_add(&val).is_some();
        self.total = self.total.wrapping_add(&val);
        if !exact {
            let step = if val >= self.zero { 1 } else { -1 };
            self.shift_balance(step);
        }

        self.deq.push_back(val).expect("deq is not full");
    }

    /// Returns the running total, or `None` while the wrap counter is non-zero.
    #[must_use]
    pub fn total(&self) -> Option<&T> {
        if self.balance == 0 {
            Some(&self.total)
        } else {
            None
        }
    }

    /// Moves the wrap counter by `step`, panicking if the counter itself leaves `isize` range.
    #[allow(clippy::expect_used)]
    fn shift_balance(&mut self, step: isize) {
        self.balance = self
            .balance
            .checked_add(step)
            .expect("overflow count itself overflowed");
    }
}
#[cfg(test)]
pub mod for_tests {
    use arraydeque::ArrayDeque;
    use arraydeque::Wrapping;
    use num_bigint::BigInt;

    /// Reference rolling sum for tests: the sum is kept in a `BigInt` and only converted back
    /// to `T` when it is read.
    #[derive(Debug, Default)]
    pub struct NaiveRollingSum<T, const WINDOW: usize> {
        /// Values inside the window; with the `Wrapping` behavior a push onto a full deque
        /// hands back the element it replaced.
        deq: ArrayDeque<T, WINDOW, Wrapping>,
        /// Sum of the initial value and everything currently in `deq`.
        sum: BigInt,
    }

    impl<T, const WINDOW: usize> NaiveRollingSum<T, WINDOW>
    where
        T: Clone + Default + Into<BigInt> + for<'a> TryFrom<&'a BigInt>,
    {
        /// Creates an empty window whose sum starts at `init`.
        ///
        /// Instantiating this function with `WINDOW == 0` is rejected at compile time.
        #[must_use]
        pub fn new(init: T) -> Self {
            const { assert!(WINDOW != 0, "RollingSum with WINDOW == 0 is not permitted") };
            let sum: BigInt = init.into();
            Self {
                deq: ArrayDeque::new(),
                sum,
            }
        }

        /// Adds `val` to the sum and, if the push displaced an older value, subtracts it.
        #[allow(clippy::expect_used)]
        #[allow(clippy::missing_panics_doc)]
        pub fn add(&mut self, val: T) {
            let incoming: BigInt = val.clone().into();
            self.sum = self
                .sum
                .checked_add(&incoming)
                .expect("no bigint overflow");

            // Nothing was displaced while the window is still filling up.
            let Some(displaced) = self.deq.push_back(val) else {
                return;
            };
            let outgoing: BigInt = displaced.into();
            self.sum = self
                .sum
                .checked_sub(&outgoing)
                .expect("no bigint underflow");
        }

        /// Returns the sum converted to `T`, or `None` when the conversion fails.
        #[must_use]
        pub fn total(&self) -> Option<T> {
            T::try_from(&self.sum).ok()
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::decimal::D1;
    use crate::decimal::D5;
    use crate::rolling_sum::for_tests::NaiveRollingSum;
    use core::fmt::Debug;
    use rand::distr::Uniform;
    use rand::rngs::SmallRng;
    use rand::RngExt;
    use rand::SeedableRng;

    /// Feeds every input into a fresh `RollingSum<T, WINDOW>` and checks the total reported
    /// after each step against the paired expected value.
    fn expect_total<T, const WINDOW: usize>(input_and_expected: impl Iterator<Item = (T, T)>)
    where
        T: WrappingAdd
            + WrappingSub
            + CheckedAdd
            + CheckedSub
            + PartialOrd
            + Copy
            + Default
            + Debug,
    {
        let mut roll = RollingSum::<T, WINDOW>::default();
        for (input, expected) in input_and_expected {
            roll.add(input);
            let actual = roll.total().unwrap();
            assert_eq!(*actual, expected);
        }
    }

    /// Compares `RollingSum` against the BigInt-backed reference on a seeded random stream.
    #[test]
    fn rng_with_naive() {
        const QLEN: usize = 1000;
        const STREAM_LEN: usize = 100_000;

        let rng = SmallRng::seed_from_u64(57);
        let range = Uniform::new(-800f32, 800.).unwrap();
        let sample = rng.sample_iter(range);

        let mut roller = RollingSum::<D5, QLEN>::default();
        let mut naive = NaiveRollingSum::<D5, QLEN>::default();
        for val in sample.take(STREAM_LEN) {
            let decimal = D5::cast(val.into());
            roller.add(decimal);
            naive.add(decimal);
            assert_eq!(roller.total(), naive.total().as_ref());
        }
    }

    #[test]
    fn basic_math() {
        let mut rs = RollingSum::<u32, 3>::default();
        assert_eq!(rs.total(), Some(&0));

        rs.add(10);
        assert_eq!(rs.total(), Some(&10));
    }

    #[test]
    fn basic_rolling() {
        expect_total::<u32, 3>([1, 2, 3].into_iter().zip([1, 3, 6]));
        expect_total::<u32, 3>([1, 2, 3, 4].into_iter().zip([1, 3, 6, 9]));
        expect_total::<u32, 3>([5, 3, 8, 2, 6].into_iter().zip([5, 8, 16, 13, 16]));
        expect_total::<u32, 1>([5, 3, 9, 1].into_iter().zip([5, 3, 9, 1]));
        expect_total::<u32, 100>([1, 2, 3, 4, 5].into_iter().zip([1, 3, 6, 10, 15]));
        expect_total::<i32, 2>([-3, 5, -2, 4].into_iter().zip([-3, 2, 3, 2]));
    }

    /// A window that sums to exactly `u8::MAX` is still reported.
    #[test]
    fn limit_boundary() {
        let mut rs = RollingSum::<u8, 3>::default();
        for val in [100, 100, 55] {
            rs.add(val);
        }
        assert_eq!(rs.total(), Some(&255));

        rs.add(0);
        assert_eq!(rs.total(), Some(&155));
    }

    /// The total disappears while the window overflows `u8` and comes back afterwards.
    #[test]
    fn limit_overflow() {
        let mut rs = RollingSum::<u8, 2>::default();

        rs.add(200);
        assert_eq!(rs.total(), Some(&200));

        rs.add(100);
        assert_eq!(rs.total(), None);

        rs.add(50);
        assert_eq!(rs.total(), Some(&150));
    }

    #[test]
    fn limit_overflow_2x() {
        let steps: [(u8, Option<u8>); 6] = [
            (200, Some(200)),
            (200, None),
            (200, None),
            (10, None),
            (10, Some(220)),
            (10, Some(30)),
        ];

        let mut rs = RollingSum::<u8, 3>::default();
        for (input, expected) in steps {
            rs.add(input);
            assert_eq!(rs.total().copied(), expected);
        }
    }

    #[test]
    fn sanity_check_big() {
        const HALF: u64 = (u64::MAX as f64 / 2.) as u64 - 1;

        let mut rs = RollingSum::<_, 2>::default();
        rs.add(HALF);
        rs.add(HALF);
        assert_eq!(rs.total(), Some(&(HALF * 2)));

        rs.add(1);
        assert_eq!(rs.total(), Some(&(HALF + 1)));
    }

    #[test]
    fn limit_extremes() {
        let mut rs = RollingSum::<i32, 3>::default();
        for val in [i32::MAX, i32::MIN, i32::MAX, i32::MAX] {
            rs.add(val);
            assert!(rs.total().is_some());
        }

        rs.add(0);
        assert!(rs.total().is_none());

        rs.add(0);
        assert_eq!(rs.total(), Some(&i32::MAX));
    }

    #[test]
    fn limit_underflow() {
        let mut rs = RollingSum::<i32, 3>::default();

        rs.add(i32::MIN);
        assert!(rs.total().is_some());

        rs.add(-1);
        assert!(rs.total().is_none());

        rs.add(1);
        assert_eq!(rs.total(), Some(&i32::MIN));
    }

    #[test]
    fn decimal_overflow() {
        let mut rs = RollingSum::<D1, 4>::default();

        rs.add(D1::MAX);
        assert!(rs.total().is_some());

        for _ in 0..100 {
            rs.add(D1::MAX);
            assert!(rs.total().is_none());
        }

        // Flush most of the window with zeros before adding the final sample.
        for _ in 0..3 {
            rs.add(D1::ZERO);
        }
        rs.add(D1::MIN_UNIT);
        assert!(matches!(rs.total(), Some(&D1::MIN_UNIT)));
    }
}