Skip to main content

crabllm_proxy/
admin_providers.rs

1use crate::PREFIX_PROVIDERS;
2use arc_swap::ArcSwap;
3use axum::{
4    Json, Router,
5    extract::{Path, Request, State},
6    http::StatusCode,
7    middleware::{self, Next},
8    response::{IntoResponse, Response},
9    routing::{get, post},
10};
11use crabllm_core::{
12    Error, GatewayConfig, Provider, ProviderConfig, ProviderKind, Storage, storage_key,
13};
14use crabllm_provider::{HttpClient, ProviderRegistry};
15use serde::{Deserialize, Serialize};
16use std::{collections::HashSet, path::PathBuf, sync::Arc};
17use tokio::sync::Mutex;
18
19/// A closure that rebuilds the provider registry from a config.
20/// The binary provides this because the proxy crate doesn't know the
21/// concrete `P` construction path (e.g. `Dispatch::Remote`).
22pub type Rebuilder<P> =
23    Arc<dyn Fn(&GatewayConfig) -> Result<ProviderRegistry<P>, Error> + Send + Sync>;
24
25struct ProviderAdminState<P: Provider> {
26    registry: Arc<ArcSwap<ProviderRegistry<P>>>,
27    config_path: PathBuf,
28    admin_token: String,
29    rebuilder: Rebuilder<P>,
30    storage: Arc<dyn Storage>,
31    /// Serializes mutation paths (create/update/delete/reload). Read paths
32    /// (list/get) don't acquire this — they only touch storage + the
33    /// config file, both of which are fine under concurrent access.
34    write_lock: Arc<Mutex<()>>,
35}
36
37impl<P: Provider> Clone for ProviderAdminState<P> {
38    fn clone(&self) -> Self {
39        Self {
40            registry: self.registry.clone(),
41            config_path: self.config_path.clone(),
42            admin_token: self.admin_token.clone(),
43            rebuilder: self.rebuilder.clone(),
44            storage: self.storage.clone(),
45            write_lock: self.write_lock.clone(),
46        }
47    }
48}
49
50/// Build admin provider management routes, protected by admin token auth.
51pub fn provider_admin_routes<P: Provider + 'static>(
52    registry: Arc<ArcSwap<ProviderRegistry<P>>>,
53    config_path: PathBuf,
54    admin_token: String,
55    rebuilder: Rebuilder<P>,
56    storage: Arc<dyn Storage>,
57) -> Router {
58    let state = ProviderAdminState {
59        registry,
60        config_path,
61        admin_token,
62        rebuilder,
63        storage,
64        write_lock: Arc::new(Mutex::new(())),
65    };
66    Router::new()
67        .route(
68            "/v1/admin/providers",
69            post(create_provider::<P>).get(list_providers::<P>),
70        )
71        .route(
72            "/v1/admin/providers/{name}",
73            get(get_provider::<P>)
74                .patch(update_provider::<P>)
75                .delete(delete_provider::<P>),
76        )
77        .route_layer(middleware::from_fn_with_state(
78            state.clone(),
79            admin_auth::<P>,
80        ))
81        .with_state(state)
82}
83
84async fn admin_auth<P: Provider>(
85    State(state): State<ProviderAdminState<P>>,
86    request: Request,
87    next: Next,
88) -> Response {
89    if let Err(r) = crate::admin::check_admin_token(&request, &state.admin_token) {
90        return r;
91    }
92    next.run(request).await
93}
94
95// ── CRUD ──
96
97/// Request body for `POST /v1/admin/providers`. Flat shape: a
98/// provider name plus the full `ProviderConfig` inline.
99#[derive(Deserialize)]
100#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
101pub(crate) struct CreateProviderRequest {
102    name: String,
103    #[serde(default, alias = "standard")]
104    kind: ProviderKind,
105    #[serde(default)]
106    api_key: Option<String>,
107    #[serde(default)]
108    base_url: Option<String>,
109    #[serde(default)]
110    models: Vec<String>,
111    #[serde(default)]
112    weight: Option<u16>,
113    #[serde(default)]
114    max_retries: Option<u32>,
115    #[serde(default)]
116    api_version: Option<String>,
117    #[serde(default)]
118    timeout: Option<u64>,
119    #[serde(default)]
120    retry_deadline: Option<u64>,
121    #[serde(default)]
122    region: Option<String>,
123    #[serde(default)]
124    access_key: Option<String>,
125    #[serde(default)]
126    secret_key: Option<String>,
127}
128
129impl CreateProviderRequest {
130    fn into_parts(self) -> (String, ProviderConfig) {
131        (
132            self.name,
133            ProviderConfig {
134                kind: self.kind,
135                api_key: self.api_key,
136                base_url: self.base_url,
137                models: self.models,
138                weight: self.weight,
139                max_retries: self.max_retries,
140                api_version: self.api_version,
141                timeout: self.timeout,
142                retry_deadline: self.retry_deadline,
143                region: self.region,
144                access_key: self.access_key,
145                secret_key: self.secret_key,
146            },
147        )
148    }
149}
150
151/// Response shape for provider GET — secrets masked.
152#[derive(Serialize)]
153#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
154pub(crate) struct ProviderSummary {
155    name: String,
156    kind: ProviderKind,
157    #[serde(skip_serializing_if = "Option::is_none")]
158    api_key_prefix: Option<String>,
159    #[serde(skip_serializing_if = "Option::is_none")]
160    base_url: Option<String>,
161    models: Vec<String>,
162    #[serde(skip_serializing_if = "Option::is_none")]
163    weight: Option<u16>,
164    #[serde(skip_serializing_if = "Option::is_none")]
165    max_retries: Option<u32>,
166    #[serde(skip_serializing_if = "Option::is_none")]
167    api_version: Option<String>,
168    #[serde(skip_serializing_if = "Option::is_none")]
169    timeout: Option<u64>,
170    #[serde(skip_serializing_if = "Option::is_none")]
171    region: Option<String>,
172    #[serde(skip_serializing_if = "Option::is_none")]
173    access_key_prefix: Option<String>,
174    source: &'static str,
175}
176
177fn summarize(name: &str, cfg: &ProviderConfig, source: &'static str) -> ProviderSummary {
178    ProviderSummary {
179        name: name.to_string(),
180        kind: cfg.kind.clone(),
181        api_key_prefix: cfg.api_key.as_deref().map(mask),
182        base_url: cfg.base_url.clone(),
183        models: cfg.models.clone(),
184        weight: cfg.weight,
185        max_retries: cfg.max_retries,
186        api_version: cfg.api_version.clone(),
187        timeout: cfg.timeout,
188        region: cfg.region.clone(),
189        access_key_prefix: cfg.access_key.as_deref().map(mask),
190        source,
191    }
192}
193
/// Redact a secret for display.
///
/// A secret longer than 8 characters is shown as its first 8 characters
/// followed by `...`. A secret of 8 characters or fewer (including the empty
/// string) becomes `***`, since any prefix would reveal all of it. The cut is
/// made on a `char` boundary, so multi-byte text is safe.
fn mask(secret: &str) -> String {
    match secret.char_indices().nth(8) {
        Some((cut, _)) => format!("{}...", &secret[..cut]),
        None => "***".to_string(),
    }
}

203/// GET /v1/admin/providers — list TOML + dynamic providers. Secrets masked.
204async fn list_providers<P: Provider>(State(state): State<ProviderAdminState<P>>) -> Response {
205    let toml_config = match read_toml_config(&state.config_path).await {
206        Ok(c) => c,
207        Err(r) => return r,
208    };
209    let toml_names: HashSet<String> = toml_config.providers.keys().cloned().collect();
210
211    let mut summaries: Vec<ProviderSummary> = toml_config
212        .providers
213        .iter()
214        .map(|(name, cfg)| summarize(name, cfg, "config"))
215        .collect();
216
217    let pairs = match state.storage.list(&PREFIX_PROVIDERS).await {
218        Ok(p) => p,
219        Err(e) => {
220            return crate::admin::err_response(
221                StatusCode::INTERNAL_SERVER_ERROR,
222                &e.to_string(),
223                "server_error",
224            );
225        }
226    };
227
228    for (_k, v) in pairs {
229        let Ok(cfg) = serde_json::from_slice::<StoredProvider>(&v) else {
230            continue;
231        };
232        if toml_names.contains(&cfg.name) {
233            continue;
234        }
235        summaries.push(summarize(&cfg.name, &cfg.config, "dynamic"));
236    }
237
238    Json(summaries).into_response()
239}
240
241/// GET /v1/admin/providers/{name} — get one provider.
242async fn get_provider<P: Provider>(
243    State(state): State<ProviderAdminState<P>>,
244    Path(name): Path<String>,
245) -> Response {
246    let toml_config = match read_toml_config(&state.config_path).await {
247        Ok(c) => c,
248        Err(r) => return r,
249    };
250    if let Some(cfg) = toml_config.providers.get(&name) {
251        return Json(summarize(&name, cfg, "config")).into_response();
252    }
253
254    match load_stored(state.storage.as_ref(), &name).await {
255        Ok(Some(cfg)) => Json(summarize(&name, &cfg, "dynamic")).into_response(),
256        Ok(None) => crate::admin::err_response(
257            StatusCode::NOT_FOUND,
258            &format!("provider '{name}' not found"),
259            "invalid_request_error",
260        ),
261        Err(r) => r,
262    }
263}
264
265/// POST /v1/admin/providers — create a new dynamic provider.
266async fn create_provider<P: Provider>(
267    State(state): State<ProviderAdminState<P>>,
268    Json(body): Json<CreateProviderRequest>,
269) -> Response {
270    if body.name.is_empty() {
271        return crate::admin::err_response(
272            StatusCode::BAD_REQUEST,
273            "name is required",
274            "invalid_request_error",
275        );
276    }
277    let (name, mut config) = body.into_parts();
278
279    let _guard = state.write_lock.lock().await;
280
281    let toml_config = match read_toml_config(&state.config_path).await {
282        Ok(c) => c,
283        Err(r) => return r,
284    };
285    if toml_config.providers.contains_key(&name) {
286        return crate::admin::err_response(
287            StatusCode::CONFLICT,
288            &format!("provider '{name}' is managed by config file"),
289            "invalid_request_error",
290        );
291    }
292
293    let skey = storage_key(&PREFIX_PROVIDERS, name.as_bytes());
294    match state.storage.get(&skey).await {
295        Ok(Some(_)) => {
296            return crate::admin::err_response(
297                StatusCode::CONFLICT,
298                &format!("provider '{name}' already exists"),
299                "invalid_request_error",
300            );
301        }
302        Err(e) => {
303            return crate::admin::err_response(
304                StatusCode::INTERNAL_SERVER_ERROR,
305                &e.to_string(),
306                "server_error",
307            );
308        }
309        Ok(None) => {}
310    }
311
312    if let Err(e) = autofill_models(&mut config).await {
313        return crate::admin::err_response(StatusCode::BAD_REQUEST, &e, "invalid_request_error");
314    }
315
316    if let Err(e) = validate_single(&name, &config) {
317        return crate::admin::err_response(StatusCode::BAD_REQUEST, &e, "invalid_request_error");
318    }
319
320    if let Err(r) = commit_change(&state, &name, Some(&config)).await {
321        return r;
322    }
323
324    (
325        StatusCode::CREATED,
326        Json(summarize(&name, &config, "dynamic")),
327    )
328        .into_response()
329}
330
331/// PATCH /v1/admin/providers/{name} — partial update of a dynamic provider.
332async fn update_provider<P: Provider>(
333    State(state): State<ProviderAdminState<P>>,
334    Path(name): Path<String>,
335    Json(body): Json<serde_json::Value>,
336) -> Response {
337    if body.get("name").is_some() {
338        return crate::admin::err_response(
339            StatusCode::BAD_REQUEST,
340            "'name' is immutable and cannot be patched",
341            "invalid_request_error",
342        );
343    }
344
345    let _guard = state.write_lock.lock().await;
346
347    let toml_config = match read_toml_config(&state.config_path).await {
348        Ok(c) => c,
349        Err(r) => return r,
350    };
351    if toml_config.providers.contains_key(&name) {
352        return crate::admin::err_response(
353            StatusCode::FORBIDDEN,
354            &format!("provider '{name}' is managed by config file and cannot be updated via API"),
355            "invalid_request_error",
356        );
357    }
358
359    let mut config = match load_stored(state.storage.as_ref(), &name).await {
360        Ok(Some(c)) => c,
361        Ok(None) => {
362            return crate::admin::err_response(
363                StatusCode::NOT_FOUND,
364                &format!("provider '{name}' not found"),
365                "invalid_request_error",
366            );
367        }
368        Err(r) => return r,
369    };
370
371    if let Err(r) = apply_patch(&mut config, &body) {
372        return r;
373    }
374
375    if let Err(e) = validate_single(&name, &config) {
376        return crate::admin::err_response(StatusCode::BAD_REQUEST, &e, "invalid_request_error");
377    }
378
379    if let Err(r) = commit_change(&state, &name, Some(&config)).await {
380        return r;
381    }
382
383    Json(summarize(&name, &config, "dynamic")).into_response()
384}
385
386/// DELETE /v1/admin/providers/{name} — delete a dynamic provider.
387async fn delete_provider<P: Provider>(
388    State(state): State<ProviderAdminState<P>>,
389    Path(name): Path<String>,
390) -> Response {
391    let _guard = state.write_lock.lock().await;
392
393    let toml_config = match read_toml_config(&state.config_path).await {
394        Ok(c) => c,
395        Err(r) => return r,
396    };
397    if toml_config.providers.contains_key(&name) {
398        return crate::admin::err_response(
399            StatusCode::FORBIDDEN,
400            &format!("provider '{name}' is managed by config file and cannot be deleted via API"),
401            "invalid_request_error",
402        );
403    }
404
405    let skey = storage_key(&PREFIX_PROVIDERS, name.as_bytes());
406    match state.storage.get(&skey).await {
407        Ok(None) => {
408            return crate::admin::err_response(
409                StatusCode::NOT_FOUND,
410                &format!("provider '{name}' not found"),
411                "invalid_request_error",
412            );
413        }
414        Err(e) => {
415            return crate::admin::err_response(
416                StatusCode::INTERNAL_SERVER_ERROR,
417                &e.to_string(),
418                "server_error",
419            );
420        }
421        Ok(Some(_)) => {}
422    }
423
424    if let Err(r) = commit_change(&state, &name, None).await {
425        return r;
426    }
427
428    StatusCode::NO_CONTENT.into_response()
429}
430
431// ── Helpers ──
432
433/// Storage row: name + config. We store the name inside the value so
434/// listing doesn't need to decode raw storage keys.
435#[derive(Serialize, Deserialize)]
436struct StoredProvider {
437    name: String,
438    #[serde(flatten)]
439    config: ProviderConfig,
440}
441
442/// Merge dynamic providers from storage into a config's provider map.
443/// TOML providers take precedence on name conflicts. Called at startup,
444/// during reload, and after every CRUD mutation before rebuilding.
445pub async fn merge_stored_providers(storage: &dyn Storage, config: &mut GatewayConfig) {
446    let pairs = match storage.list(&PREFIX_PROVIDERS).await {
447        Ok(p) => p,
448        Err(e) => {
449            tracing::warn!("failed to load stored providers: {e}");
450            return;
451        }
452    };
453    for (_k, v) in pairs {
454        let Ok(sp) = serde_json::from_slice::<StoredProvider>(&v) else {
455            continue;
456        };
457        // TOML precedence: log a warning and skip if name already present.
458        if config.providers.contains_key(&sp.name) {
459            tracing::warn!(
460                name = %sp.name,
461                "dynamic provider shadowed by TOML-managed provider of the same name"
462            );
463            continue;
464        }
465        config.providers.insert(sp.name, sp.config);
466    }
467}
468
469async fn read_toml_config(path: &PathBuf) -> Result<GatewayConfig, Response> {
470    let raw = tokio::fs::read_to_string(path).await.map_err(|e| {
471        crate::admin::err_response(
472            StatusCode::INTERNAL_SERVER_ERROR,
473            &format!("failed to read config file: {e}"),
474            "server_error",
475        )
476    })?;
477    toml::from_str::<GatewayConfig>(&raw).map_err(|e| {
478        crate::admin::err_response(
479            StatusCode::INTERNAL_SERVER_ERROR,
480            &format!("failed to parse config: {e}"),
481            "server_error",
482        )
483    })
484}
485
486async fn load_stored(
487    storage: &dyn Storage,
488    name: &str,
489) -> Result<Option<ProviderConfig>, Response> {
490    let skey = storage_key(&PREFIX_PROVIDERS, name.as_bytes());
491    match storage.get(&skey).await {
492        Ok(Some(bytes)) => match serde_json::from_slice::<StoredProvider>(&bytes) {
493            Ok(sp) => Ok(Some(sp.config)),
494            Err(_) => Err(crate::admin::err_response(
495                StatusCode::INTERNAL_SERVER_ERROR,
496                "corrupt provider data",
497                "server_error",
498            )),
499        },
500        Ok(None) => Ok(None),
501        Err(e) => Err(crate::admin::err_response(
502            StatusCode::INTERNAL_SERVER_ERROR,
503            &e.to_string(),
504            "server_error",
505        )),
506    }
507}
508
509#[allow(clippy::result_large_err)]
510fn apply_patch(config: &mut ProviderConfig, body: &serde_json::Value) -> Result<(), Response> {
511    let obj = body.as_object().ok_or_else(|| {
512        crate::admin::err_response(
513            StatusCode::BAD_REQUEST,
514            "request body must be a JSON object",
515            "invalid_request_error",
516        )
517    })?;
518
519    for (key, value) in obj {
520        match key.as_str() {
521            "kind" => {
522                config.kind = serde_json::from_value(value.clone()).map_err(|e| {
523                    crate::admin::err_response(
524                        StatusCode::BAD_REQUEST,
525                        &format!("invalid 'kind': {e}"),
526                        "invalid_request_error",
527                    )
528                })?;
529            }
530            "api_key" => {
531                config.api_key = from_value_opt(value, "api_key")?;
532            }
533            "base_url" => {
534                config.base_url = from_value_opt(value, "base_url")?;
535            }
536            "models" => {
537                config.models = serde_json::from_value(value.clone()).map_err(|e| {
538                    crate::admin::err_response(
539                        StatusCode::BAD_REQUEST,
540                        &format!("invalid 'models': {e}"),
541                        "invalid_request_error",
542                    )
543                })?;
544            }
545            "weight" => config.weight = from_value_opt(value, "weight")?,
546            "max_retries" => config.max_retries = from_value_opt(value, "max_retries")?,
547            "api_version" => config.api_version = from_value_opt(value, "api_version")?,
548            "timeout" => config.timeout = from_value_opt(value, "timeout")?,
549            "region" => config.region = from_value_opt(value, "region")?,
550            "access_key" => config.access_key = from_value_opt(value, "access_key")?,
551            "secret_key" => config.secret_key = from_value_opt(value, "secret_key")?,
552            other => {
553                return Err(crate::admin::err_response(
554                    StatusCode::BAD_REQUEST,
555                    &format!("unknown field '{other}'"),
556                    "invalid_request_error",
557                ));
558            }
559        }
560    }
561    Ok(())
562}
563
564#[allow(clippy::result_large_err)]
565fn from_value_opt<T: for<'de> Deserialize<'de>>(
566    value: &serde_json::Value,
567    field: &str,
568) -> Result<Option<T>, Response> {
569    if value.is_null() {
570        return Ok(None);
571    }
572    serde_json::from_value(value.clone())
573        .map(Some)
574        .map_err(|e| {
575            crate::admin::err_response(
576                StatusCode::BAD_REQUEST,
577                &format!("invalid '{field}': {e}"),
578                "invalid_request_error",
579            )
580        })
581}
582
583fn validate_single(name: &str, config: &ProviderConfig) -> Result<(), String> {
584    config.validate(name)
585}
586
587/// If `config.models` is empty, query the provider's `GET {base_url}/models`
588/// and populate from the response. Only OpenAI-compatible kinds expose a
589/// standard models endpoint — other kinds error out asking for an explicit
590/// `--models`.
591async fn autofill_models(config: &mut ProviderConfig) -> Result<(), String> {
592    if !config.models.is_empty() {
593        return Ok(());
594    }
595
596    let base_url = match &config.kind {
597        crabllm_core::ProviderKind::Openai => config
598            .base_url
599            .as_deref()
600            .unwrap_or("https://api.openai.com/v1"),
601        crabllm_core::ProviderKind::Ollama => config
602            .base_url
603            .as_deref()
604            .unwrap_or("http://localhost:11434/v1"),
605        crabllm_core::ProviderKind::Custom(_) => config.base_url.as_deref().ok_or_else(|| {
606            "models is empty and base_url is not set; cannot auto-fetch".to_string()
607        })?,
608        other => {
609            return Err(format!(
610                "models is required for kind '{other}' — auto-fetch only supported for \
611                 openai, ollama, and custom kinds"
612            ));
613        }
614    };
615
616    let url = format!("{}/models", base_url.trim_end_matches('/'));
617    let auth = config.api_key.as_ref().map(|k| format!("Bearer {k}"));
618    let mut headers: Vec<(&str, &str)> = Vec::new();
619    if let Some(h) = auth.as_deref() {
620        headers.push(("authorization", h));
621    }
622
623    let client = HttpClient::new();
624    let resp = client
625        .get(&url, &headers)
626        .await
627        .map_err(|e| format!("failed to auto-fetch models from {url}: {e}"))?;
628
629    if !(200..300).contains(&resp.status) {
630        return Err(format!(
631            "{url} returned {}; pass --models explicitly",
632            resp.status
633        ));
634    }
635
636    let body: serde_json::Value =
637        serde_json::from_slice(&resp.body).map_err(|e| format!("invalid JSON from {url}: {e}"))?;
638    let data = body
639        .get("data")
640        .and_then(|v| v.as_array())
641        .ok_or_else(|| format!("{url} missing 'data' array; pass --models explicitly"))?;
642
643    let models: Vec<String> = data
644        .iter()
645        .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(String::from))
646        .collect();
647
648    if models.is_empty() {
649        return Err(format!(
650            "{url} returned no models; pass --models explicitly"
651        ));
652    }
653
654    tracing::info!(
655        kind = %config.kind,
656        base_url,
657        count = models.len(),
658        "auto-fetched models from provider",
659    );
660
661    config.models = models;
662    Ok(())
663}
664
665/// Apply a single-provider mutation: build registry first (on a projected
666/// config), persist to storage only if build succeeds, then swap. This
667/// ordering guarantees that a rebuild failure never leaves a corrupted
668/// row behind in storage.
669///
670/// `new_config = Some(_)` means create/update; `None` means delete.
671/// Caller must hold `state.write_lock` to serialize mutations.
672#[allow(clippy::result_large_err)]
673async fn commit_change<P: Provider>(
674    state: &ProviderAdminState<P>,
675    name: &str,
676    new_config: Option<&ProviderConfig>,
677) -> Result<(), Response> {
678    let mut config = read_toml_config(&state.config_path).await?;
679    merge_stored_providers(state.storage.as_ref(), &mut config).await;
680    match new_config {
681        Some(c) => {
682            config.providers.insert(name.to_string(), c.clone());
683        }
684        None => {
685            config.providers.remove(name);
686        }
687    }
688
689    // Build first — no side effects on failure.
690    let new_registry = (state.rebuilder)(&config).map_err(|e| {
691        crate::admin::err_response(
692            StatusCode::BAD_REQUEST,
693            &format!("failed to rebuild registry: {e}"),
694            "invalid_request_error",
695        )
696    })?;
697
698    // Persist second — storage matches the registry we're about to swap in.
699    let skey = storage_key(&PREFIX_PROVIDERS, name.as_bytes());
700    match new_config {
701        Some(c) => {
702            let stored = StoredProvider {
703                name: name.to_string(),
704                config: c.clone(),
705            };
706            let value = serde_json::to_vec(&stored).map_err(|e| {
707                crate::admin::err_response(
708                    StatusCode::INTERNAL_SERVER_ERROR,
709                    &e.to_string(),
710                    "server_error",
711                )
712            })?;
713            state.storage.set(&skey, value).await.map_err(|e| {
714                crate::admin::err_response(
715                    StatusCode::INTERNAL_SERVER_ERROR,
716                    &e.to_string(),
717                    "server_error",
718                )
719            })?;
720        }
721        None => {
722            state.storage.delete(&skey).await.map_err(|e| {
723                crate::admin::err_response(
724                    StatusCode::INTERNAL_SERVER_ERROR,
725                    &e.to_string(),
726                    "server_error",
727                )
728            })?;
729        }
730    }
731
732    // Swap last — infallible.
733    state.registry.store(Arc::new(new_registry));
734    Ok(())
735}