//! crabllm_proxy/admin.rs — admin HTTP routes for managing virtual API keys.

1use axum::{
2    Json, Router,
3    extract::{Path, Request, State},
4    http::StatusCode,
5    middleware::{self, Next},
6    response::{IntoResponse, Response},
7    routing::{get, post},
8};
9use crabllm_core::{ApiError, KeyConfig, Prefix, Storage, storage_key};
10use serde::{Deserialize, Serialize};
11use std::{
12    collections::{HashMap, HashSet},
13    sync::{Arc, RwLock},
14};
15
16const KEY_PREFIX: Prefix = *b"keys";
17
18#[derive(Clone)]
19pub struct KeyAdminState {
20    storage: Arc<dyn Storage>,
21    key_map: Arc<RwLock<HashMap<String, String>>>,
22    admin_token: String,
23    toml_key_names: HashSet<String>,
24    toml_keys: Vec<KeyConfig>,
25}
26
27/// Build admin key management routes, protected by admin token auth.
28pub fn key_admin_routes(
29    storage: Arc<dyn Storage>,
30    key_map: Arc<RwLock<HashMap<String, String>>>,
31    admin_token: String,
32    toml_keys: Vec<KeyConfig>,
33) -> Router {
34    let toml_key_names: HashSet<String> = toml_keys.iter().map(|k| k.name.clone()).collect();
35    let state = KeyAdminState {
36        storage,
37        key_map,
38        admin_token,
39        toml_key_names,
40        toml_keys,
41    };
42    Router::new()
43        .route("/v1/admin/keys", post(create_key).get(list_keys))
44        .route("/v1/admin/keys/{name}", get(get_key).delete(delete_key))
45        .route_layer(middleware::from_fn_with_state(state.clone(), admin_auth))
46        .with_state(state)
47}
48
/// Compare two tokens without stopping at the first differing byte, so the
/// time taken does not reveal how long a matching prefix was.
///
/// Inputs of different byte length are rejected immediately; only the byte
/// contents of equal-length inputs are compared in constant time.
fn constant_time_eq(a: &str, b: &str) -> bool {
    if a.len() == b.len() {
        let accumulated = a
            .as_bytes()
            .iter()
            .zip(b.as_bytes())
            .fold(0u8, |acc, (x, y)| acc | (x ^ y));
        accumulated == 0
    } else {
        false
    }
}
60
61/// Admin auth middleware — validates Bearer token against admin_token.
62async fn admin_auth(State(state): State<KeyAdminState>, request: Request, next: Next) -> Response {
63    let auth_header = request
64        .headers()
65        .get("authorization")
66        .and_then(|v| v.to_str().ok());
67
68    let token = match auth_header.and_then(|h| h.strip_prefix("Bearer ")) {
69        Some(t) => t,
70        None => {
71            return err_response(
72                StatusCode::UNAUTHORIZED,
73                "missing or invalid Authorization header",
74                "authentication_error",
75            );
76        }
77    };
78
79    if !constant_time_eq(token, &state.admin_token) {
80        return err_response(
81            StatusCode::UNAUTHORIZED,
82            "invalid admin token",
83            "authentication_error",
84        );
85    }
86
87    next.run(request).await
88}
89
90fn err_response(status: StatusCode, message: &str, error_type: &str) -> Response {
91    (status, Json(ApiError::new(message, error_type))).into_response()
92}
93
/// JSON request body for `POST /v1/admin/keys`.
#[derive(Deserialize)]
struct CreateKeyRequest {
    /// Name of the new key; `create_key` rejects it if empty, if it belongs to
    /// a config-file key, or if a stored key with this name already exists.
    name: String,
    /// Model list attached to the key; when omitted, defaults to `["*"]`
    /// (presumably "all models" — confirm against the request-auth path).
    #[serde(default = "default_models")]
    models: Vec<String>,
}
100
/// Serde default for `CreateKeyRequest::models`: a list holding the single
/// entry `"*"`.
fn default_models() -> Vec<String> {
    std::iter::once(String::from("*")).collect()
}
104
/// Response body for a successful `POST /v1/admin/keys`.
///
/// Unlike the list/get endpoints, which return a masked `KeySummary`, this
/// response carries the full secret.
#[derive(Serialize)]
struct KeyResponse {
    name: String,
    /// The full generated secret.
    key: String,
    models: Vec<String>,
}
111
/// Masked view of a key, returned by the list and get endpoints.
#[derive(Serialize)]
struct KeySummary {
    name: String,
    /// Masked secret as produced by `mask_key` (never the full key).
    key_prefix: String,
    models: Vec<String>,
    /// Where the key comes from: `"config"` (config file) or `"dynamic"`
    /// (created through the admin API and persisted in storage).
    source: &'static str,
}
119
/// Redact a secret for display.
///
/// Keeps the first 8 characters followed by `...`. Keys of 8 characters or
/// fewer become `***`, so a short key is never shown in full. Counting is by
/// `char`, so multi-byte characters are never split.
fn mask_key(key: &str) -> String {
    const VISIBLE_CHARS: usize = 8;
    match key.char_indices().nth(VISIBLE_CHARS) {
        // Byte offset of the 9th char == byte length of the first 8 chars.
        Some((cut, _)) => format!("{}...", &key[..cut]),
        None => "***".to_string(),
    }
}
128
129fn generate_key() -> String {
130    use rand::Rng;
131    let bytes: [u8; 32] = rand::rng().random();
132    let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
133    format!("sk-{hex}")
134}
135
136/// POST /v1/admin/keys — create a new virtual key.
137async fn create_key(
138    State(state): State<KeyAdminState>,
139    Json(body): Json<CreateKeyRequest>,
140) -> Response {
141    if body.name.is_empty() {
142        return err_response(
143            StatusCode::BAD_REQUEST,
144            "name is required",
145            "invalid_request_error",
146        );
147    }
148
149    // Reject names that collide with TOML-managed keys.
150    if state.toml_key_names.contains(&body.name) {
151        return err_response(
152            StatusCode::CONFLICT,
153            &format!("key '{}' is managed by config file", body.name),
154            "invalid_request_error",
155        );
156    }
157
158    // Check storage for existing name (storage is keyed by name, the
159    // authoritative source for dynamic keys).
160    let skey = storage_key(&KEY_PREFIX, body.name.as_bytes());
161    match state.storage.get(&skey).await {
162        Ok(Some(_)) => {
163            return err_response(
164                StatusCode::CONFLICT,
165                &format!("key '{}' already exists", body.name),
166                "invalid_request_error",
167            );
168        }
169        Err(e) => {
170            return err_response(
171                StatusCode::INTERNAL_SERVER_ERROR,
172                &e.to_string(),
173                "server_error",
174            );
175        }
176        Ok(None) => {}
177    }
178
179    let key = generate_key();
180    let config = KeyConfig {
181        name: body.name.clone(),
182        key: key.clone(),
183        models: body.models.clone(),
184    };
185
186    // Storage-first: persist before updating key_map.
187    let value = match serde_json::to_vec(&config) {
188        Ok(v) => v,
189        Err(e) => {
190            return err_response(
191                StatusCode::INTERNAL_SERVER_ERROR,
192                &e.to_string(),
193                "server_error",
194            );
195        }
196    };
197    if let Err(e) = state.storage.set(&skey, value).await {
198        return err_response(
199            StatusCode::INTERNAL_SERVER_ERROR,
200            &e.to_string(),
201            "server_error",
202        );
203    }
204
205    // Brief lock — no await while held.
206    state
207        .key_map
208        .write()
209        .unwrap_or_else(|e| e.into_inner())
210        .insert(key.clone(), body.name.clone());
211
212    (
213        StatusCode::CREATED,
214        Json(KeyResponse {
215            name: body.name,
216            key,
217            models: body.models,
218        }),
219    )
220        .into_response()
221}
222
223/// GET /v1/admin/keys — list all virtual keys (TOML + dynamic).
224async fn list_keys(State(state): State<KeyAdminState>) -> Response {
225    let mut keys: Vec<KeySummary> = state
226        .toml_keys
227        .iter()
228        .map(|kc| KeySummary {
229            name: kc.name.clone(),
230            key_prefix: mask_key(&kc.key),
231            models: kc.models.clone(),
232            source: "config",
233        })
234        .collect();
235
236    let pairs = match state.storage.list(&KEY_PREFIX).await {
237        Ok(p) => p,
238        Err(e) => {
239            return err_response(
240                StatusCode::INTERNAL_SERVER_ERROR,
241                &e.to_string(),
242                "server_error",
243            );
244        }
245    };
246
247    for (_k, v) in pairs {
248        if let Ok(kc) = serde_json::from_slice::<KeyConfig>(&v) {
249            // Skip storage keys that overlap with TOML (TOML already listed).
250            if state.toml_key_names.contains(&kc.name) {
251                continue;
252            }
253            keys.push(KeySummary {
254                name: kc.name,
255                key_prefix: mask_key(&kc.key),
256                models: kc.models,
257                source: "dynamic",
258            });
259        }
260    }
261
262    Json(keys).into_response()
263}
264
265/// GET /v1/admin/keys/:name — get a single key's details.
266async fn get_key(State(state): State<KeyAdminState>, Path(name): Path<String>) -> Response {
267    // Check TOML keys first.
268    if let Some(kc) = state.toml_keys.iter().find(|k| k.name == name) {
269        return Json(KeySummary {
270            name: kc.name.clone(),
271            key_prefix: mask_key(&kc.key),
272            models: kc.models.clone(),
273            source: "config",
274        })
275        .into_response();
276    }
277
278    let skey = storage_key(&KEY_PREFIX, name.as_bytes());
279    match state.storage.get(&skey).await {
280        Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
281            Ok(kc) => Json(KeySummary {
282                name: kc.name,
283                key_prefix: mask_key(&kc.key),
284                models: kc.models,
285                source: "dynamic",
286            })
287            .into_response(),
288            Err(e) => err_response(
289                StatusCode::INTERNAL_SERVER_ERROR,
290                &e.to_string(),
291                "server_error",
292            ),
293        },
294        Ok(None) => err_response(
295            StatusCode::NOT_FOUND,
296            &format!("key '{name}' not found"),
297            "invalid_request_error",
298        ),
299        Err(e) => err_response(
300            StatusCode::INTERNAL_SERVER_ERROR,
301            &e.to_string(),
302            "server_error",
303        ),
304    }
305}
306
307/// DELETE /v1/admin/keys/:name — revoke a virtual key.
308async fn delete_key(State(state): State<KeyAdminState>, Path(name): Path<String>) -> Response {
309    // TOML-managed keys cannot be deleted via the API.
310    if state.toml_key_names.contains(&name) {
311        return err_response(
312            StatusCode::FORBIDDEN,
313            &format!("key '{name}' is managed by config file and cannot be deleted via API"),
314            "invalid_request_error",
315        );
316    }
317
318    let skey = storage_key(&KEY_PREFIX, name.as_bytes());
319
320    // Load the key to find the token for key_map removal.
321    let token = match state.storage.get(&skey).await {
322        Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
323            Ok(kc) => kc.key,
324            Err(_) => {
325                return err_response(
326                    StatusCode::INTERNAL_SERVER_ERROR,
327                    "corrupt key data",
328                    "server_error",
329                );
330            }
331        },
332        Ok(None) => {
333            return err_response(
334                StatusCode::NOT_FOUND,
335                &format!("key '{name}' not found"),
336                "invalid_request_error",
337            );
338        }
339        Err(e) => {
340            return err_response(
341                StatusCode::INTERNAL_SERVER_ERROR,
342                &e.to_string(),
343                "server_error",
344            );
345        }
346    };
347
348    // Storage-first: delete from storage before updating key_map.
349    if let Err(e) = state.storage.delete(&skey).await {
350        return err_response(
351            StatusCode::INTERNAL_SERVER_ERROR,
352            &e.to_string(),
353            "server_error",
354        );
355    }
356
357    state
358        .key_map
359        .write()
360        .unwrap_or_else(|e| e.into_inner())
361        .remove(&token);
362
363    StatusCode::NO_CONTENT.into_response()
364}
365
366/// Load keys from storage and merge with TOML config keys.
367/// TOML keys take precedence on name conflicts.
368pub async fn load_stored_keys(
369    storage: &dyn Storage,
370    toml_keys: &[KeyConfig],
371    key_map: &RwLock<HashMap<String, String>>,
372) {
373    let pairs = match storage.list(&KEY_PREFIX).await {
374        Ok(p) => p,
375        Err(e) => {
376            eprintln!("warning: failed to load stored keys: {e}");
377            return;
378        }
379    };
380
381    let toml_names: HashSet<&str> = toml_keys.iter().map(|k| k.name.as_str()).collect();
382
383    let mut map = key_map.write().unwrap_or_else(|e| e.into_inner());
384    for (_k, v) in pairs {
385        if let Ok(kc) = serde_json::from_slice::<KeyConfig>(&v) {
386            // TOML keys take precedence — skip storage keys that conflict.
387            if toml_names.contains(kc.name.as_str()) {
388                continue;
389            }
390            map.insert(kc.key, kc.name);
391        }
392    }
393}