Skip to main content

crabtalk_proxy/ext/
usage.rs

1use axum::{Json, Router, routing::get};
2use crabtalk_core::{
3    BoxFuture, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, Prefix,
4    RequestContext, Storage, storage_key,
5};
6use serde::Serialize;
7use std::sync::Arc;
8
9pub struct UsageTracker {
10    storage: Arc<dyn Storage>,
11}
12
13impl UsageTracker {
14    pub fn new(_config: &serde_json::Value, storage: Arc<dyn Storage>) -> Result<Self, String> {
15        Ok(Self { storage })
16    }
17
18    /// The fixed prefix for this extension.
19    const PREFIX: Prefix = *b"usge";
20
21    pub fn admin_routes(&self) -> Router {
22        let storage = self.storage.clone();
23        let prefix = Self::PREFIX;
24        Router::new().route(
25            "/v1/usage",
26            get(move || {
27                let storage = storage.clone();
28                async move { usage_handler(storage, prefix).await }
29            }),
30        )
31    }
32
33    /// Record token usage for a given key and model.
34    async fn record(
35        &self,
36        key_name: &str,
37        model: &str,
38        prompt_tokens: u32,
39        completion_tokens: u32,
40    ) {
41        let prompt_suffix = format!("{key_name}:{model}:p");
42        let completion_suffix = format!("{key_name}:{model}:c");
43
44        let _ = self
45            .storage
46            .increment(
47                &storage_key(&Self::PREFIX, prompt_suffix.as_bytes()),
48                prompt_tokens as i64,
49            )
50            .await;
51        let _ = self
52            .storage
53            .increment(
54                &storage_key(&Self::PREFIX, completion_suffix.as_bytes()),
55                completion_tokens as i64,
56            )
57            .await;
58    }
59}
60
61impl crabtalk_core::Extension for UsageTracker {
62    fn name(&self) -> &str {
63        "usage"
64    }
65
66    fn prefix(&self) -> Prefix {
67        Self::PREFIX
68    }
69
70    fn on_response(
71        &self,
72        ctx: &RequestContext,
73        _request: &ChatCompletionRequest,
74        response: &ChatCompletionResponse,
75    ) -> BoxFuture<'_, ()> {
76        let key_name = ctx
77            .key_name
78            .clone()
79            .unwrap_or_else(|| "__global".to_string());
80        let model = ctx.model.clone();
81        let usage = response.usage.clone();
82
83        Box::pin(async move {
84            if let Some(u) = usage {
85                self.record(&key_name, &model, u.prompt_tokens, u.completion_tokens)
86                    .await;
87            }
88        })
89    }
90
91    fn on_chunk(&self, ctx: &RequestContext, chunk: &ChatCompletionChunk) -> BoxFuture<'_, ()> {
92        let key_name = ctx
93            .key_name
94            .clone()
95            .unwrap_or_else(|| "__global".to_string());
96        let model = ctx.model.clone();
97        let usage = chunk.usage.clone();
98
99        Box::pin(async move {
100            if let Some(u) = usage {
101                self.record(&key_name, &model, u.prompt_tokens, u.completion_tokens)
102                    .await;
103            }
104        })
105    }
106}
107
108#[derive(Serialize)]
109struct UsageEntry {
110    key: String,
111    model: String,
112    prompt_tokens: i64,
113    completion_tokens: i64,
114}
115
116async fn usage_handler(storage: Arc<dyn Storage>, prefix: Prefix) -> Json<Vec<UsageEntry>> {
117    let pairs = storage.list(&prefix).await.unwrap_or_default();
118
119    // Group by (key, model) — keys are PREFIX + "{key_name}:{model}:{p|c}"
120    let mut entries: std::collections::HashMap<(String, String), (i64, i64)> =
121        std::collections::HashMap::new();
122
123    for (raw_key, raw_value) in &pairs {
124        // Skip the prefix bytes, parse the suffix as UTF-8.
125        let suffix = match std::str::from_utf8(&raw_key[crabtalk_core::PREFIX_LEN..]) {
126            Ok(s) => s,
127            Err(_) => continue,
128        };
129
130        // suffix format: "{key_name}:{model}:{p|c}"
131        // Split from the right to handle key names or models containing ":"
132        let Some((rest, kind)) = suffix.rsplit_once(':') else {
133            continue;
134        };
135        let Some((key_name, model)) = rest.split_once(':') else {
136            continue;
137        };
138
139        // Parse the counter value directly from the list() result bytes.
140        let val = raw_value
141            .get(..8)
142            .and_then(|b| b.try_into().ok())
143            .map(i64::from_le_bytes)
144            .unwrap_or(0);
145
146        let entry = entries
147            .entry((key_name.to_string(), model.to_string()))
148            .or_insert((0, 0));
149
150        match kind {
151            "p" => entry.0 = val,
152            "c" => entry.1 = val,
153            _ => {}
154        }
155    }
156
157    let result: Vec<UsageEntry> = entries
158        .into_iter()
159        .map(|((key, model), (prompt, completion))| UsageEntry {
160            key,
161            model,
162            prompt_tokens: prompt,
163            completion_tokens: completion,
164        })
165        .collect();
166
167    Json(result)
168}