Skip to main content

aiclient_api/usage/
tracker.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5
6/// Usage statistics for a single provider
7#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct ProviderUsage {
9    /// Total number of requests
10    pub request_count: u64,
11    /// Total input/prompt tokens consumed
12    pub input_tokens: u64,
13    /// Total output/completion tokens consumed
14    pub output_tokens: u64,
15    /// Total tokens consumed
16    pub total_tokens: u64,
17    /// Per-model usage breakdown
18    #[serde(skip_serializing_if = "HashMap::is_empty")]
19    pub models: HashMap<String, ModelUsage>,
20}
21
22/// Usage statistics for a specific model
23#[derive(Debug, Clone, Serialize, Deserialize, Default)]
24pub struct ModelUsage {
25    pub request_count: u64,
26    pub input_tokens: u64,
27    pub output_tokens: u64,
28    pub total_tokens: u64,
29}
30
31/// Aggregated usage statistics across all providers
32#[derive(Debug, Clone, Serialize, Deserialize, Default)]
33pub struct UsageStats {
34    /// Per-provider usage
35    pub providers: HashMap<String, ProviderUsage>,
36    /// Total aggregated usage
37    pub total: ProviderUsage,
38}
39
40/// Thread-safe usage tracker
41pub struct UsageTracker {
42    stats: Arc<RwLock<UsageStats>>,
43}
44
45impl UsageTracker {
46    pub fn new() -> Self {
47        Self {
48            stats: Arc::new(RwLock::new(UsageStats::default())),
49        }
50    }
51
52    /// Record usage for a request
53    pub async fn record(
54        &self,
55        provider: &str,
56        model: &str,
57        input_tokens: u64,
58        output_tokens: u64,
59    ) {
60        let mut stats = self.stats.write().await;
61        let total_tokens = input_tokens + output_tokens;
62
63        // Update provider stats
64        let provider_stats = stats.providers.entry(provider.to_string()).or_default();
65        provider_stats.request_count += 1;
66        provider_stats.input_tokens += input_tokens;
67        provider_stats.output_tokens += output_tokens;
68        provider_stats.total_tokens += total_tokens;
69
70        // Update per-model stats
71        let model_stats = provider_stats.models.entry(model.to_string()).or_default();
72        model_stats.request_count += 1;
73        model_stats.input_tokens += input_tokens;
74        model_stats.output_tokens += output_tokens;
75        model_stats.total_tokens += total_tokens;
76
77        // Update total stats
78        stats.total.request_count += 1;
79        stats.total.input_tokens += input_tokens;
80        stats.total.output_tokens += output_tokens;
81        stats.total.total_tokens += total_tokens;
82    }
83
84    /// Get current usage statistics
85    pub async fn get_stats(&self) -> UsageStats {
86        self.stats.read().await.clone()
87    }
88
89    /// Reset all statistics
90    pub async fn reset(&self) {
91        let mut stats = self.stats.write().await;
92        *stats = UsageStats::default();
93    }
94
95    /// Get usage for a specific provider
96    pub async fn get_provider_usage(&self, provider: &str) -> Option<ProviderUsage> {
97        let stats = self.stats.read().await;
98        stats.providers.get(provider).cloned()
99    }
100}
101
102impl Default for UsageTracker {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108impl Clone for UsageTracker {
109    fn clone(&self) -> Self {
110        Self {
111            stats: Arc::clone(&self.stats),
112        }
113    }
114}