1use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use crate::Result;
10
11use super::{
12 AnthropicUsageStats, ApiCallStats, CostBreakdown, RateLimitInfo, TokenUsage, UsagePeriod,
13};
14
15#[derive(Debug, Clone)]
17pub struct LocalUsageTracker {
18 data: Arc<RwLock<UsageData>>,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
23struct UsageData {
24 calls: Vec<ApiCallRecord>,
25 session_start: DateTime<Utc>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30struct ApiCallRecord {
31 timestamp: DateTime<Utc>,
32 model: String,
33 input_tokens: u32,
34 output_tokens: u32,
35 response_time_ms: u64,
36 success: bool,
37 cost_usd: f64,
38 request_id: Option<String>,
39}
40
41impl LocalUsageTracker {
42 pub fn new() -> Self {
44 Self {
45 data: Arc::new(RwLock::new(UsageData {
46 calls: Vec::new(),
47 session_start: Utc::now(),
48 })),
49 }
50 }
51
52 pub async fn record_call(
54 &self,
55 model: &str,
56 input_tokens: u32,
57 output_tokens: u32,
58 response_time_ms: u64,
59 success: bool,
60 request_id: Option<String>,
61 ) -> Result<()> {
62 let cost_usd = self.estimate_cost(model, input_tokens, output_tokens);
63
64 let record = ApiCallRecord {
65 timestamp: Utc::now(),
66 model: model.to_string(),
67 input_tokens,
68 output_tokens,
69 response_time_ms,
70 success,
71 cost_usd,
72 request_id,
73 };
74
75 let mut data = self.data.write().await;
76 data.calls.push(record);
77
78 Ok(())
79 }
80
81 pub async fn get_usage_stats(
83 &self,
84 start_time: DateTime<Utc>,
85 end_time: DateTime<Utc>,
86 ) -> Result<AnthropicUsageStats> {
87 let data = self.data.read().await;
88
89 let calls: Vec<&ApiCallRecord> = data
90 .calls
91 .iter()
92 .filter(|call| call.timestamp >= start_time && call.timestamp <= end_time)
93 .collect();
94
95 Ok(self.create_stats_from_calls(&calls, start_time, end_time))
96 }
97
98 pub async fn get_rate_limit_info(&self) -> Result<RateLimitInfo> {
100 Ok(RateLimitInfo {
101 requests_per_minute: 1000,
102 requests_remaining: 1000,
103 reset_time: Utc::now() + chrono::Duration::seconds(60),
104 tokens_per_minute: Some(50_000),
105 tokens_remaining: Some(50_000),
106 })
107 }
108
109 pub async fn clear(&self) -> Result<()> {
111 let mut data = self.data.write().await;
112 data.calls.clear();
113 data.session_start = Utc::now();
114 Ok(())
115 }
116
117 pub async fn call_count(&self) -> usize {
119 let data = self.data.read().await;
120 data.calls.len()
121 }
122
123 fn estimate_cost(&self, model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
125 let (input_rate, output_rate) = match model {
127 "claude-3-haiku-20240307" => (0.25, 1.25),
128 "claude-3-sonnet-20240229" => (3.0, 15.0),
129 "claude-3-opus-20240229" => (15.0, 75.0),
130 "claude-3-5-sonnet-20241022" => (3.0, 15.0),
131 "claude-3-5-sonnet-20240620" => (3.0, 15.0),
132 "claude-3-5-haiku-20241022" => (1.0, 5.0),
133 _ => (3.0, 15.0), };
135
136 let input_cost = (input_tokens as f64 / 1_000_000.0) * input_rate;
137 let output_cost = (output_tokens as f64 / 1_000_000.0) * output_rate;
138
139 input_cost + output_cost
140 }
141
142 fn create_stats_from_calls(
144 &self,
145 calls: &[&ApiCallRecord],
146 start_time: DateTime<Utc>,
147 end_time: DateTime<Utc>,
148 ) -> AnthropicUsageStats {
149 if calls.is_empty() {
150 return self.create_empty_stats(start_time, end_time);
151 }
152
153 let token_usage = self.calculate_token_usage(calls);
154 let api_calls = self.calculate_api_stats(calls);
155 let costs = self.calculate_costs(calls);
156
157 AnthropicUsageStats {
158 token_usage,
159 api_calls,
160 costs,
161 model_usage: vec![], period: UsagePeriod {
163 start: start_time,
164 end: end_time,
165 period_type: "local_tracking".to_string(),
166 },
167 }
168 }
169
170 fn calculate_token_usage(&self, calls: &[&ApiCallRecord]) -> TokenUsage {
172 let total_input: u64 = calls.iter().map(|c| c.input_tokens as u64).sum();
173 let total_output: u64 = calls.iter().map(|c| c.output_tokens as u64).sum();
174
175 TokenUsage {
176 input_tokens: total_input,
177 output_tokens: total_output,
178 total_tokens: total_input + total_output,
179 by_model: HashMap::new(), }
181 }
182
183 fn calculate_api_stats(&self, calls: &[&ApiCallRecord]) -> ApiCallStats {
185 let total_calls = calls.len() as u64;
186 let successful_calls = calls.iter().filter(|c| c.success).count() as u64;
187 let failed_calls = total_calls - successful_calls;
188
189 let avg_response_time_ms = if !calls.is_empty() {
190 calls.iter().map(|c| c.response_time_ms).sum::<u64>() as f64 / calls.len() as f64
191 } else {
192 0.0
193 };
194
195 ApiCallStats {
196 total_calls,
197 successful_calls,
198 failed_calls,
199 avg_response_time_ms,
200 by_model: HashMap::new(), hourly_breakdown: vec![], }
203 }
204
205 fn calculate_costs(&self, calls: &[&ApiCallRecord]) -> CostBreakdown {
207 let total_cost_usd: f64 = calls.iter().map(|c| c.cost_usd).sum();
208
209 CostBreakdown {
210 total_cost_usd,
211 by_model: HashMap::new(), by_token_type: super::TokenCostBreakdown {
213 input_cost_usd: total_cost_usd * 0.2, output_cost_usd: total_cost_usd * 0.8, },
216 estimated_monthly_cost_usd: total_cost_usd * 30.0,
217 }
218 }
219
220 fn create_empty_stats(
222 &self,
223 start_time: DateTime<Utc>,
224 end_time: DateTime<Utc>,
225 ) -> AnthropicUsageStats {
226 AnthropicUsageStats {
227 token_usage: TokenUsage {
228 input_tokens: 0,
229 output_tokens: 0,
230 total_tokens: 0,
231 by_model: HashMap::new(),
232 },
233 api_calls: ApiCallStats {
234 total_calls: 0,
235 successful_calls: 0,
236 failed_calls: 0,
237 avg_response_time_ms: 0.0,
238 by_model: HashMap::new(),
239 hourly_breakdown: vec![],
240 },
241 costs: CostBreakdown {
242 total_cost_usd: 0.0,
243 by_model: HashMap::new(),
244 by_token_type: super::TokenCostBreakdown {
245 input_cost_usd: 0.0,
246 output_cost_usd: 0.0,
247 },
248 estimated_monthly_cost_usd: 0.0,
249 },
250 model_usage: vec![],
251 period: UsagePeriod {
252 start: start_time,
253 end: end_time,
254 period_type: "empty".to_string(),
255 },
256 }
257 }
258}
259
260impl Default for LocalUsageTracker {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_usage_tracker_new() {
        let tracker = LocalUsageTracker::new();
        let count = tracker.call_count().await;
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_usage_tracker_basic() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        let count = tracker.call_count().await;
        assert_eq!(count, 1);

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);

        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 150);
        assert_eq!(stats.api_calls.total_calls, 1);
        assert_eq!(stats.api_calls.successful_calls, 1);
        assert_eq!(stats.api_calls.failed_calls, 0);
    }

    #[tokio::test]
    async fn test_record_multiple_calls() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call(
                "claude-3-haiku-20240307",
                100,
                50,
                500,
                true,
                Some("req-1".to_string()),
            )
            .await
            .unwrap();

        tracker
            .record_call(
                "claude-3-sonnet-20240229",
                200,
                0,
                1000,
                false,
                Some("req-2".to_string()),
            )
            .await
            .unwrap();

        tracker
            .record_call("claude-3-opus-20240229", 300, 100, 750, true, None)
            .await
            .unwrap();

        let count = tracker.call_count().await;
        assert_eq!(count, 3);

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);

        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.input_tokens, 600);
        assert_eq!(stats.token_usage.output_tokens, 150);
        assert_eq!(stats.token_usage.total_tokens, 750);
        assert_eq!(stats.api_calls.total_calls, 3);
        assert_eq!(stats.api_calls.successful_calls, 2);
        assert_eq!(stats.api_calls.failed_calls, 1);
        // (500 + 1000 + 750) / 3
        assert_eq!(stats.api_calls.avg_response_time_ms, 750.0);
    }

    #[tokio::test]
    async fn test_cost_estimation() {
        let tracker = LocalUsageTracker::new();

        let haiku_cost = tracker.estimate_cost("claude-3-haiku-20240307", 1_000_000, 1_000_000);
        let sonnet_cost = tracker.estimate_cost("claude-3-sonnet-20240229", 1_000_000, 1_000_000);
        let opus_cost = tracker.estimate_cost("claude-3-opus-20240229", 1_000_000, 1_000_000);

        assert!(haiku_cost > 0.0);
        assert!(sonnet_cost > haiku_cost);
        assert!(opus_cost > sonnet_cost);

        // One million tokens each way is exactly input rate + output rate.
        assert_eq!(haiku_cost, 1.5);
        assert_eq!(sonnet_cost, 18.0);
        assert_eq!(opus_cost, 90.0);

        // Unknown models fall back to Sonnet pricing.
        let unknown_cost = tracker.estimate_cost("claude-unknown-model", 1_000_000, 1_000_000);
        assert_eq!(unknown_cost, sonnet_cost);
    }

    #[tokio::test]
    async fn test_clear_data() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        assert_eq!(tracker.call_count().await, 1);

        tracker.clear().await.unwrap();
        assert_eq!(tracker.call_count().await, 0);

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 0);
        assert_eq!(stats.api_calls.total_calls, 0);
        assert_eq!(stats.costs.total_cost_usd, 0.0);
    }

    #[tokio::test]
    async fn test_get_usage_stats_empty() {
        let tracker = LocalUsageTracker::new();

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);

        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 0);
        assert_eq!(stats.api_calls.total_calls, 0);
        assert_eq!(stats.costs.total_cost_usd, 0.0);
        assert_eq!(stats.period.period_type, "empty");
    }

    #[tokio::test]
    async fn test_get_usage_stats_time_filtering() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        // A window entirely in the future must not include the call.
        let future_start = Utc::now() + chrono::Duration::hours(1);
        let future_end = future_start + chrono::Duration::hours(1);
        let future_stats = tracker
            .get_usage_stats(future_start, future_end)
            .await
            .unwrap();
        assert_eq!(future_stats.token_usage.total_tokens, 0);
        assert_eq!(future_stats.api_calls.total_calls, 0);

        // A window covering the recent past (with a small margin past now) must include it.
        let recent_end = Utc::now() + chrono::Duration::minutes(1);
        let recent_start = recent_end - chrono::Duration::hours(1);
        let recent_stats = tracker
            .get_usage_stats(recent_start, recent_end)
            .await
            .unwrap();
        assert_eq!(recent_stats.token_usage.total_tokens, 150);
        assert_eq!(recent_stats.api_calls.total_calls, 1);
    }

    #[tokio::test]
    async fn test_get_rate_limit_info() {
        let tracker = LocalUsageTracker::new();
        let rate_limit = tracker.get_rate_limit_info().await.unwrap();

        assert_eq!(rate_limit.requests_per_minute, 1000);
        assert_eq!(rate_limit.requests_remaining, 1000);
        assert_eq!(rate_limit.tokens_per_minute, Some(50_000));
        assert_eq!(rate_limit.tokens_remaining, Some(50_000));
        assert!(rate_limit.reset_time > Utc::now());
    }

    #[tokio::test]
    async fn test_cost_calculation_precision() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 1000, 500, 500, true, None)
            .await
            .unwrap();

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();

        // 1000 * 0.25 / 1e6 + 500 * 1.25 / 1e6
        let expected_cost = 0.000875;
        assert!((stats.costs.total_cost_usd - expected_cost).abs() < f64::EPSILON);

        assert!(stats.costs.by_token_type.input_cost_usd > 0.0);
        assert!(stats.costs.by_token_type.output_cost_usd > 0.0);
        let total_cost_breakdown =
            stats.costs.by_token_type.input_cost_usd + stats.costs.by_token_type.output_cost_usd;
        assert!((total_cost_breakdown - stats.costs.total_cost_usd).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_token_usage_calculation() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        tracker
            .record_call("claude-3-sonnet-20240229", 200, 75, 600, true, None)
            .await
            .unwrap();

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();

        assert_eq!(stats.token_usage.input_tokens, 300);
        assert_eq!(stats.token_usage.output_tokens, 125);
        assert_eq!(stats.token_usage.total_tokens, 425);

        // The local tracker does not populate per-model breakdowns.
        assert!(stats.token_usage.by_model.is_empty());
    }

    #[tokio::test]
    async fn test_api_call_stats_calculation() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 200, true, None)
            .await
            .unwrap();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 0, 300, false, None)
            .await
            .unwrap();

        tracker
            .record_call("claude-3-sonnet-20240229", 200, 75, 500, true, None)
            .await
            .unwrap();

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();

        assert_eq!(stats.api_calls.total_calls, 3);
        assert_eq!(stats.api_calls.successful_calls, 2);
        assert_eq!(stats.api_calls.failed_calls, 1);

        // (200 + 300 + 500) / 3
        let expected_avg = 1000.0 / 3.0;
        assert!((stats.api_calls.avg_response_time_ms - expected_avg).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_monthly_cost_estimation() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();

        let expected_monthly = stats.costs.total_cost_usd * 30.0;
        assert_eq!(stats.costs.estimated_monthly_cost_usd, expected_monthly);
    }

    #[tokio::test]
    async fn test_default_trait() {
        let tracker1 = LocalUsageTracker::new();
        let tracker2 = LocalUsageTracker::default();

        assert_eq!(tracker1.call_count().await, 0);
        assert_eq!(tracker2.call_count().await, 0);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let tracker = LocalUsageTracker::new();

        // Keep `tracker` itself so the shared state can be inspected after both tasks finish.
        let tracker_a = tracker.clone();
        let tracker_b = tracker.clone();

        let handle1 = tokio::spawn(async move {
            tracker_a
                .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
                .await
        });

        let handle2 = tokio::spawn(async move {
            tracker_b
                .record_call("claude-3-sonnet-20240229", 200, 75, 600, true, None)
                .await
        });

        let (result1, result2) = tokio::join!(handle1, handle2);
        assert!(result1.unwrap().is_ok());
        assert!(result2.unwrap().is_ok());

        // Clones share storage, so both concurrent writes must be visible here.
        assert_eq!(tracker.call_count().await, 2);

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 425);
    }
}