1use crate::PREFIX_KEYS;
2use axum::{
3 Json, Router,
4 extract::{Path, Request, State},
5 http::StatusCode,
6 middleware::{self, Next},
7 response::{IntoResponse, Response},
8 routing::{get, post},
9};
10use crabllm_core::{ApiError, KeyConfig, KeyRateLimit, Storage, storage_key};
11use serde::{Deserialize, Serialize};
12use std::{
13 collections::{HashMap, HashSet},
14 sync::{Arc, RwLock},
15};
16
/// Shared state handed to every key-administration handler and to the
/// `admin_auth` middleware.
#[derive(Clone)]
pub struct KeyAdminState {
    // Backend holding dynamically created keys, stored under `PREFIX_KEYS`
    // as JSON-encoded `KeyConfig` values.
    storage: Arc<dyn Storage>,
    // In-memory map from key token to key name. `create_key` inserts into it
    // and `delete_key` removes from it after the storage write succeeds.
    key_map: Arc<RwLock<HashMap<String, String>>>,
    // Bearer token every admin request must present.
    admin_token: String,
    // Names of keys defined in the config file. Creating a key with one of
    // these names is refused, and such keys cannot be updated or deleted here.
    toml_key_names: HashSet<String>,
    // Full config-file key definitions, reported with `source: "config"`.
    toml_keys: Vec<KeyConfig>,
}
25
26pub fn key_admin_routes(
28 storage: Arc<dyn Storage>,
29 key_map: Arc<RwLock<HashMap<String, String>>>,
30 admin_token: String,
31 toml_keys: Vec<KeyConfig>,
32) -> Router {
33 let toml_key_names: HashSet<String> = toml_keys.iter().map(|k| k.name.clone()).collect();
34 let state = KeyAdminState {
35 storage,
36 key_map,
37 admin_token,
38 toml_key_names,
39 toml_keys,
40 };
41 Router::new()
42 .route("/v1/admin/keys", post(create_key).get(list_keys))
43 .route(
44 "/v1/admin/keys/{name}",
45 get(get_key).patch(update_key).delete(delete_key),
46 )
47 .route_layer(middleware::from_fn_with_state(state.clone(), admin_auth))
48 .with_state(state)
49}
50
/// Compares two strings without short-circuiting on the first differing byte.
///
/// Inputs of different byte length are rejected immediately. Only the
/// contents, not the length, are compared in a data-independent manner.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() != rhs.len() {
        return false;
    }
    // OR together the XOR of every byte pair; any mismatch leaves a bit set.
    let acc = lhs
        .iter()
        .zip(rhs)
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    acc == 0
}
62
63#[allow(clippy::result_large_err)]
66pub(crate) fn check_admin_token(request: &Request, admin_token: &str) -> Result<(), Response> {
67 let token = request
68 .headers()
69 .get("authorization")
70 .and_then(|v| v.to_str().ok())
71 .and_then(|h| h.strip_prefix("Bearer "));
72
73 match token {
74 Some(t) if constant_time_eq(t, admin_token) => Ok(()),
75 _ => Err(err_response(
76 StatusCode::UNAUTHORIZED,
77 "missing or invalid admin token",
78 "authentication_error",
79 )),
80 }
81}
82
83async fn admin_auth(State(state): State<KeyAdminState>, request: Request, next: Next) -> Response {
84 if let Err(r) = check_admin_token(&request, &state.admin_token) {
85 return r;
86 }
87 next.run(request).await
88}
89
90pub(crate) fn err_response(status: StatusCode, message: &str, error_type: &str) -> Response {
91 (status, Json(ApiError::new(message, error_type))).into_response()
92}
93
/// JSON body accepted by `POST /v1/admin/keys`.
#[derive(Deserialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub(crate) struct CreateKeyRequest {
    /// Key name. Must be non-empty and must not clash with a config-file key
    /// or an existing stored key.
    name: String,
    /// Models this key may use. Defaults to `["*"]` when omitted.
    #[serde(default = "default_models")]
    models: Vec<String>,
    /// Optional per-key rate limit. Absent means no limit is attached.
    #[serde(default)]
    rate_limit: Option<KeyRateLimit>,
}
103
/// Serde default for `CreateKeyRequest::models`: a single `"*"` entry.
fn default_models() -> Vec<String> {
    let wildcard = String::from("*");
    vec![wildcard]
}
107
/// Response body for a newly created key.
///
/// Unlike `KeySummary`, this carries the full, unmasked token.
#[derive(Serialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub(crate) struct KeyResponse {
    /// Key name.
    name: String,
    /// Full secret token (`sk-` followed by hex digits).
    key: String,
    /// Models this key may use.
    models: Vec<String>,
    /// Per-key rate limit. Omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    rate_limit: Option<KeyRateLimit>,
}
117
/// Masked view of a key, returned by the list, get and update endpoints.
#[derive(Serialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub(crate) struct KeySummary {
    /// Key name.
    name: String,
    /// Masked token; never the full secret.
    key_prefix: String,
    /// Models this key may use.
    models: Vec<String>,
    /// Per-key rate limit. Omitted from the JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    rate_limit: Option<KeyRateLimit>,
    /// Origin of the key: `"config"` for config-file keys, `"dynamic"` for
    /// keys created through this API.
    source: &'static str,
}
128
/// Masks a secret key for display, revealing at most a short leading prefix.
///
/// The first `VISIBLE` characters are shown, followed by `...`, only when at
/// least `MIN_HIDDEN` further characters stay hidden. Any shorter key is
/// rendered as `***`, so masking never discloses most of a short secret
/// (previously a 9-character key showed 8 of its 9 characters). Lengths are
/// counted in characters, not bytes, so multi-byte keys are handled uniformly.
fn mask_key(key: &str) -> String {
    const VISIBLE: usize = 8;
    const MIN_HIDDEN: usize = 16;

    if key.chars().count() < VISIBLE + MIN_HIDDEN {
        return "***".to_string();
    }
    let prefix: String = key.chars().take(VISIBLE).collect();
    format!("{prefix}...")
}
137
138pub fn generate_key() -> String {
141 use rand::Rng;
142 let bytes: [u8; 32] = rand::rng().random();
143 let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
144 format!("sk-{hex}")
145}
146
147async fn create_key(
149 State(state): State<KeyAdminState>,
150 Json(body): Json<CreateKeyRequest>,
151) -> Response {
152 if body.name.is_empty() {
153 return err_response(
154 StatusCode::BAD_REQUEST,
155 "name is required",
156 "invalid_request_error",
157 );
158 }
159
160 if state.toml_key_names.contains(&body.name) {
162 return err_response(
163 StatusCode::CONFLICT,
164 &format!("key '{}' is managed by config file", body.name),
165 "invalid_request_error",
166 );
167 }
168
169 let skey = storage_key(&PREFIX_KEYS, body.name.as_bytes());
172 match state.storage.get(&skey).await {
173 Ok(Some(_)) => {
174 return err_response(
175 StatusCode::CONFLICT,
176 &format!("key '{}' already exists", body.name),
177 "invalid_request_error",
178 );
179 }
180 Err(e) => {
181 return err_response(
182 StatusCode::INTERNAL_SERVER_ERROR,
183 &e.to_string(),
184 "server_error",
185 );
186 }
187 Ok(None) => {}
188 }
189
190 let key = generate_key();
191 let config = KeyConfig {
192 name: body.name.clone(),
193 key: key.clone(),
194 models: body.models.clone(),
195 rate_limit: body.rate_limit.clone(),
196 };
197
198 let value = match serde_json::to_vec(&config) {
200 Ok(v) => v,
201 Err(e) => {
202 return err_response(
203 StatusCode::INTERNAL_SERVER_ERROR,
204 &e.to_string(),
205 "server_error",
206 );
207 }
208 };
209 if let Err(e) = state.storage.set(&skey, value).await {
210 return err_response(
211 StatusCode::INTERNAL_SERVER_ERROR,
212 &e.to_string(),
213 "server_error",
214 );
215 }
216
217 state
219 .key_map
220 .write()
221 .unwrap_or_else(|e| e.into_inner())
222 .insert(key.clone(), body.name.clone());
223
224 (
225 StatusCode::CREATED,
226 Json(KeyResponse {
227 name: body.name,
228 key,
229 models: body.models,
230 rate_limit: body.rate_limit,
231 }),
232 )
233 .into_response()
234}
235
236async fn list_keys(State(state): State<KeyAdminState>) -> Response {
238 let mut keys: Vec<KeySummary> = state
239 .toml_keys
240 .iter()
241 .map(|kc| KeySummary {
242 name: kc.name.clone(),
243 key_prefix: mask_key(&kc.key),
244 models: kc.models.clone(),
245 rate_limit: kc.rate_limit.clone(),
246 source: "config",
247 })
248 .collect();
249
250 let pairs = match state.storage.list(&PREFIX_KEYS).await {
251 Ok(p) => p,
252 Err(e) => {
253 return err_response(
254 StatusCode::INTERNAL_SERVER_ERROR,
255 &e.to_string(),
256 "server_error",
257 );
258 }
259 };
260
261 for (_k, v) in pairs {
262 if let Ok(kc) = serde_json::from_slice::<KeyConfig>(&v) {
263 if state.toml_key_names.contains(&kc.name) {
265 continue;
266 }
267 keys.push(KeySummary {
268 name: kc.name,
269 key_prefix: mask_key(&kc.key),
270 models: kc.models,
271 rate_limit: kc.rate_limit,
272 source: "dynamic",
273 });
274 }
275 }
276
277 Json(keys).into_response()
278}
279
280async fn get_key(State(state): State<KeyAdminState>, Path(name): Path<String>) -> Response {
282 if let Some(kc) = state.toml_keys.iter().find(|k| k.name == name) {
284 return Json(KeySummary {
285 name: kc.name.clone(),
286 key_prefix: mask_key(&kc.key),
287 models: kc.models.clone(),
288 rate_limit: kc.rate_limit.clone(),
289 source: "config",
290 })
291 .into_response();
292 }
293
294 let skey = storage_key(&PREFIX_KEYS, name.as_bytes());
295 match state.storage.get(&skey).await {
296 Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
297 Ok(kc) => Json(KeySummary {
298 name: kc.name,
299 key_prefix: mask_key(&kc.key),
300 models: kc.models,
301 rate_limit: kc.rate_limit,
302 source: "dynamic",
303 })
304 .into_response(),
305 Err(e) => err_response(
306 StatusCode::INTERNAL_SERVER_ERROR,
307 &e.to_string(),
308 "server_error",
309 ),
310 },
311 Ok(None) => err_response(
312 StatusCode::NOT_FOUND,
313 &format!("key '{name}' not found"),
314 "invalid_request_error",
315 ),
316 Err(e) => err_response(
317 StatusCode::INTERNAL_SERVER_ERROR,
318 &e.to_string(),
319 "server_error",
320 ),
321 }
322}
323
324async fn update_key(
330 State(state): State<KeyAdminState>,
331 Path(name): Path<String>,
332 Json(body): Json<serde_json::Value>,
333) -> Response {
334 if state.toml_key_names.contains(&name) {
335 return err_response(
336 StatusCode::FORBIDDEN,
337 &format!("key '{name}' is managed by config file and cannot be updated via API"),
338 "invalid_request_error",
339 );
340 }
341
342 let skey = storage_key(&PREFIX_KEYS, name.as_bytes());
343
344 let mut config = match state.storage.get(&skey).await {
346 Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
347 Ok(kc) => kc,
348 Err(_) => {
349 return err_response(
350 StatusCode::INTERNAL_SERVER_ERROR,
351 "corrupt key data",
352 "server_error",
353 );
354 }
355 },
356 Ok(None) => {
357 return err_response(
358 StatusCode::NOT_FOUND,
359 &format!("key '{name}' not found"),
360 "invalid_request_error",
361 );
362 }
363 Err(e) => {
364 return err_response(
365 StatusCode::INTERNAL_SERVER_ERROR,
366 &e.to_string(),
367 "server_error",
368 );
369 }
370 };
371
372 if body.get("name").is_some() || body.get("key").is_some() {
374 return err_response(
375 StatusCode::BAD_REQUEST,
376 "'name' and 'key' are immutable and cannot be patched",
377 "invalid_request_error",
378 );
379 }
380
381 if let Some(models) = body.get("models") {
383 match serde_json::from_value::<Vec<String>>(models.clone()) {
384 Ok(m) => config.models = m,
385 Err(e) => {
386 return err_response(
387 StatusCode::BAD_REQUEST,
388 &format!("invalid 'models': {e}"),
389 "invalid_request_error",
390 );
391 }
392 }
393 }
394
395 if body.get("rate_limit").is_some() {
396 match serde_json::from_value::<Option<KeyRateLimit>>(body["rate_limit"].clone()) {
397 Ok(rl) => config.rate_limit = rl,
398 Err(e) => {
399 return err_response(
400 StatusCode::BAD_REQUEST,
401 &format!("invalid 'rate_limit': {e}"),
402 "invalid_request_error",
403 );
404 }
405 }
406 }
407
408 let value = match serde_json::to_vec(&config) {
410 Ok(v) => v,
411 Err(e) => {
412 return err_response(
413 StatusCode::INTERNAL_SERVER_ERROR,
414 &e.to_string(),
415 "server_error",
416 );
417 }
418 };
419 if let Err(e) = state.storage.set(&skey, value).await {
420 return err_response(
421 StatusCode::INTERNAL_SERVER_ERROR,
422 &e.to_string(),
423 "server_error",
424 );
425 }
426
427 Json(KeySummary {
428 name: config.name,
429 key_prefix: mask_key(&config.key),
430 models: config.models,
431 rate_limit: config.rate_limit,
432 source: "dynamic",
433 })
434 .into_response()
435}
436
437async fn delete_key(State(state): State<KeyAdminState>, Path(name): Path<String>) -> Response {
439 if state.toml_key_names.contains(&name) {
441 return err_response(
442 StatusCode::FORBIDDEN,
443 &format!("key '{name}' is managed by config file and cannot be deleted via API"),
444 "invalid_request_error",
445 );
446 }
447
448 let skey = storage_key(&PREFIX_KEYS, name.as_bytes());
449
450 let token = match state.storage.get(&skey).await {
452 Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
453 Ok(kc) => kc.key,
454 Err(_) => {
455 return err_response(
456 StatusCode::INTERNAL_SERVER_ERROR,
457 "corrupt key data",
458 "server_error",
459 );
460 }
461 },
462 Ok(None) => {
463 return err_response(
464 StatusCode::NOT_FOUND,
465 &format!("key '{name}' not found"),
466 "invalid_request_error",
467 );
468 }
469 Err(e) => {
470 return err_response(
471 StatusCode::INTERNAL_SERVER_ERROR,
472 &e.to_string(),
473 "server_error",
474 );
475 }
476 };
477
478 if let Err(e) = state.storage.delete(&skey).await {
480 return err_response(
481 StatusCode::INTERNAL_SERVER_ERROR,
482 &e.to_string(),
483 "server_error",
484 );
485 }
486
487 state
488 .key_map
489 .write()
490 .unwrap_or_else(|e| e.into_inner())
491 .remove(&token);
492
493 StatusCode::NO_CONTENT.into_response()
494}
495
496pub async fn load_stored_keys(
499 storage: &dyn Storage,
500 toml_keys: &[KeyConfig],
501 key_map: &RwLock<HashMap<String, String>>,
502) {
503 let pairs = match storage.list(&PREFIX_KEYS).await {
504 Ok(p) => p,
505 Err(e) => {
506 tracing::warn!("failed to load stored keys: {e}");
507 return;
508 }
509 };
510
511 let toml_names: HashSet<&str> = toml_keys.iter().map(|k| k.name.as_str()).collect();
512
513 let mut map = key_map.write().unwrap_or_else(|e| e.into_inner());
514 for (_k, v) in pairs {
515 if let Ok(kc) = serde_json::from_slice::<KeyConfig>(&v) {
516 if toml_names.contains(kc.name.as_str()) {
518 continue;
519 }
520 map.insert(kc.key, kc.name);
521 }
522 }
523}