// byokey_proxy/usage.rs

1//! In-memory usage statistics for request/token tracking, with optional
2//! persistent backing via [`UsageStore`].
3
4use byokey_types::{UsageRecord, UsageStore};
5use serde::Serialize;
6use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::{Arc, Mutex};
9use tokio::sync::mpsc;
10use utoipa::ToSchema;
11
/// Global request/token counters.
///
/// All scalar counters are lock-free atomics updated with
/// `Ordering::Relaxed`; the per-model map is guarded by a `Mutex`. Because
/// the counters are updated independently, a concurrent [`UsageStats::snapshot`]
/// may observe values that are individually valid but not mutually consistent.
#[derive(Default)]
pub struct UsageStats {
    /// Total requests received.
    pub total_requests: AtomicU64,
    /// Successful requests (2xx from upstream).
    pub success_requests: AtomicU64,
    /// Failed requests (non-2xx or internal error).
    pub failure_requests: AtomicU64,
    /// Total input tokens across all requests.
    pub input_tokens: AtomicU64,
    /// Total output tokens across all requests.
    pub output_tokens: AtomicU64,
    /// Per-model counters, keyed by model name. Written on every record and
    /// cloned wholesale in `snapshot`.
    model_counts: Mutex<HashMap<String, ModelStats>>,
}
28
/// Per-model usage counters.
#[derive(Default, Clone, Serialize, ToSchema)]
pub struct ModelStats {
    /// Total requests for this model (success + failure).
    pub requests: u64,
    /// Requests recorded via `record_success`.
    pub success: u64,
    /// Requests recorded via `record_failure`.
    pub failure: u64,
    /// Cumulative input tokens for this model.
    pub input_tokens: u64,
    /// Cumulative output tokens for this model.
    pub output_tokens: u64,
}
38
/// JSON-serializable snapshot of current usage.
///
/// Produced by [`UsageStats::snapshot`]; values are point-in-time copies of
/// the live atomic counters and the per-model map.
#[derive(Serialize, ToSchema)]
pub struct UsageSnapshot {
    /// Total requests received.
    pub total_requests: u64,
    /// Successful requests.
    pub success_requests: u64,
    /// Failed requests.
    pub failure_requests: u64,
    /// Total input tokens across all requests.
    pub input_tokens: u64,
    /// Total output tokens across all requests.
    pub output_tokens: u64,
    /// Per-model counters, keyed by model name.
    pub models: HashMap<String, ModelStats>,
}
49
50impl UsageStats {
51    /// Creates a new empty stats tracker.
52    #[must_use]
53    pub fn new() -> Self {
54        Self::default()
55    }
56
57    /// Record a successful request with optional token counts.
58    pub fn record_success(&self, model: &str, input_tokens: u64, output_tokens: u64) {
59        self.total_requests.fetch_add(1, Ordering::Relaxed);
60        self.success_requests.fetch_add(1, Ordering::Relaxed);
61        self.input_tokens.fetch_add(input_tokens, Ordering::Relaxed);
62        self.output_tokens
63            .fetch_add(output_tokens, Ordering::Relaxed);
64
65        if let Ok(mut map) = self.model_counts.lock() {
66            let entry = map.entry(model.to_string()).or_default();
67            entry.requests += 1;
68            entry.success += 1;
69            entry.input_tokens += input_tokens;
70            entry.output_tokens += output_tokens;
71        }
72    }
73
74    /// Record a failed request.
75    pub fn record_failure(&self, model: &str) {
76        self.total_requests.fetch_add(1, Ordering::Relaxed);
77        self.failure_requests.fetch_add(1, Ordering::Relaxed);
78
79        if let Ok(mut map) = self.model_counts.lock() {
80            let entry = map.entry(model.to_string()).or_default();
81            entry.requests += 1;
82            entry.failure += 1;
83        }
84    }
85
86    /// Take a JSON-serializable snapshot of current stats.
87    #[must_use]
88    pub fn snapshot(&self) -> UsageSnapshot {
89        let models = self
90            .model_counts
91            .lock()
92            .map(|m| m.clone())
93            .unwrap_or_default();
94        UsageSnapshot {
95            total_requests: self.total_requests.load(Ordering::Relaxed),
96            success_requests: self.success_requests.load(Ordering::Relaxed),
97            failure_requests: self.failure_requests.load(Ordering::Relaxed),
98            input_tokens: self.input_tokens.load(Ordering::Relaxed),
99            output_tokens: self.output_tokens.load(Ordering::Relaxed),
100            models,
101        }
102    }
103}
104
/// Combines in-memory [`UsageStats`] with an optional persistent [`UsageStore`].
///
/// Every `record_*` call updates the in-memory counters immediately and, if a
/// store is configured, sends the record to a single background task that
/// batches writes to reduce spawn overhead and `SQLite` write contention.
pub struct UsageRecorder {
    /// In-memory counters; always updated synchronously on record.
    stats: UsageStats,
    /// Optional persistent backend; exposed via [`UsageRecorder::store`].
    store: Option<Arc<dyn UsageStore>>,
    /// Channel into the background flush task. `Some` exactly when `store`
    /// is `Some` (both are set together in `new`).
    sender: Option<mpsc::UnboundedSender<UsageRecord>>,
}
115
116impl UsageRecorder {
117    /// Creates a new recorder, optionally backed by a persistent store.
118    ///
119    /// When a store is provided a background flush loop is spawned that drains
120    /// records from an mpsc channel in micro-batches (up to 64 at a time).
121    #[must_use]
122    pub fn new(store: Option<Arc<dyn UsageStore>>) -> Self {
123        let sender = store.as_ref().map(|s| {
124            let (tx, rx) = mpsc::unbounded_channel::<UsageRecord>();
125            let flush_store = Arc::clone(s);
126            tokio::spawn(Self::flush_loop(flush_store, rx));
127            tx
128        });
129        Self {
130            stats: UsageStats::new(),
131            store,
132            sender,
133        }
134    }
135
136    /// Background loop that drains the record channel in micro-batches.
137    async fn flush_loop(store: Arc<dyn UsageStore>, mut rx: mpsc::UnboundedReceiver<UsageRecord>) {
138        const BATCH_CAP: usize = 64;
139        let mut buf: Vec<UsageRecord> = Vec::with_capacity(BATCH_CAP);
140
141        while let Some(record) = rx.recv().await {
142            buf.push(record);
143
144            // Drain any additional records already queued without blocking.
145            while buf.len() < BATCH_CAP {
146                match rx.try_recv() {
147                    Ok(r) => buf.push(r),
148                    Err(_) => break,
149                }
150            }
151
152            for record in buf.drain(..) {
153                if let Err(e) = store.record(&record).await {
154                    tracing::warn!(error = %e, "failed to persist usage record");
155                }
156            }
157        }
158    }
159
160    /// Record a successful request with token counts.
161    pub fn record_success(
162        &self,
163        model: &str,
164        provider: &str,
165        input_tokens: u64,
166        output_tokens: u64,
167    ) {
168        self.stats
169            .record_success(model, input_tokens, output_tokens);
170        self.persist(model, provider, input_tokens, output_tokens, true);
171    }
172
173    /// Record a failed request.
174    pub fn record_failure(&self, model: &str, provider: &str) {
175        self.stats.record_failure(model);
176        self.persist(model, provider, 0, 0, false);
177    }
178
179    /// Take a snapshot of in-memory stats.
180    #[must_use]
181    pub fn snapshot(&self) -> UsageSnapshot {
182        self.stats.snapshot()
183    }
184
185    /// Pre-load cumulative counters from historical totals (e.g. on startup).
186    pub fn preload(&self, model: &str, requests: u64, input_tokens: u64, output_tokens: u64) {
187        self.stats
188            .total_requests
189            .fetch_add(requests, Ordering::Relaxed);
190        self.stats
191            .success_requests
192            .fetch_add(requests, Ordering::Relaxed);
193        self.stats
194            .input_tokens
195            .fetch_add(input_tokens, Ordering::Relaxed);
196        self.stats
197            .output_tokens
198            .fetch_add(output_tokens, Ordering::Relaxed);
199
200        if let Ok(mut map) = self.stats.model_counts.lock() {
201            let entry = map.entry(model.to_string()).or_default();
202            entry.requests += requests;
203            entry.success += requests;
204            entry.input_tokens += input_tokens;
205            entry.output_tokens += output_tokens;
206        }
207    }
208
209    /// Returns a reference to the backing store (if configured).
210    pub fn store(&self) -> Option<&Arc<dyn UsageStore>> {
211        self.store.as_ref()
212    }
213
214    fn persist(
215        &self,
216        model: &str,
217        provider: &str,
218        input_tokens: u64,
219        output_tokens: u64,
220        success: bool,
221    ) {
222        if let Some(sender) = &self.sender {
223            let record = UsageRecord {
224                model: model.to_string(),
225                provider: provider.to_string(),
226                input_tokens,
227                output_tokens,
228                success,
229            };
230            let _ = sender.send(record);
231        }
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_record_success() {
        let stats = UsageStats::new();
        // Two models, three successful requests total.
        for (model, input, output) in [
            ("claude-opus-4-5", 100, 200),
            ("claude-opus-4-5", 50, 100),
            ("gpt-4o", 80, 150),
        ] {
            stats.record_success(model, input, output);
        }

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 3);
        assert_eq!(snap.success_requests, 3);
        assert_eq!(snap.failure_requests, 0);
        assert_eq!(snap.input_tokens, 230);
        assert_eq!(snap.output_tokens, 450);

        // Per-model aggregation for the model hit twice.
        let claude = &snap.models["claude-opus-4-5"];
        assert_eq!(claude.requests, 2);
        assert_eq!(claude.success, 2);
        assert_eq!(claude.input_tokens, 150);
        assert_eq!(claude.output_tokens, 300);
    }

    #[test]
    fn test_record_failure() {
        let stats = UsageStats::new();
        stats.record_failure("gpt-4o");
        stats.record_success("gpt-4o", 10, 20);

        let snap = stats.snapshot();
        assert_eq!(snap.total_requests, 2);
        assert_eq!(snap.success_requests, 1);
        assert_eq!(snap.failure_requests, 1);

        // A mixed success/failure history lands in one per-model entry.
        let gpt = &snap.models["gpt-4o"];
        assert_eq!(gpt.requests, 2);
        assert_eq!(gpt.failure, 1);
        assert_eq!(gpt.success, 1);
    }

    #[test]
    fn test_snapshot_empty() {
        let snap = UsageStats::new().snapshot();
        assert_eq!(snap.total_requests, 0);
        assert!(snap.models.is_empty());
    }
}