Skip to main content

aster_server/routes/
config_management.rs

1use crate::routes::utils::check_provider_configured;
2use crate::state::AppState;
3use aster::config::declarative_providers::LoadedProvider;
4use aster::config::paths::Paths;
5use aster::config::ExtensionEntry;
6use aster::config::{Config, ConfigError};
7use aster::model::ModelConfig;
8use aster::providers::auto_detect::detect_provider_from_api_key;
9use aster::providers::base::{ProviderMetadata, ProviderType};
10use aster::providers::canonical::maybe_get_canonical_model;
11use aster::providers::create_with_default_model;
12use aster::providers::errors::ProviderError;
13use aster::providers::providers as get_providers;
14use aster::providers::{retry_operation, RetryConfig};
15use aster::{
16    agents::execute_commands, agents::ExtensionConfig, config::permission::PermissionLevel,
17    slash_commands,
18};
19use axum::routing::put;
20use axum::{
21    extract::Path,
22    routing::{delete, get, post},
23    Json, Router,
24};
25use http::StatusCode;
26use serde::{Deserialize, Serialize};
27use serde_json::Value;
28use serde_yaml;
29use std::{collections::HashMap, sync::Arc};
30use utoipa::ToSchema;
31
32#[derive(Serialize, ToSchema)]
33pub struct ExtensionResponse {
34    pub extensions: Vec<ExtensionEntry>,
35    #[serde(default)]
36    pub warnings: Vec<String>,
37}
38
39#[derive(Deserialize, ToSchema)]
40pub struct ExtensionQuery {
41    pub name: String,
42    pub config: ExtensionConfig,
43    pub enabled: bool,
44}
45
46#[derive(Deserialize, ToSchema)]
47pub struct UpsertConfigQuery {
48    pub key: String,
49    pub value: Value,
50    pub is_secret: bool,
51}
52
53#[derive(Deserialize, Serialize, ToSchema)]
54pub struct ConfigKeyQuery {
55    pub key: String,
56    pub is_secret: bool,
57}
58
59#[derive(Serialize, ToSchema)]
60pub struct ConfigResponse {
61    pub config: HashMap<String, Value>,
62}
63
64#[derive(Debug, Serialize, Deserialize, ToSchema)]
65pub struct ProviderDetails {
66    pub name: String,
67    pub metadata: ProviderMetadata,
68    pub is_configured: bool,
69    pub provider_type: ProviderType,
70}
71
72#[derive(Serialize, ToSchema)]
73pub struct ProvidersResponse {
74    pub providers: Vec<ProviderDetails>,
75}
76
77#[derive(Debug, Serialize, Deserialize, ToSchema)]
78pub struct ToolPermission {
79    pub tool_name: String,
80    pub permission: PermissionLevel,
81}
82
83#[derive(Deserialize, ToSchema)]
84pub struct UpsertPermissionsQuery {
85    pub tool_permissions: Vec<ToolPermission>,
86}
87
88#[derive(Deserialize, ToSchema)]
89pub struct UpdateCustomProviderRequest {
90    pub engine: String,
91    pub display_name: String,
92    pub api_url: String,
93    pub api_key: String,
94    pub models: Vec<String>,
95    pub supports_streaming: Option<bool>,
96    pub headers: Option<std::collections::HashMap<String, String>>,
97}
98
99#[derive(Deserialize, ToSchema)]
100pub struct CheckProviderRequest {
101    pub provider: String,
102}
103
104#[derive(Deserialize, ToSchema)]
105pub struct SetProviderRequest {
106    pub provider: String,
107    pub model: String,
108}
109
110#[derive(Serialize, ToSchema)]
111#[serde(rename_all = "camelCase")]
112pub struct MaskedSecret {
113    pub masked_value: String,
114}
115
116#[derive(Serialize, ToSchema)]
117#[serde(untagged)]
118pub enum ConfigValueResponse {
119    Value(Value),
120    MaskedValue(MaskedSecret),
121}
122
123#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
124pub enum CommandType {
125    Builtin,
126    Recipe,
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
130pub struct SlashCommand {
131    pub command: String,
132    pub help: String,
133    pub command_type: CommandType,
134}
135#[derive(Serialize, ToSchema)]
136pub struct SlashCommandsResponse {
137    pub commands: Vec<SlashCommand>,
138}
139
140#[derive(Deserialize, ToSchema)]
141pub struct DetectProviderRequest {
142    pub api_key: String,
143}
144
145#[derive(Serialize, ToSchema)]
146pub struct DetectProviderResponse {
147    pub provider_name: String,
148    pub models: Vec<String>,
149}
150#[utoipa::path(
151    post,
152    path = "/config/upsert",
153    request_body = UpsertConfigQuery,
154    responses(
155        (status = 200, description = "Configuration value upserted successfully", body = String),
156        (status = 500, description = "Internal server error")
157    )
158)]
159pub async fn upsert_config(
160    Json(query): Json<UpsertConfigQuery>,
161) -> Result<Json<Value>, StatusCode> {
162    let config = Config::global();
163    let result = config.set(&query.key, &query.value, query.is_secret);
164
165    match result {
166        Ok(_) => Ok(Json(Value::String(format!("Upserted key {}", query.key)))),
167        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
168    }
169}
170
171#[utoipa::path(
172    post,
173    path = "/config/remove",
174    request_body = ConfigKeyQuery,
175    responses(
176        (status = 200, description = "Configuration value removed successfully", body = String),
177        (status = 404, description = "Configuration key not found"),
178        (status = 500, description = "Internal server error")
179    )
180)]
181pub async fn remove_config(Json(query): Json<ConfigKeyQuery>) -> Result<Json<String>, StatusCode> {
182    let config = Config::global();
183
184    let result = if query.is_secret {
185        config.delete_secret(&query.key)
186    } else {
187        config.delete(&query.key)
188    };
189
190    match result {
191        Ok(_) => Ok(Json(format!("Removed key {}", query.key))),
192        Err(_) => Err(StatusCode::NOT_FOUND),
193    }
194}
195
196const SECRET_MASK_SHOW_LEN: usize = 8;
197
198fn mask_secret(secret: Value) -> String {
199    let as_string = match secret {
200        Value::String(s) => s,
201        _ => serde_json::to_string(&secret).unwrap_or_else(|_| secret.to_string()),
202    };
203
204    let chars: Vec<_> = as_string.chars().collect();
205    let show_len = std::cmp::min(chars.len() / 2, SECRET_MASK_SHOW_LEN);
206    let visible: String = chars.iter().take(show_len).collect();
207    let mask = "*".repeat(chars.len() - show_len);
208
209    format!("{}{}", visible, mask)
210}
211
212#[utoipa::path(
213    post,
214    path = "/config/read",
215    request_body = ConfigKeyQuery,
216    responses(
217        (status = 200, description = "Configuration value retrieved successfully", body = Value),
218        (status = 500, description = "Unable to get the configuration value"),
219    )
220)]
221pub async fn read_config(
222    Json(query): Json<ConfigKeyQuery>,
223) -> Result<Json<ConfigValueResponse>, StatusCode> {
224    if query.key == "model-limits" {
225        let limits = ModelConfig::get_all_model_limits();
226        return Ok(Json(ConfigValueResponse::Value(
227            serde_json::to_value(limits).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
228        )));
229    }
230
231    let config = Config::global();
232
233    let response_value = match config.get(&query.key, query.is_secret) {
234        Ok(value) => {
235            if query.is_secret {
236                ConfigValueResponse::MaskedValue(MaskedSecret {
237                    masked_value: mask_secret(value),
238                })
239            } else {
240                ConfigValueResponse::Value(value)
241            }
242        }
243        Err(ConfigError::NotFound(_)) => ConfigValueResponse::Value(Value::Null),
244        Err(_) => {
245            return Err(StatusCode::INTERNAL_SERVER_ERROR);
246        }
247    };
248    Ok(Json(response_value))
249}
250
251#[utoipa::path(
252    get,
253    path = "/config/extensions",
254    responses(
255        (status = 200, description = "All extensions retrieved successfully", body = ExtensionResponse),
256        (status = 500, description = "Internal server error")
257    )
258)]
259pub async fn get_extensions() -> Result<Json<ExtensionResponse>, StatusCode> {
260    let extensions = aster::config::get_all_extensions();
261    let warnings = aster::config::get_warnings();
262    Ok(Json(ExtensionResponse {
263        extensions,
264        warnings,
265    }))
266}
267
268#[utoipa::path(
269    post,
270    path = "/config/extensions",
271    request_body = ExtensionQuery,
272    responses(
273        (status = 200, description = "Extension added or updated successfully", body = String),
274        (status = 400, description = "Invalid request"),
275        (status = 422, description = "Could not serialize config.yaml"),
276        (status = 500, description = "Internal server error")
277    )
278)]
279pub async fn add_extension(
280    Json(extension_query): Json<ExtensionQuery>,
281) -> Result<Json<String>, StatusCode> {
282    let extensions = aster::config::get_all_extensions();
283    let key = aster::config::extensions::name_to_key(&extension_query.name);
284
285    let is_update = extensions.iter().any(|e| e.config.key() == key);
286
287    aster::config::set_extension(ExtensionEntry {
288        enabled: extension_query.enabled,
289        config: extension_query.config,
290    });
291
292    if is_update {
293        Ok(Json(format!("Updated extension {}", extension_query.name)))
294    } else {
295        Ok(Json(format!("Added extension {}", extension_query.name)))
296    }
297}
298
299#[utoipa::path(
300    delete,
301    path = "/config/extensions/{name}",
302    responses(
303        (status = 200, description = "Extension removed successfully", body = String),
304        (status = 404, description = "Extension not found"),
305        (status = 500, description = "Internal server error")
306    )
307)]
308pub async fn remove_extension(Path(name): Path<String>) -> Result<Json<String>, StatusCode> {
309    let key = aster::config::extensions::name_to_key(&name);
310    aster::config::remove_extension(&key);
311    Ok(Json(format!("Removed extension {}", name)))
312}
313
314#[utoipa::path(
315    get,
316    path = "/config",
317    responses(
318        (status = 200, description = "All configuration values retrieved successfully", body = ConfigResponse)
319    )
320)]
321pub async fn read_all_config() -> Result<Json<ConfigResponse>, StatusCode> {
322    let config = Config::global();
323
324    let values = config
325        .all_values()
326        .map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
327
328    Ok(Json(ConfigResponse { config: values }))
329}
330
331#[utoipa::path(
332    get,
333    path = "/config/providers",
334    responses(
335        (status = 200, description = "All configuration values retrieved successfully", body = [ProviderDetails])
336    )
337)]
338pub async fn providers() -> Result<Json<Vec<ProviderDetails>>, StatusCode> {
339    let providers = get_providers().await;
340    let providers_response: Vec<ProviderDetails> = providers
341        .into_iter()
342        .map(|(metadata, provider_type)| {
343            let is_configured = check_provider_configured(&metadata, provider_type);
344
345            ProviderDetails {
346                name: metadata.name.clone(),
347                metadata,
348                is_configured,
349                provider_type,
350            }
351        })
352        .collect();
353
354    Ok(Json(providers_response))
355}
356
357#[utoipa::path(
358    get,
359    path = "/config/providers/{name}/models",
360    params(
361        ("name" = String, Path, description = "Provider name (e.g., openai)")
362    ),
363    responses(
364        (status = 200, description = "Models fetched successfully", body = [String]),
365        (status = 400, description = "Unknown provider, provider not configured, or authentication error"),
366        (status = 429, description = "Rate limit exceeded"),
367        (status = 500, description = "Internal server error")
368    )
369)]
370pub async fn get_provider_models(
371    Path(name): Path<String>,
372) -> Result<Json<Vec<String>>, StatusCode> {
373    let loaded_provider = aster::config::declarative_providers::load_provider(name.as_str()).ok();
374    // TODO(Douwe): support a get models url for custom providers
375    if let Some(loaded_provider) = loaded_provider {
376        return Ok(Json(
377            loaded_provider
378                .config
379                .models
380                .into_iter()
381                .map(|m| m.name)
382                .collect::<Vec<_>>(),
383        ));
384    }
385
386    let all = get_providers()
387        .await
388        .into_iter()
389        //.map(|(m, p)| m)
390        .collect::<Vec<_>>();
391    let Some((metadata, provider_type)) = all.into_iter().find(|(m, _)| m.name == name) else {
392        return Err(StatusCode::BAD_REQUEST);
393    };
394    if !check_provider_configured(&metadata, provider_type) {
395        return Err(StatusCode::BAD_REQUEST);
396    }
397
398    let model_config =
399        ModelConfig::new(&metadata.default_model).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
400    let provider = aster::providers::create(&name, model_config)
401        .await
402        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
403
404    let models_result = retry_operation(&RetryConfig::default(), || async {
405        provider.fetch_recommended_models().await
406    })
407    .await;
408
409    match models_result {
410        Ok(Some(models)) => Ok(Json(models)),
411        Ok(None) => Ok(Json(Vec::new())),
412        Err(provider_error) => {
413            let status_code = match provider_error {
414                // Permanent misconfigurations - client should fix configuration
415                ProviderError::Authentication(_) => StatusCode::BAD_REQUEST,
416                ProviderError::UsageError(_) => StatusCode::BAD_REQUEST,
417
418                // Transient errors - client should retry later
419                ProviderError::RateLimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS,
420
421                // All other errors - internal server error
422                _ => StatusCode::INTERNAL_SERVER_ERROR,
423            };
424
425            tracing::warn!(
426                "Provider {} failed to fetch models: {}",
427                name,
428                provider_error
429            );
430            Err(status_code)
431        }
432    }
433}
434
435#[utoipa::path(
436    get,
437    path = "/config/slash_commands",
438    responses(
439        (status = 200, description = "Slash commands retrieved successfully", body = SlashCommandsResponse)
440    )
441)]
442pub async fn get_slash_commands() -> Result<Json<SlashCommandsResponse>, StatusCode> {
443    let mut commands: Vec<_> = slash_commands::list_commands()
444        .iter()
445        .map(|command| SlashCommand {
446            command: command.command.clone(),
447            help: command.recipe_path.clone(),
448            command_type: CommandType::Recipe,
449        })
450        .collect();
451
452    for cmd_def in execute_commands::list_commands() {
453        commands.push(SlashCommand {
454            command: cmd_def.name.to_string(),
455            help: cmd_def.description.to_string(),
456            command_type: CommandType::Builtin,
457        });
458    }
459
460    Ok(Json(SlashCommandsResponse { commands }))
461}
462
463#[derive(Serialize, ToSchema)]
464pub struct PricingData {
465    pub provider: String,
466    pub model: String,
467    pub input_token_cost: f64,
468    pub output_token_cost: f64,
469    pub currency: String,
470    pub context_length: Option<u32>,
471}
472
473#[derive(Serialize, ToSchema)]
474pub struct PricingResponse {
475    pub pricing: Vec<PricingData>,
476    pub source: String,
477}
478
479#[derive(Deserialize, ToSchema)]
480pub struct PricingQuery {
481    pub provider: String,
482    pub model: String,
483}
484
485#[utoipa::path(
486    post,
487    path = "/config/pricing",
488    request_body = PricingQuery,
489    responses(
490        (status = 200, description = "Model pricing data retrieved successfully", body = PricingResponse)
491    )
492)]
493pub async fn get_pricing(
494    Json(query): Json<PricingQuery>,
495) -> Result<Json<PricingResponse>, StatusCode> {
496    let canonical_model =
497        maybe_get_canonical_model(&query.provider, &query.model).ok_or(StatusCode::NOT_FOUND)?;
498
499    let mut pricing_data = Vec::new();
500
501    if let (Some(input_cost), Some(output_cost)) = (
502        canonical_model.pricing.prompt,
503        canonical_model.pricing.completion,
504    ) {
505        pricing_data.push(PricingData {
506            provider: query.provider.clone(),
507            model: query.model.clone(),
508            input_token_cost: input_cost,
509            output_token_cost: output_cost,
510            currency: "$".to_string(),
511            context_length: Some(canonical_model.context_length as u32),
512        });
513    }
514
515    Ok(Json(PricingResponse {
516        pricing: pricing_data,
517        source: "canonical".to_string(),
518    }))
519}
520
521#[utoipa::path(
522    post,
523    path = "/config/init",
524    responses(
525        (status = 200, description = "Config initialization check completed", body = String),
526        (status = 500, description = "Internal server error")
527    )
528)]
529pub async fn init_config() -> Result<Json<String>, StatusCode> {
530    let config = Config::global();
531
532    if config.exists() {
533        return Ok(Json("Config already exists".to_string()));
534    }
535
536    // Use the shared function to load init-config.yaml
537    match aster::config::base::load_init_config_from_workspace() {
538        Ok(init_values) => match config.initialize_if_empty(init_values) {
539            Ok(_) => Ok(Json("Config initialized successfully".to_string())),
540            Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
541        },
542        Err(_) => Ok(Json(
543            "No init-config.yaml found, using default configuration".to_string(),
544        )),
545    }
546}
547
548#[utoipa::path(
549    post,
550    path = "/config/permissions",
551    request_body = UpsertPermissionsQuery,
552    responses(
553        (status = 200, description = "Permission update completed", body = String),
554        (status = 400, description = "Invalid request"),
555    )
556)]
557pub async fn upsert_permissions(
558    Json(query): Json<UpsertPermissionsQuery>,
559) -> Result<Json<String>, StatusCode> {
560    let mut permission_manager = aster::config::PermissionManager::default();
561
562    for tool_permission in &query.tool_permissions {
563        permission_manager.update_user_permission(
564            &tool_permission.tool_name,
565            tool_permission.permission.clone(),
566        );
567    }
568
569    Ok(Json("Permissions updated successfully".to_string()))
570}
571
572#[utoipa::path(
573    post,
574    path = "/config/detect-provider",
575    request_body = DetectProviderRequest,
576    responses(
577        (status = 200, description = "Provider detected successfully", body = DetectProviderResponse),
578        (status = 404, description = "No matching provider found"),
579    )
580)]
581pub async fn detect_provider(
582    Json(detect_request): Json<DetectProviderRequest>,
583) -> Result<Json<DetectProviderResponse>, StatusCode> {
584    let api_key = detect_request.api_key.trim();
585
586    match detect_provider_from_api_key(api_key).await {
587        Some((provider_name, models)) => Ok(Json(DetectProviderResponse {
588            provider_name,
589            models,
590        })),
591        None => Err(StatusCode::NOT_FOUND),
592    }
593}
594
595#[utoipa::path(
596    post,
597    path = "/config/backup",
598    responses(
599        (status = 200, description = "Config file backed up", body = String),
600        (status = 500, description = "Internal server error")
601    )
602)]
603pub async fn backup_config() -> Result<Json<String>, StatusCode> {
604    let config_path = Paths::config_dir().join("config.yaml");
605
606    if config_path.exists() {
607        let file_name = config_path
608            .file_name()
609            .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
610
611        let mut backup_name = file_name.to_os_string();
612        backup_name.push(".bak");
613
614        let backup = config_path.with_file_name(backup_name);
615        match std::fs::copy(&config_path, &backup) {
616            Ok(_) => Ok(Json(format!("Copied {:?} to {:?}", config_path, backup))),
617            Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
618        }
619    } else {
620        Err(StatusCode::INTERNAL_SERVER_ERROR)
621    }
622}
623
624#[utoipa::path(
625    post,
626    path = "/config/recover",
627    responses(
628        (status = 200, description = "Config recovery attempted", body = String),
629        (status = 500, description = "Internal server error")
630    )
631)]
632pub async fn recover_config() -> Result<Json<String>, StatusCode> {
633    let config = Config::global();
634
635    // Force a reload which will trigger recovery if needed
636    match config.all_values() {
637        Ok(values) => {
638            let recovered_keys: Vec<String> = values.keys().cloned().collect();
639            if recovered_keys.is_empty() {
640                Ok(Json("Config recovery completed, but no data was recoverable. Starting with empty configuration.".to_string()))
641            } else {
642                Ok(Json(format!(
643                    "Config recovery completed. Recovered {} keys: {}",
644                    recovered_keys.len(),
645                    recovered_keys.join(", ")
646                )))
647            }
648        }
649        Err(e) => {
650            tracing::error!("Config recovery failed: {}", e);
651            Err(StatusCode::INTERNAL_SERVER_ERROR)
652        }
653    }
654}
655
656#[utoipa::path(
657    get,
658    path = "/config/validate",
659    responses(
660        (status = 200, description = "Config validation result", body = String),
661        (status = 422, description = "Config file is corrupted")
662    )
663)]
664pub async fn validate_config() -> Result<Json<String>, StatusCode> {
665    let config_path = Paths::config_dir().join("config.yaml");
666
667    if !config_path.exists() {
668        return Ok(Json("Config file does not exist".to_string()));
669    }
670
671    match std::fs::read_to_string(&config_path) {
672        Ok(content) => match serde_yaml::from_str::<serde_yaml::Value>(&content) {
673            Ok(_) => Ok(Json("Config file is valid".to_string())),
674            Err(e) => {
675                tracing::warn!("Config validation failed: {}", e);
676                Err(StatusCode::UNPROCESSABLE_ENTITY)
677            }
678        },
679        Err(e) => {
680            tracing::error!("Failed to read config file: {}", e);
681            Err(StatusCode::INTERNAL_SERVER_ERROR)
682        }
683    }
684}
685#[utoipa::path(
686    post,
687    path = "/config/custom-providers",
688    request_body = UpdateCustomProviderRequest,
689    responses(
690        (status = 200, description = "Custom provider created successfully", body = String),
691        (status = 400, description = "Invalid request"),
692        (status = 500, description = "Internal server error")
693    )
694)]
695pub async fn create_custom_provider(
696    Json(request): Json<UpdateCustomProviderRequest>,
697) -> Result<Json<String>, StatusCode> {
698    let config = aster::config::declarative_providers::create_custom_provider(
699        &request.engine,
700        request.display_name,
701        request.api_url,
702        request.api_key,
703        request.models,
704        request.supports_streaming,
705        request.headers,
706    )
707    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
708
709    if let Err(e) = aster::providers::refresh_custom_providers().await {
710        tracing::warn!("Failed to refresh custom providers after creation: {}", e);
711    }
712
713    Ok(Json(format!("Custom provider added - ID: {}", config.id())))
714}
715
716#[utoipa::path(
717    get,
718    path = "/config/custom-providers/{id}",
719    responses(
720        (status = 200, description = "Custom provider retrieved successfully", body = LoadedProvider),
721        (status = 404, description = "Provider not found"),
722        (status = 500, description = "Internal server error")
723    )
724)]
725pub async fn get_custom_provider(
726    Path(id): Path<String>,
727) -> Result<Json<LoadedProvider>, StatusCode> {
728    let loaded_provider = aster::config::declarative_providers::load_provider(id.as_str())
729        .map_err(|_| StatusCode::NOT_FOUND)?;
730
731    Ok(Json(loaded_provider))
732}
733
734#[utoipa::path(
735    delete,
736    path = "/config/custom-providers/{id}",
737    responses(
738        (status = 200, description = "Custom provider removed successfully", body = String),
739        (status = 404, description = "Provider not found"),
740        (status = 500, description = "Internal server error")
741    )
742)]
743pub async fn remove_custom_provider(Path(id): Path<String>) -> Result<Json<String>, StatusCode> {
744    aster::config::declarative_providers::remove_custom_provider(&id)
745        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
746
747    if let Err(e) = aster::providers::refresh_custom_providers().await {
748        tracing::warn!("Failed to refresh custom providers after deletion: {}", e);
749    }
750
751    Ok(Json(format!("Removed custom provider: {}", id)))
752}
753
754#[utoipa::path(
755    put,
756    path = "/config/custom-providers/{id}",
757    request_body = UpdateCustomProviderRequest,
758    responses(
759        (status = 200, description = "Custom provider updated successfully", body = String),
760        (status = 404, description = "Provider not found"),
761        (status = 500, description = "Internal server error")
762    )
763)]
764pub async fn update_custom_provider(
765    Path(id): Path<String>,
766    Json(request): Json<UpdateCustomProviderRequest>,
767) -> Result<Json<String>, StatusCode> {
768    aster::config::declarative_providers::update_custom_provider(
769        &id,
770        &request.engine,
771        request.display_name,
772        request.api_url,
773        request.api_key,
774        request.models,
775        request.supports_streaming,
776    )
777    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
778
779    if let Err(e) = aster::providers::refresh_custom_providers().await {
780        tracing::warn!("Failed to refresh custom providers after update: {}", e);
781    }
782
783    Ok(Json(format!("Updated custom provider: {}", id)))
784}
785
786#[utoipa::path(
787    post,
788    path = "/config/check_provider",
789    request_body = CheckProviderRequest,
790)]
791pub async fn check_provider(
792    Json(CheckProviderRequest { provider }): Json<CheckProviderRequest>,
793) -> Result<(), (StatusCode, String)> {
794    create_with_default_model(&provider)
795        .await
796        .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?;
797    Ok(())
798}
799
800#[utoipa::path(
801    post,
802    path = "/config/set_provider",
803    request_body = SetProviderRequest,
804)]
805pub async fn set_config_provider(
806    Json(SetProviderRequest { provider, model }): Json<SetProviderRequest>,
807) -> Result<(), (StatusCode, String)> {
808    create_with_default_model(&provider)
809        .await
810        .and_then(|_| {
811            let config = Config::global();
812            config
813                .set_aster_provider(provider)
814                .and_then(|_| config.set_aster_model(model))
815                .map_err(|e| anyhow::anyhow!(e))
816        })
817        .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?;
818    Ok(())
819}
820
821pub fn routes(state: Arc<AppState>) -> Router {
822    Router::new()
823        .route("/config", get(read_all_config))
824        .route("/config/upsert", post(upsert_config))
825        .route("/config/remove", post(remove_config))
826        .route("/config/read", post(read_config))
827        .route("/config/extensions", get(get_extensions))
828        .route("/config/extensions", post(add_extension))
829        .route("/config/extensions/{name}", delete(remove_extension))
830        .route("/config/providers", get(providers))
831        .route("/config/providers/{name}/models", get(get_provider_models))
832        .route("/config/detect-provider", post(detect_provider))
833        .route("/config/slash_commands", get(get_slash_commands))
834        .route("/config/pricing", post(get_pricing))
835        .route("/config/init", post(init_config))
836        .route("/config/backup", post(backup_config))
837        .route("/config/recover", post(recover_config))
838        .route("/config/validate", get(validate_config))
839        .route("/config/permissions", post(upsert_permissions))
840        .route("/config/custom-providers", post(create_custom_provider))
841        .route(
842            "/config/custom-providers/{id}",
843            delete(remove_custom_provider),
844        )
845        .route("/config/custom-providers/{id}", put(update_custom_provider))
846        .route("/config/custom-providers/{id}", get(get_custom_provider))
847        .route("/config/check_provider", post(check_provider))
848        .route("/config/set_provider", post(set_config_provider))
849        .with_state(state)
850}
851
#[cfg(test)]
mod tests {
    use super::*;

    // Reading the virtual `model-limits` key must return a plain (unmasked)
    // JSON value that deserializes into a non-empty list of model limits and
    // includes the expected `gpt-4o` entry.
    #[tokio::test]
    async fn test_read_model_limits() {
        let result = read_config(Json(ConfigKeyQuery {
            key: "model-limits".to_string(),
            is_secret: false,
        }))
        .await;

        let Json(body) = result.expect("reading model-limits should succeed");
        let response = match body {
            ConfigValueResponse::Value(value) => value,
            ConfigValueResponse::MaskedValue(_) => panic!("unexpected secret"),
        };

        let limits: Vec<aster::model::ModelLimitConfig> = serde_json::from_value(response).unwrap();
        assert!(!limits.is_empty());

        let gpt4_limit = limits
            .iter()
            .find(|l| l.pattern == "gpt-4o")
            .expect("model limits should contain a gpt-4o entry");
        assert_eq!(gpt4_limit.context_limit, 128_000);
    }
}