1use std::time::{Duration, SystemTime};
2use std::collections::HashMap;
3use std::sync::{Arc, Mutex};
4use serde::{Deserialize, Serialize};
5use crate::types::Usage;
6
7pub struct TokenCounter {
9 usage_stats: Arc<Mutex<UsageStats>>,
11 pricing: ModelPricing,
13}
14
15#[derive(Debug, Clone)]
17pub struct UsageStats {
18 pub total_input_tokens: u32,
20 pub total_output_tokens: u32,
22 pub total_cache_read_tokens: u32,
24 pub total_cache_write_tokens: u32,
26 pub request_count: u32,
28 pub total_cost_usd: f64,
30 pub model_usage: HashMap<String, ModelUsage>,
32 pub session_start: SystemTime,
34 pub last_request: Option<SystemTime>,
36}
37
/// Usage counters for a single model; all fields start at zero.
#[derive(Debug, Clone, Default)]
pub struct ModelUsage {
    /// Input tokens recorded for this model.
    pub input_tokens: u32,
    /// Output tokens recorded for this model.
    pub output_tokens: u32,
    /// Tokens read from the prompt cache for this model.
    pub cache_read_tokens: u32,
    /// Tokens written to the prompt cache for this model.
    pub cache_write_tokens: u32,
    /// Number of usage records attributed to this model.
    pub request_count: u32,
    /// Accumulated cost in USD for this model.
    pub cost_usd: f64,
}
54
/// Usage and timing record for one request.
#[derive(Debug, Clone)]
pub struct RequestUsage {
    /// Input tokens attributed to the request.
    pub input_tokens: u32,
    /// Output tokens attributed to the request.
    pub output_tokens: u32,
    /// Tokens read from the prompt cache.
    pub cache_read_tokens: u32,
    /// Tokens written to the prompt cache.
    pub cache_write_tokens: u32,
    /// Name of the model the request targets.
    pub model: String,
    /// Wall-clock time at which the record was created.
    pub start_time: SystemTime,
    /// Wall-clock completion time; `None` on a freshly created record.
    pub end_time: Option<SystemTime>,
    /// Cost of the request in USD.
    pub cost_usd: f64,
}
75
76#[derive(Debug, Clone)]
78pub struct ModelPricing {
79 pricing_table: HashMap<String, ModelPrice>,
81}
82
/// Per-model rates, each expressed as a cost per one million tokens.
#[derive(Debug, Clone)]
pub struct ModelPrice {
    /// Rate applied to input tokens.
    pub input_cost_per_million: f64,
    /// Rate applied to output tokens.
    pub output_cost_per_million: f64,
    /// Rate applied to cache-read tokens; `None` means they are not charged.
    pub cache_read_cost_per_million: Option<f64>,
    /// Rate applied to cache-write tokens; `None` means they are not charged.
    pub cache_write_cost_per_million: Option<f64>,
}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct CostBreakdown {
99 pub input_cost: f64,
101 pub output_cost: f64,
103 pub cache_read_cost: f64,
105 pub cache_write_cost: f64,
107 pub total_cost: f64,
109 pub cost_per_token: f64,
111 pub model: String,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize)]
117pub struct UsageSummary {
118 pub total_tokens: u32,
120 pub input_tokens: u32,
122 pub output_tokens: u32,
124 pub cache_tokens: u32,
126 pub total_cost_usd: f64,
128 pub avg_cost_per_token: f64,
130 pub session_duration: Duration,
132 pub requests_per_minute: f64,
134 pub tokens_per_minute: f64,
136 pub avg_cost_per_request: f64,
138}
139
140impl TokenCounter {
141 pub fn new() -> Self {
143 Self {
144 usage_stats: Arc::new(Mutex::new(UsageStats::new())),
145 pricing: ModelPricing::default(),
146 }
147 }
148
149 pub fn with_pricing(pricing: ModelPricing) -> Self {
151 Self {
152 usage_stats: Arc::new(Mutex::new(UsageStats::new())),
153 pricing,
154 }
155 }
156
157 pub fn record_usage(&self, model: &str, usage: &Usage) -> CostBreakdown {
159 let cost_breakdown = self.calculate_cost(model, usage);
160
161 let mut stats = self.usage_stats.lock().unwrap();
162 stats.add_usage(model, usage, cost_breakdown.total_cost);
163
164 cost_breakdown
165 }
166
167 pub fn start_request(&self, model: &str) -> RequestUsage {
169 RequestUsage {
170 model: model.to_string(),
171 start_time: SystemTime::now(),
172 ..Default::default()
173 }
174 }
175
176 pub fn calculate_cost(&self, model: &str, usage: &Usage) -> CostBreakdown {
178 let price = self.pricing.get_price(model);
179
180 let input_cost = (usage.input_tokens as f64 / 1_000_000.0) * price.input_cost_per_million;
181 let output_cost = (usage.output_tokens as f64 / 1_000_000.0) * price.output_cost_per_million;
182
183 let cache_read_tokens = usage.cache_read_input_tokens.unwrap_or(0);
184 let cache_write_tokens = usage.cache_creation_input_tokens.unwrap_or(0);
185
186 let cache_read_cost = price.cache_read_cost_per_million
187 .map(|rate| (cache_read_tokens as f64 / 1_000_000.0) * rate)
188 .unwrap_or(0.0);
189
190 let cache_write_cost = price.cache_write_cost_per_million
191 .map(|rate| (cache_write_tokens as f64 / 1_000_000.0) * rate)
192 .unwrap_or(0.0);
193
194 let total_cost = input_cost + output_cost + cache_read_cost + cache_write_cost;
195 let total_tokens = usage.input_tokens + usage.output_tokens + cache_read_tokens + cache_write_tokens;
196 let cost_per_token = if total_tokens > 0 {
197 total_cost / total_tokens as f64
198 } else {
199 0.0
200 };
201
202 CostBreakdown {
203 input_cost,
204 output_cost,
205 cache_read_cost,
206 cache_write_cost,
207 total_cost,
208 cost_per_token,
209 model: model.to_string(),
210 }
211 }
212
213 pub fn get_stats(&self) -> UsageStats {
215 self.usage_stats.lock().unwrap().clone()
216 }
217
218 pub fn get_summary(&self) -> UsageSummary {
220 let stats = self.usage_stats.lock().unwrap();
221 stats.to_summary()
222 }
223
224 pub fn reset(&self) {
226 let mut stats = self.usage_stats.lock().unwrap();
227 *stats = UsageStats::new();
228 }
229
230 pub fn estimate_cost(&self, model: &str, estimated_input_tokens: u32, estimated_output_tokens: u32) -> f64 {
232 let usage = Usage {
233 input_tokens: estimated_input_tokens,
234 output_tokens: estimated_output_tokens,
235 cache_creation_input_tokens: None,
236 cache_read_input_tokens: None,
237 server_tool_use: None,
238 service_tier: None,
239 };
240
241 self.calculate_cost(model, &usage).total_cost
242 }
243}
244
245impl UsageStats {
246 fn new() -> Self {
247 Self {
248 total_input_tokens: 0,
249 total_output_tokens: 0,
250 total_cache_read_tokens: 0,
251 total_cache_write_tokens: 0,
252 request_count: 0,
253 total_cost_usd: 0.0,
254 model_usage: HashMap::new(),
255 session_start: SystemTime::now(),
256 last_request: None,
257 }
258 }
259
260 fn add_usage(&mut self, model: &str, usage: &Usage, cost: f64) {
261 self.total_input_tokens += usage.input_tokens;
262 self.total_output_tokens += usage.output_tokens;
263 self.total_cache_read_tokens += usage.cache_read_input_tokens.unwrap_or(0);
264 self.total_cache_write_tokens += usage.cache_creation_input_tokens.unwrap_or(0);
265 self.total_cost_usd += cost;
266 self.request_count += 1;
267 self.last_request = Some(SystemTime::now());
268
269 let model_usage = self.model_usage.entry(model.to_string()).or_default();
271 model_usage.input_tokens += usage.input_tokens;
272 model_usage.output_tokens += usage.output_tokens;
273 model_usage.cache_read_tokens += usage.cache_read_input_tokens.unwrap_or(0);
274 model_usage.cache_write_tokens += usage.cache_creation_input_tokens.unwrap_or(0);
275 model_usage.cost_usd += cost;
276 model_usage.request_count += 1;
277 }
278
279 fn to_summary(&self) -> UsageSummary {
280 let total_tokens = self.total_input_tokens + self.total_output_tokens;
281 let cache_tokens = self.total_cache_read_tokens + self.total_cache_write_tokens;
282
283 let session_duration = self.session_start.elapsed().unwrap_or(Duration::ZERO);
284 let session_minutes = session_duration.as_secs_f64() / 60.0;
285
286 let requests_per_minute = if session_minutes > 0.0 {
287 self.request_count as f64 / session_minutes
288 } else {
289 0.0
290 };
291
292 let tokens_per_minute = if session_minutes > 0.0 {
293 total_tokens as f64 / session_minutes
294 } else {
295 0.0
296 };
297
298 let avg_cost_per_token = if total_tokens > 0 {
299 self.total_cost_usd / total_tokens as f64
300 } else {
301 0.0
302 };
303
304 let avg_cost_per_request = if self.request_count > 0 {
305 self.total_cost_usd / self.request_count as f64
306 } else {
307 0.0
308 };
309
310 UsageSummary {
311 total_tokens,
312 input_tokens: self.total_input_tokens,
313 output_tokens: self.total_output_tokens,
314 cache_tokens,
315 total_cost_usd: self.total_cost_usd,
316 avg_cost_per_token,
317 session_duration,
318 requests_per_minute,
319 tokens_per_minute,
320 avg_cost_per_request,
321 }
322 }
323}
324
325impl ModelPricing {
326 pub fn default() -> Self {
328 let mut pricing_table = HashMap::new();
329
330 pricing_table.insert("claude-3-5-sonnet-latest".to_string(), ModelPrice {
332 input_cost_per_million: 3.00,
333 output_cost_per_million: 15.00,
334 cache_read_cost_per_million: Some(0.30),
335 cache_write_cost_per_million: Some(3.75),
336 });
337
338 pricing_table.insert("claude-3-5-sonnet-20241022".to_string(), ModelPrice {
340 input_cost_per_million: 3.00,
341 output_cost_per_million: 15.00,
342 cache_read_cost_per_million: Some(0.30),
343 cache_write_cost_per_million: Some(3.75),
344 });
345
346 pricing_table.insert("claude-3-5-haiku-latest".to_string(), ModelPrice {
348 input_cost_per_million: 1.00,
349 output_cost_per_million: 5.00,
350 cache_read_cost_per_million: Some(0.10),
351 cache_write_cost_per_million: Some(1.25),
352 });
353
354 pricing_table.insert("claude-3-opus-20240229".to_string(), ModelPrice {
356 input_cost_per_million: 15.00,
357 output_cost_per_million: 75.00,
358 cache_read_cost_per_million: Some(1.50),
359 cache_write_cost_per_million: Some(18.75),
360 });
361
362 Self { pricing_table }
363 }
364
365 pub fn get_price(&self, model: &str) -> &ModelPrice {
367 self.pricing_table.get(model)
368 .unwrap_or_else(|| {
369 self.pricing_table.get("claude-3-5-sonnet-latest").unwrap()
371 })
372 }
373
374 pub fn set_price(&mut self, model: &str, price: ModelPrice) {
376 self.pricing_table.insert(model.to_string(), price);
377 }
378}
379
380impl Default for TokenCounter {
381 fn default() -> Self {
382 Self::new()
383 }
384}
385
386impl Default for UsageStats {
387 fn default() -> Self {
388 Self::new()
389 }
390}
391
392impl Default for RequestUsage {
393 fn default() -> Self {
394 Self {
395 input_tokens: 0,
396 output_tokens: 0,
397 cache_read_tokens: 0,
398 cache_write_tokens: 0,
399 model: String::new(),
400 start_time: SystemTime::now(),
401 end_time: None,
402 cost_usd: 0.0,
403 }
404 }
405}
406
407impl std::fmt::Display for UsageSummary {
408 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
409 write!(f,
410 "Usage Summary:\n\
411 Total Tokens: {} (Input: {}, Output: {}, Cache: {})\n\
412 Total Cost: ${:.4}\n\
413 Avg Cost/Token: ${:.6}\n\
414 Avg Cost/Request: ${:.4}\n\
415 Session Duration: {:.1}min\n\
416 Rate: {:.1} tokens/min, {:.1} requests/min",
417 self.total_tokens,
418 self.input_tokens,
419 self.output_tokens,
420 self.cache_tokens,
421 self.total_cost_usd,
422 self.avg_cost_per_token,
423 self.avg_cost_per_request,
424 self.session_duration.as_secs_f64() / 60.0,
425 self.tokens_per_minute,
426 self.requests_per_minute
427 )
428 }
429}
430
431impl std::fmt::Display for CostBreakdown {
432 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
433 write!(f,
434 "Cost Breakdown ({}):\n\
435 Input: ${:.4}\n\
436 Output: ${:.4}\n\
437 Cache Read: ${:.4}\n\
438 Cache Write: ${:.4}\n\
439 Total: ${:.4} (${:.6}/token)",
440 self.model,
441 self.input_cost,
442 self.output_cost,
443 self.cache_read_cost,
444 self.cache_write_cost,
445 self.total_cost,
446 self.cost_per_token
447 )
448 }
449}
450
#[cfg(test)]
mod tests {
    use super::*;

    /// Model used by every test; it has an entry in the default price table.
    const SONNET: &str = "claude-3-5-sonnet-latest";

    /// Comparison tolerance. It is far below the smallest cost term checked
    /// here (6e-5), so a dropped or mispriced term cannot slip through.
    const EPSILON: f64 = 1e-9;

    /// Builds a `Usage` with the given token counts and no tool/tier data.
    fn usage(
        input_tokens: u32,
        output_tokens: u32,
        cache_write: Option<u32>,
        cache_read: Option<u32>,
    ) -> Usage {
        Usage {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens: cache_write,
            cache_read_input_tokens: cache_read,
            server_tool_use: None,
            service_tier: None,
        }
    }

    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPSILON,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_cost_calculation() {
        let counter = TokenCounter::new();
        let cost = counter.calculate_cost(SONNET, &usage(1000, 500, Some(100), Some(200)));

        // 1000 in @ $3/M, 500 out @ $15/M, 200 cache reads @ $0.30/M,
        // 100 cache writes @ $3.75/M.
        assert_close(cost.input_cost, 0.003);
        assert_close(cost.output_cost, 0.0075);
        assert_close(cost.cache_read_cost, 0.00006);
        assert_close(cost.cache_write_cost, 0.000375);
        assert_close(cost.total_cost, 0.003 + 0.0075 + 0.00006 + 0.000375);
    }

    #[test]
    fn test_usage_tracking() {
        let counter = TokenCounter::new();

        counter.record_usage(SONNET, &usage(100, 50, None, None));
        let stats = counter.get_stats();

        assert_eq!(stats.total_input_tokens, 100);
        assert_eq!(stats.total_output_tokens, 50);
        assert_eq!(stats.request_count, 1);
        // 100 in @ $3/M + 50 out @ $15/M.
        assert_close(stats.total_cost_usd, 0.0003 + 0.00075);
    }

    #[test]
    fn test_cost_estimation() {
        let counter = TokenCounter::new();
        let cost = counter.estimate_cost(SONNET, 1000, 500);

        assert!(cost > 0.0);
        assert_close(cost, 0.0105);
    }
}