Skip to main content

cc_sdk/
token_tracker.rs

1//! Token usage tracking and budget management
2//!
3//! This module provides utilities for monitoring token consumption and managing budgets
4//! to help control costs when using Claude Code.
5
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::warn;
9
/// Token usage statistics tracker.
///
/// Accumulates per-session usage; every field starts at zero (`Default`).
/// Derives `PartialEq` so snapshots (e.g. from `BudgetManager::get_usage`) can be
/// compared directly.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TokenUsageTracker {
    /// Total input tokens consumed (prompt-cache tokens are tracked in the cache fields)
    pub total_input_tokens: u64,
    /// Total output tokens consumed
    pub total_output_tokens: u64,
    /// Total cache read input tokens (prompt caching hits)
    pub cache_read_input_tokens: u64,
    /// Total cache creation input tokens (prompt caching writes)
    pub cache_creation_input_tokens: u64,
    /// Total cost in USD
    pub total_cost_usd: f64,
    /// Number of sessions/queries recorded (one per update call)
    pub session_count: usize,
}
26
27impl TokenUsageTracker {
28    /// Create a new empty tracker
29    pub fn new() -> Self {
30        Self::default()
31    }
32
33    /// Get total tokens (input + output)
34    pub fn total_tokens(&self) -> u64 {
35        self.total_input_tokens + self.total_output_tokens
36    }
37
38    /// Get average tokens per session
39    pub fn avg_tokens_per_session(&self) -> f64 {
40        if self.session_count == 0 {
41            0.0
42        } else {
43            self.total_tokens() as f64 / self.session_count as f64
44        }
45    }
46
47    /// Get average cost per session
48    pub fn avg_cost_per_session(&self) -> f64 {
49        if self.session_count == 0 {
50            0.0
51        } else {
52            self.total_cost_usd / self.session_count as f64
53        }
54    }
55
56    /// Update statistics with new usage data
57    pub fn update(&mut self, input_tokens: u64, output_tokens: u64, cost_usd: f64) {
58        self.total_input_tokens += input_tokens;
59        self.total_output_tokens += output_tokens;
60        self.total_cost_usd += cost_usd;
61        self.session_count += 1;
62    }
63
64    /// Update statistics with full usage data including cache tokens
65    pub fn update_with_cache(
66        &mut self,
67        input_tokens: u64,
68        output_tokens: u64,
69        cache_read_input_tokens: u64,
70        cache_creation_input_tokens: u64,
71        cost_usd: f64,
72    ) {
73        self.total_input_tokens += input_tokens;
74        self.total_output_tokens += output_tokens;
75        self.cache_read_input_tokens += cache_read_input_tokens;
76        self.cache_creation_input_tokens += cache_creation_input_tokens;
77        self.total_cost_usd += cost_usd;
78        self.session_count += 1;
79    }
80
81    /// Get total cache tokens (read + creation)
82    pub fn total_cache_tokens(&self) -> u64 {
83        self.cache_read_input_tokens + self.cache_creation_input_tokens
84    }
85
86    /// Reset all statistics to zero
87    pub fn reset(&mut self) {
88        self.total_input_tokens = 0;
89        self.total_output_tokens = 0;
90        self.cache_read_input_tokens = 0;
91        self.cache_creation_input_tokens = 0;
92        self.total_cost_usd = 0.0;
93        self.session_count = 0;
94    }
95}
96
/// Budget limits and alerts.
///
/// Either cap may be `None` (unlimited). Derives `PartialEq` so configured limits
/// can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct BudgetLimit {
    /// Maximum cost in USD (None = unlimited)
    pub max_cost_usd: Option<f64>,
    /// Maximum total (input + output) tokens (None = unlimited)
    pub max_tokens: Option<u64>,
    /// Fraction of a cap at which a warning is reported (0.0-1.0, default 0.8 for 80%)
    pub warning_threshold: f64,
}
107
108impl Default for BudgetLimit {
109    fn default() -> Self {
110        Self {
111            max_cost_usd: None,
112            max_tokens: None,
113            warning_threshold: 0.8,
114        }
115    }
116}
117
118impl BudgetLimit {
119    /// Create a new budget limit with cost cap
120    pub fn with_cost(max_cost_usd: f64) -> Self {
121        Self {
122            max_cost_usd: Some(max_cost_usd),
123            ..Default::default()
124        }
125    }
126
127    /// Create a new budget limit with token cap
128    pub fn with_tokens(max_tokens: u64) -> Self {
129        Self {
130            max_tokens: Some(max_tokens),
131            ..Default::default()
132        }
133    }
134
135    /// Create a new budget limit with both caps
136    pub fn with_both(max_cost_usd: f64, max_tokens: u64) -> Self {
137        Self {
138            max_cost_usd: Some(max_cost_usd),
139            max_tokens: Some(max_tokens),
140            warning_threshold: 0.8,
141        }
142    }
143
144    /// Set warning threshold (0.0-1.0)
145    pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
146        self.warning_threshold = threshold.clamp(0.0, 1.0);
147        self
148    }
149
150    /// Check if usage exceeds limits
151    pub fn check_limits(&self, usage: &TokenUsageTracker) -> BudgetStatus {
152        let mut status = BudgetStatus::Ok;
153
154        // Check cost limit
155        if let Some(max_cost) = self.max_cost_usd {
156            let cost_ratio = usage.total_cost_usd / max_cost;
157
158            if cost_ratio >= 1.0 {
159                status = BudgetStatus::Exceeded;
160            } else if cost_ratio >= self.warning_threshold {
161                status = BudgetStatus::Warning {
162                    current_ratio: cost_ratio,
163                    message: format!(
164                        "Cost usage at {:.1}% (${:.2}/${:.2})",
165                        cost_ratio * 100.0,
166                        usage.total_cost_usd,
167                        max_cost
168                    ),
169                };
170            }
171        }
172
173        // Check token limit
174        if let Some(max_tokens) = self.max_tokens {
175            let token_ratio = usage.total_tokens() as f64 / max_tokens as f64;
176
177            if token_ratio >= 1.0 {
178                status = BudgetStatus::Exceeded;
179            } else if token_ratio >= self.warning_threshold {
180                // If already warning from cost, keep the exceeded state
181                if !matches!(status, BudgetStatus::Exceeded) {
182                    status = BudgetStatus::Warning {
183                        current_ratio: token_ratio,
184                        message: format!(
185                            "Token usage at {:.1}% ({}/{})",
186                            token_ratio * 100.0,
187                            usage.total_tokens(),
188                            max_tokens
189                        ),
190                    };
191                }
192            }
193        }
194
195        status
196    }
197}
198
/// Budget status result, as returned by `BudgetLimit::check_limits`.
#[derive(Debug, Clone, PartialEq)]
pub enum BudgetStatus {
    /// Usage is below the warning threshold of every configured cap (or no cap is set)
    Ok,
    /// Usage reached the warning threshold of a cap without reaching the cap itself
    Warning {
        /// Ratio of usage to the cap that produced the warning
        /// (at least `warning_threshold`, below 1.0)
        current_ratio: f64,
        /// Human-readable description, e.g. "Token usage at 85.0% (850/1000)"
        message: String,
    },
    /// Usage reached or passed at least one cap (ratio >= 1.0)
    Exceeded,
}
214
/// Callback invoked by `BudgetManager` for budget notifications.
///
/// It receives either the threshold-warning message or the literal text
/// `"Budget limit exceeded"`. The `Send + Sync` bounds let one callback be shared
/// across tasks and threads; it is called synchronously, so it should return quickly.
pub type BudgetWarningCallback = Arc<dyn Fn(&str) + Sync + Send>;
217
/// Budget manager that combines a usage tracker with an optional budget limit.
///
/// Every field is an `Arc`, so clones of a manager share the same usage, limit,
/// callback, and warning state.
#[derive(Clone)]
pub struct BudgetManager {
    /// Accumulated usage across all recorded sessions
    tracker: Arc<RwLock<TokenUsageTracker>>,
    /// Active limit; `None` disables all checks
    limit: Arc<RwLock<Option<BudgetLimit>>>,
    /// Optional callback notified on threshold warnings and exceeded budgets
    on_warning: Arc<RwLock<Option<BudgetWarningCallback>>>,
    /// Whether the threshold warning has already been emitted; cleared when the
    /// limit is set or cleared, usage is reset, or a check returns `Ok`
    warning_fired: Arc<RwLock<bool>>,
}
226
227impl BudgetManager {
228    /// Create a new budget manager
229    pub fn new() -> Self {
230        Self {
231            tracker: Arc::new(RwLock::new(TokenUsageTracker::new())),
232            limit: Arc::new(RwLock::new(None)),
233            on_warning: Arc::new(RwLock::new(None)),
234            warning_fired: Arc::new(RwLock::new(false)),
235        }
236    }
237
238    /// Set budget limit
239    pub async fn set_limit(&self, limit: BudgetLimit) {
240        *self.limit.write().await = Some(limit);
241        *self.warning_fired.write().await = false;
242    }
243
244    /// Set warning callback
245    pub async fn set_warning_callback(&self, callback: BudgetWarningCallback) {
246        *self.on_warning.write().await = Some(callback);
247    }
248
249    /// Clear budget limit
250    pub async fn clear_limit(&self) {
251        *self.limit.write().await = None;
252        *self.warning_fired.write().await = false;
253    }
254
255    /// Get current usage statistics
256    pub async fn get_usage(&self) -> TokenUsageTracker {
257        self.tracker.read().await.clone()
258    }
259
260    /// Update usage and check limits
261    pub async fn update_usage(&self, input_tokens: u64, output_tokens: u64, cost_usd: f64) {
262        // Update tracker
263        self.tracker.write().await.update(input_tokens, output_tokens, cost_usd);
264
265        // Check limits
266        if let Some(limit) = self.limit.read().await.as_ref() {
267            let usage = self.tracker.read().await.clone();
268            let status = limit.check_limits(&usage);
269
270            match status {
271                BudgetStatus::Warning { message, .. } => {
272                    let mut fired = self.warning_fired.write().await;
273                    if !*fired {
274                        *fired = true;
275                        warn!("Budget warning: {}", message);
276
277                        if let Some(callback) = self.on_warning.read().await.as_ref() {
278                            callback(&message);
279                        }
280                    }
281                }
282                BudgetStatus::Exceeded => {
283                    warn!("Budget exceeded! Current usage: {} tokens, ${:.2}",
284                          usage.total_tokens(), usage.total_cost_usd);
285
286                    if let Some(callback) = self.on_warning.read().await.as_ref() {
287                        callback("Budget limit exceeded");
288                    }
289                }
290                BudgetStatus::Ok => {
291                    // Reset warning flag if usage dropped below threshold
292                    *self.warning_fired.write().await = false;
293                }
294            }
295        }
296    }
297
298    /// Reset usage statistics
299    pub async fn reset_usage(&self) {
300        self.tracker.write().await.reset();
301        *self.warning_fired.write().await = false;
302    }
303
304    /// Check if budget is exceeded
305    pub async fn is_exceeded(&self) -> bool {
306        if let Some(limit) = self.limit.read().await.as_ref() {
307            let usage = self.tracker.read().await.clone();
308            matches!(limit.check_limits(&usage), BudgetStatus::Exceeded)
309        } else {
310            false
311        }
312    }
313}
314
315impl Default for BudgetManager {
316    fn default() -> Self {
317        Self::new()
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Assert two `f64` values are equal within a small tolerance; accumulated
    /// floating-point costs should not be compared with exact equality.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_tracker_basics() {
        let mut tracker = TokenUsageTracker::new();
        assert_eq!(tracker.total_tokens(), 0);
        assert_close(tracker.total_cost_usd, 0.0);

        tracker.update(100, 200, 0.05);
        assert_eq!(tracker.total_input_tokens, 100);
        assert_eq!(tracker.total_output_tokens, 200);
        assert_eq!(tracker.total_tokens(), 300);
        assert_close(tracker.total_cost_usd, 0.05);
        assert_eq!(tracker.session_count, 1);

        tracker.update(50, 100, 0.02);
        assert_eq!(tracker.total_tokens(), 450);
        assert_close(tracker.total_cost_usd, 0.07);
        assert_eq!(tracker.session_count, 2);
    }

    #[test]
    fn test_tracker_averages() {
        let mut tracker = TokenUsageTracker::new();
        // No sessions yet: averages are reported as zero rather than NaN.
        assert_close(tracker.avg_tokens_per_session(), 0.0);
        assert_close(tracker.avg_cost_per_session(), 0.0);

        tracker.update(100, 200, 0.10);
        tracker.update(50, 50, 0.30);
        assert_close(tracker.avg_tokens_per_session(), 200.0);
        assert_close(tracker.avg_cost_per_session(), 0.20);
    }

    #[test]
    fn test_budget_limits() {
        let limit = BudgetLimit::with_cost(1.0).with_warning_threshold(0.8);

        let mut tracker = TokenUsageTracker::new();
        tracker.update(100, 200, 0.5);
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Ok));

        tracker.update(100, 200, 0.35);
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Warning { .. }));

        tracker.update(100, 200, 0.2);
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Exceeded));
    }

    #[test]
    fn test_budget_limit_with_both() {
        let limit = BudgetLimit::with_both(1.0, 1000);
        let mut tracker = TokenUsageTracker::new();

        tracker.update(100, 100, 0.1);
        assert_eq!(limit.check_limits(&tracker), BudgetStatus::Ok);

        // The token cap is reached even though cost is well under its cap.
        tracker.update(400, 400, 0.1);
        assert_eq!(limit.check_limits(&tracker), BudgetStatus::Exceeded);
    }

    #[tokio::test]
    async fn test_budget_manager() {
        let manager = BudgetManager::new();

        manager.set_limit(BudgetLimit::with_tokens(1000)).await;
        manager.update_usage(300, 200, 0.05).await;

        let usage = manager.get_usage().await;
        assert_eq!(usage.total_tokens(), 500);

        assert!(!manager.is_exceeded().await);

        manager.update_usage(300, 300, 0.05).await;
        assert!(manager.is_exceeded().await);
    }

    #[tokio::test]
    async fn test_clear_limit_disables_checks() {
        let manager = BudgetManager::new();
        manager.set_limit(BudgetLimit::with_tokens(100)).await;
        manager.update_usage(100, 100, 0.0).await;
        assert!(manager.is_exceeded().await);

        manager.clear_limit().await;
        assert!(!manager.is_exceeded().await);
    }

    #[tokio::test]
    async fn test_warning_callback_fires_once_per_crossing() {
        let manager = BudgetManager::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        manager
            .set_warning_callback(Arc::new(move |_msg: &str| {
                counter.fetch_add(1, Ordering::SeqCst);
            }))
            .await;
        manager.set_limit(BudgetLimit::with_tokens(1000)).await;

        // 85% of the token cap: first warning is delivered.
        manager.update_usage(400, 450, 0.0).await;
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // 87%: still in the warning band, but the warning is not repeated.
        manager.update_usage(10, 10, 0.0).await;
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Resetting usage re-arms the warning.
        manager.reset_usage().await;
        manager.update_usage(400, 450, 0.0).await;
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_tracker_cache_tokens() {
        let mut tracker = TokenUsageTracker::new();
        assert_eq!(tracker.cache_read_input_tokens, 0);
        assert_eq!(tracker.cache_creation_input_tokens, 0);
        assert_eq!(tracker.total_cache_tokens(), 0);

        tracker.update_with_cache(100, 200, 500, 150, 0.03);
        assert_eq!(tracker.total_input_tokens, 100);
        assert_eq!(tracker.total_output_tokens, 200);
        assert_eq!(tracker.cache_read_input_tokens, 500);
        assert_eq!(tracker.cache_creation_input_tokens, 150);
        assert_eq!(tracker.total_cache_tokens(), 650);
        assert_eq!(tracker.session_count, 1);

        // Accumulate
        tracker.update_with_cache(50, 100, 300, 0, 0.02);
        assert_eq!(tracker.cache_read_input_tokens, 800);
        assert_eq!(tracker.cache_creation_input_tokens, 150);
        assert_eq!(tracker.total_cache_tokens(), 950);
        assert_eq!(tracker.session_count, 2);
    }

    #[test]
    fn test_tracker_reset_clears_cache() {
        let mut tracker = TokenUsageTracker::new();
        tracker.update_with_cache(100, 200, 500, 150, 0.03);
        assert_eq!(tracker.total_cache_tokens(), 650);

        tracker.reset();
        assert_eq!(tracker.cache_read_input_tokens, 0);
        assert_eq!(tracker.cache_creation_input_tokens, 0);
        assert_eq!(tracker.total_cache_tokens(), 0);
        assert_eq!(tracker.total_tokens(), 0);
    }

    #[test]
    fn test_update_and_update_with_cache_both_increment_sessions() {
        let mut tracker = TokenUsageTracker::new();
        tracker.update(10, 20, 0.01);
        tracker.update_with_cache(10, 20, 100, 50, 0.01);
        assert_eq!(tracker.session_count, 2);
        assert_eq!(tracker.total_tokens(), 60);
        // Only `update_with_cache` contributed cache tokens; plain `update` adds none.
        assert_eq!(tracker.cache_read_input_tokens, 100);
        assert_eq!(tracker.cache_creation_input_tokens, 50);
    }
}