cstats_core/api/
usage_tracker.rs

1//! Local usage tracking for Anthropic API calls
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9use crate::Result;
10
11use super::{
12    AnthropicUsageStats, ApiCallStats, CostBreakdown, RateLimitInfo, TokenUsage, UsagePeriod,
13};
14
15/// Local usage tracker for Anthropic API calls
16#[derive(Debug, Clone)]
17pub struct LocalUsageTracker {
18    data: Arc<RwLock<UsageData>>,
19}
20
21/// Internal usage data storage
22#[derive(Debug, Clone, Serialize, Deserialize)]
23struct UsageData {
24    calls: Vec<ApiCallRecord>,
25    session_start: DateTime<Utc>,
26}
27
28/// Record of an individual API call
29#[derive(Debug, Clone, Serialize, Deserialize)]
30struct ApiCallRecord {
31    timestamp: DateTime<Utc>,
32    model: String,
33    input_tokens: u32,
34    output_tokens: u32,
35    response_time_ms: u64,
36    success: bool,
37    cost_usd: f64,
38    request_id: Option<String>,
39}
40
41impl LocalUsageTracker {
42    /// Create a new usage tracker
43    pub fn new() -> Self {
44        Self {
45            data: Arc::new(RwLock::new(UsageData {
46                calls: Vec::new(),
47                session_start: Utc::now(),
48            })),
49        }
50    }
51
52    /// Record an API call
53    pub async fn record_call(
54        &self,
55        model: &str,
56        input_tokens: u32,
57        output_tokens: u32,
58        response_time_ms: u64,
59        success: bool,
60        request_id: Option<String>,
61    ) -> Result<()> {
62        let cost_usd = self.estimate_cost(model, input_tokens, output_tokens);
63
64        let record = ApiCallRecord {
65            timestamp: Utc::now(),
66            model: model.to_string(),
67            input_tokens,
68            output_tokens,
69            response_time_ms,
70            success,
71            cost_usd,
72            request_id,
73        };
74
75        let mut data = self.data.write().await;
76        data.calls.push(record);
77
78        Ok(())
79    }
80
81    /// Get usage statistics for a time period
82    pub async fn get_usage_stats(
83        &self,
84        start_time: DateTime<Utc>,
85        end_time: DateTime<Utc>,
86    ) -> Result<AnthropicUsageStats> {
87        let data = self.data.read().await;
88
89        let calls: Vec<&ApiCallRecord> = data
90            .calls
91            .iter()
92            .filter(|call| call.timestamp >= start_time && call.timestamp <= end_time)
93            .collect();
94
95        Ok(self.create_stats_from_calls(&calls, start_time, end_time))
96    }
97
98    /// Get estimated rate limit information
99    pub async fn get_rate_limit_info(&self) -> Result<RateLimitInfo> {
100        Ok(RateLimitInfo {
101            requests_per_minute: 1000,
102            requests_remaining: 1000,
103            reset_time: Utc::now() + chrono::Duration::seconds(60),
104            tokens_per_minute: Some(50_000),
105            tokens_remaining: Some(50_000),
106        })
107    }
108
109    /// Clear all recorded data
110    pub async fn clear(&self) -> Result<()> {
111        let mut data = self.data.write().await;
112        data.calls.clear();
113        data.session_start = Utc::now();
114        Ok(())
115    }
116
117    /// Get the number of recorded calls
118    pub async fn call_count(&self) -> usize {
119        let data = self.data.read().await;
120        data.calls.len()
121    }
122
123    /// Estimate cost for tokens based on model pricing
124    fn estimate_cost(&self, model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
125        // Simplified cost estimation
126        let (input_rate, output_rate) = match model {
127            "claude-3-haiku-20240307" => (0.25, 1.25),
128            "claude-3-sonnet-20240229" => (3.0, 15.0),
129            "claude-3-opus-20240229" => (15.0, 75.0),
130            "claude-3-5-sonnet-20241022" => (3.0, 15.0),
131            "claude-3-5-sonnet-20240620" => (3.0, 15.0),
132            "claude-3-5-haiku-20241022" => (1.0, 5.0),
133            _ => (3.0, 15.0), // Default to Sonnet pricing
134        };
135
136        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_rate;
137        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_rate;
138
139        input_cost + output_cost
140    }
141
142    /// Create usage statistics from call records
143    fn create_stats_from_calls(
144        &self,
145        calls: &[&ApiCallRecord],
146        start_time: DateTime<Utc>,
147        end_time: DateTime<Utc>,
148    ) -> AnthropicUsageStats {
149        if calls.is_empty() {
150            return self.create_empty_stats(start_time, end_time);
151        }
152
153        let token_usage = self.calculate_token_usage(calls);
154        let api_calls = self.calculate_api_stats(calls);
155        let costs = self.calculate_costs(calls);
156
157        AnthropicUsageStats {
158            token_usage,
159            api_calls,
160            costs,
161            model_usage: vec![], // Simplified for now
162            period: UsagePeriod {
163                start: start_time,
164                end: end_time,
165                period_type: "local_tracking".to_string(),
166            },
167        }
168    }
169
170    /// Calculate token usage statistics
171    fn calculate_token_usage(&self, calls: &[&ApiCallRecord]) -> TokenUsage {
172        let total_input: u64 = calls.iter().map(|c| c.input_tokens as u64).sum();
173        let total_output: u64 = calls.iter().map(|c| c.output_tokens as u64).sum();
174
175        TokenUsage {
176            input_tokens: total_input,
177            output_tokens: total_output,
178            total_tokens: total_input + total_output,
179            by_model: HashMap::new(), // Simplified for now
180        }
181    }
182
183    /// Calculate API call statistics
184    fn calculate_api_stats(&self, calls: &[&ApiCallRecord]) -> ApiCallStats {
185        let total_calls = calls.len() as u64;
186        let successful_calls = calls.iter().filter(|c| c.success).count() as u64;
187        let failed_calls = total_calls - successful_calls;
188
189        let avg_response_time_ms = if !calls.is_empty() {
190            calls.iter().map(|c| c.response_time_ms).sum::<u64>() as f64 / calls.len() as f64
191        } else {
192            0.0
193        };
194
195        ApiCallStats {
196            total_calls,
197            successful_calls,
198            failed_calls,
199            avg_response_time_ms,
200            by_model: HashMap::new(), // Simplified for now
201            hourly_breakdown: vec![], // Simplified for now
202        }
203    }
204
205    /// Calculate cost breakdown
206    fn calculate_costs(&self, calls: &[&ApiCallRecord]) -> CostBreakdown {
207        let total_cost_usd: f64 = calls.iter().map(|c| c.cost_usd).sum();
208
209        CostBreakdown {
210            total_cost_usd,
211            by_model: HashMap::new(), // Simplified for now
212            by_token_type: super::TokenCostBreakdown {
213                input_cost_usd: total_cost_usd * 0.2,  // Rough estimate
214                output_cost_usd: total_cost_usd * 0.8, // Rough estimate
215            },
216            estimated_monthly_cost_usd: total_cost_usd * 30.0,
217        }
218    }
219
220    /// Create empty stats structure
221    fn create_empty_stats(
222        &self,
223        start_time: DateTime<Utc>,
224        end_time: DateTime<Utc>,
225    ) -> AnthropicUsageStats {
226        AnthropicUsageStats {
227            token_usage: TokenUsage {
228                input_tokens: 0,
229                output_tokens: 0,
230                total_tokens: 0,
231                by_model: HashMap::new(),
232            },
233            api_calls: ApiCallStats {
234                total_calls: 0,
235                successful_calls: 0,
236                failed_calls: 0,
237                avg_response_time_ms: 0.0,
238                by_model: HashMap::new(),
239                hourly_breakdown: vec![],
240            },
241            costs: CostBreakdown {
242                total_cost_usd: 0.0,
243                by_model: HashMap::new(),
244                by_token_type: super::TokenCostBreakdown {
245                    input_cost_usd: 0.0,
246                    output_cost_usd: 0.0,
247                },
248                estimated_monthly_cost_usd: 0.0,
249            },
250            model_usage: vec![],
251            period: UsagePeriod {
252                start: start_time,
253                end: end_time,
254                period_type: "empty".to_string(),
255            },
256        }
257    }
258}
259
260impl Default for LocalUsageTracker {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_usage_tracker_new() {
        let tracker = LocalUsageTracker::new();
        let count = tracker.call_count().await;
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_usage_tracker_basic() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        let count = tracker.call_count().await;
        assert_eq!(count, 1);

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);

        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 150);
        assert_eq!(stats.api_calls.total_calls, 1);
        assert_eq!(stats.api_calls.successful_calls, 1);
        assert_eq!(stats.api_calls.failed_calls, 0);
    }

    #[tokio::test]
    async fn test_record_multiple_calls() {
        let tracker = LocalUsageTracker::new();

        // Record successful call
        tracker
            .record_call(
                "claude-3-haiku-20240307",
                100,
                50,
                500,
                true,
                Some("req-1".to_string()),
            )
            .await
            .unwrap();

        // Record failed call
        tracker
            .record_call(
                "claude-3-sonnet-20240229",
                200,
                0,
                1000,
                false,
                Some("req-2".to_string()),
            )
            .await
            .unwrap();

        // Record another successful call
        tracker
            .record_call("claude-3-opus-20240229", 300, 100, 750, true, None)
            .await
            .unwrap();

        let count = tracker.call_count().await;
        assert_eq!(count, 3);

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);

        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.input_tokens, 600); // 100 + 200 + 300
        assert_eq!(stats.token_usage.output_tokens, 150); // 50 + 0 + 100
        assert_eq!(stats.token_usage.total_tokens, 750);
        assert_eq!(stats.api_calls.total_calls, 3);
        assert_eq!(stats.api_calls.successful_calls, 2);
        assert_eq!(stats.api_calls.failed_calls, 1);
        assert_eq!(stats.api_calls.avg_response_time_ms, 750.0); // (500 + 1000 + 750) / 3
    }

    #[tokio::test]
    async fn test_cost_estimation() {
        let tracker = LocalUsageTracker::new();

        // Test different models
        let haiku_cost = tracker.estimate_cost("claude-3-haiku-20240307", 1_000_000, 1_000_000);
        let sonnet_cost = tracker.estimate_cost("claude-3-sonnet-20240229", 1_000_000, 1_000_000);
        let opus_cost = tracker.estimate_cost("claude-3-opus-20240229", 1_000_000, 1_000_000);

        assert!(haiku_cost > 0.0);
        assert!(sonnet_cost > haiku_cost);
        assert!(opus_cost > sonnet_cost);

        // Test specific values
        assert_eq!(haiku_cost, 1.5); // (0.25 + 1.25) for 1M tokens each
        assert_eq!(sonnet_cost, 18.0); // (3.0 + 15.0) for 1M tokens each
        assert_eq!(opus_cost, 90.0); // (15.0 + 75.0) for 1M tokens each

        // Claude 3.5 price entries
        let sonnet_35_cost =
            tracker.estimate_cost("claude-3-5-sonnet-20241022", 1_000_000, 1_000_000);
        assert_eq!(sonnet_35_cost, 18.0); // (3.0 + 15.0)
        let haiku_35_cost =
            tracker.estimate_cost("claude-3-5-haiku-20241022", 1_000_000, 1_000_000);
        assert_eq!(haiku_35_cost, 6.0); // (1.0 + 5.0)

        // Test unknown model defaults to Sonnet pricing
        let unknown_cost = tracker.estimate_cost("claude-unknown-model", 1_000_000, 1_000_000);
        assert_eq!(unknown_cost, sonnet_cost);
    }

    #[tokio::test]
    async fn test_clear_data() {
        let tracker = LocalUsageTracker::new();

        // Add some data
        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        assert_eq!(tracker.call_count().await, 1);

        // Clear data
        tracker.clear().await.unwrap();
        assert_eq!(tracker.call_count().await, 0);

        // Verify stats are empty
        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 0);
        assert_eq!(stats.api_calls.total_calls, 0);
        assert_eq!(stats.costs.total_cost_usd, 0.0);
    }

    #[tokio::test]
    async fn test_get_usage_stats_empty() {
        let tracker = LocalUsageTracker::new();

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);

        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 0);
        assert_eq!(stats.api_calls.total_calls, 0);
        assert_eq!(stats.costs.total_cost_usd, 0.0);
        assert_eq!(stats.period.period_type, "empty");
    }

    #[tokio::test]
    async fn test_get_usage_stats_time_filtering() {
        let tracker = LocalUsageTracker::new();

        // Record a call
        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        // Query for future time range (should be empty)
        let future_start = Utc::now() + chrono::Duration::hours(1);
        let future_end = future_start + chrono::Duration::hours(1);
        let future_stats = tracker
            .get_usage_stats(future_start, future_end)
            .await
            .unwrap();
        assert_eq!(future_stats.token_usage.total_tokens, 0);
        assert_eq!(future_stats.api_calls.total_calls, 0);

        // Query for past time range (should include the call)
        let past_end = Utc::now() + chrono::Duration::minutes(1); // slight buffer for timing
        let past_start = past_end - chrono::Duration::hours(1);
        let past_stats = tracker.get_usage_stats(past_start, past_end).await.unwrap();
        assert_eq!(past_stats.token_usage.total_tokens, 150);
        assert_eq!(past_stats.api_calls.total_calls, 1);
    }

    #[tokio::test]
    async fn test_get_usage_stats_inverted_range() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        // start > end: no timestamp can satisfy both bounds, so the result is empty.
        let start_time = Utc::now() + chrono::Duration::hours(1);
        let end_time = start_time - chrono::Duration::hours(2);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.api_calls.total_calls, 0);
        assert_eq!(stats.token_usage.total_tokens, 0);
        assert_eq!(stats.period.period_type, "empty");
    }

    #[tokio::test]
    async fn test_get_rate_limit_info() {
        let tracker = LocalUsageTracker::new();
        let rate_limit = tracker.get_rate_limit_info().await.unwrap();

        assert_eq!(rate_limit.requests_per_minute, 1000);
        assert_eq!(rate_limit.requests_remaining, 1000);
        assert_eq!(rate_limit.tokens_per_minute, Some(50_000));
        assert_eq!(rate_limit.tokens_remaining, Some(50_000));
        assert!(rate_limit.reset_time > Utc::now());
    }

    #[tokio::test]
    async fn test_cost_calculation_precision() {
        let tracker = LocalUsageTracker::new();

        // Test small token amounts
        tracker
            .record_call("claude-3-haiku-20240307", 1000, 500, 500, true, None)
            .await
            .unwrap();

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();

        // Expected cost: (1000/1M * 0.25) + (500/1M * 1.25) = 0.00025 + 0.000625 = 0.000875
        let expected_cost = 0.000875;
        assert!((stats.costs.total_cost_usd - expected_cost).abs() < f64::EPSILON);

        // The per-token-type breakdown must be positive and add up to the total.
        assert!(stats.costs.by_token_type.input_cost_usd > 0.0);
        assert!(stats.costs.by_token_type.output_cost_usd > 0.0);
        let total_cost_breakdown =
            stats.costs.by_token_type.input_cost_usd + stats.costs.by_token_type.output_cost_usd;
        assert!((total_cost_breakdown - stats.costs.total_cost_usd).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_token_usage_calculation() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        tracker
            .record_call("claude-3-sonnet-20240229", 200, 75, 600, true, None)
            .await
            .unwrap();

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();

        assert_eq!(stats.token_usage.input_tokens, 300);
        assert_eq!(stats.token_usage.output_tokens, 125);
        assert_eq!(stats.token_usage.total_tokens, 425);

        // Per-model token breakdown is not populated yet.
        assert!(stats.token_usage.by_model.is_empty());
    }

    #[tokio::test]
    async fn test_api_call_stats_calculation() {
        let tracker = LocalUsageTracker::new();

        // Mix of successful and failed calls with different response times
        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 200, true, None)
            .await
            .unwrap();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 0, 300, false, None)
            .await
            .unwrap();

        tracker
            .record_call("claude-3-sonnet-20240229", 200, 75, 500, true, None)
            .await
            .unwrap();

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();

        assert_eq!(stats.api_calls.total_calls, 3);
        assert_eq!(stats.api_calls.successful_calls, 2);
        assert_eq!(stats.api_calls.failed_calls, 1);

        // Average response time: (200 + 300 + 500) / 3 = 333.33...
        let expected_avg = 1000.0 / 3.0;
        assert!((stats.api_calls.avg_response_time_ms - expected_avg).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_monthly_cost_estimation() {
        let tracker = LocalUsageTracker::new();

        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();

        // The projection treats the queried window's total as a daily figure and
        // multiplies by 30, independent of the window length (1 hour here).
        let expected_monthly = stats.costs.total_cost_usd * 30.0;
        assert_eq!(stats.costs.estimated_monthly_cost_usd, expected_monthly);
    }

    #[tokio::test]
    async fn test_default_trait() {
        let tracker1 = LocalUsageTracker::new();
        let tracker2 = LocalUsageTracker::default();

        // Both should start with zero calls
        assert_eq!(tracker1.call_count().await, 0);
        assert_eq!(tracker2.call_count().await, 0);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let tracker = LocalUsageTracker::new();

        // Clones share the same storage, so writes through either writer must be
        // visible through `tracker` afterwards.
        let writer1 = tracker.clone();
        let writer2 = tracker.clone();

        // Spawn concurrent tasks
        let handle1 = tokio::spawn(async move {
            writer1
                .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
                .await
        });

        let handle2 = tokio::spawn(async move {
            writer2
                .record_call("claude-3-sonnet-20240229", 200, 75, 600, true, None)
                .await
        });

        // Wait for both tasks
        let (result1, result2) = tokio::join!(handle1, handle2);
        assert!(result1.unwrap().is_ok());
        assert!(result2.unwrap().is_ok());

        // Both concurrent writes must have landed.
        assert_eq!(tracker.call_count().await, 2);

        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 425); // (100 + 50) + (200 + 75)
    }
}