use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::Result;
use super::{
AnthropicUsageStats, ApiCallStats, CostBreakdown, RateLimitInfo, TokenUsage, UsagePeriod,
};
/// In-process tracker that records API calls and summarises them on demand.
///
/// Cloning is cheap: the derived `Clone` copies the inner `Arc`, so every clone
/// shares the same call log and sees records made through any other handle.
#[derive(Clone, Debug)]
pub struct LocalUsageTracker {
    /// Shared call log, guarded by an async read/write lock.
    data: Arc<RwLock<UsageData>>,
}
/// Mutable state held behind a `LocalUsageTracker`'s lock.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct UsageData {
    /// Every recorded call, appended in the order it was recorded.
    calls: Vec<ApiCallRecord>,
    /// When the current tracking session began (set on construction, reset on clear).
    session_start: DateTime<Utc>,
}
/// One recorded API call.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct ApiCallRecord {
    /// When the call was recorded (wall-clock UTC at `record_call` time).
    timestamp: DateTime<Utc>,
    /// Model identifier exactly as passed by the caller.
    model: String,
    /// Prompt tokens consumed.
    input_tokens: u32,
    /// Completion tokens produced.
    output_tokens: u32,
    /// Caller-reported latency, in milliseconds.
    response_time_ms: u64,
    /// Whether the caller reported the call as successful.
    success: bool,
    /// Estimated cost in USD, computed from the built-in price table when recorded.
    cost_usd: f64,
    /// Optional caller-supplied request identifier.
    request_id: Option<String>,
}
impl LocalUsageTracker {
    /// Creates an empty tracker whose session starts now.
    pub fn new() -> Self {
        Self {
            data: Arc::new(RwLock::new(UsageData {
                calls: Vec::new(),
                session_start: Utc::now(),
            })),
        }
    }

    /// Records a single API call, timestamped with the current UTC time.
    ///
    /// The call's cost is estimated immediately from the built-in price table
    /// (see `estimate_cost`) and stored alongside the record.
    ///
    /// # Errors
    ///
    /// Never fails today; the `Result` return is kept for API stability.
    pub async fn record_call(
        &self,
        model: &str,
        input_tokens: u32,
        output_tokens: u32,
        response_time_ms: u64,
        success: bool,
        request_id: Option<String>,
    ) -> Result<()> {
        let cost_usd = self.estimate_cost(model, input_tokens, output_tokens);
        let record = ApiCallRecord {
            timestamp: Utc::now(),
            model: model.to_string(),
            input_tokens,
            output_tokens,
            response_time_ms,
            success,
            cost_usd,
            request_id,
        };
        // The record is fully built before the lock is taken, so the write
        // guard is held only for the push.
        let mut data = self.data.write().await;
        data.calls.push(record);
        Ok(())
    }

    /// Summarises every call whose timestamp lies in the inclusive range
    /// `[start_time, end_time]`.
    ///
    /// Returns the "empty" statistics shape when no call falls in the range.
    ///
    /// # Errors
    ///
    /// Never fails today; the `Result` return is kept for API stability.
    pub async fn get_usage_stats(
        &self,
        start_time: DateTime<Utc>,
        end_time: DateTime<Utc>,
    ) -> Result<AnthropicUsageStats> {
        let data = self.data.read().await;
        let calls: Vec<&ApiCallRecord> = data
            .calls
            .iter()
            .filter(|call| call.timestamp >= start_time && call.timestamp <= end_time)
            .collect();
        Ok(self.create_stats_from_calls(&calls, start_time, end_time))
    }

    /// Returns fixed placeholder rate-limit figures.
    ///
    /// Nothing is tracked locally for rate limits: the limits are constants and
    /// the reset time is always one minute from now.
    pub async fn get_rate_limit_info(&self) -> Result<RateLimitInfo> {
        Ok(RateLimitInfo {
            requests_per_minute: 1000,
            requests_remaining: 1000,
            reset_time: Utc::now() + chrono::Duration::seconds(60),
            tokens_per_minute: Some(50_000),
            tokens_remaining: Some(50_000),
        })
    }

    /// Discards all recorded calls and restarts the session clock.
    pub async fn clear(&self) -> Result<()> {
        let mut data = self.data.write().await;
        data.calls.clear();
        data.session_start = Utc::now();
        Ok(())
    }

    /// Number of calls recorded since construction or the last `clear`.
    pub async fn call_count(&self) -> usize {
        let data = self.data.read().await;
        data.calls.len()
    }

    /// Estimated total cost in USD of one call to `model`.
    ///
    /// Unknown models are priced like Sonnet (see `pricing_per_million`).
    fn estimate_cost(&self, model: &str, input_tokens: u32, output_tokens: u32) -> f64 {
        let (input_cost, output_cost) =
            Self::estimate_cost_split(model, input_tokens, output_tokens);
        input_cost + output_cost
    }

    /// Estimated `(input_cost_usd, output_cost_usd)` of one call to `model`.
    fn estimate_cost_split(model: &str, input_tokens: u32, output_tokens: u32) -> (f64, f64) {
        let (input_rate, output_rate) = Self::pricing_per_million(model);
        // `u32 -> f64` is lossless, so no precision is lost before scaling.
        let input_cost = (input_tokens as f64 / 1_000_000.0) * input_rate;
        let output_cost = (output_tokens as f64 / 1_000_000.0) * output_rate;
        (input_cost, output_cost)
    }

    /// `(input_rate, output_rate)` in USD per million tokens for `model`.
    ///
    /// Unrecognised model names fall back to Sonnet pricing.
    /// NOTE(review): rates are hard-coded; verify against current published pricing.
    fn pricing_per_million(model: &str) -> (f64, f64) {
        match model {
            "claude-3-haiku-20240307" => (0.25, 1.25),
            "claude-3-sonnet-20240229" => (3.0, 15.0),
            "claude-3-opus-20240229" => (15.0, 75.0),
            "claude-3-5-sonnet-20241022" => (3.0, 15.0),
            "claude-3-5-sonnet-20240620" => (3.0, 15.0),
            "claude-3-5-haiku-20241022" => (1.0, 5.0),
            _ => (3.0, 15.0),
        }
    }

    /// Builds the full statistics object for an already-filtered set of calls.
    fn create_stats_from_calls(
        &self,
        calls: &[&ApiCallRecord],
        start_time: DateTime<Utc>,
        end_time: DateTime<Utc>,
    ) -> AnthropicUsageStats {
        if calls.is_empty() {
            return self.create_empty_stats(start_time, end_time);
        }
        let token_usage = self.calculate_token_usage(calls);
        let api_calls = self.calculate_api_stats(calls);
        let costs = self.calculate_costs(calls);
        AnthropicUsageStats {
            token_usage,
            api_calls,
            costs,
            // Per-model breakdowns are not computed by the local tracker.
            model_usage: vec![],
            period: UsagePeriod {
                start: start_time,
                end: end_time,
                period_type: "local_tracking".to_string(),
            },
        }
    }

    /// Sums input/output tokens across `calls`, widening to `u64` to avoid overflow.
    fn calculate_token_usage(&self, calls: &[&ApiCallRecord]) -> TokenUsage {
        let total_input: u64 = calls.iter().map(|c| c.input_tokens as u64).sum();
        let total_output: u64 = calls.iter().map(|c| c.output_tokens as u64).sum();
        TokenUsage {
            input_tokens: total_input,
            output_tokens: total_output,
            total_tokens: total_input + total_output,
            // Per-model token breakdown is not computed locally.
            by_model: HashMap::new(),
        }
    }

    /// Counts successes/failures and the mean caller-reported latency of `calls`.
    fn calculate_api_stats(&self, calls: &[&ApiCallRecord]) -> ApiCallStats {
        let total_calls = calls.len() as u64;
        let successful_calls = calls.iter().filter(|c| c.success).count() as u64;
        let failed_calls = total_calls - successful_calls;
        let avg_response_time_ms = if !calls.is_empty() {
            calls.iter().map(|c| c.response_time_ms).sum::<u64>() as f64 / calls.len() as f64
        } else {
            0.0
        };
        ApiCallStats {
            total_calls,
            successful_calls,
            failed_calls,
            avg_response_time_ms,
            // Per-model and hourly breakdowns are not computed locally.
            by_model: HashMap::new(),
            hourly_breakdown: vec![],
        }
    }

    /// Aggregates the costs of `calls`.
    ///
    /// The total is the sum of each record's stored `cost_usd`. The
    /// input/output split is derived from each call's real per-token-type
    /// rates, instead of a fixed ratio of the total.
    fn calculate_costs(&self, calls: &[&ApiCallRecord]) -> CostBreakdown {
        let total_cost_usd: f64 = calls.iter().map(|c| c.cost_usd).sum();
        let (input_cost_usd, output_cost_usd) =
            calls.iter().fold((0.0, 0.0), |(input_acc, output_acc), call| {
                let (input_cost, output_cost) =
                    Self::estimate_cost_split(&call.model, call.input_tokens, call.output_tokens);
                (input_acc + input_cost, output_acc + output_cost)
            });
        CostBreakdown {
            total_cost_usd,
            // Per-model cost breakdown is not computed locally.
            by_model: HashMap::new(),
            by_token_type: super::TokenCostBreakdown {
                input_cost_usd,
                output_cost_usd,
            },
            // NOTE(review): this scales the queried window's spend by 30 as if it
            // were one day's spend, whatever the window length actually is.
            estimated_monthly_cost_usd: total_cost_usd * 30.0,
        }
    }

    /// All-zero statistics, tagged with period type `"empty"`.
    fn create_empty_stats(
        &self,
        start_time: DateTime<Utc>,
        end_time: DateTime<Utc>,
    ) -> AnthropicUsageStats {
        AnthropicUsageStats {
            token_usage: TokenUsage {
                input_tokens: 0,
                output_tokens: 0,
                total_tokens: 0,
                by_model: HashMap::new(),
            },
            api_calls: ApiCallStats {
                total_calls: 0,
                successful_calls: 0,
                failed_calls: 0,
                avg_response_time_ms: 0.0,
                by_model: HashMap::new(),
                hourly_breakdown: vec![],
            },
            costs: CostBreakdown {
                total_cost_usd: 0.0,
                by_model: HashMap::new(),
                by_token_type: super::TokenCostBreakdown {
                    input_cost_usd: 0.0,
                    output_cost_usd: 0.0,
                },
                estimated_monthly_cost_usd: 0.0,
            },
            model_usage: vec![],
            period: UsagePeriod {
                start: start_time,
                end: end_time,
                period_type: "empty".to_string(),
            },
        }
    }
}
impl Default for LocalUsageTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_usage_tracker_new() {
        let tracker = LocalUsageTracker::new();
        let count = tracker.call_count().await;
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_usage_tracker_basic() {
        let tracker = LocalUsageTracker::new();
        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();
        let count = tracker.call_count().await;
        assert_eq!(count, 1);
        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 150);
        assert_eq!(stats.api_calls.total_calls, 1);
        assert_eq!(stats.api_calls.successful_calls, 1);
        assert_eq!(stats.api_calls.failed_calls, 0);
    }

    #[tokio::test]
    async fn test_record_multiple_calls() {
        let tracker = LocalUsageTracker::new();
        tracker
            .record_call(
                "claude-3-haiku-20240307",
                100,
                50,
                500,
                true,
                Some("req-1".to_string()),
            )
            .await
            .unwrap();
        tracker
            .record_call(
                "claude-3-sonnet-20240229",
                200,
                0,
                1000,
                false,
                Some("req-2".to_string()),
            )
            .await
            .unwrap();
        tracker
            .record_call("claude-3-opus-20240229", 300, 100, 750, true, None)
            .await
            .unwrap();
        let count = tracker.call_count().await;
        assert_eq!(count, 3);
        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        // Inputs: 100 + 200 + 300; outputs: 50 + 0 + 100.
        assert_eq!(stats.token_usage.input_tokens, 600);
        assert_eq!(stats.token_usage.output_tokens, 150);
        assert_eq!(stats.token_usage.total_tokens, 750);
        assert_eq!(stats.api_calls.total_calls, 3);
        assert_eq!(stats.api_calls.successful_calls, 2);
        assert_eq!(stats.api_calls.failed_calls, 1);
        // (500 + 1000 + 750) / 3 is exactly representable.
        assert_eq!(stats.api_calls.avg_response_time_ms, 750.0);
    }

    #[tokio::test]
    async fn test_cost_estimation() {
        let tracker = LocalUsageTracker::new();
        let haiku_cost = tracker.estimate_cost("claude-3-haiku-20240307", 1_000_000, 1_000_000);
        let sonnet_cost = tracker.estimate_cost("claude-3-sonnet-20240229", 1_000_000, 1_000_000);
        let opus_cost = tracker.estimate_cost("claude-3-opus-20240229", 1_000_000, 1_000_000);
        assert!(haiku_cost > 0.0);
        assert!(sonnet_cost > haiku_cost);
        assert!(opus_cost > sonnet_cost);
        // One million tokens each way costs exactly input_rate + output_rate.
        assert_eq!(haiku_cost, 1.5);
        assert_eq!(sonnet_cost, 18.0);
        assert_eq!(opus_cost, 90.0);
        // Unknown models fall back to Sonnet pricing.
        let unknown_cost = tracker.estimate_cost("claude-unknown-model", 1_000_000, 1_000_000);
        assert_eq!(unknown_cost, sonnet_cost);
    }

    #[tokio::test]
    async fn test_clear_data() {
        let tracker = LocalUsageTracker::new();
        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();
        assert_eq!(tracker.call_count().await, 1);
        tracker.clear().await.unwrap();
        assert_eq!(tracker.call_count().await, 0);
        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 0);
        assert_eq!(stats.api_calls.total_calls, 0);
        assert_eq!(stats.costs.total_cost_usd, 0.0);
    }

    #[tokio::test]
    async fn test_get_usage_stats_empty() {
        let tracker = LocalUsageTracker::new();
        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.total_tokens, 0);
        assert_eq!(stats.api_calls.total_calls, 0);
        assert_eq!(stats.costs.total_cost_usd, 0.0);
        assert_eq!(stats.period.period_type, "empty");
    }

    #[tokio::test]
    async fn test_get_usage_stats_time_filtering() {
        let tracker = LocalUsageTracker::new();
        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();
        let future_start = Utc::now() + chrono::Duration::hours(1);
        let future_end = future_start + chrono::Duration::hours(1);
        let future_stats = tracker
            .get_usage_stats(future_start, future_end)
            .await
            .unwrap();
        assert_eq!(future_stats.token_usage.total_tokens, 0);
        assert_eq!(future_stats.api_calls.total_calls, 0);
        // End slightly in the future so the just-recorded call is safely inside the window.
        let past_end = Utc::now() + chrono::Duration::minutes(1);
        let past_start = past_end - chrono::Duration::hours(1);
        let past_stats = tracker.get_usage_stats(past_start, past_end).await.unwrap();
        assert_eq!(past_stats.token_usage.total_tokens, 150);
        assert_eq!(past_stats.api_calls.total_calls, 1);
    }

    #[tokio::test]
    async fn test_get_rate_limit_info() {
        let tracker = LocalUsageTracker::new();
        let rate_limit = tracker.get_rate_limit_info().await.unwrap();
        assert_eq!(rate_limit.requests_per_minute, 1000);
        assert_eq!(rate_limit.requests_remaining, 1000);
        assert_eq!(rate_limit.tokens_per_minute, Some(50_000));
        assert_eq!(rate_limit.tokens_remaining, Some(50_000));
        assert!(rate_limit.reset_time > Utc::now());
    }

    #[tokio::test]
    async fn test_cost_calculation_precision() {
        let tracker = LocalUsageTracker::new();
        tracker
            .record_call("claude-3-haiku-20240307", 1000, 500, 500, true, None)
            .await
            .unwrap();
        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        // 1000 * 0.25 / 1e6 + 500 * 1.25 / 1e6
        let expected_cost = 0.000875;
        assert!((stats.costs.total_cost_usd - expected_cost).abs() < f64::EPSILON);
        assert!(stats.costs.by_token_type.input_cost_usd > 0.0);
        assert!(stats.costs.by_token_type.output_cost_usd > 0.0);
        let total_cost_breakdown =
            stats.costs.by_token_type.input_cost_usd + stats.costs.by_token_type.output_cost_usd;
        assert!((total_cost_breakdown - stats.costs.total_cost_usd).abs() < f64::EPSILON);
    }

    #[tokio::test]
    async fn test_token_usage_calculation() {
        let tracker = LocalUsageTracker::new();
        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();
        tracker
            .record_call("claude-3-sonnet-20240229", 200, 75, 600, true, None)
            .await
            .unwrap();
        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.token_usage.input_tokens, 300);
        assert_eq!(stats.token_usage.output_tokens, 125);
        assert_eq!(stats.token_usage.total_tokens, 425);
        assert!(stats.token_usage.by_model.is_empty());
    }

    #[tokio::test]
    async fn test_api_call_stats_calculation() {
        let tracker = LocalUsageTracker::new();
        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 200, true, None)
            .await
            .unwrap();
        tracker
            .record_call("claude-3-haiku-20240307", 100, 0, 300, false, None)
            .await
            .unwrap();
        tracker
            .record_call("claude-3-sonnet-20240229", 200, 75, 500, true, None)
            .await
            .unwrap();
        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.api_calls.total_calls, 3);
        assert_eq!(stats.api_calls.successful_calls, 2);
        assert_eq!(stats.api_calls.failed_calls, 1);
        let expected_avg = 1000.0 / 3.0;
        assert!((stats.api_calls.avg_response_time_ms - expected_avg).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_monthly_cost_estimation() {
        let tracker = LocalUsageTracker::new();
        tracker
            .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
            .await
            .unwrap();
        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        let expected_monthly = stats.costs.total_cost_usd * 30.0;
        assert_eq!(stats.costs.estimated_monthly_cost_usd, expected_monthly);
    }

    #[tokio::test]
    async fn test_default_trait() {
        let tracker1 = LocalUsageTracker::new();
        let tracker2 = LocalUsageTracker::default();
        assert_eq!(tracker1.call_count().await, 0);
        assert_eq!(tracker2.call_count().await, 0);
    }

    #[tokio::test]
    async fn test_concurrent_access() {
        let tracker = LocalUsageTracker::new();
        // Each task gets its own clone; all clones share one call log, and the
        // original handle is kept so the combined result can be checked afterwards.
        let tracker_a = tracker.clone();
        let tracker_b = tracker.clone();
        let handle1 = tokio::spawn(async move {
            tracker_a
                .record_call("claude-3-haiku-20240307", 100, 50, 500, true, None)
                .await
        });
        let handle2 = tokio::spawn(async move {
            tracker_b
                .record_call("claude-3-sonnet-20240229", 200, 75, 600, true, None)
                .await
        });
        let (result1, result2) = tokio::join!(handle1, handle2);
        assert!(result1.unwrap().is_ok());
        assert!(result2.unwrap().is_ok());
        // Both concurrent writes must be visible through the shared log.
        assert_eq!(tracker.call_count().await, 2);
        let end_time = Utc::now();
        let start_time = end_time - chrono::Duration::hours(1);
        let stats = tracker.get_usage_stats(start_time, end_time).await.unwrap();
        assert_eq!(stats.api_calls.total_calls, 2);
        assert_eq!(stats.token_usage.total_tokens, 425);
    }
}