claude_agent/budget/
tracker.rs

//! Budget tracking for individual agent sessions.
2
3use std::sync::atomic::{AtomicU64, Ordering};
4
5use super::pricing::{PricingTable, global_pricing_table};
6
/// Policy a budget tracker applies once recorded spend reaches its cap.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum OnExceed {
    /// Refuse to issue the next API call. This is the default policy.
    #[default]
    StopBeforeNext,
    /// Keep running; the overrun is only meant to be reported as a warning.
    WarnAndContinue,
    /// Carry on using the named (cheaper) model instead of the current one.
    FallbackModel(String),
}
18
19impl OnExceed {
20    pub fn fallback(model: impl Into<String>) -> Self {
21        Self::FallbackModel(model.into())
22    }
23
24    pub fn fallback_model(&self) -> Option<&str> {
25        match self {
26            Self::FallbackModel(model) => Some(model),
27            _ => None,
28        }
29    }
30}
31
32#[derive(Debug)]
33pub struct BudgetTracker {
34    max_cost_usd: Option<f64>,
35    used_cost_bits: AtomicU64,
36    on_exceed: OnExceed,
37    pricing: &'static PricingTable,
38}
39
40impl Default for BudgetTracker {
41    fn default() -> Self {
42        Self {
43            max_cost_usd: None,
44            used_cost_bits: AtomicU64::new(0),
45            on_exceed: OnExceed::default(),
46            pricing: global_pricing_table(),
47        }
48    }
49}
50
51impl Clone for BudgetTracker {
52    fn clone(&self) -> Self {
53        Self {
54            max_cost_usd: self.max_cost_usd,
55            used_cost_bits: AtomicU64::new(self.used_cost_bits.load(Ordering::Relaxed)),
56            on_exceed: self.on_exceed.clone(),
57            pricing: self.pricing,
58        }
59    }
60}
61
62impl BudgetTracker {
63    pub fn new(max_cost_usd: f64) -> Self {
64        Self {
65            max_cost_usd: Some(max_cost_usd),
66            ..Default::default()
67        }
68    }
69
70    pub fn with_on_exceed(mut self, on_exceed: OnExceed) -> Self {
71        self.on_exceed = on_exceed;
72        self
73    }
74
75    pub fn unlimited() -> Self {
76        Self::default()
77    }
78
79    pub fn record(&self, model: &str, usage: &crate::types::Usage) -> f64 {
80        let cost = self.pricing.calculate(model, usage);
81        let cost_bits = (cost * 1_000_000.0) as u64;
82        self.used_cost_bits.fetch_add(cost_bits, Ordering::Relaxed);
83        cost
84    }
85
86    fn used_cost_usd_internal(&self) -> f64 {
87        self.used_cost_bits.load(Ordering::Relaxed) as f64 / 1_000_000.0
88    }
89
90    pub fn check(&self) -> BudgetStatus {
91        let used = self.used_cost_usd_internal();
92        match self.max_cost_usd {
93            None => BudgetStatus::Unlimited { used },
94            Some(max) if used >= max => BudgetStatus::Exceeded {
95                used,
96                limit: max,
97                overage: used - max,
98            },
99            Some(max) => BudgetStatus::WithinBudget {
100                used,
101                limit: max,
102                remaining: max - used,
103            },
104        }
105    }
106
107    pub fn should_stop(&self) -> bool {
108        matches!(self.on_exceed, OnExceed::StopBeforeNext)
109            && matches!(self.check(), BudgetStatus::Exceeded { .. })
110    }
111
112    pub fn should_fallback(&self) -> Option<&str> {
113        if matches!(self.check(), BudgetStatus::Exceeded { .. }) {
114            self.on_exceed.fallback_model()
115        } else {
116            None
117        }
118    }
119
120    pub fn used_cost_usd(&self) -> f64 {
121        self.used_cost_usd_internal()
122    }
123
124    pub fn remaining(&self) -> Option<f64> {
125        self.max_cost_usd
126            .map(|max| (max - self.used_cost_usd_internal()).max(0.0))
127    }
128
129    pub fn on_exceed(&self) -> &OnExceed {
130        &self.on_exceed
131    }
132}
133
/// Snapshot of a tracker's spend relative to its cap, as produced by
/// `BudgetTracker::check`. All amounts are in USD.
#[derive(Clone, Debug)]
pub enum BudgetStatus {
    /// No cap is configured.
    Unlimited {
        /// Spend recorded so far.
        used: f64,
    },
    /// Spend is below the cap.
    WithinBudget {
        /// Spend recorded so far.
        used: f64,
        /// The configured cap.
        limit: f64,
        /// `limit - used`.
        remaining: f64,
    },
    /// Spend has reached or passed the cap.
    Exceeded {
        /// Spend recorded so far.
        used: f64,
        /// The configured cap.
        limit: f64,
        /// `used - limit`; zero when spend exactly equals the cap.
        overage: f64,
    },
}
150
151impl BudgetStatus {
152    pub fn is_exceeded(&self) -> bool {
153        matches!(self, Self::Exceeded { .. })
154    }
155
156    pub fn used(&self) -> f64 {
157        match self {
158            Self::Unlimited { used } => *used,
159            Self::WithinBudget { used, .. } => *used,
160            Self::Exceeded { used, .. } => *used,
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Usage;

    #[test]
    fn test_budget_tracking() {
        let tracker = BudgetTracker::new(10.0);

        let usage = Usage {
            input_tokens: 100_000,
            output_tokens: 50_000,
            ..Default::default()
        };

        // Sonnet: 0.1M * $3 + 0.05M * $15 = $0.30 + $0.75 = $1.05
        let cost = tracker.record("claude-sonnet-4-5", &usage);
        assert!((cost - 1.05).abs() < 0.01);
        assert!(!tracker.should_stop());

        // Add more usage to exceed budget
        for _ in 0..10 {
            tracker.record("claude-sonnet-4-5", &usage);
        }

        assert!(tracker.should_stop());
        assert!(matches!(tracker.check(), BudgetStatus::Exceeded { .. }));
    }

    #[test]
    fn test_unlimited_budget() {
        let tracker = BudgetTracker::unlimited();

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Default::default()
        };

        for _ in 0..100 {
            tracker.record("claude-opus-4-5", &usage);
        }

        assert!(!tracker.should_stop());
        assert!(matches!(tracker.check(), BudgetStatus::Unlimited { .. }));
        // An uncapped tracker has no notion of "remaining" budget.
        assert_eq!(tracker.remaining(), None);
    }

    #[test]
    fn test_warn_and_continue() {
        let tracker = BudgetTracker::new(1.0).with_on_exceed(OnExceed::WarnAndContinue);

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Default::default()
        };

        tracker.record("claude-sonnet-4-5", &usage);

        assert!(matches!(tracker.check(), BudgetStatus::Exceeded { .. }));
        assert!(!tracker.should_stop()); // WarnAndContinue doesn't stop
    }

    #[test]
    fn test_fallback_model() {
        let tracker =
            BudgetTracker::new(1.0).with_on_exceed(OnExceed::fallback("claude-haiku-4-5"));

        // Nothing spent yet: no fallback is suggested.
        assert_eq!(tracker.should_fallback(), None);

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Default::default()
        };
        tracker.record("claude-sonnet-4-5", &usage);

        let status = tracker.check();
        assert!(status.is_exceeded());
        assert!((status.used() - tracker.used_cost_usd()).abs() < 1e-12);
        assert_eq!(tracker.should_fallback(), Some("claude-haiku-4-5"));
        // The fallback policy switches models instead of stopping.
        assert!(!tracker.should_stop());
    }

    #[test]
    fn test_remaining_is_clamped_at_zero() {
        let tracker = BudgetTracker::new(1.0);
        assert_eq!(tracker.remaining(), Some(1.0));

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Default::default()
        };
        tracker.record("claude-sonnet-4-5", &usage);

        // Overspending reports zero remaining, never a negative amount.
        assert_eq!(tracker.remaining(), Some(0.0));
    }

    #[test]
    fn test_clone_snapshots_spend() {
        let tracker = BudgetTracker::new(10.0);
        let usage = Usage {
            input_tokens: 100_000,
            output_tokens: 50_000,
            ..Default::default()
        };
        tracker.record("claude-sonnet-4-5", &usage);
        let spent_at_clone = tracker.used_cost_usd();

        let copy = tracker.clone();
        tracker.record("claude-sonnet-4-5", &usage);

        // The clone keeps the spend captured at clone time and is unaffected
        // by later recordings on the original.
        assert!((copy.used_cost_usd() - spent_at_clone).abs() < 1e-12);
        assert!(tracker.used_cost_usd() > copy.used_cost_usd());
        assert_eq!(copy.on_exceed(), tracker.on_exceed());
    }
}