crabtalk_proxy/ext/
usage.rs1use axum::{Json, Router, routing::get};
2use crabtalk_core::{
3 BoxFuture, ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse, Prefix,
4 RequestContext, Storage, storage_key,
5};
6use serde::Serialize;
7use std::sync::Arc;
8
9pub struct UsageTracker {
10 storage: Arc<dyn Storage>,
11}
12
13impl UsageTracker {
14 pub fn new(_config: &serde_json::Value, storage: Arc<dyn Storage>) -> Result<Self, String> {
15 Ok(Self { storage })
16 }
17
18 const PREFIX: Prefix = *b"usge";
20
21 pub fn admin_routes(&self) -> Router {
22 let storage = self.storage.clone();
23 let prefix = Self::PREFIX;
24 Router::new().route(
25 "/v1/usage",
26 get(move || {
27 let storage = storage.clone();
28 async move { usage_handler(storage, prefix).await }
29 }),
30 )
31 }
32
33 async fn record(
35 &self,
36 key_name: &str,
37 model: &str,
38 prompt_tokens: u32,
39 completion_tokens: u32,
40 ) {
41 let prompt_suffix = format!("{key_name}:{model}:p");
42 let completion_suffix = format!("{key_name}:{model}:c");
43
44 let _ = self
45 .storage
46 .increment(
47 &storage_key(&Self::PREFIX, prompt_suffix.as_bytes()),
48 prompt_tokens as i64,
49 )
50 .await;
51 let _ = self
52 .storage
53 .increment(
54 &storage_key(&Self::PREFIX, completion_suffix.as_bytes()),
55 completion_tokens as i64,
56 )
57 .await;
58 }
59}
60
61impl crabtalk_core::Extension for UsageTracker {
62 fn name(&self) -> &str {
63 "usage"
64 }
65
66 fn prefix(&self) -> Prefix {
67 Self::PREFIX
68 }
69
70 fn on_response(
71 &self,
72 ctx: &RequestContext,
73 _request: &ChatCompletionRequest,
74 response: &ChatCompletionResponse,
75 ) -> BoxFuture<'_, ()> {
76 let key_name = ctx
77 .key_name
78 .clone()
79 .unwrap_or_else(|| "__global".to_string());
80 let model = ctx.model.clone();
81 let usage = response.usage.clone();
82
83 Box::pin(async move {
84 if let Some(u) = usage {
85 self.record(&key_name, &model, u.prompt_tokens, u.completion_tokens)
86 .await;
87 }
88 })
89 }
90
91 fn on_chunk(&self, ctx: &RequestContext, chunk: &ChatCompletionChunk) -> BoxFuture<'_, ()> {
92 let key_name = ctx
93 .key_name
94 .clone()
95 .unwrap_or_else(|| "__global".to_string());
96 let model = ctx.model.clone();
97 let usage = chunk.usage.clone();
98
99 Box::pin(async move {
100 if let Some(u) = usage {
101 self.record(&key_name, &model, u.prompt_tokens, u.completion_tokens)
102 .await;
103 }
104 })
105 }
106}
107
108#[derive(Serialize)]
109struct UsageEntry {
110 key: String,
111 model: String,
112 prompt_tokens: i64,
113 completion_tokens: i64,
114}
115
116async fn usage_handler(storage: Arc<dyn Storage>, prefix: Prefix) -> Json<Vec<UsageEntry>> {
117 let pairs = storage.list(&prefix).await.unwrap_or_default();
118
119 let mut entries: std::collections::HashMap<(String, String), (i64, i64)> =
121 std::collections::HashMap::new();
122
123 for (raw_key, raw_value) in &pairs {
124 let suffix = match std::str::from_utf8(&raw_key[crabtalk_core::PREFIX_LEN..]) {
126 Ok(s) => s,
127 Err(_) => continue,
128 };
129
130 let Some((rest, kind)) = suffix.rsplit_once(':') else {
133 continue;
134 };
135 let Some((key_name, model)) = rest.split_once(':') else {
136 continue;
137 };
138
139 let val = raw_value
141 .get(..8)
142 .and_then(|b| b.try_into().ok())
143 .map(i64::from_le_bytes)
144 .unwrap_or(0);
145
146 let entry = entries
147 .entry((key_name.to_string(), model.to_string()))
148 .or_insert((0, 0));
149
150 match kind {
151 "p" => entry.0 = val,
152 "c" => entry.1 = val,
153 _ => {}
154 }
155 }
156
157 let result: Vec<UsageEntry> = entries
158 .into_iter()
159 .map(|((key, model), (prompt, completion))| UsageEntry {
160 key,
161 model,
162 prompt_tokens: prompt,
163 completion_tokens: completion,
164 })
165 .collect();
166
167 Json(result)
168}