1use byokey_types::{DEFAULT_ACCOUNT, UsageRecord, UsageStore};
5use serde::Serialize;
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::{Arc, Mutex};
9use tokio::sync::mpsc;
10
/// Thread-safe in-memory usage counters, kept both as global totals and per model.
///
/// The global totals are atomics; the per-model breakdown is a map behind a mutex.
/// Updated by [`UsageStats::record_success`] / [`UsageStats::record_failure`] and read
/// through [`UsageStats::snapshot`].
#[derive(Default)]
pub struct UsageStats {
    /// Number of requests recorded, successful or failed.
    pub total_requests: AtomicU64,
    /// Number of requests recorded as successful.
    pub success_requests: AtomicU64,
    /// Number of requests recorded as failed.
    pub failure_requests: AtomicU64,
    /// Total input tokens recorded (failures contribute none).
    pub input_tokens: AtomicU64,
    /// Total output tokens recorded (failures contribute none).
    pub output_tokens: AtomicU64,
    /// Per-model breakdown, keyed by model name.
    model_counts: Mutex<HashMap<String, ModelStats>>,
}
27
28#[derive(Default, Clone, Serialize)]
30pub struct ModelStats {
31 pub requests: u64,
32 pub success: u64,
33 pub failure: u64,
34 pub input_tokens: u64,
35 pub output_tokens: u64,
36}
37
/// Point-in-time, serializable copy of the counters held by [`UsageStats`].
///
/// Produced by [`UsageStats::snapshot`]; each field mirrors the live counter of the same name.
#[derive(Serialize)]
pub struct UsageSnapshot {
    /// Requests recorded, successful or failed.
    pub total_requests: u64,
    /// Requests recorded as successful.
    pub success_requests: u64,
    /// Requests recorded as failed.
    pub failure_requests: u64,
    /// Total input tokens recorded.
    pub input_tokens: u64,
    /// Total output tokens recorded.
    pub output_tokens: u64,
    /// Per-model breakdown, keyed by model name.
    pub models: HashMap<String, ModelStats>,
}
48
49impl UsageStats {
50 #[must_use]
52 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn record_success(&self, model: &str, input_tokens: u64, output_tokens: u64) {
58 self.total_requests.fetch_add(1, Ordering::Relaxed);
59 self.success_requests.fetch_add(1, Ordering::Relaxed);
60 self.input_tokens.fetch_add(input_tokens, Ordering::Relaxed);
61 self.output_tokens
62 .fetch_add(output_tokens, Ordering::Relaxed);
63
64 if let Ok(mut map) = self.model_counts.lock() {
65 let entry = map.entry(model.to_string()).or_default();
66 entry.requests += 1;
67 entry.success += 1;
68 entry.input_tokens += input_tokens;
69 entry.output_tokens += output_tokens;
70 }
71 }
72
73 pub fn record_failure(&self, model: &str) {
75 self.total_requests.fetch_add(1, Ordering::Relaxed);
76 self.failure_requests.fetch_add(1, Ordering::Relaxed);
77
78 if let Ok(mut map) = self.model_counts.lock() {
79 let entry = map.entry(model.to_string()).or_default();
80 entry.requests += 1;
81 entry.failure += 1;
82 }
83 }
84
85 #[must_use]
87 pub fn snapshot(&self) -> UsageSnapshot {
88 let models = self
89 .model_counts
90 .lock()
91 .map(|m| m.clone())
92 .unwrap_or_default();
93 UsageSnapshot {
94 total_requests: self.total_requests.load(Ordering::Relaxed),
95 success_requests: self.success_requests.load(Ordering::Relaxed),
96 failure_requests: self.failure_requests.load(Ordering::Relaxed),
97 input_tokens: self.input_tokens.load(Ordering::Relaxed),
98 output_tokens: self.output_tokens.load(Ordering::Relaxed),
99 models,
100 }
101 }
102}
103
/// Records request usage in memory and, when a store is configured, forwards each
/// record over a channel to a background task that persists it.
pub struct UsageRecorder {
    /// In-memory aggregate counters, read via [`UsageRecorder::snapshot`].
    stats: UsageStats,
    /// Optional persistent backend, also exposed through [`UsageRecorder::store`].
    store: Option<Arc<dyn UsageStore>>,
    /// Channel into the flush task; set by the constructor exactly when `store` is `Some`.
    sender: Option<mpsc::UnboundedSender<UsageRecord>>,
}
114
115impl UsageRecorder {
116 #[must_use]
121 pub fn new(store: Option<Arc<dyn UsageStore>>) -> Self {
122 let sender = store.as_ref().map(|s| {
123 let (tx, rx) = mpsc::unbounded_channel::<UsageRecord>();
124 let flush_store = Arc::clone(s);
125 tokio::spawn(Self::flush_loop(flush_store, rx));
126 tx
127 });
128 Self {
129 stats: UsageStats::new(),
130 store,
131 sender,
132 }
133 }
134
135 async fn flush_loop(store: Arc<dyn UsageStore>, mut rx: mpsc::UnboundedReceiver<UsageRecord>) {
137 const BATCH_CAP: usize = 64;
138 let mut buf: Vec<UsageRecord> = Vec::with_capacity(BATCH_CAP);
139
140 while let Some(record) = rx.recv().await {
141 buf.push(record);
142
143 while buf.len() < BATCH_CAP {
145 match rx.try_recv() {
146 Ok(r) => buf.push(r),
147 Err(_) => break,
148 }
149 }
150
151 for record in buf.drain(..) {
152 if let Err(e) = store.record(&record).await {
153 tracing::warn!(error = %e, "failed to persist usage record");
154 }
155 }
156 }
157 }
158
159 pub fn record_success(
165 &self,
166 model: &str,
167 provider: &str,
168 input_tokens: u64,
169 output_tokens: u64,
170 ) {
171 self.record_success_for(
172 model,
173 provider,
174 DEFAULT_ACCOUNT,
175 input_tokens,
176 output_tokens,
177 );
178 }
179
180 pub fn record_success_for(
183 &self,
184 model: &str,
185 provider: &str,
186 account_id: &str,
187 input_tokens: u64,
188 output_tokens: u64,
189 ) {
190 self.stats
191 .record_success(model, input_tokens, output_tokens);
192 self.persist(
193 model,
194 provider,
195 account_id,
196 input_tokens,
197 output_tokens,
198 true,
199 );
200 }
201
202 pub fn record_failure(&self, model: &str, provider: &str) {
204 self.record_failure_for(model, provider, DEFAULT_ACCOUNT);
205 }
206
207 pub fn record_failure_for(&self, model: &str, provider: &str, account_id: &str) {
209 self.stats.record_failure(model);
210 self.persist(model, provider, account_id, 0, 0, false);
211 }
212
213 #[must_use]
215 pub fn snapshot(&self) -> UsageSnapshot {
216 self.stats.snapshot()
217 }
218
219 pub fn preload(&self, model: &str, requests: u64, input_tokens: u64, output_tokens: u64) {
221 self.stats
222 .total_requests
223 .fetch_add(requests, Ordering::Relaxed);
224 self.stats
225 .success_requests
226 .fetch_add(requests, Ordering::Relaxed);
227 self.stats
228 .input_tokens
229 .fetch_add(input_tokens, Ordering::Relaxed);
230 self.stats
231 .output_tokens
232 .fetch_add(output_tokens, Ordering::Relaxed);
233
234 if let Ok(mut map) = self.stats.model_counts.lock() {
235 let entry = map.entry(model.to_string()).or_default();
236 entry.requests += requests;
237 entry.success += requests;
238 entry.input_tokens += input_tokens;
239 entry.output_tokens += output_tokens;
240 }
241 }
242
243 pub fn store(&self) -> Option<&Arc<dyn UsageStore>> {
245 self.store.as_ref()
246 }
247
248 fn persist(
249 &self,
250 model: &str,
251 provider: &str,
252 account_id: &str,
253 input_tokens: u64,
254 output_tokens: u64,
255 success: bool,
256 ) {
257 if let Some(sender) = &self.sender {
258 let record = UsageRecord {
259 model: model.to_string(),
260 provider: provider.to_string(),
261 account_id: account_id.to_string(),
262 input_tokens,
263 output_tokens,
264 success,
265 };
266 let _ = sender.send(record);
267 }
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_record_success() {
        let stats = UsageStats::new();
        stats.record_success("claude-opus-4-5", 100, 200);
        stats.record_success("claude-opus-4-5", 50, 100);
        stats.record_success("gpt-4o", 80, 150);

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 3);
        assert_eq!(snap.success_requests, 3);
        assert_eq!(snap.failure_requests, 0);
        assert_eq!(snap.input_tokens, 230);
        assert_eq!(snap.output_tokens, 450);

        let claude = &snap.models["claude-opus-4-5"];
        assert_eq!(claude.requests, 2);
        assert_eq!(claude.success, 2);
        assert_eq!(claude.input_tokens, 150);
        assert_eq!(claude.output_tokens, 300);
    }

    #[test]
    fn test_record_failure() {
        let stats = UsageStats::new();
        stats.record_failure("gpt-4o");
        stats.record_success("gpt-4o", 10, 20);

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 2);
        assert_eq!(snap.success_requests, 1);
        assert_eq!(snap.failure_requests, 1);

        let model = &snap.models["gpt-4o"];
        assert_eq!(model.requests, 2);
        assert_eq!(model.failure, 1);
        assert_eq!(model.success, 1);
    }

    #[test]
    fn test_snapshot_empty() {
        let stats = UsageStats::new();
        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 0);
        assert!(snap.models.is_empty());
    }

    // A recorder without a store spawns no flush task, so these run without a Tokio runtime.
    #[test]
    fn test_recorder_without_store_tracks_in_memory() {
        let recorder = UsageRecorder::new(None);
        assert!(recorder.store().is_none());

        recorder.record_success("gpt-4o", "openai", 10, 20);
        recorder.record_success_for("gpt-4o", "openai", "acct-2", 5, 5);
        recorder.record_failure("gpt-4o", "openai");
        recorder.record_failure_for("claude-opus-4-5", "anthropic", "acct-2");

        let snap = recorder.snapshot();
        assert_eq!(snap.total_requests, 4);
        assert_eq!(snap.success_requests, 2);
        assert_eq!(snap.failure_requests, 2);
        assert_eq!(snap.input_tokens, 15);
        assert_eq!(snap.output_tokens, 25);

        let gpt = &snap.models["gpt-4o"];
        assert_eq!(gpt.requests, 3);
        assert_eq!(gpt.success, 2);
        assert_eq!(gpt.failure, 1);

        let claude = &snap.models["claude-opus-4-5"];
        assert_eq!(claude.requests, 1);
        assert_eq!(claude.failure, 1);
        assert_eq!(claude.input_tokens, 0);
        assert_eq!(claude.output_tokens, 0);
    }

    #[test]
    fn test_preload_counts_as_success() {
        let recorder = UsageRecorder::new(None);
        recorder.preload("gpt-4o", 5, 100, 200);
        recorder.record_failure("gpt-4o", "openai");

        let snap = recorder.snapshot();
        assert_eq!(snap.total_requests, 6);
        assert_eq!(snap.success_requests, 5);
        assert_eq!(snap.failure_requests, 1);
        assert_eq!(snap.input_tokens, 100);
        assert_eq!(snap.output_tokens, 200);

        let model = &snap.models["gpt-4o"];
        assert_eq!(model.requests, 6);
        assert_eq!(model.success, 5);
        assert_eq!(model.failure, 1);
        assert_eq!(model.input_tokens, 100);
        assert_eq!(model.output_tokens, 200);
    }
}