Skip to main content

llm_budget_window/
window.rs

1//! Sliding-window budget core.
2
3use std::collections::VecDeque;
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
7use crate::error::WindowBreached;
8
/// One usage record kept in a window's queue.
#[derive(Clone, Copy, Debug)]
struct Entry {
    /// Timestamp attached to the record when it was stored.
    at: Instant,
    /// Tokens contributed by this record.
    tokens: u64,
    /// USD contributed by this record.
    usd: f64,
}
15
/// Configuration for one rolling window.
///
/// Either or both caps may be set; an unset cap is unbounded.
///
/// Built with [`Window::new`] plus the chainable `with_*` methods. Each of
/// those consumes the window and returns the updated value, so the result
/// must be kept.
#[derive(Debug, Clone, PartialEq)]
pub struct Window {
    /// Label returned by [`Window::name`].
    name: String,
    /// Length of the rolling window.
    duration: Duration,
    /// Maximum tokens allowed inside the window; `None` means no token cap.
    token_cap: Option<u64>,
    /// Maximum USD allowed inside the window; `None` means no USD cap.
    usd_cap: Option<f64>,
}

impl Window {
    /// Create a window with the given name and duration. No caps set yet.
    #[must_use]
    pub fn new(name: impl Into<String>, duration: Duration) -> Self {
        Self {
            name: name.into(),
            duration,
            token_cap: None,
            usd_cap: None,
        }
    }

    /// Add a token cap (chainable). Replaces any token cap set earlier.
    ///
    /// Takes `self` by value: the cap exists only on the returned window, so
    /// dropping the return value silently loses it.
    #[must_use = "builder methods return the updated window instead of modifying it in place"]
    pub fn with_token_cap(mut self, cap: u64) -> Self {
        self.token_cap = Some(cap);
        self
    }

    /// Add a USD cap (chainable). Replaces any USD cap set earlier.
    ///
    /// Takes `self` by value: the cap exists only on the returned window, so
    /// dropping the return value silently loses it.
    #[must_use = "builder methods return the updated window instead of modifying it in place"]
    pub fn with_usd_cap(mut self, cap: f64) -> Self {
        self.usd_cap = Some(cap);
        self
    }

    /// Window name (for error messages).
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Window duration.
    pub fn duration(&self) -> Duration {
        self.duration
    }
}
60
61#[derive(Debug)]
62struct WindowState {
63    cfg: Window,
64    entries: VecDeque<Entry>,
65    tokens_sum: u64,
66    usd_sum: f64,
67}
68
69impl WindowState {
70    fn new(cfg: Window) -> Self {
71        Self {
72            cfg,
73            entries: VecDeque::new(),
74            tokens_sum: 0,
75            usd_sum: 0.0,
76        }
77    }
78
79    fn evict_older_than(&mut self, now: Instant) {
80        let cutoff = now.checked_sub(self.cfg.duration).unwrap_or(now);
81        while let Some(front) = self.entries.front() {
82            if front.at < cutoff {
83                self.tokens_sum -= front.tokens;
84                self.usd_sum -= front.usd;
85                self.entries.pop_front();
86            } else {
87                break;
88            }
89        }
90    }
91
92    /// Check if (tokens, usd) would breach this window. Caller must have
93    /// called `evict_older_than(now)` first.
94    fn check_capacity(&self, tokens: u64, usd: f64) -> Result<(), WindowBreached> {
95        if let Some(cap) = self.cfg.token_cap {
96            let attempted = self.tokens_sum + tokens;
97            if attempted > cap {
98                return Err(WindowBreached {
99                    window_name: self.cfg.name.clone(),
100                    axis: "tokens",
101                    attempted: attempted as f64,
102                    cap: cap as f64,
103                });
104            }
105        }
106        if let Some(cap) = self.cfg.usd_cap {
107            let attempted = self.usd_sum + usd;
108            if attempted > cap {
109                return Err(WindowBreached {
110                    window_name: self.cfg.name.clone(),
111                    axis: "usd",
112                    attempted,
113                    cap,
114                });
115            }
116        }
117        Ok(())
118    }
119
120    fn record(&mut self, at: Instant, tokens: u64, usd: f64) {
121        self.entries.push_back(Entry { at, tokens, usd });
122        self.tokens_sum += tokens;
123        self.usd_sum += usd;
124    }
125}
126
/// Immutable snapshot of one window's current totals.
///
/// Implements `PartialEq` so snapshots can be compared directly, for example
/// with `assert_eq!`. The `f64` fields use ordinary float equality, so a NaN
/// value never compares equal to anything, including itself.
#[derive(Debug, Clone, PartialEq)]
pub struct WindowSnapshot {
    /// Window name.
    pub name: String,
    /// Window duration.
    pub duration: Duration,
    /// Tokens currently inside the window.
    pub tokens_used: u64,
    /// USD currently inside the window.
    pub usd_used: f64,
    /// Configured token cap, if any.
    pub token_cap: Option<u64>,
    /// Configured USD cap, if any.
    pub usd_cap: Option<f64>,
    /// How many discrete records currently fall in the window.
    pub entry_count: usize,
}
145
146/// Thread-safe time-windowed budget across N windows.
147///
148/// Every `record()` checks every window. If any window would breach,
149/// returns the first breach as `WindowBreached` and applies no change.
150pub struct BudgetWindows {
151    inner: Mutex<Vec<WindowState>>,
152}
153
154impl BudgetWindows {
155    /// Create a budget across the given windows.
156    pub fn new(windows: Vec<Window>) -> Self {
157        let inner = windows.into_iter().map(WindowState::new).collect();
158        Self {
159            inner: Mutex::new(inner),
160        }
161    }
162
163    /// Try to record `(tokens, usd)` against every window. Returns
164    /// `Err(WindowBreached)` from the first window that would breach,
165    /// and applies no change to any window. Returns `Ok(())` on success.
166    pub fn record(&self, tokens: u64, usd: f64) -> Result<(), WindowBreached> {
167        let now = Instant::now();
168        let mut state = self.inner.lock().expect("BudgetWindows lock poisoned");
169        // evict first so each window's "now" view is consistent
170        for w in state.iter_mut() {
171            w.evict_older_than(now);
172        }
173        // check all
174        for w in state.iter() {
175            w.check_capacity(tokens, usd)?;
176        }
177        // commit
178        for w in state.iter_mut() {
179            w.record(now, tokens, usd);
180        }
181        Ok(())
182    }
183
184    /// Snapshot all windows' current totals.
185    pub fn snapshot(&self) -> Vec<WindowSnapshot> {
186        let now = Instant::now();
187        let mut state = self.inner.lock().expect("BudgetWindows lock poisoned");
188        let mut out = Vec::with_capacity(state.len());
189        for w in state.iter_mut() {
190            w.evict_older_than(now);
191            out.push(WindowSnapshot {
192                name: w.cfg.name.clone(),
193                duration: w.cfg.duration,
194                tokens_used: w.tokens_sum,
195                usd_used: w.usd_sum,
196                token_cap: w.cfg.token_cap,
197                usd_cap: w.cfg.usd_cap,
198                entry_count: w.entries.len(),
199            });
200        }
201        out
202    }
203
204    /// Drop all entries from all windows.
205    pub fn reset(&self) {
206        let mut state = self.inner.lock().expect("BudgetWindows lock poisoned");
207        for w in state.iter_mut() {
208            w.entries.clear();
209            w.tokens_sum = 0;
210            w.usd_sum = 0.0;
211        }
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Uncapped window lasting `secs` seconds; caps are chained on by each test.
    fn window_secs(name: &str, secs: u64) -> Window {
        Window::new(name, Duration::from_secs(secs))
    }

    #[test]
    fn record_under_caps() {
        let budget = BudgetWindows::new(vec![window_secs("min", 60)
            .with_token_cap(1000)
            .with_usd_cap(1.0)]);
        assert!(budget.record(500, 0.5).is_ok());

        let snaps = budget.snapshot();
        let only = &snaps[0];
        assert_eq!(only.tokens_used, 500);
        assert_eq!(only.usd_used, 0.5);
        assert_eq!(only.entry_count, 1);
    }

    #[test]
    fn record_breaches_first_breaching_window() {
        let budget = BudgetWindows::new(vec![
            window_secs("min", 60).with_token_cap(100),
            window_secs("hour", 3600).with_usd_cap(1.0),
        ]);
        assert!(budget.record(90, 0.5).is_ok());

        let breach = budget.record(20, 0.0).unwrap_err();
        assert_eq!(breach.window_name, "min");
        assert_eq!(breach.axis, "tokens");

        // The rejected call must not have changed the totals.
        let snaps = budget.snapshot();
        assert_eq!(snaps[0].tokens_used, 90);
    }

    #[test]
    fn unset_cap_is_unbounded() {
        // No caps configured, so no amount can ever breach.
        let budget = BudgetWindows::new(vec![window_secs("any", 60)]);
        let mut remaining = 1000;
        while remaining > 0 {
            assert!(budget.record(1_000_000, 1_000_000.0).is_ok());
            remaining -= 1;
        }
    }

    #[test]
    fn old_entries_age_out() {
        let fast = Window::new("fast", Duration::from_millis(50)).with_token_cap(100);
        let budget = BudgetWindows::new(vec![fast]);

        assert!(budget.record(80, 0.0).is_ok());
        // While the first entry still counts, 80 + 50 exceeds the cap.
        assert!(budget.record(50, 0.0).is_err());

        std::thread::sleep(Duration::from_millis(70));

        // The first entry has expired, so the same request now fits.
        assert!(budget.record(50, 0.0).is_ok());
        let snaps = budget.snapshot();
        let only = &snaps[0];
        assert_eq!(only.tokens_used, 50);
        assert_eq!(only.entry_count, 1);
    }

    #[test]
    fn multiple_windows_all_track() {
        let budget = BudgetWindows::new(vec![
            window_secs("minute", 60).with_usd_cap(1.0),
            window_secs("hour", 3600).with_usd_cap(10.0),
        ]);
        (0..5).for_each(|_| assert!(budget.record(100, 0.1).is_ok()));

        let snaps = budget.snapshot();
        assert_eq!(snaps.len(), 2);
        for snap in &snaps {
            assert!((snap.usd_used - 0.5).abs() < 1e-9);
        }
    }

    #[test]
    fn reset_drops_everything() {
        let budget = BudgetWindows::new(vec![window_secs("min", 60).with_token_cap(1000)]);
        assert!(budget.record(500, 0.0).is_ok());

        budget.reset();

        let snaps = budget.snapshot();
        let only = &snaps[0];
        assert_eq!(only.tokens_used, 0);
        assert_eq!(only.entry_count, 0);
    }

    #[test]
    fn breach_in_atomic_record_does_not_partially_commit() {
        let budget = BudgetWindows::new(vec![
            window_secs("min", 60).with_token_cap(1000),
            window_secs("hour", 3600).with_usd_cap(1.0),
        ]);
        assert!(budget.record(500, 0.9).is_ok());

        // 0.9 + 0.2 exceeds the hour window's USD cap. The min window alone
        // would accept the request, but the rejection must apply to both.
        assert!(budget.record(100, 0.2).is_err());

        let snaps = budget.snapshot();
        assert_eq!(snaps[0].tokens_used, 500);
        assert!((snaps[1].usd_used - 0.9).abs() < 1e-9);
    }
}