use std::collections::VecDeque;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use crate::error::WindowBreached;
/// A single recorded spend inside a window: when it was recorded and how much
/// it consumed on the token and USD axes.
#[derive(Clone, Copy, Debug)]
struct Entry {
    /// Timestamp at which the spend was recorded.
    at: Instant,
    /// Tokens consumed by this spend.
    tokens: u64,
    /// Amount charged on the USD axis by this spend.
    usd: f64,
}
/// Configuration for one sliding budget window.
///
/// A window has a name, a duration, and optional caps on the tokens and USD
/// spent within that duration. A cap that is not set is unbounded on that
/// axis. Caps are inclusive: a running total equal to the cap is allowed.
#[derive(Debug, Clone)]
pub struct Window {
    name: String,
    duration: Duration,
    token_cap: Option<u64>,
    usd_cap: Option<f64>,
}

impl Window {
    /// Creates a window with the given `name` and `duration` and no caps.
    #[must_use]
    pub fn new(name: impl Into<String>, duration: Duration) -> Self {
        Self {
            name: name.into(),
            duration,
            token_cap: None,
            usd_cap: None,
        }
    }

    /// Returns this window with its token cap set to `cap`.
    ///
    /// The builder consumes `self`, so the result must be used; dropping it
    /// discards the configuration.
    #[must_use]
    pub fn with_token_cap(mut self, cap: u64) -> Self {
        self.token_cap = Some(cap);
        self
    }

    /// Returns this window with its USD cap set to `cap`.
    ///
    /// The builder consumes `self`, so the result must be used; dropping it
    /// discards the configuration.
    #[must_use]
    pub fn with_usd_cap(mut self, cap: f64) -> Self {
        self.usd_cap = Some(cap);
        self
    }

    /// The window's name, as reported in breach errors and snapshots.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// How far back in time this window looks.
    pub fn duration(&self) -> Duration {
        self.duration
    }
}
/// Mutable bookkeeping for one window: its configuration, the spends that
/// currently fall inside it, and running totals over those spends.
#[derive(Debug)]
struct WindowState {
    cfg: Window,
    /// Recorded spends in insertion order; eviction assumes the front is the
    /// oldest, so callers should append with non-decreasing timestamps.
    entries: VecDeque<Entry>,
    /// Sum of `tokens` over `entries`, saturating at `u64::MAX`. When a token
    /// cap is set the sum never exceeds it, so saturation only affects
    /// uncapped windows.
    tokens_sum: u64,
    /// Sum of `usd` over `entries`.
    usd_sum: f64,
}

impl WindowState {
    /// Creates empty state for the window described by `cfg`.
    fn new(cfg: Window) -> Self {
        Self {
            cfg,
            entries: VecDeque::new(),
            tokens_sum: 0,
            usd_sum: 0.0,
        }
    }

    /// Evicts every entry whose age at `now` is greater than the window
    /// duration, subtracting it from the running totals.
    ///
    /// An entry recorded exactly `duration` before `now` is kept.
    fn evict_older_than(&mut self, now: Instant) {
        // Compare each entry's age with the duration rather than computing
        // `now - duration`. That subtraction can be unrepresentable for long
        // windows, or shortly after boot on platforms whose clock starts at
        // boot. A fallback of `now` would evict every entry and effectively
        // disable the cap. An age that cannot exceed `duration` correctly
        // keeps the entry.
        while let Some(oldest) = self.entries.front().copied() {
            if now.saturating_duration_since(oldest.at) <= self.cfg.duration {
                break;
            }
            self.entries.pop_front();
            self.tokens_sum = self.tokens_sum.saturating_sub(oldest.tokens);
            self.usd_sum -= oldest.usd;
        }
        if self.entries.is_empty() {
            // Repeated float add/subtract leaves residue (0.1 + 0.2 - 0.1 - 0.2
            // is not 0.0). An empty window has exactly zero usage, so reset
            // both totals; this also clears any non-finite value recorded
            // while no USD cap was set.
            self.tokens_sum = 0;
            self.usd_sum = 0.0;
        }
    }

    /// Checks whether adding `tokens` and `usd` would exceed either cap.
    ///
    /// # Errors
    /// Returns [`WindowBreached`] naming this window and the axis that would
    /// be exceeded. Tokens are checked before USD. When a USD cap is set, a
    /// non-finite resulting total (NaN or infinite) is treated as a breach.
    fn check_capacity(&self, tokens: u64, usd: f64) -> Result<(), WindowBreached> {
        if let Some(cap) = self.cfg.token_cap {
            // Saturate so an enormous request is reported as a breach. Plain
            // `+` would panic in debug builds and wrap to a small value in
            // release builds, letting the request through.
            let attempted = self.tokens_sum.saturating_add(tokens);
            if attempted > cap {
                return Err(WindowBreached {
                    window_name: self.cfg.name.clone(),
                    axis: "tokens",
                    attempted: attempted as f64,
                    cap: cap as f64,
                });
            }
        }
        if let Some(cap) = self.cfg.usd_cap {
            let attempted = self.usd_sum + usd;
            // A NaN or -inf total would poison `usd_sum`. Every later
            // `> cap` comparison would then be false, silently disabling the
            // cap, so fail closed on any non-finite total.
            if !attempted.is_finite() || attempted > cap {
                return Err(WindowBreached {
                    window_name: self.cfg.name.clone(),
                    axis: "usd",
                    attempted,
                    cap,
                });
            }
        }
        Ok(())
    }

    /// Appends a spend recorded at `at` and adds it to the running totals.
    fn record(&mut self, at: Instant, tokens: u64, usd: f64) {
        self.entries.push_back(Entry { at, tokens, usd });
        self.tokens_sum = self.tokens_sum.saturating_add(tokens);
        self.usd_sum += usd;
    }
}
/// Point-in-time view of one window's usage and limits, taken after evicting
/// entries that have aged out.
#[derive(Clone, Debug)]
pub struct WindowSnapshot {
    /// Name of the window, from its configuration.
    pub name: String,
    /// Length of the sliding window.
    pub duration: Duration,
    /// Tokens recorded within the window at snapshot time.
    pub tokens_used: u64,
    /// USD recorded within the window at snapshot time.
    pub usd_used: f64,
    /// Configured token cap, or `None` if tokens are unbounded.
    pub token_cap: Option<u64>,
    /// Configured USD cap, or `None` if USD is unbounded.
    pub usd_cap: Option<f64>,
    /// Number of individual spends still inside the window.
    pub entry_count: usize,
}
/// Thread-safe set of sliding budget windows that are checked and charged
/// together.
///
/// A charge is accepted only if it fits under the caps of every window. When
/// any window would be breached, the charge is recorded in none of them.
#[derive(Debug)]
pub struct BudgetWindows {
    inner: Mutex<Vec<WindowState>>,
}

impl BudgetWindows {
    /// Creates a tracker over `windows`, each starting with no recorded usage.
    pub fn new(windows: Vec<Window>) -> Self {
        Self {
            inner: Mutex::new(windows.into_iter().map(WindowState::new).collect()),
        }
    }

    /// Acquires the lock guarding all window state.
    ///
    /// # Panics
    /// Panics if another thread panicked while holding the lock.
    fn lock(&self) -> std::sync::MutexGuard<'_, Vec<WindowState>> {
        self.inner.lock().expect("BudgetWindows lock poisoned")
    }

    /// Charges `tokens` and `usd` to every window, or to none of them.
    ///
    /// Entries that have aged out of each window are evicted first, so only
    /// usage inside each window counts toward its caps.
    ///
    /// # Errors
    /// Returns [`WindowBreached`] for the first window, in construction order,
    /// whose cap the charge would exceed. Nothing is recorded in that case.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn record(&self, tokens: u64, usd: f64) -> Result<(), WindowBreached> {
        let mut state = self.lock();
        // Read the clock only while holding the lock, so entries are appended
        // in non-decreasing timestamp order. Eviction inspects only the front
        // of each queue and relies on it being the oldest entry.
        let now = Instant::now();
        for window in state.iter_mut() {
            window.evict_older_than(now);
        }
        // Validate against every window before recording into any, so a
        // breach records nothing anywhere.
        state
            .iter()
            .try_for_each(|window| window.check_capacity(tokens, usd))?;
        for window in state.iter_mut() {
            window.record(now, tokens, usd);
        }
        Ok(())
    }

    /// Returns the current usage of every window, in construction order,
    /// after evicting entries that have aged out.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn snapshot(&self) -> Vec<WindowSnapshot> {
        let mut state = self.lock();
        let now = Instant::now();
        state
            .iter_mut()
            .map(|window| {
                window.evict_older_than(now);
                WindowSnapshot {
                    name: window.cfg.name.clone(),
                    duration: window.cfg.duration,
                    tokens_used: window.tokens_sum,
                    usd_used: window.usd_sum,
                    token_cap: window.cfg.token_cap,
                    usd_cap: window.cfg.usd_cap,
                    entry_count: window.entries.len(),
                }
            })
            .collect()
    }

    /// Discards all recorded usage in every window; configurations are kept.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn reset(&self) {
        let mut state = self.lock();
        for window in state.iter_mut() {
            window.entries.clear();
            window.tokens_sum = 0;
            window.usd_sum = 0.0;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing accumulated floating-point USD totals.
    const USD_EPS: f64 = 1e-9;

    #[test]
    fn record_under_caps() {
        let bw = BudgetWindows::new(vec![Window::new("min", Duration::from_secs(60))
            .with_token_cap(1000)
            .with_usd_cap(1.0)]);
        bw.record(500, 0.5).unwrap();
        let snap = &bw.snapshot()[0];
        assert_eq!(snap.tokens_used, 500);
        assert_eq!(snap.usd_used, 0.5);
        assert_eq!(snap.entry_count, 1);
    }

    #[test]
    fn record_breaches_first_breaching_window() {
        let bw = BudgetWindows::new(vec![
            Window::new("min", Duration::from_secs(60)).with_token_cap(100),
            Window::new("hour", Duration::from_secs(3600)).with_usd_cap(1.0),
        ]);
        bw.record(90, 0.5).unwrap();
        let err = bw.record(20, 0.0).unwrap_err();
        assert_eq!(err.window_name, "min");
        assert_eq!(err.axis, "tokens");
        let snap = &bw.snapshot()[0];
        assert_eq!(snap.tokens_used, 90);
    }

    #[test]
    fn cap_is_inclusive() {
        let bw = BudgetWindows::new(vec![
            Window::new("min", Duration::from_secs(60)).with_token_cap(100),
        ]);
        bw.record(100, 0.0).unwrap();
        let err = bw.record(1, 0.0).unwrap_err();
        assert_eq!(err.axis, "tokens");
        assert_eq!(err.attempted, 101.0);
        assert_eq!(err.cap, 100.0);
    }

    #[test]
    fn usd_breach_reports_attempted_and_cap() {
        let bw = BudgetWindows::new(vec![
            Window::new("day", Duration::from_secs(86_400)).with_usd_cap(1.0),
        ]);
        bw.record(0, 0.6).unwrap();
        let err = bw.record(0, 0.6).unwrap_err();
        assert_eq!(err.window_name, "day");
        assert_eq!(err.axis, "usd");
        assert!((err.attempted - 1.2).abs() < USD_EPS);
        assert_eq!(err.cap, 1.0);
    }

    #[test]
    fn unset_cap_is_unbounded() {
        let bw = BudgetWindows::new(vec![Window::new("any", Duration::from_secs(60))]);
        for _ in 0..1000 {
            bw.record(1_000_000, 1_000_000.0).unwrap();
        }
    }

    #[test]
    fn no_windows_accepts_everything() {
        let bw = BudgetWindows::new(Vec::new());
        bw.record(1_000, 1_000.0).unwrap();
        assert!(bw.snapshot().is_empty());
    }

    #[test]
    fn old_entries_age_out() {
        // The window is wide enough that the immediate re-check reliably lands
        // inside it even on a loaded machine. `sleep` waits at least the
        // requested time, so sleeping past the window guarantees eviction.
        let window = Duration::from_millis(200);
        let bw = BudgetWindows::new(vec![Window::new("fast", window).with_token_cap(100)]);
        bw.record(80, 0.0).unwrap();
        assert!(bw.record(50, 0.0).is_err());
        std::thread::sleep(window + Duration::from_millis(50));
        bw.record(50, 0.0).unwrap();
        let snap = &bw.snapshot()[0];
        assert_eq!(snap.tokens_used, 50);
        assert_eq!(snap.entry_count, 1);
    }

    #[test]
    fn multiple_windows_all_track() {
        let bw = BudgetWindows::new(vec![
            Window::new("minute", Duration::from_secs(60)).with_usd_cap(1.0),
            Window::new("hour", Duration::from_secs(3600)).with_usd_cap(10.0),
        ]);
        for _ in 0..5 {
            bw.record(100, 0.1).unwrap();
        }
        let snaps = bw.snapshot();
        assert_eq!(snaps.len(), 2);
        assert!((snaps[0].usd_used - 0.5).abs() < USD_EPS);
        assert!((snaps[1].usd_used - 0.5).abs() < USD_EPS);
    }

    #[test]
    fn reset_drops_everything() {
        let bw = BudgetWindows::new(vec![
            Window::new("min", Duration::from_secs(60)).with_token_cap(1000),
        ]);
        bw.record(500, 0.0).unwrap();
        bw.reset();
        let snap = &bw.snapshot()[0];
        assert_eq!(snap.tokens_used, 0);
        assert_eq!(snap.entry_count, 0);
        // Usage can be recorded again after a reset.
        bw.record(1000, 0.0).unwrap();
    }

    #[test]
    fn breach_in_atomic_record_does_not_partially_commit() {
        let bw = BudgetWindows::new(vec![
            Window::new("min", Duration::from_secs(60)).with_token_cap(1000),
            Window::new("hour", Duration::from_secs(3600)).with_usd_cap(1.0),
        ]);
        bw.record(500, 0.9).unwrap();
        assert!(bw.record(100, 0.2).is_err());
        let snaps = bw.snapshot();
        assert_eq!(snaps[0].tokens_used, 500);
        assert!((snaps[1].usd_used - 0.9).abs() < USD_EPS);
    }
}