use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Aggregated token usage for a single provider, with a per-model breakdown.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ProviderUsage {
    /// Number of recorded requests.
    pub request_count: u64,
    /// Sum of input tokens across recorded requests.
    pub input_tokens: u64,
    /// Sum of output tokens across recorded requests.
    pub output_tokens: u64,
    /// Sum of input + output tokens across recorded requests.
    pub total_tokens: u64,
    /// Per-model breakdown, keyed by model name.
    ///
    /// Omitted from serialized output when empty. `default` is required so
    /// that such output deserializes back: without it, serde rejects the
    /// payload with "missing field `models`" (only `Option` fields are
    /// implicitly optional).
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub models: HashMap<String, ModelUsage>,
}
/// Aggregated token usage for a single model of a provider.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct ModelUsage {
    /// Number of recorded requests for this model.
    pub request_count: u64,
    /// Sum of input tokens for this model.
    pub input_tokens: u64,
    /// Sum of output tokens for this model.
    pub output_tokens: u64,
    /// Sum of input + output tokens for this model.
    pub total_tokens: u64,
}
/// Snapshot of all recorded usage: per-provider figures plus a grand total.
#[derive(Debug, Clone, Default)]
#[derive(Serialize, Deserialize)]
pub struct UsageStats {
    /// Usage keyed by provider name.
    pub providers: HashMap<String, ProviderUsage>,
    /// Totals across every provider.
    pub total: ProviderUsage,
}
/// Accumulates token usage behind an `Arc<RwLock<_>>` so it can be shared
/// across tasks.
pub struct UsageTracker {
    /// The shared, lock-protected statistics.
    stats: Arc<RwLock<UsageStats>>,
}
impl UsageTracker {
    /// Creates a tracker with no recorded usage.
    pub fn new() -> Self {
        Self {
            stats: Arc::new(RwLock::new(UsageStats::default())),
        }
    }

    /// Records one request for `provider` / `model`.
    ///
    /// Increments the request count by one and adds `input_tokens`,
    /// `output_tokens`, and their sum to three places: the model's entry,
    /// the provider's entry, and the overall total. Missing provider/model
    /// entries are created on first use.
    ///
    /// Every counter saturates at `u64::MAX` instead of overflowing. Plain
    /// `+=` would panic in debug builds and silently wrap to a small value in
    /// release builds.
    pub async fn record(
        &self,
        provider: &str,
        model: &str,
        input_tokens: u64,
        output_tokens: u64,
    ) {
        let total_tokens = input_tokens.saturating_add(output_tokens);

        let mut guard = self.stats.write().await;
        // Deref the guard once so `providers` and `total` are plain disjoint
        // field borrows of the same `&mut UsageStats`.
        let stats = &mut *guard;

        let provider_stats = stats.providers.entry(provider.to_string()).or_default();
        provider_stats.request_count = provider_stats.request_count.saturating_add(1);
        provider_stats.input_tokens = provider_stats.input_tokens.saturating_add(input_tokens);
        provider_stats.output_tokens = provider_stats.output_tokens.saturating_add(output_tokens);
        provider_stats.total_tokens = provider_stats.total_tokens.saturating_add(total_tokens);

        let model_stats = provider_stats.models.entry(model.to_string()).or_default();
        model_stats.request_count = model_stats.request_count.saturating_add(1);
        model_stats.input_tokens = model_stats.input_tokens.saturating_add(input_tokens);
        model_stats.output_tokens = model_stats.output_tokens.saturating_add(output_tokens);
        model_stats.total_tokens = model_stats.total_tokens.saturating_add(total_tokens);

        let total = &mut stats.total;
        total.request_count = total.request_count.saturating_add(1);
        total.input_tokens = total.input_tokens.saturating_add(input_tokens);
        total.output_tokens = total.output_tokens.saturating_add(output_tokens);
        total.total_tokens = total.total_tokens.saturating_add(total_tokens);
    }

    /// Returns a cloned snapshot of all statistics.
    pub async fn get_stats(&self) -> UsageStats {
        self.stats.read().await.clone()
    }

    /// Discards all recorded usage, restoring the empty default state.
    pub async fn reset(&self) {
        *self.stats.write().await = UsageStats::default();
    }

    /// Returns a cloned snapshot of `provider`'s usage, or `None` if nothing
    /// has been recorded for it.
    pub async fn get_provider_usage(&self, provider: &str) -> Option<ProviderUsage> {
        let stats = self.stats.read().await;
        stats.providers.get(provider).cloned()
    }
}
impl Default for UsageTracker {
fn default() -> Self {
Self::new()
}
}
impl Clone for UsageTracker {
fn clone(&self) -> Self {
Self {
stats: Arc::clone(&self.stats),
}
}
}