Skip to main content

agentik_sdk/tokens/mod.rs

1use std::time::{Duration, SystemTime};
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex};
4use serde::{Deserialize, Serialize};
5use crate::types::Usage;
6
7/// Enhanced token counting and cost estimation utilities
8pub struct TokenCounter {
9    /// Accumulated usage statistics
10    usage_stats: Arc<Mutex<UsageStats>>,
11    /// Model pricing information
12    pricing: ModelPricing,
13}
14
15/// Accumulated usage statistics across all requests
16#[derive(Debug, Clone)]
17pub struct UsageStats {
18    /// Total input tokens across all requests
19    pub total_input_tokens: u32,
20    /// Total output tokens across all requests
21    pub total_output_tokens: u32,
22    /// Total cache read tokens
23    pub total_cache_read_tokens: u32,
24    /// Total cache write tokens  
25    pub total_cache_write_tokens: u32,
26    /// Number of requests made
27    pub request_count: u32,
28    /// Total cost in USD
29    pub total_cost_usd: f64,
30    /// Usage by model
31    pub model_usage: HashMap<String, ModelUsage>,
32    /// Session start time
33    pub session_start: SystemTime,
34    /// Last request time
35    pub last_request: Option<SystemTime>,
36}
37
/// Totals attributed to a single model within a session.
#[derive(Clone, Debug, Default)]
pub struct ModelUsage {
    /// Input tokens consumed by this model
    pub input_tokens: u32,
    /// Output tokens produced by this model
    pub output_tokens: u32,
    /// Cache-read tokens attributed to this model
    pub cache_read_tokens: u32,
    /// Cache-write tokens attributed to this model
    pub cache_write_tokens: u32,
    /// Requests served by this model
    pub request_count: u32,
    /// Accumulated cost for this model, in US dollars
    pub cost_usd: f64,
}
54
/// Token and cost tracking for one request, from start to (optional) completion.
#[derive(Clone, Debug)]
pub struct RequestUsage {
    /// Input tokens consumed by this request
    pub input_tokens: u32,
    /// Output tokens counted so far
    pub output_tokens: u32,
    /// Cache-read tokens for this request
    pub cache_read_tokens: u32,
    /// Cache-write tokens for this request
    pub cache_write_tokens: u32,
    /// Identifier of the model handling the request
    pub model: String,
    /// When the request began
    pub start_time: SystemTime,
    /// When the request finished; `None` until it does
    pub end_time: Option<SystemTime>,
    /// Cost attributed to this request, in US dollars
    pub cost_usd: f64,
}
75
76/// Model pricing information
77#[derive(Debug, Clone)]
78pub struct ModelPricing {
79    /// Pricing per model
80    pricing_table: HashMap<String, ModelPrice>,
81}
82
/// Rates for one model, each expressed in US dollars per one million tokens.
#[derive(Clone, Debug)]
pub struct ModelPrice {
    /// USD per 1,000,000 input tokens
    pub input_cost_per_million: f64,
    /// USD per 1,000,000 output tokens
    pub output_cost_per_million: f64,
    /// USD per 1,000,000 cache-read tokens; `None` means such tokens are costed at zero
    pub cache_read_cost_per_million: Option<f64>,
    /// USD per 1,000,000 cache-write tokens; `None` means such tokens are costed at zero
    pub cache_write_cost_per_million: Option<f64>,
}
95
96/// Cost breakdown for detailed analysis
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct CostBreakdown {
99    /// Input token cost
100    pub input_cost: f64,
101    /// Output token cost
102    pub output_cost: f64,
103    /// Cache read cost
104    pub cache_read_cost: f64,
105    /// Cache write cost
106    pub cache_write_cost: f64,
107    /// Total cost
108    pub total_cost: f64,
109    /// Cost per token
110    pub cost_per_token: f64,
111    /// Model used
112    pub model: String,
113}
114
115/// Usage summary for reporting
116#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct UsageSummary {
118    /// Total tokens (input + output)
119    pub total_tokens: u32,
120    /// Input tokens
121    pub input_tokens: u32,
122    /// Output tokens
123    pub output_tokens: u32,
124    /// Cache tokens
125    pub cache_tokens: u32,
126    /// Total cost
127    pub total_cost_usd: f64,
128    /// Average cost per token
129    pub avg_cost_per_token: f64,
130    /// Session duration
131    pub session_duration: Duration,
132    /// Requests per minute
133    pub requests_per_minute: f64,
134    /// Tokens per minute
135    pub tokens_per_minute: f64,
136    /// Cost per request
137    pub avg_cost_per_request: f64,
138}
139
140impl TokenCounter {
141    /// Create a new token counter with default pricing
142    pub fn new() -> Self {
143        Self {
144            usage_stats: Arc::new(Mutex::new(UsageStats::new())),
145            pricing: ModelPricing::default(),
146        }
147    }
148
149    /// Create a token counter with custom pricing
150    pub fn with_pricing(pricing: ModelPricing) -> Self {
151        Self {
152            usage_stats: Arc::new(Mutex::new(UsageStats::new())),
153            pricing,
154        }
155    }
156
157    /// Record usage from a completed request
158    pub fn record_usage(&self, model: &str, usage: &Usage) -> CostBreakdown {
159        let cost_breakdown = self.calculate_cost(model, usage);
160        
161        let mut stats = self.usage_stats.lock().unwrap();
162        stats.add_usage(model, usage, cost_breakdown.total_cost);
163        
164        cost_breakdown
165    }
166
167    /// Start tracking a new request
168    pub fn start_request(&self, model: &str) -> RequestUsage {
169        RequestUsage {
170            model: model.to_string(),
171            start_time: SystemTime::now(),
172            ..Default::default()
173        }
174    }
175
176    /// Calculate cost for a usage
177    pub fn calculate_cost(&self, model: &str, usage: &Usage) -> CostBreakdown {
178        let price = self.pricing.get_price(model);
179        
180        let input_cost = (usage.input_tokens as f64 / 1_000_000.0) * price.input_cost_per_million;
181        let output_cost = (usage.output_tokens as f64 / 1_000_000.0) * price.output_cost_per_million;
182        
183        let cache_read_tokens = usage.cache_read_input_tokens.unwrap_or(0);
184        let cache_write_tokens = usage.cache_creation_input_tokens.unwrap_or(0);
185        
186        let cache_read_cost = price.cache_read_cost_per_million
187            .map(|rate| (cache_read_tokens as f64 / 1_000_000.0) * rate)
188            .unwrap_or(0.0);
189            
190        let cache_write_cost = price.cache_write_cost_per_million
191            .map(|rate| (cache_write_tokens as f64 / 1_000_000.0) * rate)
192            .unwrap_or(0.0);
193        
194        let total_cost = input_cost + output_cost + cache_read_cost + cache_write_cost;
195        let total_tokens = usage.input_tokens + usage.output_tokens + cache_read_tokens + cache_write_tokens;
196        let cost_per_token = if total_tokens > 0 {
197            total_cost / total_tokens as f64
198        } else {
199            0.0
200        };
201
202        CostBreakdown {
203            input_cost,
204            output_cost,
205            cache_read_cost,
206            cache_write_cost,
207            total_cost,
208            cost_per_token,
209            model: model.to_string(),
210        }
211    }
212
213    /// Get current usage statistics
214    pub fn get_stats(&self) -> UsageStats {
215        self.usage_stats.lock().unwrap().clone()
216    }
217
218    /// Get usage summary
219    pub fn get_summary(&self) -> UsageSummary {
220        let stats = self.usage_stats.lock().unwrap();
221        stats.to_summary()
222    }
223
224    /// Reset usage statistics
225    pub fn reset(&self) {
226        let mut stats = self.usage_stats.lock().unwrap();
227        *stats = UsageStats::new();
228    }
229
230    /// Estimate cost for a request before sending
231    pub fn estimate_cost(&self, model: &str, estimated_input_tokens: u32, estimated_output_tokens: u32) -> f64 {
232        let usage = Usage {
233            input_tokens: estimated_input_tokens,
234            output_tokens: estimated_output_tokens,
235            cache_creation_input_tokens: None,
236            cache_read_input_tokens: None,
237            server_tool_use: None,
238            service_tier: None,
239        };
240        
241        self.calculate_cost(model, &usage).total_cost
242    }
243}
244
245impl UsageStats {
246    fn new() -> Self {
247        Self {
248            total_input_tokens: 0,
249            total_output_tokens: 0,
250            total_cache_read_tokens: 0,
251            total_cache_write_tokens: 0,
252            request_count: 0,
253            total_cost_usd: 0.0,
254            model_usage: HashMap::new(),
255            session_start: SystemTime::now(),
256            last_request: None,
257        }
258    }
259
260    fn add_usage(&mut self, model: &str, usage: &Usage, cost: f64) {
261        self.total_input_tokens += usage.input_tokens;
262        self.total_output_tokens += usage.output_tokens;
263        self.total_cache_read_tokens += usage.cache_read_input_tokens.unwrap_or(0);
264        self.total_cache_write_tokens += usage.cache_creation_input_tokens.unwrap_or(0);
265        self.total_cost_usd += cost;
266        self.request_count += 1;
267        self.last_request = Some(SystemTime::now());
268
269        // Update model-specific usage
270        let model_usage = self.model_usage.entry(model.to_string()).or_default();
271        model_usage.input_tokens += usage.input_tokens;
272        model_usage.output_tokens += usage.output_tokens;
273        model_usage.cache_read_tokens += usage.cache_read_input_tokens.unwrap_or(0);
274        model_usage.cache_write_tokens += usage.cache_creation_input_tokens.unwrap_or(0);
275        model_usage.cost_usd += cost;
276        model_usage.request_count += 1;
277    }
278
279    fn to_summary(&self) -> UsageSummary {
280        let total_tokens = self.total_input_tokens + self.total_output_tokens;
281        let cache_tokens = self.total_cache_read_tokens + self.total_cache_write_tokens;
282        
283        let session_duration = self.session_start.elapsed().unwrap_or(Duration::ZERO);
284        let session_minutes = session_duration.as_secs_f64() / 60.0;
285        
286        let requests_per_minute = if session_minutes > 0.0 {
287            self.request_count as f64 / session_minutes
288        } else {
289            0.0
290        };
291        
292        let tokens_per_minute = if session_minutes > 0.0 {
293            total_tokens as f64 / session_minutes
294        } else {
295            0.0
296        };
297        
298        let avg_cost_per_token = if total_tokens > 0 {
299            self.total_cost_usd / total_tokens as f64
300        } else {
301            0.0
302        };
303        
304        let avg_cost_per_request = if self.request_count > 0 {
305            self.total_cost_usd / self.request_count as f64
306        } else {
307            0.0
308        };
309
310        UsageSummary {
311            total_tokens,
312            input_tokens: self.total_input_tokens,
313            output_tokens: self.total_output_tokens,
314            cache_tokens,
315            total_cost_usd: self.total_cost_usd,
316            avg_cost_per_token,
317            session_duration,
318            requests_per_minute,
319            tokens_per_minute,
320            avg_cost_per_request,
321        }
322    }
323}
324
325impl ModelPricing {
326    /// Create default pricing with current Anthropic rates
327    pub fn default() -> Self {
328        let mut pricing_table = HashMap::new();
329        
330        // Claude 3.5 Sonnet (latest)
331        pricing_table.insert("claude-3-5-sonnet-latest".to_string(), ModelPrice {
332            input_cost_per_million: 3.00,
333            output_cost_per_million: 15.00,
334            cache_read_cost_per_million: Some(0.30),
335            cache_write_cost_per_million: Some(3.75),
336        });
337        
338        // Claude 3.5 Sonnet (20241022)
339        pricing_table.insert("claude-3-5-sonnet-20241022".to_string(), ModelPrice {
340            input_cost_per_million: 3.00,
341            output_cost_per_million: 15.00,
342            cache_read_cost_per_million: Some(0.30),
343            cache_write_cost_per_million: Some(3.75),
344        });
345        
346        // Claude 3.5 Haiku (latest)
347        pricing_table.insert("claude-3-5-haiku-latest".to_string(), ModelPrice {
348            input_cost_per_million: 1.00,
349            output_cost_per_million: 5.00,
350            cache_read_cost_per_million: Some(0.10),
351            cache_write_cost_per_million: Some(1.25),
352        });
353        
354        // Claude 3 Opus
355        pricing_table.insert("claude-3-opus-20240229".to_string(), ModelPrice {
356            input_cost_per_million: 15.00,
357            output_cost_per_million: 75.00,
358            cache_read_cost_per_million: Some(1.50),
359            cache_write_cost_per_million: Some(18.75),
360        });
361
362        Self { pricing_table }
363    }
364
365    /// Get pricing for a model
366    pub fn get_price(&self, model: &str) -> &ModelPrice {
367        self.pricing_table.get(model)
368            .unwrap_or_else(|| {
369                // Default to Claude 3.5 Sonnet pricing for unknown models
370                self.pricing_table.get("claude-3-5-sonnet-latest").unwrap()
371            })
372    }
373
374    /// Set pricing for a model
375    pub fn set_price(&mut self, model: &str, price: ModelPrice) {
376        self.pricing_table.insert(model.to_string(), price);
377    }
378}
379
380impl Default for TokenCounter {
381    fn default() -> Self {
382        Self::new()
383    }
384}
385
386impl Default for UsageStats {
387    fn default() -> Self {
388        Self::new()
389    }
390}
391
392impl Default for RequestUsage {
393    fn default() -> Self {
394        Self {
395            input_tokens: 0,
396            output_tokens: 0,
397            cache_read_tokens: 0,
398            cache_write_tokens: 0,
399            model: String::new(),
400            start_time: SystemTime::now(),
401            end_time: None,
402            cost_usd: 0.0,
403        }
404    }
405}
406
407impl std::fmt::Display for UsageSummary {
408    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409        write!(f, 
410            "Usage Summary:\n\
411             Total Tokens: {} (Input: {}, Output: {}, Cache: {})\n\
412             Total Cost: ${:.4}\n\
413             Avg Cost/Token: ${:.6}\n\
414             Avg Cost/Request: ${:.4}\n\
415             Session Duration: {:.1}min\n\
416             Rate: {:.1} tokens/min, {:.1} requests/min",
417            self.total_tokens,
418            self.input_tokens,
419            self.output_tokens,
420            self.cache_tokens,
421            self.total_cost_usd,
422            self.avg_cost_per_token,
423            self.avg_cost_per_request,
424            self.session_duration.as_secs_f64() / 60.0,
425            self.tokens_per_minute,
426            self.requests_per_minute
427        )
428    }
429}
430
431impl std::fmt::Display for CostBreakdown {
432    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433        write!(f,
434            "Cost Breakdown ({}):\n\
435             Input: ${:.4}\n\
436             Output: ${:.4}\n\
437             Cache Read: ${:.4}\n\
438             Cache Write: ${:.4}\n\
439             Total: ${:.4} (${:.6}/token)",
440            self.model,
441            self.input_cost,
442            self.output_cost,
443            self.cache_read_cost,
444            self.cache_write_cost,
445            self.total_cost,
446            self.cost_per_token
447        )
448    }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing computed USD amounts. It is tight enough that
    /// dropping any single cost component fails a test; the smallest component
    /// used below is $0.00006. The previous 1e-4 tolerance did not have this property.
    const EPS: f64 = 1e-12;

    /// Build a `Usage` with no server-tool or service-tier information.
    fn make_usage(input: u32, output: u32, cache_write: Option<u32>, cache_read: Option<u32>) -> Usage {
        Usage {
            input_tokens: input,
            output_tokens: output,
            cache_creation_input_tokens: cache_write,
            cache_read_input_tokens: cache_read,
            server_tool_use: None,
            service_tier: None,
        }
    }

    #[test]
    fn test_cost_calculation() {
        let counter = TokenCounter::new();
        let usage = make_usage(1000, 500, Some(100), Some(200));

        let cost = counter.calculate_cost("claude-3-5-sonnet-latest", &usage);

        // Expected: (1000/1M * $3) + (500/1M * $15) + (200/1M * $0.30) + (100/1M * $3.75)
        assert!((cost.input_cost - 0.003).abs() < EPS);
        assert!((cost.output_cost - 0.0075).abs() < EPS);
        assert!((cost.cache_read_cost - 0.00006).abs() < EPS);
        assert!((cost.cache_write_cost - 0.000375).abs() < EPS);
        let expected = 0.003 + 0.0075 + 0.00006 + 0.000375;
        assert!((cost.total_cost - expected).abs() < EPS);
        // Per-token cost divides by all 1800 tokens, cache tokens included.
        assert!((cost.cost_per_token - expected / 1800.0).abs() < EPS);
        assert_eq!(cost.model, "claude-3-5-sonnet-latest");
    }

    #[test]
    fn test_unknown_model_uses_fallback_pricing() {
        let counter = TokenCounter::new();
        let usage = make_usage(1000, 500, None, None);
        let known = counter.calculate_cost("claude-3-5-sonnet-latest", &usage).total_cost;
        let unknown = counter.calculate_cost("not-a-real-model", &usage).total_cost;
        assert!((known - unknown).abs() < EPS);
    }

    #[test]
    fn test_usage_tracking() {
        let counter = TokenCounter::new();
        let usage = make_usage(100, 50, None, None);

        counter.record_usage("claude-3-5-sonnet-latest", &usage);
        let stats = counter.get_stats();

        assert_eq!(stats.total_input_tokens, 100);
        assert_eq!(stats.total_output_tokens, 50);
        assert_eq!(stats.request_count, 1);
        assert!(stats.total_cost_usd > 0.0);
        assert!(stats.last_request.is_some());
    }

    #[test]
    fn test_per_model_usage_and_reset() {
        let counter = TokenCounter::new();
        counter.record_usage("claude-3-5-haiku-latest", &make_usage(10, 20, None, Some(5)));
        counter.record_usage("claude-3-5-haiku-latest", &make_usage(1, 2, Some(3), None));
        counter.record_usage("claude-3-opus-20240229", &make_usage(7, 8, None, None));

        let stats = counter.get_stats();
        assert_eq!(stats.request_count, 3);
        let haiku = &stats.model_usage["claude-3-5-haiku-latest"];
        assert_eq!(haiku.input_tokens, 11);
        assert_eq!(haiku.output_tokens, 22);
        assert_eq!(haiku.cache_read_tokens, 5);
        assert_eq!(haiku.cache_write_tokens, 3);
        assert_eq!(haiku.request_count, 2);
        assert_eq!(stats.model_usage["claude-3-opus-20240229"].request_count, 1);

        counter.reset();
        let stats = counter.get_stats();
        assert_eq!(stats.request_count, 0);
        assert_eq!(stats.total_input_tokens, 0);
        assert_eq!(stats.total_cost_usd, 0.0);
        assert!(stats.model_usage.is_empty());
        assert!(stats.last_request.is_none());
    }

    #[test]
    fn test_empty_summary() {
        let summary = TokenCounter::new().get_summary();
        assert_eq!(summary.total_tokens, 0);
        assert_eq!(summary.total_cost_usd, 0.0);
        assert_eq!(summary.avg_cost_per_token, 0.0);
        assert_eq!(summary.avg_cost_per_request, 0.0);
    }

    #[test]
    fn test_cost_estimation() {
        let counter = TokenCounter::new();
        let cost = counter.estimate_cost("claude-3-5-sonnet-latest", 1000, 500);
        assert!(cost > 0.0);

        // Should be: (1000/1M * $3) + (500/1M * $15) = $0.0105
        assert!((cost - 0.0105).abs() < EPS);
    }
}