Skip to main content

crabllm_proxy/ext/
budget.rs

1use crate::PREFIX_BUDGET;
2use axum::{Json, Router, routing::get};
3use crabllm_core::{BoxFuture, ExtensionError, ModelInfo, RequestContext, Storage, storage_key};
4use serde::Serialize;
5use std::{collections::HashMap, sync::Arc};
6
/// Budget extension state: per-principal spending limits plus the storage
/// backend in which the running spend counters are kept.
///
/// All monetary amounts held here are integers in micro-USD (USD × 1_000_000).
pub struct Budget {
    /// Storage backend holding one spend counter per principal.
    storage: Arc<dyn Storage>,
    /// Model metadata used to price a request, looked up by
    /// `"{provider}/{model}"` first and by bare model name second.
    models: HashMap<String, ModelInfo>,
    /// Budget (micro-USD) applied to principals that have no entry in
    /// `key_budgets`.
    default_budget_micros: i64,
    /// Per-principal budget overrides (micro-USD), keyed by principal.
    key_budgets: HashMap<String, i64>,
}
13
14impl Budget {
15    pub fn new(
16        config: &serde_json::Value,
17        storage: Arc<dyn Storage>,
18        models: HashMap<String, ModelInfo>,
19    ) -> Result<Self, String> {
20        let default_budget = config
21            .get("default_budget")
22            .and_then(|v| v.as_f64())
23            .ok_or("budget: missing or invalid 'default_budget' (USD float)")?;
24
25        if default_budget <= 0.0 {
26            return Err("budget: 'default_budget' must be positive".to_string());
27        }
28
29        let default_budget_micros = (default_budget * 1_000_000.0) as i64;
30
31        let mut key_budgets = HashMap::new();
32        if let Some(keys_table) = config.get("keys").and_then(|v| v.as_object()) {
33            for (principal, key_config) in keys_table {
34                let budget = key_config
35                    .get("budget")
36                    .and_then(|v| v.as_f64())
37                    .ok_or(format!(
38                        "budget: key '{principal}' missing or invalid 'budget'"
39                    ))?;
40                key_budgets.insert(principal.clone(), (budget * 1_000_000.0) as i64);
41            }
42        }
43
44        Ok(Self {
45            storage,
46            models,
47            default_budget_micros,
48            key_budgets,
49        })
50    }
51
52    fn budget_for_key(&self, principal: &str) -> i64 {
53        self.key_budgets
54            .get(principal)
55            .copied()
56            .unwrap_or(self.default_budget_micros)
57    }
58
59    fn cost_micros(&self, model: &str, provider: &str, usage: &crabllm_core::Usage) -> i64 {
60        let qualified = format!("{provider}/{model}");
61        let info = self
62            .models
63            .get(qualified.as_str())
64            .or_else(|| self.models.get(model));
65        let Some(info) = info else {
66            return 0;
67        };
68        (info.cost(usage) * 1_000_000.0).round() as i64
69    }
70
71    pub fn admin_routes(&self) -> Router {
72        let storage = self.storage.clone();
73        let prefix = PREFIX_BUDGET;
74        let default_budget = self.default_budget_micros;
75        let key_budgets = self.key_budgets.clone();
76
77        Router::new().route(
78            "/v1/budget",
79            get(move || {
80                let storage = storage.clone();
81                let key_budgets = key_budgets.clone();
82                async move { budget_handler(storage, prefix, default_budget, key_budgets).await }
83            }),
84        )
85    }
86
87    async fn record_cost(
88        &self,
89        principal: &str,
90        model: &str,
91        provider: &str,
92        usage: &crabllm_core::Usage,
93    ) {
94        let micros = self.cost_micros(model, provider, usage);
95        if micros > 0 {
96            let key = storage_key(&PREFIX_BUDGET, principal.as_bytes());
97            let _ = self.storage.increment(&key, micros).await;
98        }
99    }
100}
101
102impl crabllm_core::Extension for Budget {
103    fn name(&self) -> &str {
104        "budget"
105    }
106
107    fn prefix(&self) -> crabllm_core::Prefix {
108        PREFIX_BUDGET
109    }
110
111    fn on_request(&self, ctx: &RequestContext) -> BoxFuture<'_, Result<(), ExtensionError>> {
112        let principal = ctx
113            .principal
114            .clone()
115            .unwrap_or_else(|| "__global".to_string());
116        let budget = self.budget_for_key(&principal);
117
118        Box::pin(async move {
119            let key = storage_key(&PREFIX_BUDGET, principal.as_bytes());
120            let spent = self.storage.increment(&key, 0).await.unwrap_or(0);
121
122            if spent >= budget {
123                return Err(ExtensionError::new(
124                    429,
125                    "budget exceeded",
126                    "budget_exceeded",
127                ));
128            }
129
130            Ok(())
131        })
132    }
133
134    fn on_response(
135        &self,
136        ctx: &RequestContext,
137        _raw_request: &[u8],
138        raw_response: &[u8],
139    ) -> BoxFuture<'_, ()> {
140        let usage = crabllm_core::Usage::from(raw_response);
141        if usage.total_tokens() == 0 {
142            return Box::pin(async {});
143        }
144
145        let principal = ctx
146            .principal
147            .clone()
148            .unwrap_or_else(|| "__global".to_string());
149        let model = ctx.model.clone();
150        let provider = ctx.provider.clone();
151
152        Box::pin(async move {
153            self.record_cost(&principal, &model, &provider, &usage)
154                .await;
155        })
156    }
157
158    fn on_chunk(&self, ctx: &RequestContext, raw_chunk: &[u8]) -> BoxFuture<'_, ()> {
159        let usage = crabllm_core::Usage::from(raw_chunk);
160        if usage.total_tokens() == 0 {
161            return Box::pin(async {});
162        }
163
164        let principal = ctx
165            .principal
166            .clone()
167            .unwrap_or_else(|| "__global".to_string());
168        let model = ctx.model.clone();
169        let provider = ctx.provider.clone();
170
171        Box::pin(async move {
172            self.record_cost(&principal, &model, &provider, &usage)
173                .await;
174        })
175    }
176}
177
/// One row of the `/v1/budget` report, serialized to JSON.
#[derive(Serialize)]
struct BudgetEntry {
    /// Storage key with the extension prefix stripped (the part the budget
    /// lookup is keyed on).
    key: String,
    /// Amount spent so far, in USD.
    spent_usd: f64,
    /// Budget applicable to this key, in USD.
    budget_usd: f64,
    /// Budget minus spend, in USD, floored at zero.
    remaining_usd: f64,
}
185
186async fn budget_handler(
187    storage: Arc<dyn Storage>,
188    prefix: crabllm_core::Prefix,
189    default_budget_micros: i64,
190    key_budgets: HashMap<String, i64>,
191) -> Json<Vec<BudgetEntry>> {
192    let pairs = storage.list(&prefix).await.unwrap_or_default();
193
194    let mut entries = Vec::new();
195    for (raw_key, raw_value) in &pairs {
196        let suffix = match std::str::from_utf8(&raw_key[crabllm_core::PREFIX_LEN..]) {
197            Ok(s) => s,
198            Err(_) => continue,
199        };
200
201        // Parse the counter value directly from the list() result bytes.
202        let spent_micros = raw_value
203            .get(..8)
204            .and_then(|b| b.try_into().ok())
205            .map(i64::from_le_bytes)
206            .unwrap_or(0);
207
208        let budget_micros = key_budgets
209            .get(suffix)
210            .copied()
211            .unwrap_or(default_budget_micros);
212
213        let spent_usd = spent_micros as f64 / 1_000_000.0;
214        let budget_usd = budget_micros as f64 / 1_000_000.0;
215
216        entries.push(BudgetEntry {
217            key: suffix.to_string(),
218            spent_usd,
219            budget_usd,
220            remaining_usd: (budget_usd - spent_usd).max(0.0),
221        });
222    }
223
224    Json(entries)
225}