// agent_core/controller/usage/tracker.rs
use std::collections::HashMap;

use tokio::sync::RwLock;

/// Accumulated token counts, kept separately for input and output tokens.
///
/// Both counters are plain `i64`s; nothing in this type rejects negative values.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenMeter {
    /// Number of input tokens accumulated so far.
    pub input_tokens: i64,
    /// Number of output tokens accumulated so far.
    pub output_tokens: i64,
}

14impl TokenMeter {
15 pub fn new() -> Self {
17 Self::default()
18 }
19
20 pub fn with_values(input_tokens: i64, output_tokens: i64) -> Self {
22 Self {
23 input_tokens,
24 output_tokens,
25 }
26 }
27
28 pub fn total_tokens(&self) -> i64 {
30 self.input_tokens + self.output_tokens
31 }
32
33 pub fn add(&mut self, other: &TokenMeter) {
35 self.input_tokens += other.input_tokens;
36 self.output_tokens += other.output_tokens;
37 }
38}
39
40pub struct TokenUsageTracker {
43 tokens_per_session: RwLock<HashMap<i64, TokenMeter>>,
44 tokens_per_model: RwLock<HashMap<String, TokenMeter>>,
45 total_usage: RwLock<TokenMeter>,
46}
47
48impl TokenUsageTracker {
49 pub fn new() -> Self {
51 Self {
52 tokens_per_session: RwLock::new(HashMap::new()),
53 tokens_per_model: RwLock::new(HashMap::new()),
54 total_usage: RwLock::new(TokenMeter::new()),
55 }
56 }
57
58 pub async fn increment(
60 &self,
61 session_id: i64,
62 model: &str,
63 input_tokens: i64,
64 output_tokens: i64,
65 ) {
66 {
68 let mut sessions = self.tokens_per_session.write().await;
69 let meter = sessions.entry(session_id).or_insert_with(TokenMeter::new);
70 meter.input_tokens += input_tokens;
71 meter.output_tokens += output_tokens;
72 }
73
74 {
76 let mut models = self.tokens_per_model.write().await;
77 let meter = models
78 .entry(model.to_string())
79 .or_insert_with(TokenMeter::new);
80 meter.input_tokens += input_tokens;
81 meter.output_tokens += output_tokens;
82 }
83
84 {
86 let mut total = self.total_usage.write().await;
87 total.input_tokens += input_tokens;
88 total.output_tokens += output_tokens;
89 }
90 }
91
92 pub async fn get_session_usage(&self, session_id: i64) -> Option<TokenMeter> {
95 let sessions = self.tokens_per_session.read().await;
96 sessions.get(&session_id).cloned()
97 }
98
99 pub async fn get_model_usage(&self, model: &str) -> Option<TokenMeter> {
102 let models = self.tokens_per_model.read().await;
103 models.get(model).cloned()
104 }
105
106 pub async fn get_total_usage(&self) -> TokenMeter {
108 let total = self.total_usage.read().await;
109 total.clone()
110 }
111
112 pub async fn get_all_session_usage(&self) -> HashMap<i64, TokenMeter> {
114 let sessions = self.tokens_per_session.read().await;
115 sessions.clone()
116 }
117
118 pub async fn get_all_model_usage(&self) -> HashMap<String, TokenMeter> {
120 let models = self.tokens_per_model.read().await;
121 models.clone()
122 }
123
124 pub async fn remove_session(&self, session_id: i64) {
126 let mut sessions = self.tokens_per_session.write().await;
127 sessions.remove(&session_id);
128 }
129
130 pub async fn session_count(&self) -> usize {
132 let sessions = self.tokens_per_session.read().await;
133 sessions.len()
134 }
135
136 pub async fn model_count(&self) -> usize {
138 let models = self.tokens_per_model.read().await;
139 models.len()
140 }
141}
142
143impl Default for TokenUsageTracker {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_increment_usage() {
        let tracker = TokenUsageTracker::new();

        tracker.increment(1, "claude-3-sonnet", 100, 50).await;
        tracker.increment(1, "claude-3-sonnet", 200, 100).await;
        tracker.increment(2, "gpt-4", 150, 75).await;

        let session1 = tracker.get_session_usage(1).await.unwrap();
        assert_eq!(session1.input_tokens, 300);
        assert_eq!(session1.output_tokens, 150);

        let session2 = tracker.get_session_usage(2).await.unwrap();
        assert_eq!(session2.input_tokens, 150);
        assert_eq!(session2.output_tokens, 75);

        let claude = tracker.get_model_usage("claude-3-sonnet").await.unwrap();
        assert_eq!(claude.input_tokens, 300);
        assert_eq!(claude.output_tokens, 150);

        let gpt4 = tracker.get_model_usage("gpt-4").await.unwrap();
        assert_eq!(gpt4.input_tokens, 150);
        assert_eq!(gpt4.output_tokens, 75);

        let total = tracker.get_total_usage().await;
        assert_eq!(total.input_tokens, 450);
        assert_eq!(total.output_tokens, 225);
    }

    #[tokio::test]
    async fn test_nonexistent_session() {
        let tracker = TokenUsageTracker::new();

        let usage = tracker.get_session_usage(999).await;
        assert!(usage.is_none());
    }

    #[tokio::test]
    async fn test_nonexistent_model() {
        let tracker = TokenUsageTracker::new();
        tracker.increment(1, "gpt-4", 10, 5).await;

        assert!(tracker.get_model_usage("claude-3-sonnet").await.is_none());
    }

    #[tokio::test]
    async fn test_remove_session() {
        let tracker = TokenUsageTracker::new();

        tracker.increment(1, "model", 100, 50).await;
        assert!(tracker.get_session_usage(1).await.is_some());

        tracker.remove_session(1).await;
        assert!(tracker.get_session_usage(1).await.is_none());

        // Removing a session leaves the global and per-model aggregates untouched.
        let total = tracker.get_total_usage().await;
        assert_eq!(total.input_tokens, 100);
        assert_eq!(total.output_tokens, 50);

        let model = tracker.get_model_usage("model").await.unwrap();
        assert_eq!(model.input_tokens, 100);
        assert_eq!(model.output_tokens, 50);
    }

    #[tokio::test]
    async fn test_snapshots_and_counts() {
        let tracker = TokenUsageTracker::default();
        assert_eq!(tracker.session_count().await, 0);
        assert_eq!(tracker.model_count().await, 0);

        tracker.increment(1, "claude-3-sonnet", 100, 50).await;
        tracker.increment(2, "claude-3-sonnet", 10, 5).await;
        tracker.increment(2, "gpt-4", 1, 2).await;

        assert_eq!(tracker.session_count().await, 2);
        assert_eq!(tracker.model_count().await, 2);

        let sessions = tracker.get_all_session_usage().await;
        assert_eq!(sessions.len(), 2);
        assert_eq!(sessions[&1].total_tokens(), 150);
        assert_eq!(sessions[&2].input_tokens, 11);
        assert_eq!(sessions[&2].output_tokens, 7);

        let models = tracker.get_all_model_usage().await;
        assert_eq!(models.len(), 2);
        assert_eq!(models["claude-3-sonnet"].input_tokens, 110);
        assert_eq!(models["claude-3-sonnet"].output_tokens, 55);
        assert_eq!(models["gpt-4"].total_tokens(), 3);

        tracker.remove_session(2).await;
        assert_eq!(tracker.session_count().await, 1);
        assert_eq!(tracker.model_count().await, 2);
    }

    // `TokenMeter` is purely synchronous, so no async runtime is needed here.
    #[test]
    fn test_token_meter() {
        let meter = TokenMeter::with_values(100, 50);
        assert_eq!(meter.total_tokens(), 150);

        let mut meter2 = TokenMeter::new();
        meter2.add(&meter);
        assert_eq!(meter2.input_tokens, 100);
        assert_eq!(meter2.output_tokens, 50);
    }
}