1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use super::{
5 anthropic::AnthropicProvider,
6 azure::AzureProvider,
7 base::{Provider, ProviderMetadata},
8 claude_code::ClaudeCodeProvider,
9 codex::CodexProvider,
10 codex_stateful::CodexStatefulProvider,
11 cursor_agent::CursorAgentProvider,
12 databricks::DatabricksProvider,
13 gcpvertexai::GcpVertexAIProvider,
14 gemini_cli::GeminiCliProvider,
15 githubcopilot::GithubCopilotProvider,
16 google::GoogleProvider,
17 lead_worker::LeadWorkerProvider,
18 litellm::LiteLLMProvider,
19 ollama::OllamaProvider,
20 openai::OpenAiProvider,
21 openrouter::OpenRouterProvider,
22 provider_registry::ProviderRegistry,
23 snowflake::SnowflakeProvider,
24 tetrate::TetrateProvider,
25 venice::VeniceProvider,
26 xai::XaiProvider,
27};
28#[cfg(feature = "provider-aws")]
29use super::{bedrock::BedrockProvider, sagemaker_tgi::SageMakerTgiProvider};
30use crate::model::ModelConfig;
31use crate::providers::base::ProviderType;
32use crate::{
33 config::declarative_providers::register_declarative_providers,
34 providers::provider_registry::ProviderEntry,
35};
36use anyhow::Result;
37use tokio::sync::OnceCell;
38
/// Fallback for `ASTER_LEAD_TURNS` when that setting cannot be read; passed to
/// `LeadWorkerProvider::new_with_settings` as its lead-turn setting.
const DEFAULT_LEAD_TURNS: usize = 3;
/// Fallback for `ASTER_LEAD_FAILURE_THRESHOLD` when that setting cannot be read.
const DEFAULT_FAILURE_THRESHOLD: usize = 2;
/// Fallback for `ASTER_LEAD_FALLBACK_TURNS` when that setting cannot be read.
const DEFAULT_FALLBACK_TURNS: usize = 2;

/// Process-wide provider registry, built lazily by `init_registry` on the first
/// call to `get_registry`. The `RwLock` lets custom providers be replaced at
/// runtime (see `refresh_custom_providers`).
static REGISTRY: OnceCell<RwLock<ProviderRegistry>> = OnceCell::const_new();
44
45async fn init_registry() -> RwLock<ProviderRegistry> {
46 let mut registry = ProviderRegistry::new().with_providers(|registry| {
47 registry
48 .register::<AnthropicProvider, _>(|m| Box::pin(AnthropicProvider::from_env(m)), true);
49 registry.register::<AzureProvider, _>(|m| Box::pin(AzureProvider::from_env(m)), false);
50 #[cfg(feature = "provider-aws")]
51 registry.register::<BedrockProvider, _>(|m| Box::pin(BedrockProvider::from_env(m)), false);
52 registry
53 .register::<ClaudeCodeProvider, _>(|m| Box::pin(ClaudeCodeProvider::from_env(m)), true);
54 registry.register::<CodexProvider, _>(|m| Box::pin(CodexProvider::from_env(m)), true);
55 registry.register::<CodexStatefulProvider, _>(
56 |m| Box::pin(CodexStatefulProvider::from_env(m)),
57 true,
58 );
59 registry.register::<CursorAgentProvider, _>(
60 |m| Box::pin(CursorAgentProvider::from_env(m)),
61 false,
62 );
63 registry
64 .register::<DatabricksProvider, _>(|m| Box::pin(DatabricksProvider::from_env(m)), true);
65 registry.register::<GcpVertexAIProvider, _>(
66 |m| Box::pin(GcpVertexAIProvider::from_env(m)),
67 false,
68 );
69 registry
70 .register::<GeminiCliProvider, _>(|m| Box::pin(GeminiCliProvider::from_env(m)), false);
71 registry.register::<GithubCopilotProvider, _>(
72 |m| Box::pin(GithubCopilotProvider::from_env(m)),
73 false,
74 );
75 registry.register::<GoogleProvider, _>(|m| Box::pin(GoogleProvider::from_env(m)), true);
76 registry.register::<LiteLLMProvider, _>(|m| Box::pin(LiteLLMProvider::from_env(m)), false);
77 registry.register::<OllamaProvider, _>(|m| Box::pin(OllamaProvider::from_env(m)), true);
78 registry.register::<OpenAiProvider, _>(|m| Box::pin(OpenAiProvider::from_env(m)), true);
79 registry
80 .register::<OpenRouterProvider, _>(|m| Box::pin(OpenRouterProvider::from_env(m)), true);
81 #[cfg(feature = "provider-aws")]
82 registry.register::<SageMakerTgiProvider, _>(
83 |m| Box::pin(SageMakerTgiProvider::from_env(m)),
84 false,
85 );
86 registry
87 .register::<SnowflakeProvider, _>(|m| Box::pin(SnowflakeProvider::from_env(m)), false);
88 registry.register::<TetrateProvider, _>(|m| Box::pin(TetrateProvider::from_env(m)), true);
89 registry.register::<VeniceProvider, _>(|m| Box::pin(VeniceProvider::from_env(m)), false);
90 registry.register::<XaiProvider, _>(|m| Box::pin(XaiProvider::from_env(m)), false);
91 });
92 if let Err(e) = load_custom_providers_into_registry(&mut registry) {
93 tracing::warn!("Failed to load custom providers: {}", e);
94 }
95 RwLock::new(registry)
96}
97
98fn load_custom_providers_into_registry(registry: &mut ProviderRegistry) -> Result<()> {
99 register_declarative_providers(registry)
100}
101
102async fn get_registry() -> &'static RwLock<ProviderRegistry> {
103 REGISTRY.get_or_init(init_registry).await
104}
105
106pub async fn providers() -> Vec<(ProviderMetadata, ProviderType)> {
107 get_registry()
108 .await
109 .read()
110 .unwrap()
111 .all_metadata_with_types()
112}
113
114pub async fn refresh_custom_providers() -> Result<()> {
115 let registry = get_registry().await;
116 registry.write().unwrap().remove_custom_providers();
117
118 if let Err(e) = load_custom_providers_into_registry(&mut registry.write().unwrap()) {
119 tracing::warn!("Failed to refresh custom providers: {}", e);
120 return Err(e);
121 }
122
123 tracing::info!("Custom providers refreshed");
124 Ok(())
125}
126
127async fn get_from_registry(name: &str) -> Result<ProviderEntry> {
128 let mapped_name = map_provider_alias(name);
130
131 #[cfg(not(feature = "provider-aws"))]
132 if mapped_name == "bedrock" || mapped_name == "sagemaker_tgi" {
133 return Err(anyhow::anyhow!(
134 "Provider {} is disabled at compile time; rebuild with feature provider-aws",
135 mapped_name
136 ));
137 }
138
139 let guard = get_registry().await.read().unwrap();
140 guard
141 .entries
142 .get(mapped_name.as_str())
143 .ok_or_else(|| anyhow::anyhow!("Unknown provider: {} (mapped to: {})", name, mapped_name))
144 .cloned()
145}
146
147fn parse_provider_alias_overrides(raw: &str) -> HashMap<String, String> {
156 let trimmed = raw.trim();
157 if trimmed.is_empty() {
158 return HashMap::new();
159 }
160
161 if let Ok(json_map) = serde_json::from_str::<HashMap<String, String>>(trimmed) {
162 return json_map
163 .into_iter()
164 .map(|(alias, target)| (alias.trim().to_lowercase(), target.trim().to_lowercase()))
165 .filter(|(alias, target)| !alias.is_empty() && !target.is_empty())
166 .collect();
167 }
168
169 let mut overrides = HashMap::new();
170 for pair in trimmed.split(',') {
171 let entry = pair.trim();
172 if entry.is_empty() {
173 continue;
174 }
175
176 let parsed = entry.split_once('=').or_else(|| entry.split_once(':'));
177
178 if let Some((alias, target)) = parsed {
179 let alias = alias.trim().to_lowercase();
180 let target = target.trim().to_lowercase();
181 if !alias.is_empty() && !target.is_empty() {
182 overrides.insert(alias, target);
183 }
184 }
185 }
186
187 overrides
188}
189
190fn load_provider_alias_overrides() -> HashMap<String, String> {
191 std::env::var("ASTER_PROVIDER_ALIAS_OVERRIDES")
192 .ok()
193 .map(|raw| parse_provider_alias_overrides(&raw))
194 .unwrap_or_default()
195}
196
197fn map_provider_alias(name: &str) -> String {
198 let normalized = name.trim().to_lowercase();
199
200 if normalized.is_empty() {
201 return normalized;
202 }
203
204 if normalized.starts_with("custom-") {
207 return "openai".to_string();
208 }
209
210 if let Some(mapped) = load_provider_alias_overrides().get(normalized.as_str()) {
212 return mapped.clone();
213 }
214
215 let mapped = match normalized.as_str() {
216 "deepseek" | "deep_seek" | "deep-seek" => "openai",
219 "qwen" | "tongyi" | "dashscope" | "aliyun" => "openai",
220 "zhipu" | "glm" | "chatglm" => "openai",
221 "baichuan" => "openai",
222 "moonshot" | "kimi" => "openai",
223 "minimax" => "openai",
224 "yi" | "01ai" | "lingyiwanwu" => "openai",
225 "stepfun" | "step" => "openai",
226 "bailian" | "百炼" => "openai",
227 "doubao" | "豆包" => "openai",
228 "spark" | "讯飞" | "xunfei" => "openai",
229 "hunyuan" | "混元" => "openai",
230 "ernie" | "文心" | "wenxin" => "openai",
231
232 "groq" => "openai",
234 "together" | "togetherai" => "openai",
235 "fireworks" | "fireworksai" => "openai",
236 "perplexity" => "openai",
237 "anyscale" => "openai",
238 "lepton" | "leptonai" => "openai",
239 "novita" | "novitaai" => "openai",
240 "siliconflow" => "openai",
241 "mistral" => "openai",
242 "cohere" => "openai",
243
244 "oneapi" | "one-api" | "one_api" => "openai",
246 "newapi" | "new-api" | "new_api" => "openai",
247 "vercel" | "vercel_ai" | "vercel-ai" => "openai",
248
249 "custom" | "custom_openai" | "openai_compatible" => "openai",
251
252 "claude" => "anthropic",
254 "anthropic_compatible" | "anthropic-compatible" => "anthropic",
255
256 "gemini" | "gemini_api_key" => "google",
258 "antigravity" => "google",
259
260 "azure" | "azure_openai" | "azure-openai" => "azure",
262 "vertex" | "vertexai" | "vertex_ai" => "gcpvertexai",
263 "aws_bedrock" | "aws-bedrock" => "bedrock",
264 "kiro" => "bedrock", _ => normalized.as_str(),
268 };
269
270 mapped.to_string()
271}
272
273pub async fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
274 let config = crate::config::Config::global();
275
276 if let Ok(lead_model_name) = config.get_param::<String>("ASTER_LEAD_MODEL") {
277 tracing::info!("Creating lead/worker provider from environment variables");
278 return create_lead_worker_from_env(name, &model, &lead_model_name).await;
279 }
280
281 let constructor = get_from_registry(name).await?.constructor.clone();
282 constructor(model).await
283}
284
285pub async fn create_with_default_model(name: impl AsRef<str>) -> Result<Arc<dyn Provider>> {
286 get_from_registry(name.as_ref())
287 .await?
288 .create_with_default_model()
289 .await
290}
291
292pub async fn create_with_named_model(
293 provider_name: &str,
294 model_name: &str,
295) -> Result<Arc<dyn Provider>> {
296 let config = ModelConfig::new(model_name)?;
297 create(provider_name, config).await
298}
299
300async fn create_lead_worker_from_env(
301 default_provider_name: &str,
302 default_model: &ModelConfig,
303 lead_model_name: &str,
304) -> Result<Arc<dyn Provider>> {
305 let config = crate::config::Config::global();
306
307 let lead_provider_name_raw = config
308 .get_param::<String>("ASTER_LEAD_PROVIDER")
309 .unwrap_or_else(|_| default_provider_name.to_string());
310 let lead_provider_name = map_provider_alias(&lead_provider_name_raw);
311 let worker_provider_name = map_provider_alias(default_provider_name);
312
313 let lead_turns = config
314 .get_param::<usize>("ASTER_LEAD_TURNS")
315 .unwrap_or(DEFAULT_LEAD_TURNS);
316 let failure_threshold = config
317 .get_param::<usize>("ASTER_LEAD_FAILURE_THRESHOLD")
318 .unwrap_or(DEFAULT_FAILURE_THRESHOLD);
319 let fallback_turns = config
320 .get_param::<usize>("ASTER_LEAD_FALLBACK_TURNS")
321 .unwrap_or(DEFAULT_FALLBACK_TURNS);
322
323 let lead_model_config = ModelConfig::new_with_context_env(
324 lead_model_name.to_string(),
325 Some("ASTER_LEAD_CONTEXT_LIMIT"),
326 )?;
327
328 let worker_model_config = create_worker_model_config(default_model)?;
329
330 let registry = get_registry().await;
331
332 let lead_constructor = {
333 let guard = registry.read().unwrap();
334 guard
335 .entries
336 .get(lead_provider_name.as_str())
337 .ok_or_else(|| {
338 anyhow::anyhow!(
339 "Unknown provider: {} (mapped to: {})",
340 lead_provider_name_raw,
341 lead_provider_name
342 )
343 })?
344 .constructor
345 .clone()
346 };
347
348 let worker_constructor = {
349 let guard = registry.read().unwrap();
350 guard
351 .entries
352 .get(worker_provider_name.as_str())
353 .ok_or_else(|| {
354 anyhow::anyhow!(
355 "Unknown provider: {} (mapped to: {})",
356 default_provider_name,
357 worker_provider_name
358 )
359 })?
360 .constructor
361 .clone()
362 };
363
364 let lead_provider = lead_constructor(lead_model_config).await?;
365 let worker_provider = worker_constructor(worker_model_config).await?;
366
367 Ok(Arc::new(LeadWorkerProvider::new_with_settings(
368 lead_provider,
369 worker_provider,
370 lead_turns,
371 failure_threshold,
372 fallback_turns,
373 )))
374}
375
376fn create_worker_model_config(default_model: &ModelConfig) -> Result<ModelConfig> {
377 let mut worker_config = ModelConfig::new_or_fail(&default_model.model_name)
378 .with_context_limit(default_model.context_limit)
379 .with_temperature(default_model.temperature)
380 .with_max_tokens(default_model.max_tokens)
381 .with_toolshim(default_model.toolshim)
382 .with_toolshim_model(default_model.toolshim_model.clone());
383
384 let global_config = crate::config::Config::global();
385
386 if let Ok(limit) = global_config.get_param::<usize>("ASTER_WORKER_CONTEXT_LIMIT") {
387 worker_config = worker_config.with_context_limit(Some(limit));
388 } else if let Ok(limit) = global_config.get_param::<usize>("ASTER_CONTEXT_LIMIT") {
389 worker_config = worker_config.with_context_limit(Some(limit));
390 }
391
392 Ok(worker_config)
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    #[test_case::test_case(None, None, None, DEFAULT_LEAD_TURNS, DEFAULT_FAILURE_THRESHOLD, DEFAULT_FALLBACK_TURNS ; "defaults")]
    #[test_case::test_case(Some("7"), Some("4"), Some("3"), 7, 4, 3 ; "custom")]
    #[tokio::test]
    async fn test_create_lead_worker_provider(
        lead_turns: Option<&str>,
        failure_threshold: Option<&str>,
        fallback_turns: Option<&str>,
        expected_turns: usize,
        expected_failure: usize,
        expected_fallback: usize,
    ) {
        let _guard = env_lock::lock_env([
            ("ASTER_LEAD_MODEL", Some("gpt-4o")),
            ("ASTER_LEAD_PROVIDER", None),
            ("ASTER_LEAD_TURNS", lead_turns),
            ("ASTER_LEAD_FAILURE_THRESHOLD", failure_threshold),
            ("ASTER_LEAD_FALLBACK_TURNS", fallback_turns),
            ("OPENAI_API_KEY", Some("fake-openai-no-keyring")),
        ]);

        let provider = create("openai", ModelConfig::new_or_fail("gpt-4o-mini"))
            .await
            .unwrap();
        let lw = provider.as_lead_worker().unwrap();
        let (lead, worker) = lw.get_model_info();
        assert_eq!(lead, "gpt-4o");
        assert_eq!(worker, "gpt-4o-mini");
        assert_eq!(
            lw.get_settings(),
            (expected_turns, expected_failure, expected_fallback)
        );
    }

    #[tokio::test]
    async fn test_create_regular_provider_without_lead_config() {
        let _guard = env_lock::lock_env([
            ("ASTER_LEAD_MODEL", None),
            ("ASTER_LEAD_PROVIDER", None),
            ("ASTER_LEAD_TURNS", None),
            ("ASTER_LEAD_FAILURE_THRESHOLD", None),
            ("ASTER_LEAD_FALLBACK_TURNS", None),
            ("OPENAI_API_KEY", Some("fake-openai-no-keyring")),
        ]);

        let provider = create("openai", ModelConfig::new_or_fail("gpt-4o-mini"))
            .await
            .unwrap();
        assert!(provider.as_lead_worker().is_none());
        assert_eq!(provider.get_model_config().model_name, "gpt-4o-mini");
    }

    #[test_case::test_case(None, None, 16_000 ; "no overrides uses default")]
    #[test_case::test_case(Some("32000"), None, 32_000 ; "worker limit overrides default")]
    #[test_case::test_case(Some("32000"), Some("64000"), 32_000 ; "worker limit takes priority over global")]
    fn test_worker_model_context_limit(
        worker_limit: Option<&str>,
        global_limit: Option<&str>,
        expected_limit: usize,
    ) {
        let _guard = env_lock::lock_env([
            ("ASTER_WORKER_CONTEXT_LIMIT", worker_limit),
            ("ASTER_CONTEXT_LIMIT", global_limit),
        ]);

        let default_model =
            ModelConfig::new_or_fail("gpt-3.5-turbo").with_context_limit(Some(16_000));

        let result = create_worker_model_config(&default_model).unwrap();
        assert_eq!(result.context_limit, Some(expected_limit));
    }

    #[tokio::test]
    async fn test_openai_compatible_providers_config_keys() {
        let providers_list = providers().await;

        // `OpenAiProvider` is registered unconditionally by `init_registry`, so a
        // missing `openai` entry is a failure, not a reason to skip. The other
        // names depend on which declarative providers are installed and are
        // only checked when present.
        assert!(
            providers_list.iter().any(|(m, _)| m.name == "openai"),
            "openai provider should always be registered"
        );

        let cases = [
            ("openai", "OPENAI_API_KEY"),
            ("groq", "GROQ_API_KEY"),
            ("mistral", "MISTRAL_API_KEY"),
            ("custom_deepseek", "DEEPSEEK_API_KEY"),
        ];
        for (name, expected_key) in cases {
            let Some((meta, _)) = providers_list.iter().find(|(m, _)| m.name == name) else {
                continue;
            };
            assert!(
                !meta.config_keys.is_empty(),
                "{name} provider should have config keys"
            );
            assert_eq!(
                meta.config_keys[0].name, expected_key,
                "First config key for {name} should be {expected_key}, got {}",
                meta.config_keys[0].name
            );
            assert!(
                meta.config_keys[0].required,
                "{expected_key} should be required"
            );
            assert!(
                meta.config_keys[0].secret,
                "{expected_key} should be secret"
            );
        }
    }

    #[test]
    fn test_map_provider_alias_custom_uuid() {
        let _guard = env_lock::lock_env([("ASTER_PROVIDER_ALIAS_OVERRIDES", None::<&str>)]);

        assert_eq!(
            map_provider_alias("custom-ba4e7574-dd00-4784-945a-0f383dfa1272"),
            "openai"
        );
        assert_eq!(
            map_provider_alias("custom-12345678-1234-1234-1234-123456789abc"),
            "openai"
        );
        assert_eq!(map_provider_alias("custom"), "openai");
        assert_eq!(map_provider_alias("custom_openai"), "openai");
    }

    #[test]
    fn test_map_provider_alias_known_providers() {
        let _guard = env_lock::lock_env([("ASTER_PROVIDER_ALIAS_OVERRIDES", None::<&str>)]);

        assert_eq!(map_provider_alias("deepseek"), "openai");
        assert_eq!(map_provider_alias("qwen"), "openai");
        assert_eq!(map_provider_alias("claude"), "anthropic");
        assert_eq!(map_provider_alias("gemini"), "google");
        assert_eq!(map_provider_alias("kiro"), "bedrock");
        assert_eq!(map_provider_alias("openai"), "openai");
        assert_eq!(map_provider_alias("anthropic"), "anthropic");
        assert_eq!(map_provider_alias("google"), "google");
    }

    #[test]
    fn test_map_provider_alias_fallback_to_lowercase() {
        let _guard = env_lock::lock_env([("ASTER_PROVIDER_ALIAS_OVERRIDES", None::<&str>)]);

        assert_eq!(map_provider_alias("OpenAI"), "openai");
        assert_eq!(
            map_provider_alias("My-Custom-Provider"),
            "my-custom-provider"
        );
        // Surrounding whitespace is ignored, and a blank name stays blank.
        assert_eq!(map_provider_alias("  Claude  "), "anthropic");
        assert_eq!(map_provider_alias("   "), "");
    }

    #[test]
    fn test_map_provider_alias_env_override_json() {
        let _guard = env_lock::lock_env([(
            "ASTER_PROVIDER_ALIAS_OVERRIDES",
            Some(r#"{"moonshotai":"openrouter","gemini":"google"}"#),
        )]);

        assert_eq!(map_provider_alias("moonshotai"), "openrouter");
        assert_eq!(map_provider_alias("gemini"), "google");
    }

    #[test]
    fn test_map_provider_alias_env_override_kv() {
        let _guard = env_lock::lock_env([(
            "ASTER_PROVIDER_ALIAS_OVERRIDES",
            Some("deepseek=openrouter,claude=openai"),
        )]);

        assert_eq!(map_provider_alias("deepseek"), "openrouter");
        assert_eq!(map_provider_alias("claude"), "openai");
    }

    #[test]
    fn test_parse_provider_alias_overrides_formats() {
        assert!(parse_provider_alias_overrides("   ").is_empty());

        // Pair list: `=` or `:` separators, trimmed and lowercased; empty
        // entries and entries without a separator are skipped.
        let pairs =
            parse_provider_alias_overrides(" DeepSeek = OpenRouter , claude:openai ,, bad ");
        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs.get("deepseek").map(String::as_str), Some("openrouter"));
        assert_eq!(pairs.get("claude").map(String::as_str), Some("openai"));

        // JSON object: same normalization; pairs with a blank side are dropped.
        let json = parse_provider_alias_overrides(r#"{" Kimi ":"OpenRouter"," ":"x"}"#);
        assert_eq!(json.len(), 1);
        assert_eq!(json.get("kimi").map(String::as_str), Some("openrouter"));
    }
}