use axum::{
Json, Router,
extract::{Path, Request, State},
http::StatusCode,
middleware::{self, Next},
response::{IntoResponse, Response},
routing::{get, post},
};
use crabllm_core::{ApiError, KeyConfig, Prefix, Storage, storage_key};
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, HashSet},
sync::{Arc, RwLock},
};
const KEY_PREFIX: Prefix = *b"keys";
#[derive(Clone)]
pub struct KeyAdminState {
storage: Arc<dyn Storage>,
key_map: Arc<RwLock<HashMap<String, String>>>,
admin_token: String,
toml_key_names: HashSet<String>,
toml_keys: Vec<KeyConfig>,
}
pub fn key_admin_routes(
storage: Arc<dyn Storage>,
key_map: Arc<RwLock<HashMap<String, String>>>,
admin_token: String,
toml_keys: Vec<KeyConfig>,
) -> Router {
let toml_key_names: HashSet<String> = toml_keys.iter().map(|k| k.name.clone()).collect();
let state = KeyAdminState {
storage,
key_map,
admin_token,
toml_key_names,
toml_keys,
};
Router::new()
.route("/v1/admin/keys", post(create_key).get(list_keys))
.route("/v1/admin/keys/{name}", get(get_key).delete(delete_key))
.route_layer(middleware::from_fn_with_state(state.clone(), admin_auth))
.with_state(state)
}
fn constant_time_eq(a: &str, b: &str) -> bool {
if a.len() != b.len() {
return false;
}
let mut diff = 0u8;
for (x, y) in a.bytes().zip(b.bytes()) {
diff |= x ^ y;
}
diff == 0
}
async fn admin_auth(State(state): State<KeyAdminState>, request: Request, next: Next) -> Response {
let auth_header = request
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok());
let token = match auth_header.and_then(|h| h.strip_prefix("Bearer ")) {
Some(t) => t,
None => {
return err_response(
StatusCode::UNAUTHORIZED,
"missing or invalid Authorization header",
"authentication_error",
);
}
};
if !constant_time_eq(token, &state.admin_token) {
return err_response(
StatusCode::UNAUTHORIZED,
"invalid admin token",
"authentication_error",
);
}
next.run(request).await
}
fn err_response(status: StatusCode, message: &str, error_type: &str) -> Response {
(status, Json(ApiError::new(message, error_type))).into_response()
}
#[derive(Deserialize)]
struct CreateKeyRequest {
name: String,
#[serde(default = "default_models")]
models: Vec<String>,
}
fn default_models() -> Vec<String> {
vec!["*".to_string()]
}
#[derive(Serialize)]
struct KeyResponse {
name: String,
key: String,
models: Vec<String>,
}
#[derive(Serialize)]
struct KeySummary {
name: String,
key_prefix: String,
models: Vec<String>,
source: &'static str,
}
fn mask_key(key: &str) -> String {
let prefix: String = key.chars().take(8).collect();
if prefix.len() < key.len() {
format!("{prefix}...")
} else {
"***".to_string()
}
}
fn generate_key() -> String {
use rand::Rng;
let bytes: [u8; 32] = rand::rng().random();
let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
format!("sk-{hex}")
}
async fn create_key(
State(state): State<KeyAdminState>,
Json(body): Json<CreateKeyRequest>,
) -> Response {
if body.name.is_empty() {
return err_response(
StatusCode::BAD_REQUEST,
"name is required",
"invalid_request_error",
);
}
if state.toml_key_names.contains(&body.name) {
return err_response(
StatusCode::CONFLICT,
&format!("key '{}' is managed by config file", body.name),
"invalid_request_error",
);
}
let skey = storage_key(&KEY_PREFIX, body.name.as_bytes());
match state.storage.get(&skey).await {
Ok(Some(_)) => {
return err_response(
StatusCode::CONFLICT,
&format!("key '{}' already exists", body.name),
"invalid_request_error",
);
}
Err(e) => {
return err_response(
StatusCode::INTERNAL_SERVER_ERROR,
&e.to_string(),
"server_error",
);
}
Ok(None) => {}
}
let key = generate_key();
let config = KeyConfig {
name: body.name.clone(),
key: key.clone(),
models: body.models.clone(),
};
let value = match serde_json::to_vec(&config) {
Ok(v) => v,
Err(e) => {
return err_response(
StatusCode::INTERNAL_SERVER_ERROR,
&e.to_string(),
"server_error",
);
}
};
if let Err(e) = state.storage.set(&skey, value).await {
return err_response(
StatusCode::INTERNAL_SERVER_ERROR,
&e.to_string(),
"server_error",
);
}
state
.key_map
.write()
.unwrap_or_else(|e| e.into_inner())
.insert(key.clone(), body.name.clone());
(
StatusCode::CREATED,
Json(KeyResponse {
name: body.name,
key,
models: body.models,
}),
)
.into_response()
}
async fn list_keys(State(state): State<KeyAdminState>) -> Response {
let mut keys: Vec<KeySummary> = state
.toml_keys
.iter()
.map(|kc| KeySummary {
name: kc.name.clone(),
key_prefix: mask_key(&kc.key),
models: kc.models.clone(),
source: "config",
})
.collect();
let pairs = match state.storage.list(&KEY_PREFIX).await {
Ok(p) => p,
Err(e) => {
return err_response(
StatusCode::INTERNAL_SERVER_ERROR,
&e.to_string(),
"server_error",
);
}
};
for (_k, v) in pairs {
if let Ok(kc) = serde_json::from_slice::<KeyConfig>(&v) {
if state.toml_key_names.contains(&kc.name) {
continue;
}
keys.push(KeySummary {
name: kc.name,
key_prefix: mask_key(&kc.key),
models: kc.models,
source: "dynamic",
});
}
}
Json(keys).into_response()
}
async fn get_key(State(state): State<KeyAdminState>, Path(name): Path<String>) -> Response {
if let Some(kc) = state.toml_keys.iter().find(|k| k.name == name) {
return Json(KeySummary {
name: kc.name.clone(),
key_prefix: mask_key(&kc.key),
models: kc.models.clone(),
source: "config",
})
.into_response();
}
let skey = storage_key(&KEY_PREFIX, name.as_bytes());
match state.storage.get(&skey).await {
Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
Ok(kc) => Json(KeySummary {
name: kc.name,
key_prefix: mask_key(&kc.key),
models: kc.models,
source: "dynamic",
})
.into_response(),
Err(e) => err_response(
StatusCode::INTERNAL_SERVER_ERROR,
&e.to_string(),
"server_error",
),
},
Ok(None) => err_response(
StatusCode::NOT_FOUND,
&format!("key '{name}' not found"),
"invalid_request_error",
),
Err(e) => err_response(
StatusCode::INTERNAL_SERVER_ERROR,
&e.to_string(),
"server_error",
),
}
}
async fn delete_key(State(state): State<KeyAdminState>, Path(name): Path<String>) -> Response {
if state.toml_key_names.contains(&name) {
return err_response(
StatusCode::FORBIDDEN,
&format!("key '{name}' is managed by config file and cannot be deleted via API"),
"invalid_request_error",
);
}
let skey = storage_key(&KEY_PREFIX, name.as_bytes());
let token = match state.storage.get(&skey).await {
Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
Ok(kc) => kc.key,
Err(_) => {
return err_response(
StatusCode::INTERNAL_SERVER_ERROR,
"corrupt key data",
"server_error",
);
}
},
Ok(None) => {
return err_response(
StatusCode::NOT_FOUND,
&format!("key '{name}' not found"),
"invalid_request_error",
);
}
Err(e) => {
return err_response(
StatusCode::INTERNAL_SERVER_ERROR,
&e.to_string(),
"server_error",
);
}
};
if let Err(e) = state.storage.delete(&skey).await {
return err_response(
StatusCode::INTERNAL_SERVER_ERROR,
&e.to_string(),
"server_error",
);
}
state
.key_map
.write()
.unwrap_or_else(|e| e.into_inner())
.remove(&token);
StatusCode::NO_CONTENT.into_response()
}
pub async fn load_stored_keys(
storage: &dyn Storage,
toml_keys: &[KeyConfig],
key_map: &RwLock<HashMap<String, String>>,
) {
let pairs = match storage.list(&KEY_PREFIX).await {
Ok(p) => p,
Err(e) => {
eprintln!("warning: failed to load stored keys: {e}");
return;
}
};
let toml_names: HashSet<&str> = toml_keys.iter().map(|k| k.name.as_str()).collect();
let mut map = key_map.write().unwrap_or_else(|e| e.into_inner());
for (_k, v) in pairs {
if let Ok(kc) = serde_json::from_slice::<KeyConfig>(&v) {
if toml_names.contains(kc.name.as_str()) {
continue;
}
map.insert(kc.key, kc.name);
}
}
}