1use std::sync::{Arc, Mutex};
42
43use crate::error::{Error, Result};
44
/// Threshold notification hook; it is handed the running total in USD at the moment a
/// threshold was crossed. Stored behind an `Arc` so it can be shared, and required to be
/// `Send + Sync` so the owning tracker can be used from several threads.
type Callback = Arc<dyn Fn(f64) + Sync + Send + 'static>;
46
47struct Config {
48 max_usd: Option<f64>,
49 warn_at_usd: Option<f64>,
50 on_warning: Option<Callback>,
51 on_exceeded: Option<Callback>,
52}
53
/// Mutable bookkeeping that lives behind the tracker's mutex.
///
/// All fields are plain `Copy` values, so the type derives `Copy`/`Debug`/`PartialEq` to allow
/// cheap snapshots and direct comparison in diagnostics and tests.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
struct State {
    /// Sum of every accepted cost since construction or the last reset.
    total_usd: f64,
    /// Set once the warning threshold has been reached, so `on_warning` fires only once.
    warned: bool,
    /// Set once the maximum has been reached, so `on_exceeded` fires only once.
    exceeded: bool,
}
60
61struct Inner {
62 config: Config,
63 state: Mutex<State>,
64}
65
66#[derive(Clone)]
70pub struct BudgetTracker {
71 inner: Arc<Inner>,
72}
73
74impl BudgetTracker {
75 pub fn builder() -> BudgetBuilder {
77 BudgetBuilder::default()
78 }
79
80 pub fn record(&self, cost_usd: f64) {
84 if cost_usd <= 0.0 || !cost_usd.is_finite() {
85 return;
86 }
87
88 let (warn_fired, exceeded_fired, total) = {
89 let mut state = self.inner.state.lock().expect("budget mutex poisoned");
90 state.total_usd += cost_usd;
91
92 let warn_fired = match self.inner.config.warn_at_usd {
93 Some(threshold) if !state.warned && state.total_usd >= threshold => {
94 state.warned = true;
95 true
96 }
97 _ => false,
98 };
99
100 let exceeded_fired = match self.inner.config.max_usd {
101 Some(threshold) if !state.exceeded && state.total_usd >= threshold => {
102 state.exceeded = true;
103 true
104 }
105 _ => false,
106 };
107
108 (warn_fired, exceeded_fired, state.total_usd)
109 };
110
111 if warn_fired && let Some(cb) = &self.inner.config.on_warning {
112 cb(total);
113 }
114 if exceeded_fired && let Some(cb) = &self.inner.config.on_exceeded {
115 cb(total);
116 }
117 }
118
119 pub fn check(&self) -> Result<()> {
123 let Some(max) = self.inner.config.max_usd else {
124 return Ok(());
125 };
126 let total = self
127 .inner
128 .state
129 .lock()
130 .expect("budget mutex poisoned")
131 .total_usd;
132 if total >= max {
133 Err(Error::BudgetExceeded {
134 total_usd: total,
135 max_usd: max,
136 })
137 } else {
138 Ok(())
139 }
140 }
141
142 pub fn total_usd(&self) -> f64 {
144 self.inner
145 .state
146 .lock()
147 .expect("budget mutex poisoned")
148 .total_usd
149 }
150
151 pub fn remaining_usd(&self) -> Option<f64> {
154 let max = self.inner.config.max_usd?;
155 let total = self
156 .inner
157 .state
158 .lock()
159 .expect("budget mutex poisoned")
160 .total_usd;
161 Some((max - total).max(0.0))
162 }
163
164 pub fn max_usd(&self) -> Option<f64> {
166 self.inner.config.max_usd
167 }
168
169 pub fn warn_at_usd(&self) -> Option<f64> {
171 self.inner.config.warn_at_usd
172 }
173
174 pub fn reset(&self) {
176 let mut state = self.inner.state.lock().expect("budget mutex poisoned");
177 state.total_usd = 0.0;
178 state.warned = false;
179 state.exceeded = false;
180 }
181}
182
183impl std::fmt::Debug for BudgetTracker {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 let state = self.inner.state.lock().expect("budget mutex poisoned");
186 f.debug_struct("BudgetTracker")
187 .field("max_usd", &self.inner.config.max_usd)
188 .field("warn_at_usd", &self.inner.config.warn_at_usd)
189 .field("total_usd", &state.total_usd)
190 .field("warned", &state.warned)
191 .field("exceeded", &state.exceeded)
192 .finish()
193 }
194}
195
196#[derive(Default)]
198pub struct BudgetBuilder {
199 max_usd: Option<f64>,
200 warn_at_usd: Option<f64>,
201 on_warning: Option<Callback>,
202 on_exceeded: Option<Callback>,
203}
204
205impl BudgetBuilder {
206 pub fn max_usd(mut self, max: f64) -> Self {
209 self.max_usd = Some(max);
210 self
211 }
212
213 pub fn warn_at_usd(mut self, warn: f64) -> Self {
216 self.warn_at_usd = Some(warn);
217 self
218 }
219
220 pub fn on_warning<F>(mut self, f: F) -> Self
223 where
224 F: Fn(f64) + Send + Sync + 'static,
225 {
226 self.on_warning = Some(Arc::new(f));
227 self
228 }
229
230 pub fn on_exceeded<F>(mut self, f: F) -> Self
233 where
234 F: Fn(f64) + Send + Sync + 'static,
235 {
236 self.on_exceeded = Some(Arc::new(f));
237 self
238 }
239
240 pub fn build(self) -> BudgetTracker {
242 BudgetTracker {
243 inner: Arc::new(Inner {
244 config: Config {
245 max_usd: self.max_usd,
246 warn_at_usd: self.warn_at_usd,
247 on_warning: self.on_warning,
248 on_exceeded: self.on_exceeded,
249 },
250 state: Mutex::new(State::default()),
251 }),
252 }
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn record_accumulates() {
        let b = BudgetTracker::builder().build();
        b.record(0.01);
        b.record(0.02);
        b.record(0.03);
        assert!((b.total_usd() - 0.06).abs() < 1e-9);
    }

    #[test]
    fn record_ignores_non_positive_and_non_finite() {
        let b = BudgetTracker::builder().build();
        b.record(0.0);
        b.record(-0.5);
        b.record(f64::NAN);
        b.record(f64::INFINITY);
        assert_eq!(b.total_usd(), 0.0);
    }

    #[test]
    fn warn_callback_fires_once_at_threshold() {
        let count = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&count);
        let b = BudgetTracker::builder()
            .warn_at_usd(0.10)
            .on_warning(move |_| {
                c.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        b.record(0.05);
        assert_eq!(count.load(Ordering::SeqCst), 0);
        // 0.05 + 0.06 crosses the 0.10 warning threshold.
        b.record(0.06);
        assert_eq!(count.load(Ordering::SeqCst), 1);
        // Further spend must not fire the warning again.
        b.record(0.20);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn exceeded_callback_fires_once_at_threshold() {
        let count = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&count);
        let b = BudgetTracker::builder()
            .max_usd(1.00)
            .on_exceeded(move |_| {
                c.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        b.record(0.50);
        b.record(0.49);
        assert_eq!(count.load(Ordering::SeqCst), 0);
        // 0.99 + 0.02 crosses the 1.00 maximum.
        b.record(0.02);
        assert_eq!(count.load(Ordering::SeqCst), 1);
        // Further spend must not fire the callback again.
        b.record(0.50);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn check_errors_once_over_max() {
        let b = BudgetTracker::builder().max_usd(0.10).build();
        b.record(0.05);
        assert!(b.check().is_ok());
        // Reaching the maximum exactly already counts as exceeded.
        b.record(0.05);
        match b.check() {
            Err(Error::BudgetExceeded { total_usd, max_usd }) => {
                assert!((total_usd - 0.10).abs() < 1e-9);
                assert!((max_usd - 0.10).abs() < 1e-9);
            }
            other => panic!("expected BudgetExceeded, got {other:?}"),
        }
    }

    #[test]
    fn check_noop_without_max() {
        let b = BudgetTracker::builder().build();
        b.record(1_000.0);
        assert!(b.check().is_ok());
    }

    #[test]
    fn remaining_usd_clamps_at_zero() {
        let b = BudgetTracker::builder().max_usd(1.00).build();
        assert_eq!(b.remaining_usd(), Some(1.00));
        b.record(0.40);
        assert!((b.remaining_usd().unwrap() - 0.60).abs() < 1e-9);
        b.record(10.00);
        assert_eq!(b.remaining_usd(), Some(0.0));
    }

    #[test]
    fn remaining_usd_none_without_max() {
        let b = BudgetTracker::builder().build();
        assert!(b.remaining_usd().is_none());
    }

    #[test]
    fn reset_clears_total_and_rearms_callbacks() {
        let warn = Arc::new(AtomicUsize::new(0));
        let exc = Arc::new(AtomicUsize::new(0));
        let w = Arc::clone(&warn);
        let e = Arc::clone(&exc);
        let b = BudgetTracker::builder()
            .warn_at_usd(0.10)
            .max_usd(0.20)
            .on_warning(move |_| {
                w.fetch_add(1, Ordering::SeqCst);
            })
            .on_exceeded(move |_| {
                e.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        b.record(0.25);
        assert_eq!(warn.load(Ordering::SeqCst), 1);
        assert_eq!(exc.load(Ordering::SeqCst), 1);
        assert!(b.check().is_err());

        b.reset();
        assert_eq!(b.total_usd(), 0.0);
        assert!(b.check().is_ok());

        b.record(0.25);
        assert_eq!(warn.load(Ordering::SeqCst), 2);
        assert_eq!(exc.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn warning_fires_before_exceeded_when_one_record_crosses_both() {
        let log: Arc<Mutex<Vec<(&'static str, f64)>>> = Arc::new(Mutex::new(Vec::new()));
        let (w, e) = (Arc::clone(&log), Arc::clone(&log));
        let b = BudgetTracker::builder()
            .warn_at_usd(0.50)
            .max_usd(1.00)
            .on_warning(move |total| w.lock().unwrap().push(("warn", total)))
            .on_exceeded(move |total| e.lock().unwrap().push(("exceeded", total)))
            .build();

        b.record(2.00);
        assert_eq!(*log.lock().unwrap(), vec![("warn", 2.00), ("exceeded", 2.00)]);
    }

    #[test]
    fn callbacks_may_call_back_into_tracker() {
        // If the state lock were still held while callbacks run, this would deadlock.
        let slot: Arc<Mutex<Option<BudgetTracker>>> = Arc::new(Mutex::new(None));
        let seen: Arc<Mutex<Option<Option<f64>>>> = Arc::new(Mutex::new(None));
        let (s, out) = (Arc::clone(&slot), Arc::clone(&seen));
        let b = BudgetTracker::builder()
            .max_usd(1.00)
            .on_exceeded(move |_| {
                let tracker = s.lock().unwrap().clone();
                if let Some(tracker) = tracker {
                    *out.lock().unwrap() = Some(tracker.remaining_usd());
                }
            })
            .build();
        *slot.lock().unwrap() = Some(b.clone());

        b.record(1.50);
        assert_eq!(*seen.lock().unwrap(), Some(Some(0.0)));

        // Break the tracker -> callback -> slot -> tracker reference cycle.
        *slot.lock().unwrap() = None;
    }

    #[test]
    fn debug_reports_config_and_state() {
        let b = BudgetTracker::builder().max_usd(1.00).warn_at_usd(0.25).build();
        b.record(0.50);
        assert_eq!(
            format!("{b:?}"),
            "BudgetTracker { max_usd: Some(1.0), warn_at_usd: Some(0.25), total_usd: 0.5, warned: true, exceeded: false }"
        );
    }

    #[test]
    fn clones_share_state() {
        let a = BudgetTracker::builder().max_usd(1.00).build();
        let b = a.clone();
        a.record(0.60);
        b.record(0.50);
        assert!((a.total_usd() - 1.10).abs() < 1e-9);
        assert!((b.total_usd() - 1.10).abs() < 1e-9);
        assert!(a.check().is_err());
        assert!(b.check().is_err());
    }

    #[test]
    fn concurrent_record_preserves_total() {
        use std::thread;

        let b = BudgetTracker::builder().build();
        let mut handles = Vec::new();
        for _ in 0..8 {
            let b = b.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..1000 {
                    b.record(0.001);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert!((b.total_usd() - 8.0).abs() < 1e-6);
    }
}