1use std::collections::VecDeque;
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
7use crate::error::WindowBreached;
8
/// A single charge booked against a window.
///
/// Entries are kept oldest-first so expired ones can be popped off the front.
#[derive(Debug, Clone, Copy)]
struct Entry {
    /// Instant at which the charge was booked; used to decide when it ages out.
    at: Instant,
    /// Number of tokens in this charge.
    tokens: u64,
    /// USD amount of this charge.
    usd: f64,
}
15
/// Configuration of one rolling budget window.
///
/// A window has a name (reported back when the window is breached), a length,
/// and optional caps on the total tokens and total USD that may be charged
/// within it. An unset cap leaves that axis unbounded.
#[derive(Debug, Clone)]
pub struct Window {
    name: String,
    duration: Duration,
    token_cap: Option<u64>,
    usd_cap: Option<f64>,
}

impl Window {
    /// Creates a window named `name` spanning `duration`, with no caps set.
    #[must_use]
    pub fn new(name: impl Into<String>, duration: Duration) -> Self {
        Self {
            name: name.into(),
            duration,
            token_cap: None,
            usd_cap: None,
        }
    }

    /// Returns this window with its token cap set to `cap`, replacing any previous cap.
    #[must_use]
    pub fn with_token_cap(mut self, cap: u64) -> Self {
        self.token_cap = Some(cap);
        self
    }

    /// Returns this window with its USD cap set to `cap`, replacing any previous cap.
    #[must_use]
    pub fn with_usd_cap(mut self, cap: f64) -> Self {
        self.usd_cap = Some(cap);
        self
    }

    /// Name given to the window at construction.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Length of the rolling window.
    pub fn duration(&self) -> Duration {
        self.duration
    }

    /// Token cap, or `None` if tokens are unbounded.
    pub fn token_cap(&self) -> Option<u64> {
        self.token_cap
    }

    /// USD cap, or `None` if USD is unbounded.
    pub fn usd_cap(&self) -> Option<f64> {
        self.usd_cap
    }
}
60
61#[derive(Debug)]
62struct WindowState {
63 cfg: Window,
64 entries: VecDeque<Entry>,
65 tokens_sum: u64,
66 usd_sum: f64,
67}
68
69impl WindowState {
70 fn new(cfg: Window) -> Self {
71 Self {
72 cfg,
73 entries: VecDeque::new(),
74 tokens_sum: 0,
75 usd_sum: 0.0,
76 }
77 }
78
79 fn evict_older_than(&mut self, now: Instant) {
80 let cutoff = now.checked_sub(self.cfg.duration).unwrap_or(now);
81 while let Some(front) = self.entries.front() {
82 if front.at < cutoff {
83 self.tokens_sum -= front.tokens;
84 self.usd_sum -= front.usd;
85 self.entries.pop_front();
86 } else {
87 break;
88 }
89 }
90 }
91
92 fn check_capacity(&self, tokens: u64, usd: f64) -> Result<(), WindowBreached> {
95 if let Some(cap) = self.cfg.token_cap {
96 let attempted = self.tokens_sum + tokens;
97 if attempted > cap {
98 return Err(WindowBreached {
99 window_name: self.cfg.name.clone(),
100 axis: "tokens",
101 attempted: attempted as f64,
102 cap: cap as f64,
103 });
104 }
105 }
106 if let Some(cap) = self.cfg.usd_cap {
107 let attempted = self.usd_sum + usd;
108 if attempted > cap {
109 return Err(WindowBreached {
110 window_name: self.cfg.name.clone(),
111 axis: "usd",
112 attempted,
113 cap,
114 });
115 }
116 }
117 Ok(())
118 }
119
120 fn record(&mut self, at: Instant, tokens: u64, usd: f64) {
121 self.entries.push_back(Entry { at, tokens, usd });
122 self.tokens_sum += tokens;
123 self.usd_sum += usd;
124 }
125}
126
/// Point-in-time view of one window, as produced by `BudgetWindows::snapshot`.
///
/// The usage figures reflect only entries still inside the window at the
/// moment the snapshot was taken.
#[derive(Debug, Clone, PartialEq)]
pub struct WindowSnapshot {
    /// Name the window was configured with.
    pub name: String,
    /// Length of the rolling window.
    pub duration: Duration,
    /// Tokens charged within the window.
    pub tokens_used: u64,
    /// USD charged within the window.
    pub usd_used: f64,
    /// Configured token cap, or `None` if tokens are unbounded.
    pub token_cap: Option<u64>,
    /// Configured USD cap, or `None` if USD is unbounded.
    pub usd_cap: Option<f64>,
    /// Number of charges currently inside the window.
    pub entry_count: usize,
}
145
146pub struct BudgetWindows {
151 inner: Mutex<Vec<WindowState>>,
152}
153
154impl BudgetWindows {
155 pub fn new(windows: Vec<Window>) -> Self {
157 let inner = windows.into_iter().map(WindowState::new).collect();
158 Self {
159 inner: Mutex::new(inner),
160 }
161 }
162
163 pub fn record(&self, tokens: u64, usd: f64) -> Result<(), WindowBreached> {
167 let now = Instant::now();
168 let mut state = self.inner.lock().expect("BudgetWindows lock poisoned");
169 for w in state.iter_mut() {
171 w.evict_older_than(now);
172 }
173 for w in state.iter() {
175 w.check_capacity(tokens, usd)?;
176 }
177 for w in state.iter_mut() {
179 w.record(now, tokens, usd);
180 }
181 Ok(())
182 }
183
184 pub fn snapshot(&self) -> Vec<WindowSnapshot> {
186 let now = Instant::now();
187 let mut state = self.inner.lock().expect("BudgetWindows lock poisoned");
188 let mut out = Vec::with_capacity(state.len());
189 for w in state.iter_mut() {
190 w.evict_older_than(now);
191 out.push(WindowSnapshot {
192 name: w.cfg.name.clone(),
193 duration: w.cfg.duration,
194 tokens_used: w.tokens_sum,
195 usd_used: w.usd_sum,
196 token_cap: w.cfg.token_cap,
197 usd_cap: w.cfg.usd_cap,
198 entry_count: w.entries.len(),
199 });
200 }
201 out
202 }
203
204 pub fn reset(&self) {
206 let mut state = self.inner.lock().expect("BudgetWindows lock poisoned");
207 for w in state.iter_mut() {
208 w.entries.clear();
209 w.tokens_sum = 0;
210 w.usd_sum = 0.0;
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an uncapped window whose length is given in whole seconds.
    fn secs_window(name: &str, secs: u64) -> Window {
        Window::new(name, Duration::from_secs(secs))
    }

    #[test]
    fn record_under_caps() {
        let windows = BudgetWindows::new(vec![secs_window("min", 60)
            .with_token_cap(1000)
            .with_usd_cap(1.0)]);
        windows.record(500, 0.5).unwrap();

        let snaps = windows.snapshot();
        let min = &snaps[0];
        assert_eq!(min.tokens_used, 500);
        assert_eq!(min.usd_used, 0.5);
        assert_eq!(min.entry_count, 1);
    }

    #[test]
    fn record_breaches_first_breaching_window() {
        let windows = BudgetWindows::new(vec![
            secs_window("min", 60).with_token_cap(100),
            secs_window("hour", 3600).with_usd_cap(1.0),
        ]);
        windows.record(90, 0.5).unwrap();

        let breach = windows.record(20, 0.0).unwrap_err();
        assert_eq!(breach.window_name, "min");
        assert_eq!(breach.axis, "tokens");

        // The rejected charge must not have been booked.
        assert_eq!(windows.snapshot()[0].tokens_used, 90);
    }

    #[test]
    fn unset_cap_is_unbounded() {
        let windows = BudgetWindows::new(vec![secs_window("any", 60)]);
        (0..1000).for_each(|_| windows.record(1_000_000, 1_000_000.0).unwrap());
    }

    #[test]
    fn old_entries_age_out() {
        let windows = BudgetWindows::new(vec![
            Window::new("fast", Duration::from_millis(50)).with_token_cap(100),
        ]);
        windows.record(80, 0.0).unwrap();
        assert!(windows.record(50, 0.0).is_err());

        // Sleep past the 50 ms window so the first charge ages out.
        std::thread::sleep(Duration::from_millis(70));
        windows.record(50, 0.0).unwrap();

        let snaps = windows.snapshot();
        let fast = &snaps[0];
        assert_eq!(fast.tokens_used, 50);
        assert_eq!(fast.entry_count, 1);
    }

    #[test]
    fn multiple_windows_all_track() {
        let windows = BudgetWindows::new(vec![
            secs_window("minute", 60).with_usd_cap(1.0),
            secs_window("hour", 3600).with_usd_cap(10.0),
        ]);
        for _ in 0..5 {
            windows.record(100, 0.1).unwrap();
        }

        let snaps = windows.snapshot();
        assert_eq!(snaps.len(), 2);
        for snap in &snaps {
            assert!((snap.usd_used - 0.5).abs() < 1e-9);
        }
    }

    #[test]
    fn reset_drops_everything() {
        let windows = BudgetWindows::new(vec![secs_window("min", 60).with_token_cap(1000)]);
        windows.record(500, 0.0).unwrap();
        windows.reset();

        let snaps = windows.snapshot();
        let min = &snaps[0];
        assert_eq!(min.tokens_used, 0);
        assert_eq!(min.entry_count, 0);
    }

    #[test]
    fn breach_in_atomic_record_does_not_partially_commit() {
        let windows = BudgetWindows::new(vec![
            secs_window("min", 60).with_token_cap(1000),
            secs_window("hour", 3600).with_usd_cap(1.0),
        ]);
        windows.record(500, 0.9).unwrap();
        // "hour" rejects this charge (0.9 + 0.2 > 1.0), so "min" must not see it either.
        assert!(windows.record(100, 0.2).is_err());

        let snaps = windows.snapshot();
        assert_eq!(snaps[0].tokens_used, 500);
        assert!((snaps[1].usd_used - 0.9).abs() < 1e-9);
    }
}