cc_sdk/
token_tracker.rs

//! Token usage tracking and budget management
//!
//! This module provides utilities for monitoring token consumption and managing budgets
//! to help control costs when using Claude Code.
5
6use std::sync::Arc;
7use tokio::sync::RwLock;
8use tracing::warn;
9
/// Running totals of token consumption and spend across recorded sessions.
///
/// All counters are public so callers can inspect them directly;
/// [`TokenUsageTracker::update`] is the normal way to accumulate new usage.
#[derive(Debug, Clone, Default)]
pub struct TokenUsageTracker {
    /// Sum of all input tokens recorded so far
    pub total_input_tokens: u64,
    /// Sum of all output tokens recorded so far
    pub total_output_tokens: u64,
    /// Accumulated spend, in US dollars
    pub total_cost_usd: f64,
    /// How many sessions/queries have been recorded via `update`
    pub session_count: usize,
}

impl TokenUsageTracker {
    /// Returns a tracker whose counters are all zero.
    pub fn new() -> Self {
        Default::default()
    }

    /// Input and output tokens combined.
    pub fn total_tokens(&self) -> u64 {
        let (input, output) = (self.total_input_tokens, self.total_output_tokens);
        input + output
    }

    /// Mean number of tokens per recorded session; `0.0` before any session is recorded.
    pub fn avg_tokens_per_session(&self) -> f64 {
        match self.session_count {
            0 => 0.0,
            sessions => self.total_tokens() as f64 / sessions as f64,
        }
    }

    /// Mean cost (USD) per recorded session; `0.0` before any session is recorded.
    pub fn avg_cost_per_session(&self) -> f64 {
        match self.session_count {
            0 => 0.0,
            sessions => self.total_cost_usd / sessions as f64,
        }
    }

    /// Adds one session's token counts and cost to the running totals.
    pub fn update(&mut self, input_tokens: u64, output_tokens: u64, cost_usd: f64) {
        self.session_count += 1;
        self.total_cost_usd += cost_usd;
        self.total_output_tokens += output_tokens;
        self.total_input_tokens += input_tokens;
    }

    /// Sets every counter back to zero.
    pub fn reset(&mut self) {
        *self = Self::default();
    }
}
68
/// Optional spending caps plus the fraction of a cap at which warnings begin.
#[derive(Debug, Clone)]
pub struct BudgetLimit {
    /// Cost ceiling in USD; `None` means cost is not capped
    pub max_cost_usd: Option<f64>,
    /// Ceiling on input + output tokens; `None` means tokens are not capped
    pub max_tokens: Option<u64>,
    /// Fraction of a cap (0.0-1.0) at which a warning is reported; 0.8 (80%) by default
    pub warning_threshold: f64,
}

impl Default for BudgetLimit {
    /// No caps at all, with warnings at 80% of any cap that is later set.
    fn default() -> Self {
        let warning_threshold = 0.8;
        Self {
            warning_threshold,
            max_tokens: None,
            max_cost_usd: None,
        }
    }
}
89
90impl BudgetLimit {
91    /// Create a new budget limit with cost cap
92    pub fn with_cost(max_cost_usd: f64) -> Self {
93        Self {
94            max_cost_usd: Some(max_cost_usd),
95            ..Default::default()
96        }
97    }
98
99    /// Create a new budget limit with token cap
100    pub fn with_tokens(max_tokens: u64) -> Self {
101        Self {
102            max_tokens: Some(max_tokens),
103            ..Default::default()
104        }
105    }
106
107    /// Create a new budget limit with both caps
108    pub fn with_both(max_cost_usd: f64, max_tokens: u64) -> Self {
109        Self {
110            max_cost_usd: Some(max_cost_usd),
111            max_tokens: Some(max_tokens),
112            warning_threshold: 0.8,
113        }
114    }
115
116    /// Set warning threshold (0.0-1.0)
117    pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
118        self.warning_threshold = threshold.clamp(0.0, 1.0);
119        self
120    }
121
122    /// Check if usage exceeds limits
123    pub fn check_limits(&self, usage: &TokenUsageTracker) -> BudgetStatus {
124        let mut status = BudgetStatus::Ok;
125
126        // Check cost limit
127        if let Some(max_cost) = self.max_cost_usd {
128            let cost_ratio = usage.total_cost_usd / max_cost;
129
130            if cost_ratio >= 1.0 {
131                status = BudgetStatus::Exceeded;
132            } else if cost_ratio >= self.warning_threshold {
133                status = BudgetStatus::Warning {
134                    current_ratio: cost_ratio,
135                    message: format!(
136                        "Cost usage at {:.1}% (${:.2}/${:.2})",
137                        cost_ratio * 100.0,
138                        usage.total_cost_usd,
139                        max_cost
140                    ),
141                };
142            }
143        }
144
145        // Check token limit
146        if let Some(max_tokens) = self.max_tokens {
147            let token_ratio = usage.total_tokens() as f64 / max_tokens as f64;
148
149            if token_ratio >= 1.0 {
150                status = BudgetStatus::Exceeded;
151            } else if token_ratio >= self.warning_threshold {
152                // If already warning from cost, keep the exceeded state
153                if !matches!(status, BudgetStatus::Exceeded) {
154                    status = BudgetStatus::Warning {
155                        current_ratio: token_ratio,
156                        message: format!(
157                            "Token usage at {:.1}% ({}/{})",
158                            token_ratio * 100.0,
159                            usage.total_tokens(),
160                            max_tokens
161                        ),
162                    };
163                }
164            }
165        }
166
167        status
168    }
169}
170
/// Outcome of comparing usage against a [`BudgetLimit`].
#[derive(Clone, Debug, PartialEq)]
pub enum BudgetStatus {
    /// No configured cap has reached its warning threshold
    Ok,
    /// A cap has reached the warning threshold but none has been exhausted
    Warning {
        /// Usage as a fraction of the cap that triggered the warning (0.0-1.0)
        current_ratio: f64,
        /// Human-readable description of the usage level
        message: String,
    },
    /// At least one cap has been reached or passed
    Exceeded,
}
186
/// Shareable callback that receives the text of a budget alert; must be `Send + Sync`.
pub type BudgetWarningCallback = Arc<dyn Fn(&str) + Sync + Send>;
189
190/// Budget manager that combines tracker and limits
191#[derive(Clone)]
192pub struct BudgetManager {
193    tracker: Arc<RwLock<TokenUsageTracker>>,
194    limit: Arc<RwLock<Option<BudgetLimit>>>,
195    on_warning: Arc<RwLock<Option<BudgetWarningCallback>>>,
196    warning_fired: Arc<RwLock<bool>>,
197}
198
199impl BudgetManager {
200    /// Create a new budget manager
201    pub fn new() -> Self {
202        Self {
203            tracker: Arc::new(RwLock::new(TokenUsageTracker::new())),
204            limit: Arc::new(RwLock::new(None)),
205            on_warning: Arc::new(RwLock::new(None)),
206            warning_fired: Arc::new(RwLock::new(false)),
207        }
208    }
209
210    /// Set budget limit
211    pub async fn set_limit(&self, limit: BudgetLimit) {
212        *self.limit.write().await = Some(limit);
213        *self.warning_fired.write().await = false;
214    }
215
216    /// Set warning callback
217    pub async fn set_warning_callback(&self, callback: BudgetWarningCallback) {
218        *self.on_warning.write().await = Some(callback);
219    }
220
221    /// Clear budget limit
222    pub async fn clear_limit(&self) {
223        *self.limit.write().await = None;
224        *self.warning_fired.write().await = false;
225    }
226
227    /// Get current usage statistics
228    pub async fn get_usage(&self) -> TokenUsageTracker {
229        self.tracker.read().await.clone()
230    }
231
232    /// Update usage and check limits
233    pub async fn update_usage(&self, input_tokens: u64, output_tokens: u64, cost_usd: f64) {
234        // Update tracker
235        self.tracker.write().await.update(input_tokens, output_tokens, cost_usd);
236
237        // Check limits
238        if let Some(limit) = self.limit.read().await.as_ref() {
239            let usage = self.tracker.read().await.clone();
240            let status = limit.check_limits(&usage);
241
242            match status {
243                BudgetStatus::Warning { message, .. } => {
244                    let mut fired = self.warning_fired.write().await;
245                    if !*fired {
246                        *fired = true;
247                        warn!("Budget warning: {}", message);
248
249                        if let Some(callback) = self.on_warning.read().await.as_ref() {
250                            callback(&message);
251                        }
252                    }
253                }
254                BudgetStatus::Exceeded => {
255                    warn!("Budget exceeded! Current usage: {} tokens, ${:.2}",
256                          usage.total_tokens(), usage.total_cost_usd);
257
258                    if let Some(callback) = self.on_warning.read().await.as_ref() {
259                        callback("Budget limit exceeded");
260                    }
261                }
262                BudgetStatus::Ok => {
263                    // Reset warning flag if usage dropped below threshold
264                    *self.warning_fired.write().await = false;
265                }
266            }
267        }
268    }
269
270    /// Reset usage statistics
271    pub async fn reset_usage(&self) {
272        self.tracker.write().await.reset();
273        *self.warning_fired.write().await = false;
274    }
275
276    /// Check if budget is exceeded
277    pub async fn is_exceeded(&self) -> bool {
278        if let Some(limit) = self.limit.read().await.as_ref() {
279            let usage = self.tracker.read().await.clone();
280            matches!(limit.check_limits(&usage), BudgetStatus::Exceeded)
281        } else {
282            false
283        }
284    }
285}
286
287impl Default for BudgetManager {
288    fn default() -> Self {
289        Self::new()
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_tracker_basics() {
        let mut tracker = TokenUsageTracker::new();
        assert_eq!(tracker.total_tokens(), 0);
        assert_eq!(tracker.total_cost_usd, 0.0);

        tracker.update(100, 200, 0.05);
        assert_eq!(tracker.total_input_tokens, 100);
        assert_eq!(tracker.total_output_tokens, 200);
        assert_eq!(tracker.total_tokens(), 300);
        assert_eq!(tracker.total_cost_usd, 0.05);
        assert_eq!(tracker.session_count, 1);

        tracker.update(50, 100, 0.02);
        assert_eq!(tracker.total_tokens(), 450);
        assert_eq!(tracker.total_cost_usd, 0.07);
        assert_eq!(tracker.session_count, 2);
    }

    #[test]
    fn test_tracker_averages_and_reset() {
        let mut tracker = TokenUsageTracker::new();
        // No sessions yet: averages must not divide by zero.
        assert_eq!(tracker.avg_tokens_per_session(), 0.0);
        assert_eq!(tracker.avg_cost_per_session(), 0.0);

        tracker.update(100, 100, 0.5);
        tracker.update(200, 0, 0.25);
        assert_eq!(tracker.avg_tokens_per_session(), 200.0);
        assert_eq!(tracker.avg_cost_per_session(), 0.375);

        tracker.reset();
        assert_eq!(tracker.total_tokens(), 0);
        assert_eq!(tracker.total_cost_usd, 0.0);
        assert_eq!(tracker.session_count, 0);
    }

    #[test]
    fn test_budget_limits() {
        let limit = BudgetLimit::with_cost(1.0).with_warning_threshold(0.8);

        let mut tracker = TokenUsageTracker::new();
        tracker.update(100, 200, 0.5);
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Ok));

        tracker.update(100, 200, 0.35);
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Warning { .. }));

        tracker.update(100, 200, 0.2);
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Exceeded));
    }

    #[test]
    fn test_warning_threshold_is_clamped() {
        let high = BudgetLimit::with_cost(1.0).with_warning_threshold(1.5);
        assert_eq!(high.warning_threshold, 1.0);

        let low = BudgetLimit::with_cost(1.0).with_warning_threshold(-0.5);
        assert_eq!(low.warning_threshold, 0.0);
    }

    #[test]
    fn test_token_limit_statuses() {
        let limit = BudgetLimit::with_tokens(1000);
        let mut tracker = TokenUsageTracker::new();

        tracker.update(400, 399, 0.0); // 79.9%: below the default 80% threshold
        assert_eq!(limit.check_limits(&tracker), BudgetStatus::Ok);

        tracker.update(1, 0, 0.0); // exactly 80%: the threshold is inclusive
        assert!(matches!(limit.check_limits(&tracker), BudgetStatus::Warning { .. }));

        tracker.update(200, 0, 0.0); // exactly 100%: the cap is inclusive
        assert_eq!(limit.check_limits(&tracker), BudgetStatus::Exceeded);
    }

    #[test]
    fn test_exceeded_takes_precedence_over_warning() {
        let limit = BudgetLimit::with_both(1.0, 1000);
        let mut tracker = TokenUsageTracker::new();

        // Tokens at 85% (warning band) but cost at 150% (exceeded).
        tracker.update(450, 400, 1.5);
        assert_eq!(limit.check_limits(&tracker), BudgetStatus::Exceeded);
    }

    #[tokio::test]
    async fn test_budget_manager() {
        let manager = BudgetManager::new();

        manager.set_limit(BudgetLimit::with_tokens(1000)).await;
        manager.update_usage(300, 200, 0.05).await;

        let usage = manager.get_usage().await;
        assert_eq!(usage.total_tokens(), 500);

        assert!(!manager.is_exceeded().await);

        manager.update_usage(300, 300, 0.05).await;
        assert!(manager.is_exceeded().await);
    }

    #[tokio::test]
    async fn test_warning_callback_fires_once_then_on_every_overrun() {
        let manager = BudgetManager::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        manager
            .set_warning_callback(Arc::new(move |_msg: &str| {
                counter.fetch_add(1, Ordering::SeqCst);
            }))
            .await;
        manager.set_limit(BudgetLimit::with_tokens(1000)).await;

        manager.update_usage(850, 0, 0.0).await; // 85%: first warning fires
        manager.update_usage(50, 0, 0.0).await; // 90%: already fired, stays quiet
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        manager.update_usage(100, 0, 0.0).await; // 100%: exceeded always notifies
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}