Skip to main content

crabllm_proxy/
admin.rs

1use crate::PREFIX_KEYS;
2use axum::{
3    Json, Router,
4    extract::{Path, Request, State},
5    http::StatusCode,
6    middleware::{self, Next},
7    response::{IntoResponse, Response},
8    routing::{get, post},
9};
10use crabllm_core::{ApiError, KeyConfig, KeyRateLimit, Storage, storage_key};
11use serde::{Deserialize, Serialize};
12use std::{
13    collections::{HashMap, HashSet},
14    sync::{Arc, RwLock},
15};
16
/// Shared state for the admin key-management handlers and their auth middleware.
#[derive(Clone)]
pub struct KeyAdminState {
    /// Persistent store for dynamically created keys. Entries live under
    /// `storage_key(&PREFIX_KEYS, name)` and hold a JSON-encoded `KeyConfig`.
    storage: Arc<dyn Storage>,
    /// In-memory map from secret key (token) to key name. Handlers insert on
    /// create and remove on delete. Presumably the same map the proxy consults
    /// to authenticate requests — confirm with the caller.
    key_map: Arc<RwLock<HashMap<String, String>>>,
    /// Bearer token that admin requests must present.
    admin_token: String,
    /// Names of keys defined in the config file; these cannot be created,
    /// patched, or deleted through the API.
    toml_key_names: HashSet<String>,
    /// Key definitions from the config file, reported with `source: "config"`.
    toml_keys: Vec<KeyConfig>,
}
25
26/// Build admin key management routes, protected by admin token auth.
27pub fn key_admin_routes(
28    storage: Arc<dyn Storage>,
29    key_map: Arc<RwLock<HashMap<String, String>>>,
30    admin_token: String,
31    toml_keys: Vec<KeyConfig>,
32) -> Router {
33    let toml_key_names: HashSet<String> = toml_keys.iter().map(|k| k.name.clone()).collect();
34    let state = KeyAdminState {
35        storage,
36        key_map,
37        admin_token,
38        toml_key_names,
39        toml_keys,
40    };
41    Router::new()
42        .route("/v1/admin/keys", post(create_key).get(list_keys))
43        .route(
44            "/v1/admin/keys/{name}",
45            get(get_key).patch(update_key).delete(delete_key),
46        )
47        .route_layer(middleware::from_fn_with_state(state.clone(), admin_auth))
48        .with_state(state)
49}
50
/// Compare two tokens without stopping at the first differing byte.
///
/// Only the lengths are compared eagerly. Timing can therefore reveal whether
/// the lengths match, but not where the contents differ.
fn constant_time_eq(a: &str, b: &str) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair; any mismatch leaves a set bit.
    let accumulated = a
        .as_bytes()
        .iter()
        .zip(b.as_bytes())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    accumulated == 0
}
62
63/// Validate admin Bearer token from request headers.
64/// Returns `Ok(())` on success, `Err(Response)` with 401 on failure.
65#[allow(clippy::result_large_err)]
66pub(crate) fn check_admin_token(request: &Request, admin_token: &str) -> Result<(), Response> {
67    let token = request
68        .headers()
69        .get("authorization")
70        .and_then(|v| v.to_str().ok())
71        .and_then(|h| h.strip_prefix("Bearer "));
72
73    match token {
74        Some(t) if constant_time_eq(t, admin_token) => Ok(()),
75        _ => Err(err_response(
76            StatusCode::UNAUTHORIZED,
77            "missing or invalid admin token",
78            "authentication_error",
79        )),
80    }
81}
82
83async fn admin_auth(State(state): State<KeyAdminState>, request: Request, next: Next) -> Response {
84    if let Err(r) = check_admin_token(&request, &state.admin_token) {
85        return r;
86    }
87    next.run(request).await
88}
89
90pub(crate) fn err_response(status: StatusCode, message: &str, error_type: &str) -> Response {
91    (status, Json(ApiError::new(message, error_type))).into_response()
92}
93
/// Request body for `POST /v1/admin/keys`.
#[derive(Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub(crate) struct CreateKeyRequest {
    /// Name of the new key. Must be non-empty and must not match a
    /// config-file key or an existing stored key (checked by `create_key`).
    name: String,
    /// Models the key may use. Defaults to `["*"]` when omitted; `"*"`
    /// presumably means "all models" — confirm against the model-matching code.
    #[serde(default = "default_models")]
    models: Vec<String>,
    /// Optional per-key rate limit; absent means none.
    #[serde(default)]
    rate_limit: Option<KeyRateLimit>,
}
103
/// Serde default for `CreateKeyRequest::models`: the single wildcard entry `"*"`.
fn default_models() -> Vec<String> {
    vec![String::from("*")]
}
107
/// Response body for a successful `POST /v1/admin/keys` (HTTP 201).
///
/// This is the only response that carries the full secret; other endpoints
/// return a masked `KeySummary`.
#[derive(Serialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub(crate) struct KeyResponse {
    /// Name of the created key.
    name: String,
    /// Full generated secret (`sk-` followed by hex).
    key: String,
    /// Models the key may use.
    models: Vec<String>,
    /// Per-key rate limit; omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    rate_limit: Option<KeyRateLimit>,
}
117
/// Secret-free view of a key, returned by the list, get, and patch endpoints.
#[derive(Serialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub(crate) struct KeySummary {
    /// Key name.
    name: String,
    /// Masked secret produced by `mask_key` — never the full secret.
    key_prefix: String,
    /// Models the key may use.
    models: Vec<String>,
    /// Per-key rate limit; omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    rate_limit: Option<KeyRateLimit>,
    /// Where the key is defined: `"config"` (config file) or `"dynamic"` (storage).
    source: &'static str,
}
128
/// Shorten a secret for display: its first 8 characters followed by `...`.
///
/// Keys of 8 characters or fewer become `***`, because a "prefix" that long
/// would expose the entire secret.
fn mask_key(key: &str) -> String {
    // Byte offset of the 9th character, present only if the key is longer than 8 chars.
    match key.char_indices().nth(8) {
        Some((cut, _)) => format!("{}...", &key[..cut]),
        None => String::from("***"),
    }
}
137
138/// Generate a random `sk-`-prefixed secret (32 bytes of entropy, hex-encoded).
139/// Used for both admin tokens and virtual API keys.
140pub fn generate_key() -> String {
141    use rand::Rng;
142    let bytes: [u8; 32] = rand::rng().random();
143    let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
144    format!("sk-{hex}")
145}
146
147/// POST /v1/admin/keys — create a new virtual key.
148async fn create_key(
149    State(state): State<KeyAdminState>,
150    Json(body): Json<CreateKeyRequest>,
151) -> Response {
152    if body.name.is_empty() {
153        return err_response(
154            StatusCode::BAD_REQUEST,
155            "name is required",
156            "invalid_request_error",
157        );
158    }
159
160    // Reject names that collide with TOML-managed keys.
161    if state.toml_key_names.contains(&body.name) {
162        return err_response(
163            StatusCode::CONFLICT,
164            &format!("key '{}' is managed by config file", body.name),
165            "invalid_request_error",
166        );
167    }
168
169    // Check storage for existing name (storage is keyed by name, the
170    // authoritative source for dynamic keys).
171    let skey = storage_key(&PREFIX_KEYS, body.name.as_bytes());
172    match state.storage.get(&skey).await {
173        Ok(Some(_)) => {
174            return err_response(
175                StatusCode::CONFLICT,
176                &format!("key '{}' already exists", body.name),
177                "invalid_request_error",
178            );
179        }
180        Err(e) => {
181            return err_response(
182                StatusCode::INTERNAL_SERVER_ERROR,
183                &e.to_string(),
184                "server_error",
185            );
186        }
187        Ok(None) => {}
188    }
189
190    let key = generate_key();
191    let config = KeyConfig {
192        name: body.name.clone(),
193        key: key.clone(),
194        models: body.models.clone(),
195        rate_limit: body.rate_limit.clone(),
196    };
197
198    // Storage-first: persist before updating key_map.
199    let value = match serde_json::to_vec(&config) {
200        Ok(v) => v,
201        Err(e) => {
202            return err_response(
203                StatusCode::INTERNAL_SERVER_ERROR,
204                &e.to_string(),
205                "server_error",
206            );
207        }
208    };
209    if let Err(e) = state.storage.set(&skey, value).await {
210        return err_response(
211            StatusCode::INTERNAL_SERVER_ERROR,
212            &e.to_string(),
213            "server_error",
214        );
215    }
216
217    // Brief lock — no await while held.
218    state
219        .key_map
220        .write()
221        .unwrap_or_else(|e| e.into_inner())
222        .insert(key.clone(), body.name.clone());
223
224    (
225        StatusCode::CREATED,
226        Json(KeyResponse {
227            name: body.name,
228            key,
229            models: body.models,
230            rate_limit: body.rate_limit,
231        }),
232    )
233        .into_response()
234}
235
236/// GET /v1/admin/keys — list all virtual keys (TOML + dynamic).
237async fn list_keys(State(state): State<KeyAdminState>) -> Response {
238    let mut keys: Vec<KeySummary> = state
239        .toml_keys
240        .iter()
241        .map(|kc| KeySummary {
242            name: kc.name.clone(),
243            key_prefix: mask_key(&kc.key),
244            models: kc.models.clone(),
245            rate_limit: kc.rate_limit.clone(),
246            source: "config",
247        })
248        .collect();
249
250    let pairs = match state.storage.list(&PREFIX_KEYS).await {
251        Ok(p) => p,
252        Err(e) => {
253            return err_response(
254                StatusCode::INTERNAL_SERVER_ERROR,
255                &e.to_string(),
256                "server_error",
257            );
258        }
259    };
260
261    for (_k, v) in pairs {
262        if let Ok(kc) = serde_json::from_slice::<KeyConfig>(&v) {
263            // Skip storage keys that overlap with TOML (TOML already listed).
264            if state.toml_key_names.contains(&kc.name) {
265                continue;
266            }
267            keys.push(KeySummary {
268                name: kc.name,
269                key_prefix: mask_key(&kc.key),
270                models: kc.models,
271                rate_limit: kc.rate_limit,
272                source: "dynamic",
273            });
274        }
275    }
276
277    Json(keys).into_response()
278}
279
280/// GET /v1/admin/keys/:name — get a single key's details.
281async fn get_key(State(state): State<KeyAdminState>, Path(name): Path<String>) -> Response {
282    // Check TOML keys first.
283    if let Some(kc) = state.toml_keys.iter().find(|k| k.name == name) {
284        return Json(KeySummary {
285            name: kc.name.clone(),
286            key_prefix: mask_key(&kc.key),
287            models: kc.models.clone(),
288            rate_limit: kc.rate_limit.clone(),
289            source: "config",
290        })
291        .into_response();
292    }
293
294    let skey = storage_key(&PREFIX_KEYS, name.as_bytes());
295    match state.storage.get(&skey).await {
296        Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
297            Ok(kc) => Json(KeySummary {
298                name: kc.name,
299                key_prefix: mask_key(&kc.key),
300                models: kc.models,
301                rate_limit: kc.rate_limit,
302                source: "dynamic",
303            })
304            .into_response(),
305            Err(e) => err_response(
306                StatusCode::INTERNAL_SERVER_ERROR,
307                &e.to_string(),
308                "server_error",
309            ),
310        },
311        Ok(None) => err_response(
312            StatusCode::NOT_FOUND,
313            &format!("key '{name}' not found"),
314            "invalid_request_error",
315        ),
316        Err(e) => err_response(
317            StatusCode::INTERNAL_SERVER_ERROR,
318            &e.to_string(),
319            "server_error",
320        ),
321    }
322}
323
324/// PATCH /v1/admin/keys/:name — partial update of a dynamic key.
325///
326/// JSON Merge Patch semantics: present fields are set, absent fields
327/// are unchanged, `null` clears the field. The `name` and `key` fields
328/// are immutable (not accepted in the patch body).
329async fn update_key(
330    State(state): State<KeyAdminState>,
331    Path(name): Path<String>,
332    Json(body): Json<serde_json::Value>,
333) -> Response {
334    if state.toml_key_names.contains(&name) {
335        return err_response(
336            StatusCode::FORBIDDEN,
337            &format!("key '{name}' is managed by config file and cannot be updated via API"),
338            "invalid_request_error",
339        );
340    }
341
342    let skey = storage_key(&PREFIX_KEYS, name.as_bytes());
343
344    // Load the existing config from storage.
345    let mut config = match state.storage.get(&skey).await {
346        Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
347            Ok(kc) => kc,
348            Err(_) => {
349                return err_response(
350                    StatusCode::INTERNAL_SERVER_ERROR,
351                    "corrupt key data",
352                    "server_error",
353                );
354            }
355        },
356        Ok(None) => {
357            return err_response(
358                StatusCode::NOT_FOUND,
359                &format!("key '{name}' not found"),
360                "invalid_request_error",
361            );
362        }
363        Err(e) => {
364            return err_response(
365                StatusCode::INTERNAL_SERVER_ERROR,
366                &e.to_string(),
367                "server_error",
368            );
369        }
370    };
371
372    // Reject immutable fields.
373    if body.get("name").is_some() || body.get("key").is_some() {
374        return err_response(
375            StatusCode::BAD_REQUEST,
376            "'name' and 'key' are immutable and cannot be patched",
377            "invalid_request_error",
378        );
379    }
380
381    // Apply patch fields.
382    if let Some(models) = body.get("models") {
383        match serde_json::from_value::<Vec<String>>(models.clone()) {
384            Ok(m) => config.models = m,
385            Err(e) => {
386                return err_response(
387                    StatusCode::BAD_REQUEST,
388                    &format!("invalid 'models': {e}"),
389                    "invalid_request_error",
390                );
391            }
392        }
393    }
394
395    if body.get("rate_limit").is_some() {
396        match serde_json::from_value::<Option<KeyRateLimit>>(body["rate_limit"].clone()) {
397            Ok(rl) => config.rate_limit = rl,
398            Err(e) => {
399                return err_response(
400                    StatusCode::BAD_REQUEST,
401                    &format!("invalid 'rate_limit': {e}"),
402                    "invalid_request_error",
403                );
404            }
405        }
406    }
407
408    // Persist updated config.
409    let value = match serde_json::to_vec(&config) {
410        Ok(v) => v,
411        Err(e) => {
412            return err_response(
413                StatusCode::INTERNAL_SERVER_ERROR,
414                &e.to_string(),
415                "server_error",
416            );
417        }
418    };
419    if let Err(e) = state.storage.set(&skey, value).await {
420        return err_response(
421            StatusCode::INTERNAL_SERVER_ERROR,
422            &e.to_string(),
423            "server_error",
424        );
425    }
426
427    Json(KeySummary {
428        name: config.name,
429        key_prefix: mask_key(&config.key),
430        models: config.models,
431        rate_limit: config.rate_limit,
432        source: "dynamic",
433    })
434    .into_response()
435}
436
437/// DELETE /v1/admin/keys/:name — revoke a virtual key.
438async fn delete_key(State(state): State<KeyAdminState>, Path(name): Path<String>) -> Response {
439    // TOML-managed keys cannot be deleted via the API.
440    if state.toml_key_names.contains(&name) {
441        return err_response(
442            StatusCode::FORBIDDEN,
443            &format!("key '{name}' is managed by config file and cannot be deleted via API"),
444            "invalid_request_error",
445        );
446    }
447
448    let skey = storage_key(&PREFIX_KEYS, name.as_bytes());
449
450    // Load the key to find the token for key_map removal.
451    let token = match state.storage.get(&skey).await {
452        Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
453            Ok(kc) => kc.key,
454            Err(_) => {
455                return err_response(
456                    StatusCode::INTERNAL_SERVER_ERROR,
457                    "corrupt key data",
458                    "server_error",
459                );
460            }
461        },
462        Ok(None) => {
463            return err_response(
464                StatusCode::NOT_FOUND,
465                &format!("key '{name}' not found"),
466                "invalid_request_error",
467            );
468        }
469        Err(e) => {
470            return err_response(
471                StatusCode::INTERNAL_SERVER_ERROR,
472                &e.to_string(),
473                "server_error",
474            );
475        }
476    };
477
478    // Storage-first: delete from storage before updating key_map.
479    if let Err(e) = state.storage.delete(&skey).await {
480        return err_response(
481            StatusCode::INTERNAL_SERVER_ERROR,
482            &e.to_string(),
483            "server_error",
484        );
485    }
486
487    state
488        .key_map
489        .write()
490        .unwrap_or_else(|e| e.into_inner())
491        .remove(&token);
492
493    StatusCode::NO_CONTENT.into_response()
494}
495
496/// Load keys from storage and merge with TOML config keys.
497/// TOML keys take precedence on name conflicts.
498pub async fn load_stored_keys(
499    storage: &dyn Storage,
500    toml_keys: &[KeyConfig],
501    key_map: &RwLock<HashMap<String, String>>,
502) {
503    let pairs = match storage.list(&PREFIX_KEYS).await {
504        Ok(p) => p,
505        Err(e) => {
506            tracing::warn!("failed to load stored keys: {e}");
507            return;
508        }
509    };
510
511    let toml_names: HashSet<&str> = toml_keys.iter().map(|k| k.name.as_str()).collect();
512
513    let mut map = key_map.write().unwrap_or_else(|e| e.into_inner());
514    for (_k, v) in pairs {
515        if let Ok(kc) = serde_json::from_slice::<KeyConfig>(&v) {
516            // TOML keys take precedence — skip storage keys that conflict.
517            if toml_names.contains(kc.name.as_str()) {
518                continue;
519            }
520            map.insert(kc.key, kc.name);
521        }
522    }
523}