1use crate::routes::utils::check_provider_configured;
2use crate::state::AppState;
3use aster::config::declarative_providers::LoadedProvider;
4use aster::config::paths::Paths;
5use aster::config::ExtensionEntry;
6use aster::config::{Config, ConfigError};
7use aster::model::ModelConfig;
8use aster::providers::auto_detect::detect_provider_from_api_key;
9use aster::providers::base::{ProviderMetadata, ProviderType};
10use aster::providers::canonical::maybe_get_canonical_model;
11use aster::providers::create_with_default_model;
12use aster::providers::errors::ProviderError;
13use aster::providers::providers as get_providers;
14use aster::providers::{retry_operation, RetryConfig};
15use aster::{
16 agents::execute_commands, agents::ExtensionConfig, config::permission::PermissionLevel,
17 slash_commands,
18};
19use axum::routing::put;
20use axum::{
21 extract::Path,
22 routing::{delete, get, post},
23 Json, Router,
24};
25use http::StatusCode;
26use serde::{Deserialize, Serialize};
27use serde_json::Value;
28use serde_yaml;
29use std::{collections::HashMap, sync::Arc};
30use utoipa::ToSchema;
31
/// Response body for `GET /config/extensions`.
#[derive(Serialize, ToSchema)]
pub struct ExtensionResponse {
    /// Extension entries as returned by `aster::config::get_all_extensions`.
    pub extensions: Vec<ExtensionEntry>,
    /// Warning messages as returned by `aster::config::get_warnings`.
    // NOTE(review): `serde(default)` only affects deserialization; this type derives
    // `Serialize` only, so the attribute currently has no effect.
    #[serde(default)]
    pub warnings: Vec<String>,
}
38
/// Request body for `POST /config/extensions` (add or update one extension).
#[derive(Deserialize, ToSchema)]
pub struct ExtensionQuery {
    /// Extension name; converted with `name_to_key` to decide whether this is an update.
    pub name: String,
    /// Extension configuration to store.
    pub config: ExtensionConfig,
    /// Whether the stored extension entry is enabled.
    pub enabled: bool,
}
45
/// Request body for `POST /config/upsert`.
#[derive(Deserialize, ToSchema)]
pub struct UpsertConfigQuery {
    /// Configuration key to write.
    pub key: String,
    /// Arbitrary JSON value to store under `key`.
    pub value: Value,
    /// Forwarded to `Config::set` to select secret storage instead of plain config.
    pub is_secret: bool,
}
52
/// Request body naming one configuration key; used by `POST /config/read` and
/// `POST /config/remove`.
#[derive(Deserialize, Serialize, ToSchema)]
pub struct ConfigKeyQuery {
    /// Configuration key to read or delete.
    pub key: String,
    /// Whether the key refers to a secret rather than a plain configuration value.
    pub is_secret: bool,
}
58
/// Response body for `GET /config`.
#[derive(Serialize, ToSchema)]
pub struct ConfigResponse {
    /// Every value returned by `Config::all_values`, keyed by configuration key.
    pub config: HashMap<String, Value>,
}
63
/// Summary of one provider, as listed by `GET /config/providers`.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ProviderDetails {
    /// Provider name (a copy of `metadata.name`).
    pub name: String,
    /// Provider metadata.
    pub metadata: ProviderMetadata,
    /// Result of `check_provider_configured` for this provider.
    pub is_configured: bool,
    /// Kind of provider.
    pub provider_type: ProviderType,
}
71
/// Wrapper around a list of provider summaries.
// NOTE(review): `GET /config/providers` returns a bare `Vec<ProviderDetails>`, not this
// wrapper, and nothing in this file constructs it — confirm whether it is still needed.
#[derive(Serialize, ToSchema)]
pub struct ProvidersResponse {
    /// Provider summaries.
    pub providers: Vec<ProviderDetails>,
}
76
/// A permission level to assign to a single tool.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ToolPermission {
    /// Name of the tool whose permission is being set.
    pub tool_name: String,
    /// Permission level passed to `PermissionManager::update_user_permission`.
    pub permission: PermissionLevel,
}
82
/// Request body for `POST /config/permissions`.
#[derive(Deserialize, ToSchema)]
pub struct UpsertPermissionsQuery {
    /// Permission entries to apply, one per tool.
    pub tool_permissions: Vec<ToolPermission>,
}
87
/// Request body for creating (`POST /config/custom-providers`) or updating
/// (`PUT /config/custom-providers/{id}`) a custom provider.
#[derive(Deserialize, ToSchema)]
pub struct UpdateCustomProviderRequest {
    /// Engine identifier forwarded to the declarative-provider helpers.
    pub engine: String,
    /// Human-readable provider name.
    pub display_name: String,
    /// Base URL of the provider API.
    pub api_url: String,
    /// API key for the provider.
    pub api_key: String,
    /// Model names offered by the provider.
    pub models: Vec<String>,
    /// Optional streaming-support flag.
    pub supports_streaming: Option<bool>,
    /// Optional header name/value map.
    // NOTE(review): only the create handler forwards `headers`; the update handler drops it.
    pub headers: Option<std::collections::HashMap<String, String>>,
}
98
/// Request body for `POST /config/check_provider`.
#[derive(Deserialize, ToSchema)]
pub struct CheckProviderRequest {
    /// Name of the provider to instantiate as a check.
    pub provider: String,
}
103
/// Request body for `POST /config/set_provider`.
#[derive(Deserialize, ToSchema)]
pub struct SetProviderRequest {
    /// Provider name to store as the active provider.
    pub provider: String,
    /// Model name to store as the active model.
    pub model: String,
}
109
/// A secret value with most of its characters replaced by `*`.
///
/// Serialized with camelCase keys, i.e. `{"maskedValue": "..."}`.
#[derive(Serialize, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct MaskedSecret {
    /// The partially masked secret text.
    pub masked_value: String,
}
115
/// Response body for `POST /config/read`.
///
/// `untagged`: the variant name is not emitted, so a plain value serializes as the raw
/// JSON value and a secret serializes as a `MaskedSecret` object.
#[derive(Serialize, ToSchema)]
#[serde(untagged)]
pub enum ConfigValueResponse {
    /// A non-secret value, returned verbatim (or `null` when the key is missing).
    Value(Value),
    /// A secret value, returned only in masked form.
    MaskedValue(MaskedSecret),
}
122
/// Origin of a slash command.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub enum CommandType {
    /// Listed by `execute_commands::list_commands`.
    Builtin,
    /// Listed by `slash_commands::list_commands` (backed by a recipe).
    Recipe,
}
128
/// One entry in the slash-command listing.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct SlashCommand {
    /// Command name.
    pub command: String,
    /// Help text: the description for built-ins, the recipe path for recipe commands.
    pub help: String,
    /// Where the command comes from.
    pub command_type: CommandType,
}
/// Response body for `GET /config/slash_commands`.
#[derive(Serialize, ToSchema)]
pub struct SlashCommandsResponse {
    /// All known slash commands.
    pub commands: Vec<SlashCommand>,
}
139
/// Request body for `POST /config/detect-provider`.
#[derive(Deserialize, ToSchema)]
pub struct DetectProviderRequest {
    /// API key to identify; surrounding whitespace is trimmed before detection.
    pub api_key: String,
}
144
/// Response body for a successful `POST /config/detect-provider`.
#[derive(Serialize, ToSchema)]
pub struct DetectProviderResponse {
    /// Name of the provider the key was matched to.
    pub provider_name: String,
    /// Models reported by the detection helper for that provider.
    pub models: Vec<String>,
}
150#[utoipa::path(
151 post,
152 path = "/config/upsert",
153 request_body = UpsertConfigQuery,
154 responses(
155 (status = 200, description = "Configuration value upserted successfully", body = String),
156 (status = 500, description = "Internal server error")
157 )
158)]
159pub async fn upsert_config(
160 Json(query): Json<UpsertConfigQuery>,
161) -> Result<Json<Value>, StatusCode> {
162 let config = Config::global();
163 let result = config.set(&query.key, &query.value, query.is_secret);
164
165 match result {
166 Ok(_) => Ok(Json(Value::String(format!("Upserted key {}", query.key)))),
167 Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
168 }
169}
170
171#[utoipa::path(
172 post,
173 path = "/config/remove",
174 request_body = ConfigKeyQuery,
175 responses(
176 (status = 200, description = "Configuration value removed successfully", body = String),
177 (status = 404, description = "Configuration key not found"),
178 (status = 500, description = "Internal server error")
179 )
180)]
181pub async fn remove_config(Json(query): Json<ConfigKeyQuery>) -> Result<Json<String>, StatusCode> {
182 let config = Config::global();
183
184 let result = if query.is_secret {
185 config.delete_secret(&query.key)
186 } else {
187 config.delete(&query.key)
188 };
189
190 match result {
191 Ok(_) => Ok(Json(format!("Removed key {}", query.key))),
192 Err(_) => Err(StatusCode::NOT_FOUND),
193 }
194}
195
196const SECRET_MASK_SHOW_LEN: usize = 8;
197
198fn mask_secret(secret: Value) -> String {
199 let as_string = match secret {
200 Value::String(s) => s,
201 _ => serde_json::to_string(&secret).unwrap_or_else(|_| secret.to_string()),
202 };
203
204 let chars: Vec<_> = as_string.chars().collect();
205 let show_len = std::cmp::min(chars.len() / 2, SECRET_MASK_SHOW_LEN);
206 let visible: String = chars.iter().take(show_len).collect();
207 let mask = "*".repeat(chars.len() - show_len);
208
209 format!("{}{}", visible, mask)
210}
211
212#[utoipa::path(
213 post,
214 path = "/config/read",
215 request_body = ConfigKeyQuery,
216 responses(
217 (status = 200, description = "Configuration value retrieved successfully", body = Value),
218 (status = 500, description = "Unable to get the configuration value"),
219 )
220)]
221pub async fn read_config(
222 Json(query): Json<ConfigKeyQuery>,
223) -> Result<Json<ConfigValueResponse>, StatusCode> {
224 if query.key == "model-limits" {
225 let limits = ModelConfig::get_all_model_limits();
226 return Ok(Json(ConfigValueResponse::Value(
227 serde_json::to_value(limits).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?,
228 )));
229 }
230
231 let config = Config::global();
232
233 let response_value = match config.get(&query.key, query.is_secret) {
234 Ok(value) => {
235 if query.is_secret {
236 ConfigValueResponse::MaskedValue(MaskedSecret {
237 masked_value: mask_secret(value),
238 })
239 } else {
240 ConfigValueResponse::Value(value)
241 }
242 }
243 Err(ConfigError::NotFound(_)) => ConfigValueResponse::Value(Value::Null),
244 Err(_) => {
245 return Err(StatusCode::INTERNAL_SERVER_ERROR);
246 }
247 };
248 Ok(Json(response_value))
249}
250
251#[utoipa::path(
252 get,
253 path = "/config/extensions",
254 responses(
255 (status = 200, description = "All extensions retrieved successfully", body = ExtensionResponse),
256 (status = 500, description = "Internal server error")
257 )
258)]
259pub async fn get_extensions() -> Result<Json<ExtensionResponse>, StatusCode> {
260 let extensions = aster::config::get_all_extensions();
261 let warnings = aster::config::get_warnings();
262 Ok(Json(ExtensionResponse {
263 extensions,
264 warnings,
265 }))
266}
267
268#[utoipa::path(
269 post,
270 path = "/config/extensions",
271 request_body = ExtensionQuery,
272 responses(
273 (status = 200, description = "Extension added or updated successfully", body = String),
274 (status = 400, description = "Invalid request"),
275 (status = 422, description = "Could not serialize config.yaml"),
276 (status = 500, description = "Internal server error")
277 )
278)]
279pub async fn add_extension(
280 Json(extension_query): Json<ExtensionQuery>,
281) -> Result<Json<String>, StatusCode> {
282 let extensions = aster::config::get_all_extensions();
283 let key = aster::config::extensions::name_to_key(&extension_query.name);
284
285 let is_update = extensions.iter().any(|e| e.config.key() == key);
286
287 aster::config::set_extension(ExtensionEntry {
288 enabled: extension_query.enabled,
289 config: extension_query.config,
290 });
291
292 if is_update {
293 Ok(Json(format!("Updated extension {}", extension_query.name)))
294 } else {
295 Ok(Json(format!("Added extension {}", extension_query.name)))
296 }
297}
298
299#[utoipa::path(
300 delete,
301 path = "/config/extensions/{name}",
302 responses(
303 (status = 200, description = "Extension removed successfully", body = String),
304 (status = 404, description = "Extension not found"),
305 (status = 500, description = "Internal server error")
306 )
307)]
308pub async fn remove_extension(Path(name): Path<String>) -> Result<Json<String>, StatusCode> {
309 let key = aster::config::extensions::name_to_key(&name);
310 aster::config::remove_extension(&key);
311 Ok(Json(format!("Removed extension {}", name)))
312}
313
314#[utoipa::path(
315 get,
316 path = "/config",
317 responses(
318 (status = 200, description = "All configuration values retrieved successfully", body = ConfigResponse)
319 )
320)]
321pub async fn read_all_config() -> Result<Json<ConfigResponse>, StatusCode> {
322 let config = Config::global();
323
324 let values = config
325 .all_values()
326 .map_err(|_| StatusCode::UNPROCESSABLE_ENTITY)?;
327
328 Ok(Json(ConfigResponse { config: values }))
329}
330
331#[utoipa::path(
332 get,
333 path = "/config/providers",
334 responses(
335 (status = 200, description = "All configuration values retrieved successfully", body = [ProviderDetails])
336 )
337)]
338pub async fn providers() -> Result<Json<Vec<ProviderDetails>>, StatusCode> {
339 let providers = get_providers().await;
340 let providers_response: Vec<ProviderDetails> = providers
341 .into_iter()
342 .map(|(metadata, provider_type)| {
343 let is_configured = check_provider_configured(&metadata, provider_type);
344
345 ProviderDetails {
346 name: metadata.name.clone(),
347 metadata,
348 is_configured,
349 provider_type,
350 }
351 })
352 .collect();
353
354 Ok(Json(providers_response))
355}
356
357#[utoipa::path(
358 get,
359 path = "/config/providers/{name}/models",
360 params(
361 ("name" = String, Path, description = "Provider name (e.g., openai)")
362 ),
363 responses(
364 (status = 200, description = "Models fetched successfully", body = [String]),
365 (status = 400, description = "Unknown provider, provider not configured, or authentication error"),
366 (status = 429, description = "Rate limit exceeded"),
367 (status = 500, description = "Internal server error")
368 )
369)]
370pub async fn get_provider_models(
371 Path(name): Path<String>,
372) -> Result<Json<Vec<String>>, StatusCode> {
373 let loaded_provider = aster::config::declarative_providers::load_provider(name.as_str()).ok();
374 if let Some(loaded_provider) = loaded_provider {
376 return Ok(Json(
377 loaded_provider
378 .config
379 .models
380 .into_iter()
381 .map(|m| m.name)
382 .collect::<Vec<_>>(),
383 ));
384 }
385
386 let all = get_providers()
387 .await
388 .into_iter()
389 .collect::<Vec<_>>();
391 let Some((metadata, provider_type)) = all.into_iter().find(|(m, _)| m.name == name) else {
392 return Err(StatusCode::BAD_REQUEST);
393 };
394 if !check_provider_configured(&metadata, provider_type) {
395 return Err(StatusCode::BAD_REQUEST);
396 }
397
398 let model_config =
399 ModelConfig::new(&metadata.default_model).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
400 let provider = aster::providers::create(&name, model_config)
401 .await
402 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
403
404 let models_result = retry_operation(&RetryConfig::default(), || async {
405 provider.fetch_recommended_models().await
406 })
407 .await;
408
409 match models_result {
410 Ok(Some(models)) => Ok(Json(models)),
411 Ok(None) => Ok(Json(Vec::new())),
412 Err(provider_error) => {
413 let status_code = match provider_error {
414 ProviderError::Authentication(_) => StatusCode::BAD_REQUEST,
416 ProviderError::UsageError(_) => StatusCode::BAD_REQUEST,
417
418 ProviderError::RateLimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS,
420
421 _ => StatusCode::INTERNAL_SERVER_ERROR,
423 };
424
425 tracing::warn!(
426 "Provider {} failed to fetch models: {}",
427 name,
428 provider_error
429 );
430 Err(status_code)
431 }
432 }
433}
434
435#[utoipa::path(
436 get,
437 path = "/config/slash_commands",
438 responses(
439 (status = 200, description = "Slash commands retrieved successfully", body = SlashCommandsResponse)
440 )
441)]
442pub async fn get_slash_commands() -> Result<Json<SlashCommandsResponse>, StatusCode> {
443 let mut commands: Vec<_> = slash_commands::list_commands()
444 .iter()
445 .map(|command| SlashCommand {
446 command: command.command.clone(),
447 help: command.recipe_path.clone(),
448 command_type: CommandType::Recipe,
449 })
450 .collect();
451
452 for cmd_def in execute_commands::list_commands() {
453 commands.push(SlashCommand {
454 command: cmd_def.name.to_string(),
455 help: cmd_def.description.to_string(),
456 command_type: CommandType::Builtin,
457 });
458 }
459
460 Ok(Json(SlashCommandsResponse { commands }))
461}
462
/// Pricing information for one provider/model pair.
#[derive(Serialize, ToSchema)]
pub struct PricingData {
    /// Provider name, echoed from the request.
    pub provider: String,
    /// Model name, echoed from the request.
    pub model: String,
    /// Canonical prompt (input) price. NOTE(review): unit, e.g. per token vs per
    /// million tokens, is not visible here — confirm against the canonical registry.
    pub input_token_cost: f64,
    /// Canonical completion (output) price; same unit caveat as `input_token_cost`.
    pub output_token_cost: f64,
    /// Currency symbol (`get_pricing` always uses `"$"`).
    pub currency: String,
    /// Context window length from the canonical model, if representable.
    pub context_length: Option<u32>,
}
472
/// Response body for `POST /config/pricing`.
#[derive(Serialize, ToSchema)]
pub struct PricingResponse {
    /// Zero or one pricing entries for the requested model.
    pub pricing: Vec<PricingData>,
    /// Where the pricing came from (`get_pricing` always reports `"canonical"`).
    pub source: String,
}
478
/// Request body for `POST /config/pricing`.
#[derive(Deserialize, ToSchema)]
pub struct PricingQuery {
    /// Provider name to look up.
    pub provider: String,
    /// Model name to look up.
    pub model: String,
}
484
485#[utoipa::path(
486 post,
487 path = "/config/pricing",
488 request_body = PricingQuery,
489 responses(
490 (status = 200, description = "Model pricing data retrieved successfully", body = PricingResponse)
491 )
492)]
493pub async fn get_pricing(
494 Json(query): Json<PricingQuery>,
495) -> Result<Json<PricingResponse>, StatusCode> {
496 let canonical_model =
497 maybe_get_canonical_model(&query.provider, &query.model).ok_or(StatusCode::NOT_FOUND)?;
498
499 let mut pricing_data = Vec::new();
500
501 if let (Some(input_cost), Some(output_cost)) = (
502 canonical_model.pricing.prompt,
503 canonical_model.pricing.completion,
504 ) {
505 pricing_data.push(PricingData {
506 provider: query.provider.clone(),
507 model: query.model.clone(),
508 input_token_cost: input_cost,
509 output_token_cost: output_cost,
510 currency: "$".to_string(),
511 context_length: Some(canonical_model.context_length as u32),
512 });
513 }
514
515 Ok(Json(PricingResponse {
516 pricing: pricing_data,
517 source: "canonical".to_string(),
518 }))
519}
520
521#[utoipa::path(
522 post,
523 path = "/config/init",
524 responses(
525 (status = 200, description = "Config initialization check completed", body = String),
526 (status = 500, description = "Internal server error")
527 )
528)]
529pub async fn init_config() -> Result<Json<String>, StatusCode> {
530 let config = Config::global();
531
532 if config.exists() {
533 return Ok(Json("Config already exists".to_string()));
534 }
535
536 match aster::config::base::load_init_config_from_workspace() {
538 Ok(init_values) => match config.initialize_if_empty(init_values) {
539 Ok(_) => Ok(Json("Config initialized successfully".to_string())),
540 Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
541 },
542 Err(_) => Ok(Json(
543 "No init-config.yaml found, using default configuration".to_string(),
544 )),
545 }
546}
547
548#[utoipa::path(
549 post,
550 path = "/config/permissions",
551 request_body = UpsertPermissionsQuery,
552 responses(
553 (status = 200, description = "Permission update completed", body = String),
554 (status = 400, description = "Invalid request"),
555 )
556)]
557pub async fn upsert_permissions(
558 Json(query): Json<UpsertPermissionsQuery>,
559) -> Result<Json<String>, StatusCode> {
560 let mut permission_manager = aster::config::PermissionManager::default();
561
562 for tool_permission in &query.tool_permissions {
563 permission_manager.update_user_permission(
564 &tool_permission.tool_name,
565 tool_permission.permission.clone(),
566 );
567 }
568
569 Ok(Json("Permissions updated successfully".to_string()))
570}
571
572#[utoipa::path(
573 post,
574 path = "/config/detect-provider",
575 request_body = DetectProviderRequest,
576 responses(
577 (status = 200, description = "Provider detected successfully", body = DetectProviderResponse),
578 (status = 404, description = "No matching provider found"),
579 )
580)]
581pub async fn detect_provider(
582 Json(detect_request): Json<DetectProviderRequest>,
583) -> Result<Json<DetectProviderResponse>, StatusCode> {
584 let api_key = detect_request.api_key.trim();
585
586 match detect_provider_from_api_key(api_key).await {
587 Some((provider_name, models)) => Ok(Json(DetectProviderResponse {
588 provider_name,
589 models,
590 })),
591 None => Err(StatusCode::NOT_FOUND),
592 }
593}
594
595#[utoipa::path(
596 post,
597 path = "/config/backup",
598 responses(
599 (status = 200, description = "Config file backed up", body = String),
600 (status = 500, description = "Internal server error")
601 )
602)]
603pub async fn backup_config() -> Result<Json<String>, StatusCode> {
604 let config_path = Paths::config_dir().join("config.yaml");
605
606 if config_path.exists() {
607 let file_name = config_path
608 .file_name()
609 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
610
611 let mut backup_name = file_name.to_os_string();
612 backup_name.push(".bak");
613
614 let backup = config_path.with_file_name(backup_name);
615 match std::fs::copy(&config_path, &backup) {
616 Ok(_) => Ok(Json(format!("Copied {:?} to {:?}", config_path, backup))),
617 Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
618 }
619 } else {
620 Err(StatusCode::INTERNAL_SERVER_ERROR)
621 }
622}
623
624#[utoipa::path(
625 post,
626 path = "/config/recover",
627 responses(
628 (status = 200, description = "Config recovery attempted", body = String),
629 (status = 500, description = "Internal server error")
630 )
631)]
632pub async fn recover_config() -> Result<Json<String>, StatusCode> {
633 let config = Config::global();
634
635 match config.all_values() {
637 Ok(values) => {
638 let recovered_keys: Vec<String> = values.keys().cloned().collect();
639 if recovered_keys.is_empty() {
640 Ok(Json("Config recovery completed, but no data was recoverable. Starting with empty configuration.".to_string()))
641 } else {
642 Ok(Json(format!(
643 "Config recovery completed. Recovered {} keys: {}",
644 recovered_keys.len(),
645 recovered_keys.join(", ")
646 )))
647 }
648 }
649 Err(e) => {
650 tracing::error!("Config recovery failed: {}", e);
651 Err(StatusCode::INTERNAL_SERVER_ERROR)
652 }
653 }
654}
655
656#[utoipa::path(
657 get,
658 path = "/config/validate",
659 responses(
660 (status = 200, description = "Config validation result", body = String),
661 (status = 422, description = "Config file is corrupted")
662 )
663)]
664pub async fn validate_config() -> Result<Json<String>, StatusCode> {
665 let config_path = Paths::config_dir().join("config.yaml");
666
667 if !config_path.exists() {
668 return Ok(Json("Config file does not exist".to_string()));
669 }
670
671 match std::fs::read_to_string(&config_path) {
672 Ok(content) => match serde_yaml::from_str::<serde_yaml::Value>(&content) {
673 Ok(_) => Ok(Json("Config file is valid".to_string())),
674 Err(e) => {
675 tracing::warn!("Config validation failed: {}", e);
676 Err(StatusCode::UNPROCESSABLE_ENTITY)
677 }
678 },
679 Err(e) => {
680 tracing::error!("Failed to read config file: {}", e);
681 Err(StatusCode::INTERNAL_SERVER_ERROR)
682 }
683 }
684}
685#[utoipa::path(
686 post,
687 path = "/config/custom-providers",
688 request_body = UpdateCustomProviderRequest,
689 responses(
690 (status = 200, description = "Custom provider created successfully", body = String),
691 (status = 400, description = "Invalid request"),
692 (status = 500, description = "Internal server error")
693 )
694)]
695pub async fn create_custom_provider(
696 Json(request): Json<UpdateCustomProviderRequest>,
697) -> Result<Json<String>, StatusCode> {
698 let config = aster::config::declarative_providers::create_custom_provider(
699 &request.engine,
700 request.display_name,
701 request.api_url,
702 request.api_key,
703 request.models,
704 request.supports_streaming,
705 request.headers,
706 )
707 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
708
709 if let Err(e) = aster::providers::refresh_custom_providers().await {
710 tracing::warn!("Failed to refresh custom providers after creation: {}", e);
711 }
712
713 Ok(Json(format!("Custom provider added - ID: {}", config.id())))
714}
715
716#[utoipa::path(
717 get,
718 path = "/config/custom-providers/{id}",
719 responses(
720 (status = 200, description = "Custom provider retrieved successfully", body = LoadedProvider),
721 (status = 404, description = "Provider not found"),
722 (status = 500, description = "Internal server error")
723 )
724)]
725pub async fn get_custom_provider(
726 Path(id): Path<String>,
727) -> Result<Json<LoadedProvider>, StatusCode> {
728 let loaded_provider = aster::config::declarative_providers::load_provider(id.as_str())
729 .map_err(|_| StatusCode::NOT_FOUND)?;
730
731 Ok(Json(loaded_provider))
732}
733
734#[utoipa::path(
735 delete,
736 path = "/config/custom-providers/{id}",
737 responses(
738 (status = 200, description = "Custom provider removed successfully", body = String),
739 (status = 404, description = "Provider not found"),
740 (status = 500, description = "Internal server error")
741 )
742)]
743pub async fn remove_custom_provider(Path(id): Path<String>) -> Result<Json<String>, StatusCode> {
744 aster::config::declarative_providers::remove_custom_provider(&id)
745 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
746
747 if let Err(e) = aster::providers::refresh_custom_providers().await {
748 tracing::warn!("Failed to refresh custom providers after deletion: {}", e);
749 }
750
751 Ok(Json(format!("Removed custom provider: {}", id)))
752}
753
754#[utoipa::path(
755 put,
756 path = "/config/custom-providers/{id}",
757 request_body = UpdateCustomProviderRequest,
758 responses(
759 (status = 200, description = "Custom provider updated successfully", body = String),
760 (status = 404, description = "Provider not found"),
761 (status = 500, description = "Internal server error")
762 )
763)]
764pub async fn update_custom_provider(
765 Path(id): Path<String>,
766 Json(request): Json<UpdateCustomProviderRequest>,
767) -> Result<Json<String>, StatusCode> {
768 aster::config::declarative_providers::update_custom_provider(
769 &id,
770 &request.engine,
771 request.display_name,
772 request.api_url,
773 request.api_key,
774 request.models,
775 request.supports_streaming,
776 )
777 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
778
779 if let Err(e) = aster::providers::refresh_custom_providers().await {
780 tracing::warn!("Failed to refresh custom providers after update: {}", e);
781 }
782
783 Ok(Json(format!("Updated custom provider: {}", id)))
784}
785
786#[utoipa::path(
787 post,
788 path = "/config/check_provider",
789 request_body = CheckProviderRequest,
790)]
791pub async fn check_provider(
792 Json(CheckProviderRequest { provider }): Json<CheckProviderRequest>,
793) -> Result<(), (StatusCode, String)> {
794 create_with_default_model(&provider)
795 .await
796 .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?;
797 Ok(())
798}
799
800#[utoipa::path(
801 post,
802 path = "/config/set_provider",
803 request_body = SetProviderRequest,
804)]
805pub async fn set_config_provider(
806 Json(SetProviderRequest { provider, model }): Json<SetProviderRequest>,
807) -> Result<(), (StatusCode, String)> {
808 create_with_default_model(&provider)
809 .await
810 .and_then(|_| {
811 let config = Config::global();
812 config
813 .set_aster_provider(provider)
814 .and_then(|_| config.set_aster_model(model))
815 .map_err(|e| anyhow::anyhow!(e))
816 })
817 .map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?;
818 Ok(())
819}
820
821pub fn routes(state: Arc<AppState>) -> Router {
822 Router::new()
823 .route("/config", get(read_all_config))
824 .route("/config/upsert", post(upsert_config))
825 .route("/config/remove", post(remove_config))
826 .route("/config/read", post(read_config))
827 .route("/config/extensions", get(get_extensions))
828 .route("/config/extensions", post(add_extension))
829 .route("/config/extensions/{name}", delete(remove_extension))
830 .route("/config/providers", get(providers))
831 .route("/config/providers/{name}/models", get(get_provider_models))
832 .route("/config/detect-provider", post(detect_provider))
833 .route("/config/slash_commands", get(get_slash_commands))
834 .route("/config/pricing", post(get_pricing))
835 .route("/config/init", post(init_config))
836 .route("/config/backup", post(backup_config))
837 .route("/config/recover", post(recover_config))
838 .route("/config/validate", get(validate_config))
839 .route("/config/permissions", post(upsert_permissions))
840 .route("/config/custom-providers", post(create_custom_provider))
841 .route(
842 "/config/custom-providers/{id}",
843 delete(remove_custom_provider),
844 )
845 .route("/config/custom-providers/{id}", put(update_custom_provider))
846 .route("/config/custom-providers/{id}", get(get_custom_provider))
847 .route("/config/check_provider", post(check_provider))
848 .route("/config/set_provider", post(set_config_provider))
849 .with_state(state)
850}
851
#[cfg(test)]
mod tests {
    use super::*;

    /// The `model-limits` pseudo-key must bypass the config store and return the
    /// built-in model limits as a plain (unmasked) JSON value.
    #[tokio::test]
    async fn test_read_model_limits() {
        let Json(response) = read_config(Json(ConfigKeyQuery {
            key: "model-limits".to_string(),
            is_secret: false,
        }))
        .await
        .expect("reading model-limits should succeed");

        let ConfigValueResponse::Value(value) = response else {
            panic!("model-limits must not be returned as a masked secret");
        };

        let limits: Vec<aster::model::ModelLimitConfig> =
            serde_json::from_value(value).expect("model limits should deserialize");
        assert!(!limits.is_empty());

        let gpt4o = limits
            .iter()
            .find(|l| l.pattern == "gpt-4o")
            .expect("gpt-4o limit should be present");
        assert_eq!(gpt4o.context_limit, 128_000);
    }
}