// byokey_proxy/usage.rs

//! In-memory usage statistics for request/token tracking, with optional
//! persistent backing via [`UsageStore`].

use byokey_types::{DEFAULT_ACCOUNT, UsageRecord, UsageStore};
use serde::Serialize;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use tokio::sync::mpsc;

11/// Global request/token counters.
12#[derive(Default)]
13pub struct UsageStats {
14    /// Total requests received.
15    pub total_requests: AtomicU64,
16    /// Successful requests (2xx from upstream).
17    pub success_requests: AtomicU64,
18    /// Failed requests (non-2xx or internal error).
19    pub failure_requests: AtomicU64,
20    /// Total input tokens across all requests.
21    pub input_tokens: AtomicU64,
22    /// Total output tokens across all requests.
23    pub output_tokens: AtomicU64,
24    /// Per-model request counts.
25    model_counts: Mutex<HashMap<String, ModelStats>>,
26}
27
28/// Per-model usage counters.
29#[derive(Default, Clone, Serialize)]
30pub struct ModelStats {
31    pub requests: u64,
32    pub success: u64,
33    pub failure: u64,
34    pub input_tokens: u64,
35    pub output_tokens: u64,
36}
37
38/// JSON-serializable snapshot of current usage.
39#[derive(Serialize)]
40pub struct UsageSnapshot {
41    pub total_requests: u64,
42    pub success_requests: u64,
43    pub failure_requests: u64,
44    pub input_tokens: u64,
45    pub output_tokens: u64,
46    pub models: HashMap<String, ModelStats>,
47}
48
49impl UsageStats {
50    /// Creates a new empty stats tracker.
51    #[must_use]
52    pub fn new() -> Self {
53        Self::default()
54    }
55
56    /// Record a successful request with optional token counts.
57    pub fn record_success(&self, model: &str, input_tokens: u64, output_tokens: u64) {
58        self.total_requests.fetch_add(1, Ordering::Relaxed);
59        self.success_requests.fetch_add(1, Ordering::Relaxed);
60        self.input_tokens.fetch_add(input_tokens, Ordering::Relaxed);
61        self.output_tokens
62            .fetch_add(output_tokens, Ordering::Relaxed);
63
64        if let Ok(mut map) = self.model_counts.lock() {
65            let entry = map.entry(model.to_string()).or_default();
66            entry.requests += 1;
67            entry.success += 1;
68            entry.input_tokens += input_tokens;
69            entry.output_tokens += output_tokens;
70        }
71    }
72
73    /// Record a failed request.
74    pub fn record_failure(&self, model: &str) {
75        self.total_requests.fetch_add(1, Ordering::Relaxed);
76        self.failure_requests.fetch_add(1, Ordering::Relaxed);
77
78        if let Ok(mut map) = self.model_counts.lock() {
79            let entry = map.entry(model.to_string()).or_default();
80            entry.requests += 1;
81            entry.failure += 1;
82        }
83    }
84
85    /// Take a JSON-serializable snapshot of current stats.
86    #[must_use]
87    pub fn snapshot(&self) -> UsageSnapshot {
88        let models = self
89            .model_counts
90            .lock()
91            .map(|m| m.clone())
92            .unwrap_or_default();
93        UsageSnapshot {
94            total_requests: self.total_requests.load(Ordering::Relaxed),
95            success_requests: self.success_requests.load(Ordering::Relaxed),
96            failure_requests: self.failure_requests.load(Ordering::Relaxed),
97            input_tokens: self.input_tokens.load(Ordering::Relaxed),
98            output_tokens: self.output_tokens.load(Ordering::Relaxed),
99            models,
100        }
101    }
102}
103
104/// Combines in-memory [`UsageStats`] with an optional persistent [`UsageStore`].
105///
106/// Every `record_*` call updates the in-memory counters immediately and, if a
107/// store is configured, sends the record to a single background task that
108/// batches writes to reduce spawn overhead and `SQLite` write contention.
109pub struct UsageRecorder {
110    stats: UsageStats,
111    store: Option<Arc<dyn UsageStore>>,
112    sender: Option<mpsc::UnboundedSender<UsageRecord>>,
113}
114
115impl UsageRecorder {
116    /// Creates a new recorder, optionally backed by a persistent store.
117    ///
118    /// When a store is provided a background flush loop is spawned that drains
119    /// records from an mpsc channel in micro-batches (up to 64 at a time).
120    #[must_use]
121    pub fn new(store: Option<Arc<dyn UsageStore>>) -> Self {
122        let sender = store.as_ref().map(|s| {
123            let (tx, rx) = mpsc::unbounded_channel::<UsageRecord>();
124            let flush_store = Arc::clone(s);
125            tokio::spawn(Self::flush_loop(flush_store, rx));
126            tx
127        });
128        Self {
129            stats: UsageStats::new(),
130            store,
131            sender,
132        }
133    }
134
135    /// Background loop that drains the record channel in micro-batches.
136    async fn flush_loop(store: Arc<dyn UsageStore>, mut rx: mpsc::UnboundedReceiver<UsageRecord>) {
137        const BATCH_CAP: usize = 64;
138        let mut buf: Vec<UsageRecord> = Vec::with_capacity(BATCH_CAP);
139
140        while let Some(record) = rx.recv().await {
141            buf.push(record);
142
143            // Drain any additional records already queued without blocking.
144            while buf.len() < BATCH_CAP {
145                match rx.try_recv() {
146                    Ok(r) => buf.push(r),
147                    Err(_) => break,
148                }
149            }
150
151            for record in buf.drain(..) {
152                if let Err(e) = store.record(&record).await {
153                    tracing::warn!(error = %e, "failed to persist usage record");
154                }
155            }
156        }
157    }
158
159    /// Record a successful request with token counts.
160    ///
161    /// Uses [`DEFAULT_ACCOUNT`] as the account attribution; call
162    /// [`record_success_for`](Self::record_success_for) when the specific
163    /// OAuth account is known.
164    pub fn record_success(
165        &self,
166        model: &str,
167        provider: &str,
168        input_tokens: u64,
169        output_tokens: u64,
170    ) {
171        self.record_success_for(
172            model,
173            provider,
174            DEFAULT_ACCOUNT,
175            input_tokens,
176            output_tokens,
177        );
178    }
179
180    /// Record a successful request with token counts, attributing it to a
181    /// specific account.
182    pub fn record_success_for(
183        &self,
184        model: &str,
185        provider: &str,
186        account_id: &str,
187        input_tokens: u64,
188        output_tokens: u64,
189    ) {
190        self.stats
191            .record_success(model, input_tokens, output_tokens);
192        self.persist(
193            model,
194            provider,
195            account_id,
196            input_tokens,
197            output_tokens,
198            true,
199        );
200    }
201
202    /// Record a failed request.
203    pub fn record_failure(&self, model: &str, provider: &str) {
204        self.record_failure_for(model, provider, DEFAULT_ACCOUNT);
205    }
206
207    /// Record a failed request, attributing it to a specific account.
208    pub fn record_failure_for(&self, model: &str, provider: &str, account_id: &str) {
209        self.stats.record_failure(model);
210        self.persist(model, provider, account_id, 0, 0, false);
211    }
212
213    /// Take a snapshot of in-memory stats.
214    #[must_use]
215    pub fn snapshot(&self) -> UsageSnapshot {
216        self.stats.snapshot()
217    }
218
219    /// Pre-load cumulative counters from historical totals (e.g. on startup).
220    pub fn preload(&self, model: &str, requests: u64, input_tokens: u64, output_tokens: u64) {
221        self.stats
222            .total_requests
223            .fetch_add(requests, Ordering::Relaxed);
224        self.stats
225            .success_requests
226            .fetch_add(requests, Ordering::Relaxed);
227        self.stats
228            .input_tokens
229            .fetch_add(input_tokens, Ordering::Relaxed);
230        self.stats
231            .output_tokens
232            .fetch_add(output_tokens, Ordering::Relaxed);
233
234        if let Ok(mut map) = self.stats.model_counts.lock() {
235            let entry = map.entry(model.to_string()).or_default();
236            entry.requests += requests;
237            entry.success += requests;
238            entry.input_tokens += input_tokens;
239            entry.output_tokens += output_tokens;
240        }
241    }
242
243    /// Returns a reference to the backing store (if configured).
244    pub fn store(&self) -> Option<&Arc<dyn UsageStore>> {
245        self.store.as_ref()
246    }
247
248    fn persist(
249        &self,
250        model: &str,
251        provider: &str,
252        account_id: &str,
253        input_tokens: u64,
254        output_tokens: u64,
255        success: bool,
256    ) {
257        if let Some(sender) = &self.sender {
258            let record = UsageRecord {
259                model: model.to_string(),
260                provider: provider.to_string(),
261                account_id: account_id.to_string(),
262                input_tokens,
263                output_tokens,
264                success,
265            };
266            let _ = sender.send(record);
267        }
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_record_success() {
        let stats = UsageStats::new();
        stats.record_success("claude-opus-4-5", 100, 200);
        stats.record_success("claude-opus-4-5", 50, 100);
        stats.record_success("gpt-4o", 80, 150);

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 3);
        assert_eq!(snap.success_requests, 3);
        assert_eq!(snap.failure_requests, 0);
        assert_eq!(snap.input_tokens, 230);
        assert_eq!(snap.output_tokens, 450);

        let claude = &snap.models["claude-opus-4-5"];
        assert_eq!(claude.requests, 2);
        assert_eq!(claude.success, 2);
        assert_eq!(claude.input_tokens, 150);
        assert_eq!(claude.output_tokens, 300);
    }

    #[test]
    fn test_record_failure() {
        let stats = UsageStats::new();
        stats.record_failure("gpt-4o");
        stats.record_success("gpt-4o", 10, 20);

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 2);
        assert_eq!(snap.success_requests, 1);
        assert_eq!(snap.failure_requests, 1);

        let model = &snap.models["gpt-4o"];
        assert_eq!(model.requests, 2);
        assert_eq!(model.failure, 1);
        assert_eq!(model.success, 1);
    }

    #[test]
    fn test_snapshot_empty() {
        let stats = UsageStats::new();
        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 0);
        assert_eq!(snap.success_requests, 0);
        assert_eq!(snap.failure_requests, 0);
        assert_eq!(snap.input_tokens, 0);
        assert_eq!(snap.output_tokens, 0);
        assert!(snap.models.is_empty());
    }

    #[test]
    fn test_recorder_without_store() {
        // With no store, `new` spawns nothing, so no Tokio runtime is needed.
        let recorder = UsageRecorder::new(None);
        assert!(recorder.store().is_none());

        recorder.record_success("gpt-4o", "openai", 10, 20);
        recorder.record_failure_for("gpt-4o", "openai", "acct-1");

        let snap = recorder.snapshot();
        assert_eq!(snap.total_requests, 2);
        assert_eq!(snap.success_requests, 1);
        assert_eq!(snap.failure_requests, 1);
        assert_eq!(snap.input_tokens, 10);
        assert_eq!(snap.output_tokens, 20);

        let model = &snap.models["gpt-4o"];
        assert_eq!(model.requests, 2);
        assert_eq!(model.success, 1);
        assert_eq!(model.failure, 1);
    }

    #[test]
    fn test_preload_counts_as_success() {
        let recorder = UsageRecorder::new(None);
        recorder.preload("claude-opus-4-5", 5, 100, 200);
        recorder.record_failure("claude-opus-4-5", "anthropic");

        let snap = recorder.snapshot();
        assert_eq!(snap.total_requests, 6);
        assert_eq!(snap.success_requests, 5);
        assert_eq!(snap.failure_requests, 1);
        assert_eq!(snap.input_tokens, 100);
        assert_eq!(snap.output_tokens, 200);

        let model = &snap.models["claude-opus-4-5"];
        assert_eq!(model.requests, 6);
        assert_eq!(model.success, 5);
        assert_eq!(model.failure, 1);
        assert_eq!(model.input_tokens, 100);
        assert_eq!(model.output_tokens, 200);
    }
}