Skip to main content

claude_wrapper/
budget.rs

1//! Cumulative USD budget tracking for Claude sessions.
2//!
3//! A [`BudgetTracker`] accumulates cost across turns and fires
4//! caller-supplied callbacks when configurable thresholds are crossed.
//! When a `max_usd` ceiling is set, [`BudgetTracker::check`] returns
//! [`Error::BudgetExceeded`] once the running total is at or above the
//! ceiling, giving callers a hard stop before the next CLI invocation.
9//!
10//! # Ownership and sharing
11//!
12//! `BudgetTracker` wraps its state in `Arc<Mutex<...>>` so clones share
13//! the same running total and fire-once flags. Attach one tracker to a
14//! [`Session`](crate::session::Session), or clone it across several
15//! sessions to enforce a fleet-wide ceiling.
16//!
17//! # Example
18//!
19//! ```no_run
20//! use std::sync::Arc;
21//! use claude_wrapper::{Claude, BudgetTracker};
22//! use claude_wrapper::session::Session;
23//!
24//! # async fn example() -> claude_wrapper::Result<()> {
25//! let budget = BudgetTracker::builder()
26//!     .max_usd(5.00)
27//!     .warn_at_usd(4.00)
28//!     .on_warning(|total| eprintln!("warning: ${total:.2} spent"))
29//!     .on_exceeded(|total| eprintln!("budget hit: ${total:.2}"))
30//!     .build();
31//!
32//! let claude = Arc::new(Claude::builder().build()?);
33//! let mut session = Session::new(claude).with_budget(budget.clone());
34//!
35//! session.send("hello").await?;
36//! println!("spent so far: ${:.4}", budget.total_usd());
37//! # Ok(())
38//! # }
39//! ```
40
41use std::sync::{Arc, Mutex};
42
43use crate::error::{Error, Result};
44
/// Threshold callback; receives the running total in USD at the moment a
/// threshold is crossed.
///
/// `Send + Sync` is required because a tracker (and therefore its
/// callbacks) may be cloned and used from several threads at once.
type Callback = Arc<dyn Fn(f64) + Send + Sync>;
46
/// Thresholds and callbacks fixed when [`BudgetBuilder::build`] runs; never
/// mutated afterwards.
struct Config {
    /// Hard ceiling in USD. `None` makes `check` always succeed and means
    /// `on_exceeded` can never fire.
    max_usd: Option<f64>,
    /// Soft threshold in USD at which `on_warning` fires. `None` disables it.
    warn_at_usd: Option<f64>,
    /// Invoked once when the running total first reaches `warn_at_usd`.
    on_warning: Option<Callback>,
    /// Invoked once when the running total first reaches `max_usd`.
    on_exceeded: Option<Callback>,
}
53
/// Mutable accounting state; lives behind the mutex in `Inner`.
#[derive(Default)]
struct State {
    /// Sum of every accepted cost since construction or the last `reset`.
    total_usd: f64,
    /// Latched when the total first reaches `warn_at_usd`, so the warning
    /// fires only once; cleared by `reset`.
    warned: bool,
    /// Latched when the total first reaches `max_usd`, so `on_exceeded`
    /// fires only once; cleared by `reset`.
    exceeded: bool,
}
60
/// Payload shared (via `Arc`) by every clone of a [`BudgetTracker`].
struct Inner {
    /// Immutable thresholds and callbacks.
    config: Config,
    /// Running total and fire-once flags, guarded for cross-thread use.
    state: Mutex<State>,
}
65
/// Cumulative USD budget tracker with threshold callbacks.
///
/// Cloning only bumps an `Arc` reference count: every clone shares the same
/// running total and fire-once flags, so one ceiling can be enforced across
/// several sessions. See the [module docs](crate::budget) for the full
/// design.
#[derive(Clone)]
pub struct BudgetTracker {
    inner: Arc<Inner>,
}
73
74impl BudgetTracker {
75    /// Start building a tracker.
76    pub fn builder() -> BudgetBuilder {
77        BudgetBuilder::default()
78    }
79
80    /// Record an additional cost in USD. Fires `on_warning` the first
81    /// time the running total reaches `warn_at_usd`, and `on_exceeded`
82    /// the first time it reaches `max_usd`.
83    pub fn record(&self, cost_usd: f64) {
84        if cost_usd <= 0.0 || !cost_usd.is_finite() {
85            return;
86        }
87
88        let (warn_fired, exceeded_fired, total) = {
89            let mut state = self.inner.state.lock().expect("budget mutex poisoned");
90            state.total_usd += cost_usd;
91
92            let warn_fired = match self.inner.config.warn_at_usd {
93                Some(threshold) if !state.warned && state.total_usd >= threshold => {
94                    state.warned = true;
95                    true
96                }
97                _ => false,
98            };
99
100            let exceeded_fired = match self.inner.config.max_usd {
101                Some(threshold) if !state.exceeded && state.total_usd >= threshold => {
102                    state.exceeded = true;
103                    true
104                }
105                _ => false,
106            };
107
108            (warn_fired, exceeded_fired, state.total_usd)
109        };
110
111        if warn_fired && let Some(cb) = &self.inner.config.on_warning {
112            cb(total);
113        }
114        if exceeded_fired && let Some(cb) = &self.inner.config.on_exceeded {
115            cb(total);
116        }
117    }
118
119    /// Return `Err(Error::BudgetExceeded)` if the running total is at
120    /// or above the configured `max_usd`. Returns `Ok(())` when no
121    /// ceiling is set.
122    pub fn check(&self) -> Result<()> {
123        let Some(max) = self.inner.config.max_usd else {
124            return Ok(());
125        };
126        let total = self
127            .inner
128            .state
129            .lock()
130            .expect("budget mutex poisoned")
131            .total_usd;
132        if total >= max {
133            Err(Error::BudgetExceeded {
134                total_usd: total,
135                max_usd: max,
136            })
137        } else {
138            Ok(())
139        }
140    }
141
142    /// Cumulative cost recorded so far, in USD.
143    pub fn total_usd(&self) -> f64 {
144        self.inner
145            .state
146            .lock()
147            .expect("budget mutex poisoned")
148            .total_usd
149    }
150
151    /// Remaining budget in USD, if a ceiling is set. Clamped at zero
152    /// once the ceiling is reached.
153    pub fn remaining_usd(&self) -> Option<f64> {
154        let max = self.inner.config.max_usd?;
155        let total = self
156            .inner
157            .state
158            .lock()
159            .expect("budget mutex poisoned")
160            .total_usd;
161        Some((max - total).max(0.0))
162    }
163
164    /// Configured ceiling, if any.
165    pub fn max_usd(&self) -> Option<f64> {
166        self.inner.config.max_usd
167    }
168
169    /// Configured warning threshold, if any.
170    pub fn warn_at_usd(&self) -> Option<f64> {
171        self.inner.config.warn_at_usd
172    }
173
174    /// Clear the running total and re-arm both callbacks.
175    pub fn reset(&self) {
176        let mut state = self.inner.state.lock().expect("budget mutex poisoned");
177        state.total_usd = 0.0;
178        state.warned = false;
179        state.exceeded = false;
180    }
181}
182
183impl std::fmt::Debug for BudgetTracker {
184    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185        let state = self.inner.state.lock().expect("budget mutex poisoned");
186        f.debug_struct("BudgetTracker")
187            .field("max_usd", &self.inner.config.max_usd)
188            .field("warn_at_usd", &self.inner.config.warn_at_usd)
189            .field("total_usd", &state.total_usd)
190            .field("warned", &state.warned)
191            .field("exceeded", &state.exceeded)
192            .finish()
193    }
194}
195
/// Builder for [`BudgetTracker`].
///
/// Every setting starts unset (`None`); obtain one with
/// [`BudgetTracker::builder`] and finish with [`BudgetBuilder::build`].
#[derive(Default)]
pub struct BudgetBuilder {
    /// Hard ceiling in USD, if any.
    max_usd: Option<f64>,
    /// Warning threshold in USD, if any.
    warn_at_usd: Option<f64>,
    /// Callback for the warning threshold, if any.
    on_warning: Option<Callback>,
    /// Callback for the hard ceiling, if any.
    on_exceeded: Option<Callback>,
}
204
205impl BudgetBuilder {
206    /// Hard ceiling in USD. Once the running total hits this value,
207    /// [`BudgetTracker::check`] returns [`Error::BudgetExceeded`].
208    pub fn max_usd(mut self, max: f64) -> Self {
209        self.max_usd = Some(max);
210        self
211    }
212
213    /// Warning threshold in USD. [`BudgetBuilder::on_warning`] fires
214    /// the first time the running total reaches this value.
215    pub fn warn_at_usd(mut self, warn: f64) -> Self {
216        self.warn_at_usd = Some(warn);
217        self
218    }
219
220    /// Callback fired once when `warn_at_usd` is crossed. The argument
221    /// is the running total at the crossing.
222    pub fn on_warning<F>(mut self, f: F) -> Self
223    where
224        F: Fn(f64) + Send + Sync + 'static,
225    {
226        self.on_warning = Some(Arc::new(f));
227        self
228    }
229
230    /// Callback fired once when `max_usd` is crossed. The argument is
231    /// the running total at the crossing.
232    pub fn on_exceeded<F>(mut self, f: F) -> Self
233    where
234        F: Fn(f64) + Send + Sync + 'static,
235    {
236        self.on_exceeded = Some(Arc::new(f));
237        self
238    }
239
240    /// Finish construction.
241    pub fn build(self) -> BudgetTracker {
242        BudgetTracker {
243            inner: Arc::new(Inner {
244                config: Config {
245                    max_usd: self.max_usd,
246                    warn_at_usd: self.warn_at_usd,
247                    on_warning: self.on_warning,
248                    on_exceeded: self.on_exceeded,
249                },
250                state: Mutex::new(State::default()),
251            }),
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::OnceLock;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn record_accumulates() {
        let b = BudgetTracker::builder().build();
        b.record(0.01);
        b.record(0.02);
        b.record(0.03);
        assert!((b.total_usd() - 0.06).abs() < 1e-9);
    }

    #[test]
    fn record_ignores_non_positive_and_non_finite() {
        let b = BudgetTracker::builder().build();
        b.record(0.0);
        b.record(-0.5);
        b.record(f64::NAN);
        b.record(f64::INFINITY);
        assert_eq!(b.total_usd(), 0.0);
    }

    #[test]
    fn warn_callback_fires_once_at_threshold() {
        let count = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&count);
        let b = BudgetTracker::builder()
            .warn_at_usd(0.10)
            .on_warning(move |_| {
                c.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        b.record(0.05);
        assert_eq!(count.load(Ordering::SeqCst), 0);
        b.record(0.06); // total 0.11 -> crosses
        assert_eq!(count.load(Ordering::SeqCst), 1);
        b.record(0.20); // further spend, no re-fire
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn exceeded_callback_fires_once_at_threshold() {
        let count = Arc::new(AtomicUsize::new(0));
        let c = Arc::clone(&count);
        let b = BudgetTracker::builder()
            .max_usd(1.00)
            .on_exceeded(move |_| {
                c.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        b.record(0.50);
        b.record(0.49);
        assert_eq!(count.load(Ordering::SeqCst), 0);
        b.record(0.02); // total 1.01 -> crosses
        assert_eq!(count.load(Ordering::SeqCst), 1);
        b.record(0.50);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn check_errors_once_over_max() {
        let b = BudgetTracker::builder().max_usd(0.10).build();
        b.record(0.05);
        assert!(b.check().is_ok());
        b.record(0.05); // total 0.10 -> at threshold
        match b.check() {
            Err(Error::BudgetExceeded { total_usd, max_usd }) => {
                assert!((total_usd - 0.10).abs() < 1e-9);
                assert!((max_usd - 0.10).abs() < 1e-9);
            }
            other => panic!("expected BudgetExceeded, got {other:?}"),
        }
    }

    #[test]
    fn check_noop_without_max() {
        let b = BudgetTracker::builder().build();
        b.record(1_000.0);
        assert!(b.check().is_ok());
    }

    #[test]
    fn remaining_usd_clamps_at_zero() {
        let b = BudgetTracker::builder().max_usd(1.00).build();
        assert_eq!(b.remaining_usd(), Some(1.00));
        b.record(0.40);
        assert!((b.remaining_usd().unwrap() - 0.60).abs() < 1e-9);
        b.record(10.00);
        assert_eq!(b.remaining_usd(), Some(0.0));
    }

    #[test]
    fn remaining_usd_none_without_max() {
        let b = BudgetTracker::builder().build();
        assert!(b.remaining_usd().is_none());
    }

    #[test]
    fn reset_clears_total_and_rearms_callbacks() {
        let warn = Arc::new(AtomicUsize::new(0));
        let exc = Arc::new(AtomicUsize::new(0));
        let w = Arc::clone(&warn);
        let e = Arc::clone(&exc);
        let b = BudgetTracker::builder()
            .warn_at_usd(0.10)
            .max_usd(0.20)
            .on_warning(move |_| {
                w.fetch_add(1, Ordering::SeqCst);
            })
            .on_exceeded(move |_| {
                e.fetch_add(1, Ordering::SeqCst);
            })
            .build();

        b.record(0.25);
        assert_eq!(warn.load(Ordering::SeqCst), 1);
        assert_eq!(exc.load(Ordering::SeqCst), 1);
        assert!(b.check().is_err());

        b.reset();
        assert_eq!(b.total_usd(), 0.0);
        assert!(b.check().is_ok());

        b.record(0.25);
        assert_eq!(warn.load(Ordering::SeqCst), 2);
        assert_eq!(exc.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn clones_share_state() {
        let a = BudgetTracker::builder().max_usd(1.00).build();
        let b = a.clone();
        a.record(0.60);
        b.record(0.50);
        assert!((a.total_usd() - 1.10).abs() < 1e-9);
        assert!((b.total_usd() - 1.10).abs() < 1e-9);
        assert!(a.check().is_err());
        assert!(b.check().is_err());
    }

    #[test]
    fn concurrent_record_preserves_total() {
        use std::thread;

        let b = BudgetTracker::builder().build();
        let mut handles = Vec::new();
        for _ in 0..8 {
            let b = b.clone();
            handles.push(thread::spawn(move || {
                for _ in 0..1000 {
                    b.record(0.001);
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert!((b.total_usd() - 8.0).abs() < 1e-6);
    }

    #[test]
    fn single_record_fires_both_callbacks_with_running_total() {
        let seen: Arc<Mutex<Vec<(&'static str, f64)>>> = Arc::new(Mutex::new(Vec::new()));
        let w = Arc::clone(&seen);
        let e = Arc::clone(&seen);
        let b = BudgetTracker::builder()
            .warn_at_usd(0.50)
            .max_usd(1.00)
            .on_warning(move |total| w.lock().unwrap().push(("warn", total)))
            .on_exceeded(move |total| e.lock().unwrap().push(("exceeded", total)))
            .build();

        b.record(0.25);
        b.record(1.25); // total 1.5 -> crosses both thresholds at once
        assert_eq!(
            *seen.lock().unwrap(),
            vec![("warn", 1.5), ("exceeded", 1.5)]
        );
    }

    #[test]
    fn callbacks_run_unlocked_and_may_reenter() {
        // The callback needs a handle to the tracker it belongs to; the cell
        // is filled right after `build`. The resulting Arc cycle is fine in
        // a test.
        let cell: Arc<OnceLock<BudgetTracker>> = Arc::new(OnceLock::new());
        let seen: Arc<Mutex<Option<f64>>> = Arc::new(Mutex::new(None));
        let c = Arc::clone(&cell);
        let s = Arc::clone(&seen);
        let b = BudgetTracker::builder()
            .max_usd(1.00)
            .on_exceeded(move |_| {
                // Would deadlock if `record` still held the state lock.
                let tracker = c.get().expect("tracker registered before record");
                *s.lock().unwrap() = Some(tracker.total_usd());
            })
            .build();
        assert!(cell.set(b.clone()).is_ok());

        b.record(1.50);
        assert_eq!(*seen.lock().unwrap(), Some(1.50));
    }

    #[test]
    fn debug_reports_config_and_state() {
        let b = BudgetTracker::builder().max_usd(1.00).build();
        b.record(0.25);
        assert_eq!(
            format!("{b:?}"),
            "BudgetTracker { max_usd: Some(1.0), warn_at_usd: None, total_usd: 0.25, warned: false, exceeded: false }"
        );
    }
}