1use axum::{
2 Json, Router,
3 extract::{Path, Request, State},
4 http::StatusCode,
5 middleware::{self, Next},
6 response::{IntoResponse, Response},
7 routing::{get, post},
8};
9use crabllm_core::{ApiError, KeyConfig, Prefix, Storage, storage_key};
10use serde::{Deserialize, Serialize};
11use std::{
12 collections::{HashMap, HashSet},
13 sync::{Arc, RwLock},
14};
15
16const KEY_PREFIX: Prefix = *b"keys";
17
18#[derive(Clone)]
19pub struct KeyAdminState {
20 storage: Arc<dyn Storage>,
21 key_map: Arc<RwLock<HashMap<String, String>>>,
22 admin_token: String,
23 toml_key_names: HashSet<String>,
24 toml_keys: Vec<KeyConfig>,
25}
26
27pub fn key_admin_routes(
29 storage: Arc<dyn Storage>,
30 key_map: Arc<RwLock<HashMap<String, String>>>,
31 admin_token: String,
32 toml_keys: Vec<KeyConfig>,
33) -> Router {
34 let toml_key_names: HashSet<String> = toml_keys.iter().map(|k| k.name.clone()).collect();
35 let state = KeyAdminState {
36 storage,
37 key_map,
38 admin_token,
39 toml_key_names,
40 toml_keys,
41 };
42 Router::new()
43 .route("/v1/admin/keys", post(create_key).get(list_keys))
44 .route("/v1/admin/keys/{name}", get(get_key).delete(delete_key))
45 .route_layer(middleware::from_fn_with_state(state.clone(), admin_auth))
46 .with_state(state)
47}
48
/// Compares two strings without stopping at the first differing byte, so the
/// comparison time does not reveal how long a matching prefix a guess had.
///
/// Inputs of different length are rejected immediately; length is not treated
/// as secret.
fn constant_time_eq(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    if lhs.len() == rhs.len() {
        // OR together the XOR of every byte pair; any mismatch leaves a set bit.
        lhs.iter().zip(rhs).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    } else {
        false
    }
}
60
61async fn admin_auth(State(state): State<KeyAdminState>, request: Request, next: Next) -> Response {
63 let auth_header = request
64 .headers()
65 .get("authorization")
66 .and_then(|v| v.to_str().ok());
67
68 let token = match auth_header.and_then(|h| h.strip_prefix("Bearer ")) {
69 Some(t) => t,
70 None => {
71 return err_response(
72 StatusCode::UNAUTHORIZED,
73 "missing or invalid Authorization header",
74 "authentication_error",
75 );
76 }
77 };
78
79 if !constant_time_eq(token, &state.admin_token) {
80 return err_response(
81 StatusCode::UNAUTHORIZED,
82 "invalid admin token",
83 "authentication_error",
84 );
85 }
86
87 next.run(request).await
88}
89
90fn err_response(status: StatusCode, message: &str, error_type: &str) -> Response {
91 (status, Json(ApiError::new(message, error_type))).into_response()
92}
93
/// JSON body accepted by `POST /v1/admin/keys`.
#[derive(Deserialize)]
struct CreateKeyRequest {
    /// Name of the new key. `create_key` rejects it if it is blank, belongs to
    /// a config-file key, or already exists in storage.
    name: String,
    /// Models the key is allowed to use; defaults to `["*"]` when omitted.
    #[serde(default = "default_models")]
    models: Vec<String>,
}
100
/// Serde default for `CreateKeyRequest::models`: a single `"*"` entry.
fn default_models() -> Vec<String> {
    std::iter::once(String::from("*")).collect()
}
104
/// Response body for a successful `POST /v1/admin/keys`.
///
/// This is the only admin response that carries the full, unmasked key.
/// Field order here is the serialized JSON field order.
#[derive(Serialize)]
struct KeyResponse {
    name: String,
    /// The generated secret, returned once at creation time.
    key: String,
    models: Vec<String>,
}
111
/// Masked view of a key, returned by the list and get endpoints.
#[derive(Serialize)]
struct KeySummary {
    name: String,
    /// Display-safe form of the key, as produced by `mask_key`.
    key_prefix: String,
    models: Vec<String>,
    /// `"config"` for config-file keys, `"dynamic"` for keys created via the admin API.
    source: &'static str,
}
119
/// Produces a display-safe form of an API key.
///
/// Shows the first 8 characters followed by `...`, but only when the key is
/// at least twice that long. Shorter keys are rendered as `***`, so a short,
/// hand-configured key is never mostly revealed (a 9-character key would
/// otherwise expose 8 of its 9 characters).
fn mask_key(key: &str) -> String {
    const VISIBLE: usize = 8;
    // Count characters, not bytes, to agree with the `chars().take` below.
    if key.chars().count() < 2 * VISIBLE {
        return "***".to_string();
    }
    let prefix: String = key.chars().take(VISIBLE).collect();
    format!("{prefix}...")
}
128
129fn generate_key() -> String {
130 use rand::Rng;
131 let bytes: [u8; 32] = rand::rng().random();
132 let hex: String = bytes.iter().map(|b| format!("{b:02x}")).collect();
133 format!("sk-{hex}")
134}
135
136async fn create_key(
138 State(state): State<KeyAdminState>,
139 Json(body): Json<CreateKeyRequest>,
140) -> Response {
141 if body.name.is_empty() {
142 return err_response(
143 StatusCode::BAD_REQUEST,
144 "name is required",
145 "invalid_request_error",
146 );
147 }
148
149 if state.toml_key_names.contains(&body.name) {
151 return err_response(
152 StatusCode::CONFLICT,
153 &format!("key '{}' is managed by config file", body.name),
154 "invalid_request_error",
155 );
156 }
157
158 let skey = storage_key(&KEY_PREFIX, body.name.as_bytes());
161 match state.storage.get(&skey).await {
162 Ok(Some(_)) => {
163 return err_response(
164 StatusCode::CONFLICT,
165 &format!("key '{}' already exists", body.name),
166 "invalid_request_error",
167 );
168 }
169 Err(e) => {
170 return err_response(
171 StatusCode::INTERNAL_SERVER_ERROR,
172 &e.to_string(),
173 "server_error",
174 );
175 }
176 Ok(None) => {}
177 }
178
179 let key = generate_key();
180 let config = KeyConfig {
181 name: body.name.clone(),
182 key: key.clone(),
183 models: body.models.clone(),
184 };
185
186 let value = match serde_json::to_vec(&config) {
188 Ok(v) => v,
189 Err(e) => {
190 return err_response(
191 StatusCode::INTERNAL_SERVER_ERROR,
192 &e.to_string(),
193 "server_error",
194 );
195 }
196 };
197 if let Err(e) = state.storage.set(&skey, value).await {
198 return err_response(
199 StatusCode::INTERNAL_SERVER_ERROR,
200 &e.to_string(),
201 "server_error",
202 );
203 }
204
205 state
207 .key_map
208 .write()
209 .unwrap_or_else(|e| e.into_inner())
210 .insert(key.clone(), body.name.clone());
211
212 (
213 StatusCode::CREATED,
214 Json(KeyResponse {
215 name: body.name,
216 key,
217 models: body.models,
218 }),
219 )
220 .into_response()
221}
222
223async fn list_keys(State(state): State<KeyAdminState>) -> Response {
225 let mut keys: Vec<KeySummary> = state
226 .toml_keys
227 .iter()
228 .map(|kc| KeySummary {
229 name: kc.name.clone(),
230 key_prefix: mask_key(&kc.key),
231 models: kc.models.clone(),
232 source: "config",
233 })
234 .collect();
235
236 let pairs = match state.storage.list(&KEY_PREFIX).await {
237 Ok(p) => p,
238 Err(e) => {
239 return err_response(
240 StatusCode::INTERNAL_SERVER_ERROR,
241 &e.to_string(),
242 "server_error",
243 );
244 }
245 };
246
247 for (_k, v) in pairs {
248 if let Ok(kc) = serde_json::from_slice::<KeyConfig>(&v) {
249 if state.toml_key_names.contains(&kc.name) {
251 continue;
252 }
253 keys.push(KeySummary {
254 name: kc.name,
255 key_prefix: mask_key(&kc.key),
256 models: kc.models,
257 source: "dynamic",
258 });
259 }
260 }
261
262 Json(keys).into_response()
263}
264
265async fn get_key(State(state): State<KeyAdminState>, Path(name): Path<String>) -> Response {
267 if let Some(kc) = state.toml_keys.iter().find(|k| k.name == name) {
269 return Json(KeySummary {
270 name: kc.name.clone(),
271 key_prefix: mask_key(&kc.key),
272 models: kc.models.clone(),
273 source: "config",
274 })
275 .into_response();
276 }
277
278 let skey = storage_key(&KEY_PREFIX, name.as_bytes());
279 match state.storage.get(&skey).await {
280 Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
281 Ok(kc) => Json(KeySummary {
282 name: kc.name,
283 key_prefix: mask_key(&kc.key),
284 models: kc.models,
285 source: "dynamic",
286 })
287 .into_response(),
288 Err(e) => err_response(
289 StatusCode::INTERNAL_SERVER_ERROR,
290 &e.to_string(),
291 "server_error",
292 ),
293 },
294 Ok(None) => err_response(
295 StatusCode::NOT_FOUND,
296 &format!("key '{name}' not found"),
297 "invalid_request_error",
298 ),
299 Err(e) => err_response(
300 StatusCode::INTERNAL_SERVER_ERROR,
301 &e.to_string(),
302 "server_error",
303 ),
304 }
305}
306
307async fn delete_key(State(state): State<KeyAdminState>, Path(name): Path<String>) -> Response {
309 if state.toml_key_names.contains(&name) {
311 return err_response(
312 StatusCode::FORBIDDEN,
313 &format!("key '{name}' is managed by config file and cannot be deleted via API"),
314 "invalid_request_error",
315 );
316 }
317
318 let skey = storage_key(&KEY_PREFIX, name.as_bytes());
319
320 let token = match state.storage.get(&skey).await {
322 Ok(Some(bytes)) => match serde_json::from_slice::<KeyConfig>(&bytes) {
323 Ok(kc) => kc.key,
324 Err(_) => {
325 return err_response(
326 StatusCode::INTERNAL_SERVER_ERROR,
327 "corrupt key data",
328 "server_error",
329 );
330 }
331 },
332 Ok(None) => {
333 return err_response(
334 StatusCode::NOT_FOUND,
335 &format!("key '{name}' not found"),
336 "invalid_request_error",
337 );
338 }
339 Err(e) => {
340 return err_response(
341 StatusCode::INTERNAL_SERVER_ERROR,
342 &e.to_string(),
343 "server_error",
344 );
345 }
346 };
347
348 if let Err(e) = state.storage.delete(&skey).await {
350 return err_response(
351 StatusCode::INTERNAL_SERVER_ERROR,
352 &e.to_string(),
353 "server_error",
354 );
355 }
356
357 state
358 .key_map
359 .write()
360 .unwrap_or_else(|e| e.into_inner())
361 .remove(&token);
362
363 StatusCode::NO_CONTENT.into_response()
364}
365
366pub async fn load_stored_keys(
369 storage: &dyn Storage,
370 toml_keys: &[KeyConfig],
371 key_map: &RwLock<HashMap<String, String>>,
372) {
373 let pairs = match storage.list(&KEY_PREFIX).await {
374 Ok(p) => p,
375 Err(e) => {
376 eprintln!("warning: failed to load stored keys: {e}");
377 return;
378 }
379 };
380
381 let toml_names: HashSet<&str> = toml_keys.iter().map(|k| k.name.as_str()).collect();
382
383 let mut map = key_map.write().unwrap_or_else(|e| e.into_inner());
384 for (_k, v) in pairs {
385 if let Ok(kc) = serde_json::from_slice::<KeyConfig>(&v) {
386 if toml_names.contains(kc.name.as_str()) {
388 continue;
389 }
390 map.insert(kc.key, kc.name);
391 }
392 }
393}