crabllm_proxy/ext/
budget.rs1use crate::PREFIX_BUDGET;
2use axum::{Json, Router, routing::get};
3use crabllm_core::{BoxFuture, ExtensionError, ModelInfo, RequestContext, Storage, storage_key};
4use serde::Serialize;
5use std::{collections::HashMap, sync::Arc};
6
7pub struct Budget {
8 storage: Arc<dyn Storage>,
9 models: HashMap<String, ModelInfo>,
10 default_budget_micros: i64,
11 key_budgets: HashMap<String, i64>,
12}
13
14impl Budget {
15 pub fn new(
16 config: &serde_json::Value,
17 storage: Arc<dyn Storage>,
18 models: HashMap<String, ModelInfo>,
19 ) -> Result<Self, String> {
20 let default_budget = config
21 .get("default_budget")
22 .and_then(|v| v.as_f64())
23 .ok_or("budget: missing or invalid 'default_budget' (USD float)")?;
24
25 if default_budget <= 0.0 {
26 return Err("budget: 'default_budget' must be positive".to_string());
27 }
28
29 let default_budget_micros = (default_budget * 1_000_000.0) as i64;
30
31 let mut key_budgets = HashMap::new();
32 if let Some(keys_table) = config.get("keys").and_then(|v| v.as_object()) {
33 for (principal, key_config) in keys_table {
34 let budget = key_config
35 .get("budget")
36 .and_then(|v| v.as_f64())
37 .ok_or(format!(
38 "budget: key '{principal}' missing or invalid 'budget'"
39 ))?;
40 key_budgets.insert(principal.clone(), (budget * 1_000_000.0) as i64);
41 }
42 }
43
44 Ok(Self {
45 storage,
46 models,
47 default_budget_micros,
48 key_budgets,
49 })
50 }
51
52 fn budget_for_key(&self, principal: &str) -> i64 {
53 self.key_budgets
54 .get(principal)
55 .copied()
56 .unwrap_or(self.default_budget_micros)
57 }
58
59 fn cost_micros(&self, model: &str, provider: &str, usage: &crabllm_core::Usage) -> i64 {
60 let qualified = format!("{provider}/{model}");
61 let info = self
62 .models
63 .get(qualified.as_str())
64 .or_else(|| self.models.get(model));
65 let Some(info) = info else {
66 return 0;
67 };
68 (info.cost(usage) * 1_000_000.0).round() as i64
69 }
70
71 pub fn admin_routes(&self) -> Router {
72 let storage = self.storage.clone();
73 let prefix = PREFIX_BUDGET;
74 let default_budget = self.default_budget_micros;
75 let key_budgets = self.key_budgets.clone();
76
77 Router::new().route(
78 "/v1/budget",
79 get(move || {
80 let storage = storage.clone();
81 let key_budgets = key_budgets.clone();
82 async move { budget_handler(storage, prefix, default_budget, key_budgets).await }
83 }),
84 )
85 }
86
87 async fn record_cost(
88 &self,
89 principal: &str,
90 model: &str,
91 provider: &str,
92 usage: &crabllm_core::Usage,
93 ) {
94 let micros = self.cost_micros(model, provider, usage);
95 if micros > 0 {
96 let key = storage_key(&PREFIX_BUDGET, principal.as_bytes());
97 let _ = self.storage.increment(&key, micros).await;
98 }
99 }
100}
101
102impl crabllm_core::Extension for Budget {
103 fn name(&self) -> &str {
104 "budget"
105 }
106
107 fn prefix(&self) -> crabllm_core::Prefix {
108 PREFIX_BUDGET
109 }
110
111 fn on_request(&self, ctx: &RequestContext) -> BoxFuture<'_, Result<(), ExtensionError>> {
112 let principal = ctx
113 .principal
114 .clone()
115 .unwrap_or_else(|| "__global".to_string());
116 let budget = self.budget_for_key(&principal);
117
118 Box::pin(async move {
119 let key = storage_key(&PREFIX_BUDGET, principal.as_bytes());
120 let spent = self.storage.increment(&key, 0).await.unwrap_or(0);
121
122 if spent >= budget {
123 return Err(ExtensionError::new(
124 429,
125 "budget exceeded",
126 "budget_exceeded",
127 ));
128 }
129
130 Ok(())
131 })
132 }
133
134 fn on_response(
135 &self,
136 ctx: &RequestContext,
137 _raw_request: &[u8],
138 raw_response: &[u8],
139 ) -> BoxFuture<'_, ()> {
140 let usage = crabllm_core::Usage::from(raw_response);
141 if usage.total_tokens() == 0 {
142 return Box::pin(async {});
143 }
144
145 let principal = ctx
146 .principal
147 .clone()
148 .unwrap_or_else(|| "__global".to_string());
149 let model = ctx.model.clone();
150 let provider = ctx.provider.clone();
151
152 Box::pin(async move {
153 self.record_cost(&principal, &model, &provider, &usage)
154 .await;
155 })
156 }
157
158 fn on_chunk(&self, ctx: &RequestContext, raw_chunk: &[u8]) -> BoxFuture<'_, ()> {
159 let usage = crabllm_core::Usage::from(raw_chunk);
160 if usage.total_tokens() == 0 {
161 return Box::pin(async {});
162 }
163
164 let principal = ctx
165 .principal
166 .clone()
167 .unwrap_or_else(|| "__global".to_string());
168 let model = ctx.model.clone();
169 let provider = ctx.provider.clone();
170
171 Box::pin(async move {
172 self.record_cost(&principal, &model, &provider, &usage)
173 .await;
174 })
175 }
176}
177
178#[derive(Serialize)]
179struct BudgetEntry {
180 key: String,
181 spent_usd: f64,
182 budget_usd: f64,
183 remaining_usd: f64,
184}
185
186async fn budget_handler(
187 storage: Arc<dyn Storage>,
188 prefix: crabllm_core::Prefix,
189 default_budget_micros: i64,
190 key_budgets: HashMap<String, i64>,
191) -> Json<Vec<BudgetEntry>> {
192 let pairs = storage.list(&prefix).await.unwrap_or_default();
193
194 let mut entries = Vec::new();
195 for (raw_key, raw_value) in &pairs {
196 let suffix = match std::str::from_utf8(&raw_key[crabllm_core::PREFIX_LEN..]) {
197 Ok(s) => s,
198 Err(_) => continue,
199 };
200
201 let spent_micros = raw_value
203 .get(..8)
204 .and_then(|b| b.try_into().ok())
205 .map(i64::from_le_bytes)
206 .unwrap_or(0);
207
208 let budget_micros = key_budgets
209 .get(suffix)
210 .copied()
211 .unwrap_or(default_budget_micros);
212
213 let spent_usd = spent_micros as f64 / 1_000_000.0;
214 let budget_usd = budget_micros as f64 / 1_000_000.0;
215
216 entries.push(BudgetEntry {
217 key: suffix.to_string(),
218 spent_usd,
219 budget_usd,
220 remaining_usd: (budget_usd - spent_usd).max(0.0),
221 });
222 }
223
224 Json(entries)
225}