claude_agent/budget/
tracker.rs1use std::sync::atomic::{AtomicU64, Ordering};
4
5use super::pricing::{PricingTable, global_pricing_table};
6
/// Policy a `BudgetTracker` applies once recorded spend reaches its limit.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum OnExceed {
    /// Signal the caller to stop before the next request (the default).
    ///
    /// `BudgetTracker::should_stop` can only return `true` under this policy.
    #[default]
    StopBeforeNext,
    /// Keep running past the limit; `BudgetTracker::should_stop` stays `false`.
    /// Any warning is left to the caller.
    WarnAndContinue,
    /// Switch to the named model once the budget is exceeded; the name is
    /// surfaced through `BudgetTracker::should_fallback`.
    FallbackModel(String),
}

impl OnExceed {
    /// Convenience constructor for the `FallbackModel` policy.
    pub fn fallback(model: impl Into<String>) -> Self {
        let model_name: String = model.into();
        Self::FallbackModel(model_name)
    }

    /// Returns the fallback model name, or `None` for every other policy.
    pub fn fallback_model(&self) -> Option<&str> {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::FallbackModel(name) => Some(name.as_str()),
            Self::StopBeforeNext | Self::WarnAndContinue => None,
        }
    }
}
31
32#[derive(Debug)]
33pub struct BudgetTracker {
34 max_cost_usd: Option<f64>,
35 used_cost_bits: AtomicU64,
36 on_exceed: OnExceed,
37 pricing: &'static PricingTable,
38}
39
40impl Default for BudgetTracker {
41 fn default() -> Self {
42 Self {
43 max_cost_usd: None,
44 used_cost_bits: AtomicU64::new(0),
45 on_exceed: OnExceed::default(),
46 pricing: global_pricing_table(),
47 }
48 }
49}
50
51impl Clone for BudgetTracker {
52 fn clone(&self) -> Self {
53 Self {
54 max_cost_usd: self.max_cost_usd,
55 used_cost_bits: AtomicU64::new(self.used_cost_bits.load(Ordering::Relaxed)),
56 on_exceed: self.on_exceed.clone(),
57 pricing: self.pricing,
58 }
59 }
60}
61
62impl BudgetTracker {
63 pub fn new(max_cost_usd: f64) -> Self {
64 Self {
65 max_cost_usd: Some(max_cost_usd),
66 ..Default::default()
67 }
68 }
69
70 pub fn with_on_exceed(mut self, on_exceed: OnExceed) -> Self {
71 self.on_exceed = on_exceed;
72 self
73 }
74
75 pub fn unlimited() -> Self {
76 Self::default()
77 }
78
79 pub fn record(&self, model: &str, usage: &crate::types::Usage) -> f64 {
80 let cost = self.pricing.calculate(model, usage);
81 let cost_bits = (cost * 1_000_000.0) as u64;
82 self.used_cost_bits.fetch_add(cost_bits, Ordering::Relaxed);
83 cost
84 }
85
86 fn used_cost_usd_internal(&self) -> f64 {
87 self.used_cost_bits.load(Ordering::Relaxed) as f64 / 1_000_000.0
88 }
89
90 pub fn check(&self) -> BudgetStatus {
91 let used = self.used_cost_usd_internal();
92 match self.max_cost_usd {
93 None => BudgetStatus::Unlimited { used },
94 Some(max) if used >= max => BudgetStatus::Exceeded {
95 used,
96 limit: max,
97 overage: used - max,
98 },
99 Some(max) => BudgetStatus::WithinBudget {
100 used,
101 limit: max,
102 remaining: max - used,
103 },
104 }
105 }
106
107 pub fn should_stop(&self) -> bool {
108 matches!(self.on_exceed, OnExceed::StopBeforeNext)
109 && matches!(self.check(), BudgetStatus::Exceeded { .. })
110 }
111
112 pub fn should_fallback(&self) -> Option<&str> {
113 if matches!(self.check(), BudgetStatus::Exceeded { .. }) {
114 self.on_exceed.fallback_model()
115 } else {
116 None
117 }
118 }
119
120 pub fn used_cost_usd(&self) -> f64 {
121 self.used_cost_usd_internal()
122 }
123
124 pub fn remaining(&self) -> Option<f64> {
125 self.max_cost_usd
126 .map(|max| (max - self.used_cost_usd_internal()).max(0.0))
127 }
128
129 pub fn on_exceed(&self) -> &OnExceed {
130 &self.on_exceed
131 }
132}
133
/// Snapshot of spend relative to a budget, as produced by
/// `BudgetTracker::check`. All amounts are in USD.
#[derive(Clone, Debug)]
pub enum BudgetStatus {
    /// No limit is configured.
    Unlimited {
        /// Total spend so far.
        used: f64,
    },
    /// Spend is below the limit.
    WithinBudget {
        /// Total spend so far.
        used: f64,
        /// The configured limit.
        limit: f64,
        /// `limit - used`.
        remaining: f64,
    },
    /// Spend has reached or passed the limit.
    Exceeded {
        /// Total spend so far.
        used: f64,
        /// The configured limit.
        limit: f64,
        /// `used - limit`; zero when spend exactly equals the limit.
        overage: f64,
    },
}

impl BudgetStatus {
    /// `true` only for the `Exceeded` variant.
    pub fn is_exceeded(&self) -> bool {
        matches!(self, Self::Exceeded { .. })
    }

    /// Total spend, which every variant carries.
    pub fn used(&self) -> f64 {
        match self {
            Self::Unlimited { used }
            | Self::WithinBudget { used, .. }
            | Self::Exceeded { used, .. } => *used,
        }
    }
}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Usage;

    #[test]
    fn test_budget_tracking() {
        let tracker = BudgetTracker::new(10.0);

        let usage = Usage {
            input_tokens: 100_000,
            output_tokens: 50_000,
            ..Default::default()
        };

        // 100k input + 50k output tokens on Sonnet 4.5 is roughly $1.05.
        let cost = tracker.record("claude-sonnet-4-5", &usage);
        assert!((cost - 1.05).abs() < 0.01);
        assert!(!tracker.should_stop());

        // Ten more identical calls push total spend past the $10 limit.
        for _ in 0..10 {
            tracker.record("claude-sonnet-4-5", &usage);
        }

        assert!(tracker.should_stop());
        assert!(matches!(tracker.check(), BudgetStatus::Exceeded { .. }));
    }

    #[test]
    fn test_unlimited_budget() {
        let tracker = BudgetTracker::unlimited();

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Default::default()
        };

        for _ in 0..100 {
            tracker.record("claude-opus-4-5", &usage);
        }

        assert!(!tracker.should_stop());
        assert!(matches!(tracker.check(), BudgetStatus::Unlimited { .. }));
        assert_eq!(tracker.remaining(), None);
    }

    #[test]
    fn test_warn_and_continue() {
        let tracker = BudgetTracker::new(1.0).with_on_exceed(OnExceed::WarnAndContinue);

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Default::default()
        };

        tracker.record("claude-sonnet-4-5", &usage);

        assert!(matches!(tracker.check(), BudgetStatus::Exceeded { .. }));
        assert!(!tracker.should_stop());
        assert_eq!(tracker.should_fallback(), None);
    }

    #[test]
    fn test_fallback_model() {
        let tracker =
            BudgetTracker::new(1.0).with_on_exceed(OnExceed::fallback("claude-haiku-4-5"));

        // Nothing spent yet, so no fallback is requested.
        assert_eq!(tracker.should_fallback(), None);

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Default::default()
        };
        tracker.record("claude-sonnet-4-5", &usage);

        assert_eq!(tracker.should_fallback(), Some("claude-haiku-4-5"));
        // A fallback policy must not also ask the caller to stop.
        assert!(!tracker.should_stop());
    }

    #[test]
    fn test_remaining_is_clamped_at_zero() {
        let tracker = BudgetTracker::new(2.0);
        assert_eq!(tracker.remaining(), Some(2.0));

        let usage = Usage {
            input_tokens: 1_000_000,
            output_tokens: 1_000_000,
            ..Default::default()
        };
        tracker.record("claude-sonnet-4-5", &usage);

        assert_eq!(tracker.remaining(), Some(0.0));
    }

    #[test]
    fn test_recorded_total_matches_cost() {
        let tracker = BudgetTracker::new(10.0);

        let usage = Usage {
            input_tokens: 100_000,
            output_tokens: 50_000,
            ..Default::default()
        };
        let cost = tracker.record("claude-sonnet-4-5", &usage);

        // Spend is stored at micro-dollar precision.
        assert!((tracker.used_cost_usd() - cost).abs() < 1e-5);
        assert!((tracker.check().used() - cost).abs() < 1e-5);
    }

    #[test]
    fn test_clone_snapshots_spend() {
        let tracker = BudgetTracker::new(10.0);

        let usage = Usage {
            input_tokens: 100_000,
            output_tokens: 50_000,
            ..Default::default()
        };
        tracker.record("claude-sonnet-4-5", &usage);

        let cloned = tracker.clone();
        assert_eq!(cloned.used_cost_usd(), tracker.used_cost_usd());

        // The clone has its own counter; the original is unaffected.
        cloned.record("claude-sonnet-4-5", &usage);
        assert!(cloned.used_cost_usd() > tracker.used_cost_usd());
    }
}