1use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::warn;
9
/// Running totals of token consumption and cost.
///
/// Each call to [`TokenUsageTracker::update`] or
/// [`TokenUsageTracker::update_with_cache`] adds to these counters and counts
/// as one session.
#[derive(Clone, Debug, Default)]
pub struct TokenUsageTracker {
    /// Input tokens accumulated over all recorded sessions.
    pub total_input_tokens: u64,
    /// Output tokens accumulated over all recorded sessions.
    pub total_output_tokens: u64,
    /// Cache-read input tokens reported via `update_with_cache`; kept separate
    /// and not included in `total_tokens()`.
    pub cache_read_input_tokens: u64,
    /// Cache-creation input tokens reported via `update_with_cache`; kept
    /// separate and not included in `total_tokens()`.
    pub cache_creation_input_tokens: u64,
    /// Accumulated cost in US dollars.
    pub total_cost_usd: f64,
    /// Number of recorded sessions (each update call adds one).
    pub session_count: usize,
}
26
27impl TokenUsageTracker {
28 pub fn new() -> Self {
30 Self::default()
31 }
32
33 pub fn total_tokens(&self) -> u64 {
35 self.total_input_tokens + self.total_output_tokens
36 }
37
38 pub fn avg_tokens_per_session(&self) -> f64 {
40 if self.session_count == 0 {
41 0.0
42 } else {
43 self.total_tokens() as f64 / self.session_count as f64
44 }
45 }
46
47 pub fn avg_cost_per_session(&self) -> f64 {
49 if self.session_count == 0 {
50 0.0
51 } else {
52 self.total_cost_usd / self.session_count as f64
53 }
54 }
55
56 pub fn update(&mut self, input_tokens: u64, output_tokens: u64, cost_usd: f64) {
58 self.total_input_tokens += input_tokens;
59 self.total_output_tokens += output_tokens;
60 self.total_cost_usd += cost_usd;
61 self.session_count += 1;
62 }
63
64 pub fn update_with_cache(
66 &mut self,
67 input_tokens: u64,
68 output_tokens: u64,
69 cache_read_input_tokens: u64,
70 cache_creation_input_tokens: u64,
71 cost_usd: f64,
72 ) {
73 self.total_input_tokens += input_tokens;
74 self.total_output_tokens += output_tokens;
75 self.cache_read_input_tokens += cache_read_input_tokens;
76 self.cache_creation_input_tokens += cache_creation_input_tokens;
77 self.total_cost_usd += cost_usd;
78 self.session_count += 1;
79 }
80
81 pub fn total_cache_tokens(&self) -> u64 {
83 self.cache_read_input_tokens + self.cache_creation_input_tokens
84 }
85
86 pub fn reset(&mut self) {
88 self.total_input_tokens = 0;
89 self.total_output_tokens = 0;
90 self.cache_read_input_tokens = 0;
91 self.cache_creation_input_tokens = 0;
92 self.total_cost_usd = 0.0;
93 self.session_count = 0;
94 }
95}
96
/// Optional caps on accumulated cost and tokens, checked with
/// `BudgetLimit::check_limits`. A cap set to `None` is not enforced.
#[derive(Clone, Debug)]
pub struct BudgetLimit {
    /// Maximum accumulated cost in US dollars, if capped.
    pub max_cost_usd: Option<f64>,
    /// Maximum input + output tokens, if capped.
    pub max_tokens: Option<u64>,
    /// Fraction of a cap (usage / cap) at which a warning is reported.
    pub warning_threshold: f64,
}
107
108impl Default for BudgetLimit {
109 fn default() -> Self {
110 Self {
111 max_cost_usd: None,
112 max_tokens: None,
113 warning_threshold: 0.8,
114 }
115 }
116}
117
118impl BudgetLimit {
119 pub fn with_cost(max_cost_usd: f64) -> Self {
121 Self {
122 max_cost_usd: Some(max_cost_usd),
123 ..Default::default()
124 }
125 }
126
127 pub fn with_tokens(max_tokens: u64) -> Self {
129 Self {
130 max_tokens: Some(max_tokens),
131 ..Default::default()
132 }
133 }
134
135 pub fn with_both(max_cost_usd: f64, max_tokens: u64) -> Self {
137 Self {
138 max_cost_usd: Some(max_cost_usd),
139 max_tokens: Some(max_tokens),
140 warning_threshold: 0.8,
141 }
142 }
143
144 pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
146 self.warning_threshold = threshold.clamp(0.0, 1.0);
147 self
148 }
149
150 pub fn check_limits(&self, usage: &TokenUsageTracker) -> BudgetStatus {
152 let mut status = BudgetStatus::Ok;
153
154 if let Some(max_cost) = self.max_cost_usd {
156 let cost_ratio = usage.total_cost_usd / max_cost;
157
158 if cost_ratio >= 1.0 {
159 status = BudgetStatus::Exceeded;
160 } else if cost_ratio >= self.warning_threshold {
161 status = BudgetStatus::Warning {
162 current_ratio: cost_ratio,
163 message: format!(
164 "Cost usage at {:.1}% (${:.2}/${:.2})",
165 cost_ratio * 100.0,
166 usage.total_cost_usd,
167 max_cost
168 ),
169 };
170 }
171 }
172
173 if let Some(max_tokens) = self.max_tokens {
175 let token_ratio = usage.total_tokens() as f64 / max_tokens as f64;
176
177 if token_ratio >= 1.0 {
178 status = BudgetStatus::Exceeded;
179 } else if token_ratio >= self.warning_threshold {
180 if !matches!(status, BudgetStatus::Exceeded) {
182 status = BudgetStatus::Warning {
183 current_ratio: token_ratio,
184 message: format!(
185 "Token usage at {:.1}% ({}/{})",
186 token_ratio * 100.0,
187 usage.total_tokens(),
188 max_tokens
189 ),
190 };
191 }
192 }
193 }
194
195 status
196 }
197}
198
/// Result of comparing usage against a `BudgetLimit`.
#[derive(Clone, Debug, PartialEq)]
pub enum BudgetStatus {
    /// No configured cap has reached its warning threshold (also the result when no caps are set).
    Ok,
    /// A cap has reached the warning threshold, but no cap is exhausted.
    Warning {
        /// Usage divided by the cap that triggered the warning.
        current_ratio: f64,
        /// Human-readable description of the usage level.
        message: String,
    },
    /// At least one cap has been reached or passed.
    Exceeded,
}
214
/// Shared, thread-safe callback that receives a human-readable message when a
/// budget warning or overrun is reported.
pub type BudgetWarningCallback = Arc<dyn Fn(&str) + Send + Sync + 'static>;
217
218#[derive(Clone)]
220pub struct BudgetManager {
221 tracker: Arc<RwLock<TokenUsageTracker>>,
222 limit: Arc<RwLock<Option<BudgetLimit>>>,
223 on_warning: Arc<RwLock<Option<BudgetWarningCallback>>>,
224 warning_fired: Arc<RwLock<bool>>,
225}
226
227impl BudgetManager {
228 pub fn new() -> Self {
230 Self {
231 tracker: Arc::new(RwLock::new(TokenUsageTracker::new())),
232 limit: Arc::new(RwLock::new(None)),
233 on_warning: Arc::new(RwLock::new(None)),
234 warning_fired: Arc::new(RwLock::new(false)),
235 }
236 }
237
238 pub async fn set_limit(&self, limit: BudgetLimit) {
240 *self.limit.write().await = Some(limit);
241 *self.warning_fired.write().await = false;
242 }
243
244 pub async fn set_warning_callback(&self, callback: BudgetWarningCallback) {
246 *self.on_warning.write().await = Some(callback);
247 }
248
249 pub async fn clear_limit(&self) {
251 *self.limit.write().await = None;
252 *self.warning_fired.write().await = false;
253 }
254
255 pub async fn get_usage(&self) -> TokenUsageTracker {
257 self.tracker.read().await.clone()
258 }
259
260 pub async fn update_usage(&self, input_tokens: u64, output_tokens: u64, cost_usd: f64) {
262 self.tracker.write().await.update(input_tokens, output_tokens, cost_usd);
264
265 if let Some(limit) = self.limit.read().await.as_ref() {
267 let usage = self.tracker.read().await.clone();
268 let status = limit.check_limits(&usage);
269
270 match status {
271 BudgetStatus::Warning { message, .. } => {
272 let mut fired = self.warning_fired.write().await;
273 if !*fired {
274 *fired = true;
275 warn!("Budget warning: {}", message);
276
277 if let Some(callback) = self.on_warning.read().await.as_ref() {
278 callback(&message);
279 }
280 }
281 }
282 BudgetStatus::Exceeded => {
283 warn!("Budget exceeded! Current usage: {} tokens, ${:.2}",
284 usage.total_tokens(), usage.total_cost_usd);
285
286 if let Some(callback) = self.on_warning.read().await.as_ref() {
287 callback("Budget limit exceeded");
288 }
289 }
290 BudgetStatus::Ok => {
291 *self.warning_fired.write().await = false;
293 }
294 }
295 }
296 }
297
298 pub async fn reset_usage(&self) {
300 self.tracker.write().await.reset();
301 *self.warning_fired.write().await = false;
302 }
303
304 pub async fn is_exceeded(&self) -> bool {
306 if let Some(limit) = self.limit.read().await.as_ref() {
307 let usage = self.tracker.read().await.clone();
308 matches!(limit.check_limits(&usage), BudgetStatus::Exceeded)
309 } else {
310 false
311 }
312 }
313}
314
315impl Default for BudgetManager {
316 fn default() -> Self {
317 Self::new()
318 }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Absolute tolerance for accumulated `f64` costs. Exact equality on sums of
    /// decimal fractions only holds by rounding luck.
    const COST_EPSILON: f64 = 1e-9;

    /// Asserts two costs agree within `COST_EPSILON`.
    fn assert_cost_eq(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < COST_EPSILON,
            "expected cost {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_tracker_basics() {
        let mut tracker = TokenUsageTracker::new();
        assert_eq!(tracker.total_tokens(), 0);
        assert_eq!(tracker.total_cost_usd, 0.0);

        tracker.update(100, 200, 0.05);
        assert_eq!(tracker.total_input_tokens, 100);
        assert_eq!(tracker.total_output_tokens, 200);
        assert_eq!(tracker.total_tokens(), 300);
        assert_cost_eq(tracker.total_cost_usd, 0.05);
        assert_eq!(tracker.session_count, 1);

        tracker.update(50, 100, 0.02);
        assert_eq!(tracker.total_tokens(), 450);
        assert_cost_eq(tracker.total_cost_usd, 0.07);
        assert_eq!(tracker.session_count, 2);
    }

    #[test]
    fn test_tracker_averages() {
        let mut tracker = TokenUsageTracker::new();
        // With no sessions the averages are defined as zero, not NaN.
        assert_eq!(tracker.avg_tokens_per_session(), 0.0);
        assert_eq!(tracker.avg_cost_per_session(), 0.0);

        tracker.update(100, 100, 0.5);
        tracker.update(200, 200, 0.25);
        assert_eq!(tracker.avg_tokens_per_session(), 300.0);
        assert_cost_eq(tracker.avg_cost_per_session(), 0.375);
    }

    #[test]
    fn test_budget_limits() {
        let limit = BudgetLimit::with_cost(1.0).with_warning_threshold(0.8);

        let mut tracker = TokenUsageTracker::new();
        tracker.update(100, 200, 0.5);
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Ok));

        tracker.update(100, 200, 0.35);
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Warning { .. }));

        tracker.update(100, 200, 0.2);
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Exceeded));
    }

    #[test]
    fn test_warning_threshold_is_clamped() {
        assert_eq!(BudgetLimit::default().warning_threshold, 0.8);
        assert_eq!(
            BudgetLimit::with_cost(1.0).with_warning_threshold(1.5).warning_threshold,
            1.0
        );
        assert_eq!(
            BudgetLimit::with_cost(1.0).with_warning_threshold(-0.5).warning_threshold,
            0.0
        );
    }

    #[test]
    fn test_no_caps_is_always_ok() {
        let limit = BudgetLimit::default();
        let mut tracker = TokenUsageTracker::new();
        tracker.update(1_000_000, 1_000_000, 1_000_000.0);
        assert_eq!(limit.check_limits(&tracker), BudgetStatus::Ok);
    }

    #[test]
    fn test_with_both_enforces_each_cap() {
        let limit = BudgetLimit::with_both(1.0, 1000);

        let mut over_tokens = TokenUsageTracker::new();
        over_tokens.update(600, 500, 0.1);
        assert_eq!(limit.check_limits(&over_tokens), BudgetStatus::Exceeded);

        let mut over_cost = TokenUsageTracker::new();
        over_cost.update(10, 10, 1.5);
        assert_eq!(limit.check_limits(&over_cost), BudgetStatus::Exceeded);
    }

    #[tokio::test]
    async fn test_budget_manager() {
        let manager = BudgetManager::new();

        manager.set_limit(BudgetLimit::with_tokens(1000)).await;
        manager.update_usage(300, 200, 0.05).await;

        let usage = manager.get_usage().await;
        assert_eq!(usage.total_tokens(), 500);

        assert!(!manager.is_exceeded().await);

        manager.update_usage(300, 300, 0.05).await;
        assert!(manager.is_exceeded().await);
    }

    #[tokio::test]
    async fn test_budget_manager_without_limit_is_never_exceeded() {
        let manager = BudgetManager::default();
        manager.update_usage(1_000_000, 1_000_000, 1_000.0).await;
        assert!(!manager.is_exceeded().await);
    }

    #[tokio::test]
    async fn test_budget_manager_warning_fires_once_until_rearmed() {
        let manager = BudgetManager::new();
        let messages = Arc::new(Mutex::new(Vec::<String>::new()));
        let sink = Arc::clone(&messages);
        manager
            .set_warning_callback(Arc::new(move |msg: &str| {
                sink.lock().unwrap().push(msg.to_string());
            }))
            .await;
        manager.set_limit(BudgetLimit::with_tokens(1000)).await;

        manager.update_usage(800, 50, 0.0).await; // 85%: first warning fires
        manager.update_usage(10, 0, 0.0).await; // 86%: suppressed
        assert_eq!(messages.lock().unwrap().len(), 1);
        assert_eq!(messages.lock().unwrap()[0], "Token usage at 85.0% (850/1000)");

        manager.update_usage(200, 0, 0.0).await; // 106%: exceeded
        assert_eq!(messages.lock().unwrap().len(), 2);
        assert_eq!(messages.lock().unwrap()[1], "Budget limit exceeded");

        // Resetting usage re-arms the one-shot warning.
        manager.reset_usage().await;
        manager.update_usage(850, 0, 0.0).await;
        assert_eq!(messages.lock().unwrap().len(), 3);
    }

    #[test]
    fn test_tracker_cache_tokens() {
        let mut tracker = TokenUsageTracker::new();
        assert_eq!(tracker.cache_read_input_tokens, 0);
        assert_eq!(tracker.cache_creation_input_tokens, 0);
        assert_eq!(tracker.total_cache_tokens(), 0);

        tracker.update_with_cache(100, 200, 500, 150, 0.03);
        assert_eq!(tracker.total_input_tokens, 100);
        assert_eq!(tracker.total_output_tokens, 200);
        assert_eq!(tracker.cache_read_input_tokens, 500);
        assert_eq!(tracker.cache_creation_input_tokens, 150);
        assert_eq!(tracker.total_cache_tokens(), 650);
        assert_eq!(tracker.session_count, 1);

        tracker.update_with_cache(50, 100, 300, 0, 0.02);
        assert_eq!(tracker.cache_read_input_tokens, 800);
        assert_eq!(tracker.cache_creation_input_tokens, 150);
        assert_eq!(tracker.total_cache_tokens(), 950);
        assert_eq!(tracker.session_count, 2);
    }

    #[test]
    fn test_tracker_reset_clears_cache() {
        let mut tracker = TokenUsageTracker::new();
        tracker.update_with_cache(100, 200, 500, 150, 0.03);
        assert_eq!(tracker.total_cache_tokens(), 650);

        tracker.reset();
        assert_eq!(tracker.cache_read_input_tokens, 0);
        assert_eq!(tracker.cache_creation_input_tokens, 0);
        assert_eq!(tracker.total_cache_tokens(), 0);
        assert_eq!(tracker.total_tokens(), 0);
    }

    #[test]
    fn test_update_and_update_with_cache_both_increment_sessions() {
        let mut tracker = TokenUsageTracker::new();
        tracker.update(10, 20, 0.01);
        tracker.update_with_cache(10, 20, 100, 50, 0.01);
        assert_eq!(tracker.session_count, 2);
        assert_eq!(tracker.total_tokens(), 60);
        assert_eq!(tracker.cache_read_input_tokens, 100);
    }
}