1use byokey_types::{UsageRecord, UsageStore};
5use serde::Serialize;
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::{Arc, Mutex};
9use tokio::sync::mpsc;
10use utoipa::ToSchema;
11
12#[derive(Default)]
14pub struct UsageStats {
15 pub total_requests: AtomicU64,
17 pub success_requests: AtomicU64,
19 pub failure_requests: AtomicU64,
21 pub input_tokens: AtomicU64,
23 pub output_tokens: AtomicU64,
25 model_counts: Mutex<HashMap<String, ModelStats>>,
27}
28
29#[derive(Default, Clone, Serialize, ToSchema)]
31pub struct ModelStats {
32 pub requests: u64,
33 pub success: u64,
34 pub failure: u64,
35 pub input_tokens: u64,
36 pub output_tokens: u64,
37}
38
39#[derive(Serialize, ToSchema)]
41pub struct UsageSnapshot {
42 pub total_requests: u64,
43 pub success_requests: u64,
44 pub failure_requests: u64,
45 pub input_tokens: u64,
46 pub output_tokens: u64,
47 pub models: HashMap<String, ModelStats>,
48}
49
50impl UsageStats {
51 #[must_use]
53 pub fn new() -> Self {
54 Self::default()
55 }
56
57 pub fn record_success(&self, model: &str, input_tokens: u64, output_tokens: u64) {
59 self.total_requests.fetch_add(1, Ordering::Relaxed);
60 self.success_requests.fetch_add(1, Ordering::Relaxed);
61 self.input_tokens.fetch_add(input_tokens, Ordering::Relaxed);
62 self.output_tokens
63 .fetch_add(output_tokens, Ordering::Relaxed);
64
65 if let Ok(mut map) = self.model_counts.lock() {
66 let entry = map.entry(model.to_string()).or_default();
67 entry.requests += 1;
68 entry.success += 1;
69 entry.input_tokens += input_tokens;
70 entry.output_tokens += output_tokens;
71 }
72 }
73
74 pub fn record_failure(&self, model: &str) {
76 self.total_requests.fetch_add(1, Ordering::Relaxed);
77 self.failure_requests.fetch_add(1, Ordering::Relaxed);
78
79 if let Ok(mut map) = self.model_counts.lock() {
80 let entry = map.entry(model.to_string()).or_default();
81 entry.requests += 1;
82 entry.failure += 1;
83 }
84 }
85
86 #[must_use]
88 pub fn snapshot(&self) -> UsageSnapshot {
89 let models = self
90 .model_counts
91 .lock()
92 .map(|m| m.clone())
93 .unwrap_or_default();
94 UsageSnapshot {
95 total_requests: self.total_requests.load(Ordering::Relaxed),
96 success_requests: self.success_requests.load(Ordering::Relaxed),
97 failure_requests: self.failure_requests.load(Ordering::Relaxed),
98 input_tokens: self.input_tokens.load(Ordering::Relaxed),
99 output_tokens: self.output_tokens.load(Ordering::Relaxed),
100 models,
101 }
102 }
103}
104
105pub struct UsageRecorder {
111 stats: UsageStats,
112 store: Option<Arc<dyn UsageStore>>,
113 sender: Option<mpsc::UnboundedSender<UsageRecord>>,
114}
115
116impl UsageRecorder {
117 #[must_use]
122 pub fn new(store: Option<Arc<dyn UsageStore>>) -> Self {
123 let sender = store.as_ref().map(|s| {
124 let (tx, rx) = mpsc::unbounded_channel::<UsageRecord>();
125 let flush_store = Arc::clone(s);
126 tokio::spawn(Self::flush_loop(flush_store, rx));
127 tx
128 });
129 Self {
130 stats: UsageStats::new(),
131 store,
132 sender,
133 }
134 }
135
136 async fn flush_loop(store: Arc<dyn UsageStore>, mut rx: mpsc::UnboundedReceiver<UsageRecord>) {
138 const BATCH_CAP: usize = 64;
139 let mut buf: Vec<UsageRecord> = Vec::with_capacity(BATCH_CAP);
140
141 while let Some(record) = rx.recv().await {
142 buf.push(record);
143
144 while buf.len() < BATCH_CAP {
146 match rx.try_recv() {
147 Ok(r) => buf.push(r),
148 Err(_) => break,
149 }
150 }
151
152 for record in buf.drain(..) {
153 if let Err(e) = store.record(&record).await {
154 tracing::warn!(error = %e, "failed to persist usage record");
155 }
156 }
157 }
158 }
159
160 pub fn record_success(
162 &self,
163 model: &str,
164 provider: &str,
165 input_tokens: u64,
166 output_tokens: u64,
167 ) {
168 self.stats
169 .record_success(model, input_tokens, output_tokens);
170 self.persist(model, provider, input_tokens, output_tokens, true);
171 }
172
173 pub fn record_failure(&self, model: &str, provider: &str) {
175 self.stats.record_failure(model);
176 self.persist(model, provider, 0, 0, false);
177 }
178
179 #[must_use]
181 pub fn snapshot(&self) -> UsageSnapshot {
182 self.stats.snapshot()
183 }
184
185 pub fn preload(&self, model: &str, requests: u64, input_tokens: u64, output_tokens: u64) {
187 self.stats
188 .total_requests
189 .fetch_add(requests, Ordering::Relaxed);
190 self.stats
191 .success_requests
192 .fetch_add(requests, Ordering::Relaxed);
193 self.stats
194 .input_tokens
195 .fetch_add(input_tokens, Ordering::Relaxed);
196 self.stats
197 .output_tokens
198 .fetch_add(output_tokens, Ordering::Relaxed);
199
200 if let Ok(mut map) = self.stats.model_counts.lock() {
201 let entry = map.entry(model.to_string()).or_default();
202 entry.requests += requests;
203 entry.success += requests;
204 entry.input_tokens += input_tokens;
205 entry.output_tokens += output_tokens;
206 }
207 }
208
209 pub fn store(&self) -> Option<&Arc<dyn UsageStore>> {
211 self.store.as_ref()
212 }
213
214 fn persist(
215 &self,
216 model: &str,
217 provider: &str,
218 input_tokens: u64,
219 output_tokens: u64,
220 success: bool,
221 ) {
222 if let Some(sender) = &self.sender {
223 let record = UsageRecord {
224 model: model.to_string(),
225 provider: provider.to_string(),
226 input_tokens,
227 output_tokens,
228 success,
229 };
230 let _ = sender.send(record);
231 }
232 }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// Successful calls feed both the global totals and the per-model entry.
    #[test]
    fn test_record_success() {
        let stats = UsageStats::new();
        let calls = [
            ("claude-opus-4-5", 100, 200),
            ("claude-opus-4-5", 50, 100),
            ("gpt-4o", 80, 150),
        ];
        for (model, input, output) in calls {
            stats.record_success(model, input, output);
        }

        let UsageSnapshot {
            total_requests,
            success_requests,
            failure_requests,
            input_tokens,
            output_tokens,
            models,
        } = stats.snapshot();
        assert_eq!(total_requests, 3);
        assert_eq!(success_requests, 3);
        assert_eq!(failure_requests, 0);
        assert_eq!(input_tokens, 230);
        assert_eq!(output_tokens, 450);

        let claude = &models["claude-opus-4-5"];
        assert_eq!(claude.requests, 2);
        assert_eq!(claude.success, 2);
        assert_eq!(claude.input_tokens, 150);
        assert_eq!(claude.output_tokens, 300);
    }

    /// A failure and a success for the same model are tallied separately.
    #[test]
    fn test_record_failure() {
        let stats = UsageStats::new();
        stats.record_failure("gpt-4o");
        stats.record_success("gpt-4o", 10, 20);

        let snap = stats.snapshot();
        assert_eq!(
            (
                snap.total_requests,
                snap.success_requests,
                snap.failure_requests
            ),
            (2, 1, 1)
        );

        let model = &snap.models["gpt-4o"];
        assert_eq!(model.requests, 2);
        assert_eq!(model.failure, 1);
        assert_eq!(model.success, 1);
    }

    /// A fresh instance snapshots to zero requests and no models.
    #[test]
    fn test_snapshot_empty() {
        let snap = UsageStats::new().snapshot();
        assert_eq!(snap.total_requests, 0);
        assert!(snap.models.is_empty());
    }
}