agent_core/controller/usage/
tracker.rs

1use std::collections::HashMap;
2
3use tokio::sync::RwLock;
4
/// Token meter for tracking input and output tokens.
///
/// A plain value type: the two counters are independent public fields, and
/// the derived [`Default`] starts both at zero. Meters compare by value, so
/// two meters are equal exactly when both counters match.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenMeter {
    /// Number of input tokens.
    pub input_tokens: i64,
    /// Number of output tokens.
    pub output_tokens: i64,
}
13
14impl TokenMeter {
15    /// Create a new empty token meter.
16    pub fn new() -> Self {
17        Self::default()
18    }
19
20    /// Create a token meter with initial values.
21    pub fn with_values(input_tokens: i64, output_tokens: i64) -> Self {
22        Self {
23            input_tokens,
24            output_tokens,
25        }
26    }
27
28    /// Returns the total tokens (input + output).
29    pub fn total_tokens(&self) -> i64 {
30        self.input_tokens + self.output_tokens
31    }
32
33    /// Add tokens from another meter.
34    pub fn add(&mut self, other: &TokenMeter) {
35        self.input_tokens += other.input_tokens;
36        self.output_tokens += other.output_tokens;
37    }
38}
39
40/// Thread-safe token usage tracker.
41/// Tracks usage by session, model, and total.
42pub struct TokenUsageTracker {
43    tokens_per_session: RwLock<HashMap<i64, TokenMeter>>,
44    tokens_per_model: RwLock<HashMap<String, TokenMeter>>,
45    total_usage: RwLock<TokenMeter>,
46}
47
48impl TokenUsageTracker {
49    /// Create a new token usage tracker.
50    pub fn new() -> Self {
51        Self {
52            tokens_per_session: RwLock::new(HashMap::new()),
53            tokens_per_model: RwLock::new(HashMap::new()),
54            total_usage: RwLock::new(TokenMeter::new()),
55        }
56    }
57
58    /// Increment token usage for a session and model.
59    pub async fn increment(
60        &self,
61        session_id: i64,
62        model: &str,
63        input_tokens: i64,
64        output_tokens: i64,
65    ) {
66        // Update session usage
67        {
68            let mut sessions = self.tokens_per_session.write().await;
69            let meter = sessions.entry(session_id).or_insert_with(TokenMeter::new);
70            meter.input_tokens += input_tokens;
71            meter.output_tokens += output_tokens;
72        }
73
74        // Update model usage
75        {
76            let mut models = self.tokens_per_model.write().await;
77            let meter = models
78                .entry(model.to_string())
79                .or_insert_with(TokenMeter::new);
80            meter.input_tokens += input_tokens;
81            meter.output_tokens += output_tokens;
82        }
83
84        // Update total usage
85        {
86            let mut total = self.total_usage.write().await;
87            total.input_tokens += input_tokens;
88            total.output_tokens += output_tokens;
89        }
90    }
91
92    /// Get token usage for a specific session.
93    /// Returns None if the session has no recorded usage.
94    pub async fn get_session_usage(&self, session_id: i64) -> Option<TokenMeter> {
95        let sessions = self.tokens_per_session.read().await;
96        sessions.get(&session_id).cloned()
97    }
98
99    /// Get token usage for a specific model.
100    /// Returns None if the model has no recorded usage.
101    pub async fn get_model_usage(&self, model: &str) -> Option<TokenMeter> {
102        let models = self.tokens_per_model.read().await;
103        models.get(model).cloned()
104    }
105
106    /// Get total token usage across all sessions and models.
107    pub async fn get_total_usage(&self) -> TokenMeter {
108        let total = self.total_usage.read().await;
109        total.clone()
110    }
111
112    /// Get all session usage as a map.
113    pub async fn get_all_session_usage(&self) -> HashMap<i64, TokenMeter> {
114        let sessions = self.tokens_per_session.read().await;
115        sessions.clone()
116    }
117
118    /// Get all model usage as a map.
119    pub async fn get_all_model_usage(&self) -> HashMap<String, TokenMeter> {
120        let models = self.tokens_per_model.read().await;
121        models.clone()
122    }
123
124    /// Remove usage tracking for a session (e.g., when session is deleted).
125    pub async fn remove_session(&self, session_id: i64) {
126        let mut sessions = self.tokens_per_session.write().await;
127        sessions.remove(&session_id);
128    }
129
130    /// Get the number of tracked sessions.
131    pub async fn session_count(&self) -> usize {
132        let sessions = self.tokens_per_session.read().await;
133        sessions.len()
134    }
135
136    /// Get the number of tracked models.
137    pub async fn model_count(&self) -> usize {
138        let models = self.tokens_per_model.read().await;
139        models.len()
140    }
141}
142
143impl Default for TokenUsageTracker {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Increments must accumulate per session, per model, and in the total.
    #[tokio::test]
    async fn test_increment_usage() {
        let tracker = TokenUsageTracker::new();

        tracker.increment(1, "claude-3-sonnet", 100, 50).await;
        tracker.increment(1, "claude-3-sonnet", 200, 100).await;
        tracker.increment(2, "gpt-4", 150, 75).await;

        // Check session usage
        let session1 = tracker.get_session_usage(1).await.unwrap();
        assert_eq!(session1.input_tokens, 300);
        assert_eq!(session1.output_tokens, 150);

        let session2 = tracker.get_session_usage(2).await.unwrap();
        assert_eq!(session2.input_tokens, 150);
        assert_eq!(session2.output_tokens, 75);

        // Check model usage
        let claude = tracker.get_model_usage("claude-3-sonnet").await.unwrap();
        assert_eq!(claude.input_tokens, 300);
        assert_eq!(claude.output_tokens, 150);

        let gpt4 = tracker.get_model_usage("gpt-4").await.unwrap();
        assert_eq!(gpt4.input_tokens, 150);
        assert_eq!(gpt4.output_tokens, 75);

        // Check total usage
        let total = tracker.get_total_usage().await;
        assert_eq!(total.input_tokens, 450);
        assert_eq!(total.output_tokens, 225);
    }

    /// Lookups for keys that were never incremented return `None`.
    #[tokio::test]
    async fn test_nonexistent_session() {
        let tracker = TokenUsageTracker::new();

        let usage = tracker.get_session_usage(999).await;
        assert!(usage.is_none());
        assert!(tracker.get_model_usage("unknown").await.is_none());
    }

    /// Removing a session drops only its per-session entry.
    #[tokio::test]
    async fn test_remove_session() {
        let tracker = TokenUsageTracker::new();

        tracker.increment(1, "model", 100, 50).await;
        assert!(tracker.get_session_usage(1).await.is_some());

        tracker.remove_session(1).await;
        assert!(tracker.get_session_usage(1).await.is_none());

        // Total should still reflect the removed session's usage
        let total = tracker.get_total_usage().await;
        assert_eq!(total.input_tokens, 100);
        assert_eq!(total.output_tokens, 50);

        // Model usage is likewise retained after the session is gone
        let model = tracker.get_model_usage("model").await.unwrap();
        assert_eq!(model.input_tokens, 100);
        assert_eq!(model.output_tokens, 50);
    }

    /// Counts and bulk snapshots track the distinct sessions and models seen.
    #[tokio::test]
    async fn test_counts_and_snapshots() {
        let tracker = TokenUsageTracker::new();
        assert_eq!(tracker.session_count().await, 0);
        assert_eq!(tracker.model_count().await, 0);

        tracker.increment(1, "a", 1, 2).await;
        tracker.increment(2, "a", 3, 4).await;
        tracker.increment(2, "b", 5, 6).await;

        assert_eq!(tracker.session_count().await, 2);
        assert_eq!(tracker.model_count().await, 2);

        let sessions = tracker.get_all_session_usage().await;
        assert_eq!(sessions.len(), 2);
        assert_eq!(sessions[&2].input_tokens, 8);
        assert_eq!(sessions[&2].output_tokens, 10);

        let models = tracker.get_all_model_usage().await;
        assert_eq!(models.len(), 2);
        assert_eq!(models["a"].input_tokens, 4);
        assert_eq!(models["a"].output_tokens, 6);

        tracker.remove_session(1).await;
        assert_eq!(tracker.session_count().await, 1);
        assert_eq!(tracker.model_count().await, 2);
    }

    /// `TokenMeter` is purely synchronous, so no async runtime is needed here.
    #[test]
    fn test_token_meter() {
        let meter = TokenMeter::with_values(100, 50);
        assert_eq!(meter.total_tokens(), 150);

        let mut meter2 = TokenMeter::new();
        meter2.add(&meter);
        assert_eq!(meter2.input_tokens, 100);
        assert_eq!(meter2.output_tokens, 50);
    }
}
218}