Skip to main content

aster/providers/
factory.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use super::{
5    anthropic::AnthropicProvider,
6    azure::AzureProvider,
7    base::{Provider, ProviderMetadata},
8    claude_code::ClaudeCodeProvider,
9    codex::CodexProvider,
10    codex_stateful::CodexStatefulProvider,
11    cursor_agent::CursorAgentProvider,
12    databricks::DatabricksProvider,
13    gcpvertexai::GcpVertexAIProvider,
14    gemini_cli::GeminiCliProvider,
15    githubcopilot::GithubCopilotProvider,
16    google::GoogleProvider,
17    lead_worker::LeadWorkerProvider,
18    litellm::LiteLLMProvider,
19    ollama::OllamaProvider,
20    openai::OpenAiProvider,
21    openrouter::OpenRouterProvider,
22    provider_registry::ProviderRegistry,
23    snowflake::SnowflakeProvider,
24    tetrate::TetrateProvider,
25    venice::VeniceProvider,
26    xai::XaiProvider,
27};
28#[cfg(feature = "provider-aws")]
29use super::{bedrock::BedrockProvider, sagemaker_tgi::SageMakerTgiProvider};
30use crate::model::ModelConfig;
31use crate::providers::base::ProviderType;
32use crate::{
33    config::declarative_providers::register_declarative_providers,
34    providers::provider_registry::ProviderEntry,
35};
36use anyhow::Result;
37use tokio::sync::OnceCell;
38
/// Fallback for `ASTER_LEAD_TURNS` when it cannot be read from config.
const DEFAULT_LEAD_TURNS: usize = 3;
/// Fallback for `ASTER_LEAD_FAILURE_THRESHOLD` when it cannot be read from config.
const DEFAULT_FAILURE_THRESHOLD: usize = 2;
/// Fallback for `ASTER_LEAD_FALLBACK_TURNS` when it cannot be read from config.
const DEFAULT_FALLBACK_TURNS: usize = 2;
42
43static REGISTRY: OnceCell<RwLock<ProviderRegistry>> = OnceCell::const_new();
44
45async fn init_registry() -> RwLock<ProviderRegistry> {
46    let mut registry = ProviderRegistry::new().with_providers(|registry| {
47        registry
48            .register::<AnthropicProvider, _>(|m| Box::pin(AnthropicProvider::from_env(m)), true);
49        registry.register::<AzureProvider, _>(|m| Box::pin(AzureProvider::from_env(m)), false);
50        #[cfg(feature = "provider-aws")]
51        registry.register::<BedrockProvider, _>(|m| Box::pin(BedrockProvider::from_env(m)), false);
52        registry
53            .register::<ClaudeCodeProvider, _>(|m| Box::pin(ClaudeCodeProvider::from_env(m)), true);
54        registry.register::<CodexProvider, _>(|m| Box::pin(CodexProvider::from_env(m)), true);
55        registry.register::<CodexStatefulProvider, _>(
56            |m| Box::pin(CodexStatefulProvider::from_env(m)),
57            true,
58        );
59        registry.register::<CursorAgentProvider, _>(
60            |m| Box::pin(CursorAgentProvider::from_env(m)),
61            false,
62        );
63        registry
64            .register::<DatabricksProvider, _>(|m| Box::pin(DatabricksProvider::from_env(m)), true);
65        registry.register::<GcpVertexAIProvider, _>(
66            |m| Box::pin(GcpVertexAIProvider::from_env(m)),
67            false,
68        );
69        registry
70            .register::<GeminiCliProvider, _>(|m| Box::pin(GeminiCliProvider::from_env(m)), false);
71        registry.register::<GithubCopilotProvider, _>(
72            |m| Box::pin(GithubCopilotProvider::from_env(m)),
73            false,
74        );
75        registry.register::<GoogleProvider, _>(|m| Box::pin(GoogleProvider::from_env(m)), true);
76        registry.register::<LiteLLMProvider, _>(|m| Box::pin(LiteLLMProvider::from_env(m)), false);
77        registry.register::<OllamaProvider, _>(|m| Box::pin(OllamaProvider::from_env(m)), true);
78        registry.register::<OpenAiProvider, _>(|m| Box::pin(OpenAiProvider::from_env(m)), true);
79        registry
80            .register::<OpenRouterProvider, _>(|m| Box::pin(OpenRouterProvider::from_env(m)), true);
81        #[cfg(feature = "provider-aws")]
82        registry.register::<SageMakerTgiProvider, _>(
83            |m| Box::pin(SageMakerTgiProvider::from_env(m)),
84            false,
85        );
86        registry
87            .register::<SnowflakeProvider, _>(|m| Box::pin(SnowflakeProvider::from_env(m)), false);
88        registry.register::<TetrateProvider, _>(|m| Box::pin(TetrateProvider::from_env(m)), true);
89        registry.register::<VeniceProvider, _>(|m| Box::pin(VeniceProvider::from_env(m)), false);
90        registry.register::<XaiProvider, _>(|m| Box::pin(XaiProvider::from_env(m)), false);
91    });
92    if let Err(e) = load_custom_providers_into_registry(&mut registry) {
93        tracing::warn!("Failed to load custom providers: {}", e);
94    }
95    RwLock::new(registry)
96}
97
98fn load_custom_providers_into_registry(registry: &mut ProviderRegistry) -> Result<()> {
99    register_declarative_providers(registry)
100}
101
102async fn get_registry() -> &'static RwLock<ProviderRegistry> {
103    REGISTRY.get_or_init(init_registry).await
104}
105
106pub async fn providers() -> Vec<(ProviderMetadata, ProviderType)> {
107    get_registry()
108        .await
109        .read()
110        .unwrap()
111        .all_metadata_with_types()
112}
113
114pub async fn refresh_custom_providers() -> Result<()> {
115    let registry = get_registry().await;
116    registry.write().unwrap().remove_custom_providers();
117
118    if let Err(e) = load_custom_providers_into_registry(&mut registry.write().unwrap()) {
119        tracing::warn!("Failed to refresh custom providers: {}", e);
120        return Err(e);
121    }
122
123    tracing::info!("Custom providers refreshed");
124    Ok(())
125}
126
127async fn get_from_registry(name: &str) -> Result<ProviderEntry> {
128    // 将各种 Provider 名称映射到 Aster 支持的 Provider
129    let mapped_name = map_provider_alias(name);
130
131    #[cfg(not(feature = "provider-aws"))]
132    if mapped_name == "bedrock" || mapped_name == "sagemaker_tgi" {
133        return Err(anyhow::anyhow!(
134            "Provider {} is disabled at compile time; rebuild with feature provider-aws",
135            mapped_name
136        ));
137    }
138
139    let guard = get_registry().await.read().unwrap();
140    guard
141        .entries
142        .get(mapped_name.as_str())
143        .ok_or_else(|| anyhow::anyhow!("Unknown provider: {} (mapped to: {})", name, mapped_name))
144        .cloned()
145}
146
147/// 将各种 Provider 名称映射到 Aster 支持的 Provider
148///
149/// Aster 原生支持的 Provider:
150/// - openai, anthropic, google, azure, bedrock, ollama, gcpvertexai
151/// - openrouter, litellm, databricks, codex, xai, venice, tetrate
152/// - snowflake, sagemaker_tgi, githubcopilot, gemini_cli, cursor_agent, claude_code
153///
154/// 其他 Provider 会映射到兼容的 Provider
155fn parse_provider_alias_overrides(raw: &str) -> HashMap<String, String> {
156    let trimmed = raw.trim();
157    if trimmed.is_empty() {
158        return HashMap::new();
159    }
160
161    if let Ok(json_map) = serde_json::from_str::<HashMap<String, String>>(trimmed) {
162        return json_map
163            .into_iter()
164            .map(|(alias, target)| (alias.trim().to_lowercase(), target.trim().to_lowercase()))
165            .filter(|(alias, target)| !alias.is_empty() && !target.is_empty())
166            .collect();
167    }
168
169    let mut overrides = HashMap::new();
170    for pair in trimmed.split(',') {
171        let entry = pair.trim();
172        if entry.is_empty() {
173            continue;
174        }
175
176        let parsed = entry.split_once('=').or_else(|| entry.split_once(':'));
177
178        if let Some((alias, target)) = parsed {
179            let alias = alias.trim().to_lowercase();
180            let target = target.trim().to_lowercase();
181            if !alias.is_empty() && !target.is_empty() {
182                overrides.insert(alias, target);
183            }
184        }
185    }
186
187    overrides
188}
189
190fn load_provider_alias_overrides() -> HashMap<String, String> {
191    std::env::var("ASTER_PROVIDER_ALIAS_OVERRIDES")
192        .ok()
193        .map(|raw| parse_provider_alias_overrides(&raw))
194        .unwrap_or_default()
195}
196
197fn map_provider_alias(name: &str) -> String {
198    let normalized = name.trim().to_lowercase();
199
200    if normalized.is_empty() {
201        return normalized;
202    }
203
204    // 自定义 Provider(UUID 格式,如 custom-ba4e7574-dd00-4784-945a-0f383dfa1272)
205    // 这些是用户通过 API Key Provider 添加的自定义服务,通常是 OpenAI 兼容的
206    if normalized.starts_with("custom-") {
207        return "openai".to_string();
208    }
209
210    // 应用层可通过环境变量覆盖别名映射,避免框架层频繁改代码
211    if let Some(mapped) = load_provider_alias_overrides().get(normalized.as_str()) {
212        return mapped.clone();
213    }
214
215    let mapped = match normalized.as_str() {
216        // ========== OpenAI 兼容格式 ==========
217        // 国内 AI 服务
218        "deepseek" | "deep_seek" | "deep-seek" => "openai",
219        "qwen" | "tongyi" | "dashscope" | "aliyun" => "openai",
220        "zhipu" | "glm" | "chatglm" => "openai",
221        "baichuan" => "openai",
222        "moonshot" | "kimi" => "openai",
223        "minimax" => "openai",
224        "yi" | "01ai" | "lingyiwanwu" => "openai",
225        "stepfun" | "step" => "openai",
226        "bailian" | "百炼" => "openai",
227        "doubao" | "豆包" => "openai",
228        "spark" | "讯飞" | "xunfei" => "openai",
229        "hunyuan" | "混元" => "openai",
230        "ernie" | "文心" | "wenxin" => "openai",
231
232        // 国际 AI 服务(OpenAI 兼容)
233        "groq" => "openai",
234        "together" | "togetherai" => "openai",
235        "fireworks" | "fireworksai" => "openai",
236        "perplexity" => "openai",
237        "anyscale" => "openai",
238        "lepton" | "leptonai" => "openai",
239        "novita" | "novitaai" => "openai",
240        "siliconflow" => "openai",
241        "mistral" => "openai",
242        "cohere" => "openai",
243
244        // API 聚合服务
245        "oneapi" | "one-api" | "one_api" => "openai",
246        "newapi" | "new-api" | "new_api" => "openai",
247        "vercel" | "vercel_ai" | "vercel-ai" => "openai",
248
249        // 自定义/通用 OpenAI 兼容
250        "custom" | "custom_openai" | "openai_compatible" => "openai",
251
252        // ========== Anthropic 兼容格式 ==========
253        "claude" => "anthropic",
254        "anthropic_compatible" | "anthropic-compatible" => "anthropic",
255
256        // ========== Google/Gemini 格式 ==========
257        "gemini" | "gemini_api_key" => "google",
258        "antigravity" => "google",
259
260        // ========== 其他已支持的 Provider(保持原名) ==========
261        "azure" | "azure_openai" | "azure-openai" => "azure",
262        "vertex" | "vertexai" | "vertex_ai" => "gcpvertexai",
263        "aws_bedrock" | "aws-bedrock" => "bedrock",
264        "kiro" => "bedrock", // Kiro 使用 CodeWhisperer API
265
266        // 默认返回小写原名称(让 Aster 原生处理)
267        _ => normalized.as_str(),
268    };
269
270    mapped.to_string()
271}
272
273pub async fn create(name: &str, model: ModelConfig) -> Result<Arc<dyn Provider>> {
274    let config = crate::config::Config::global();
275
276    if let Ok(lead_model_name) = config.get_param::<String>("ASTER_LEAD_MODEL") {
277        tracing::info!("Creating lead/worker provider from environment variables");
278        return create_lead_worker_from_env(name, &model, &lead_model_name).await;
279    }
280
281    let constructor = get_from_registry(name).await?.constructor.clone();
282    constructor(model).await
283}
284
285pub async fn create_with_default_model(name: impl AsRef<str>) -> Result<Arc<dyn Provider>> {
286    get_from_registry(name.as_ref())
287        .await?
288        .create_with_default_model()
289        .await
290}
291
292pub async fn create_with_named_model(
293    provider_name: &str,
294    model_name: &str,
295) -> Result<Arc<dyn Provider>> {
296    let config = ModelConfig::new(model_name)?;
297    create(provider_name, config).await
298}
299
300async fn create_lead_worker_from_env(
301    default_provider_name: &str,
302    default_model: &ModelConfig,
303    lead_model_name: &str,
304) -> Result<Arc<dyn Provider>> {
305    let config = crate::config::Config::global();
306
307    let lead_provider_name_raw = config
308        .get_param::<String>("ASTER_LEAD_PROVIDER")
309        .unwrap_or_else(|_| default_provider_name.to_string());
310    let lead_provider_name = map_provider_alias(&lead_provider_name_raw);
311    let worker_provider_name = map_provider_alias(default_provider_name);
312
313    let lead_turns = config
314        .get_param::<usize>("ASTER_LEAD_TURNS")
315        .unwrap_or(DEFAULT_LEAD_TURNS);
316    let failure_threshold = config
317        .get_param::<usize>("ASTER_LEAD_FAILURE_THRESHOLD")
318        .unwrap_or(DEFAULT_FAILURE_THRESHOLD);
319    let fallback_turns = config
320        .get_param::<usize>("ASTER_LEAD_FALLBACK_TURNS")
321        .unwrap_or(DEFAULT_FALLBACK_TURNS);
322
323    let lead_model_config = ModelConfig::new_with_context_env(
324        lead_model_name.to_string(),
325        Some("ASTER_LEAD_CONTEXT_LIMIT"),
326    )?;
327
328    let worker_model_config = create_worker_model_config(default_model)?;
329
330    let registry = get_registry().await;
331
332    let lead_constructor = {
333        let guard = registry.read().unwrap();
334        guard
335            .entries
336            .get(lead_provider_name.as_str())
337            .ok_or_else(|| {
338                anyhow::anyhow!(
339                    "Unknown provider: {} (mapped to: {})",
340                    lead_provider_name_raw,
341                    lead_provider_name
342                )
343            })?
344            .constructor
345            .clone()
346    };
347
348    let worker_constructor = {
349        let guard = registry.read().unwrap();
350        guard
351            .entries
352            .get(worker_provider_name.as_str())
353            .ok_or_else(|| {
354                anyhow::anyhow!(
355                    "Unknown provider: {} (mapped to: {})",
356                    default_provider_name,
357                    worker_provider_name
358                )
359            })?
360            .constructor
361            .clone()
362    };
363
364    let lead_provider = lead_constructor(lead_model_config).await?;
365    let worker_provider = worker_constructor(worker_model_config).await?;
366
367    Ok(Arc::new(LeadWorkerProvider::new_with_settings(
368        lead_provider,
369        worker_provider,
370        lead_turns,
371        failure_threshold,
372        fallback_turns,
373    )))
374}
375
376fn create_worker_model_config(default_model: &ModelConfig) -> Result<ModelConfig> {
377    let mut worker_config = ModelConfig::new_or_fail(&default_model.model_name)
378        .with_context_limit(default_model.context_limit)
379        .with_temperature(default_model.temperature)
380        .with_max_tokens(default_model.max_tokens)
381        .with_toolshim(default_model.toolshim)
382        .with_toolshim_model(default_model.toolshim_model.clone());
383
384    let global_config = crate::config::Config::global();
385
386    if let Ok(limit) = global_config.get_param::<usize>("ASTER_WORKER_CONTEXT_LIMIT") {
387        worker_config = worker_config.with_context_limit(Some(limit));
388    } else if let Ok(limit) = global_config.get_param::<usize>("ASTER_CONTEXT_LIMIT") {
389        worker_config = worker_config.with_context_limit(Some(limit));
390    }
391
392    Ok(worker_config)
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    #[test_case::test_case(None, None, None, DEFAULT_LEAD_TURNS, DEFAULT_FAILURE_THRESHOLD, DEFAULT_FALLBACK_TURNS ; "defaults")]
    #[test_case::test_case(Some("7"), Some("4"), Some("3"), 7, 4, 3 ; "custom")]
    #[tokio::test]
    async fn test_create_lead_worker_provider(
        lead_turns: Option<&str>,
        failure_threshold: Option<&str>,
        fallback_turns: Option<&str>,
        expected_turns: usize,
        expected_failure: usize,
        expected_fallback: usize,
    ) {
        let _guard = env_lock::lock_env([
            ("ASTER_LEAD_MODEL", Some("gpt-4o")),
            ("ASTER_LEAD_PROVIDER", None),
            ("ASTER_LEAD_TURNS", lead_turns),
            ("ASTER_LEAD_FAILURE_THRESHOLD", failure_threshold),
            ("ASTER_LEAD_FALLBACK_TURNS", fallback_turns),
            ("OPENAI_API_KEY", Some("fake-openai-no-keyring")),
        ]);

        let provider = create("openai", ModelConfig::new_or_fail("gpt-4o-mini"))
            .await
            .unwrap();
        let lw = provider.as_lead_worker().unwrap();
        let (lead, worker) = lw.get_model_info();
        assert_eq!(lead, "gpt-4o");
        assert_eq!(worker, "gpt-4o-mini");
        assert_eq!(
            lw.get_settings(),
            (expected_turns, expected_failure, expected_fallback)
        );
    }

    #[tokio::test]
    async fn test_create_regular_provider_without_lead_config() {
        let _guard = env_lock::lock_env([
            ("ASTER_LEAD_MODEL", None),
            ("ASTER_LEAD_PROVIDER", None),
            ("ASTER_LEAD_TURNS", None),
            ("ASTER_LEAD_FAILURE_THRESHOLD", None),
            ("ASTER_LEAD_FALLBACK_TURNS", None),
            ("OPENAI_API_KEY", Some("fake-openai-no-keyring")),
        ]);

        let provider = create("openai", ModelConfig::new_or_fail("gpt-4o-mini"))
            .await
            .unwrap();
        assert!(provider.as_lead_worker().is_none());
        assert_eq!(provider.get_model_config().model_name, "gpt-4o-mini");
    }

    #[test_case::test_case(None, None, 16_000 ; "no overrides uses default")]
    #[test_case::test_case(Some("32000"), None, 32_000 ; "worker limit overrides default")]
    #[test_case::test_case(Some("32000"), Some("64000"), 32_000 ; "worker limit takes priority over global")]
    fn test_worker_model_context_limit(
        worker_limit: Option<&str>,
        global_limit: Option<&str>,
        expected_limit: usize,
    ) {
        let _guard = env_lock::lock_env([
            ("ASTER_WORKER_CONTEXT_LIMIT", worker_limit),
            ("ASTER_CONTEXT_LIMIT", global_limit),
        ]);

        let default_model =
            ModelConfig::new_or_fail("gpt-3.5-turbo").with_context_limit(Some(16_000));

        let result = create_worker_model_config(&default_model).unwrap();
        assert_eq!(result.context_limit, Some(expected_limit));
    }

    #[tokio::test]
    async fn test_openai_compatible_providers_config_keys() {
        let providers_list = providers().await;
        let cases = vec![
            ("openai", "OPENAI_API_KEY"),
            ("groq", "GROQ_API_KEY"),
            ("mistral", "MISTRAL_API_KEY"),
            ("custom_deepseek", "DEEPSEEK_API_KEY"),
        ];
        for (name, expected_key) in cases {
            // Providers that are not registered in this build are skipped.
            if let Some((meta, _)) = providers_list.iter().find(|(m, _)| m.name == name) {
                assert!(
                    !meta.config_keys.is_empty(),
                    "{name} provider should have config keys"
                );
                assert_eq!(
                    meta.config_keys[0].name, expected_key,
                    "First config key for {name} should be {expected_key}, got {}",
                    meta.config_keys[0].name
                );
                assert!(
                    meta.config_keys[0].required,
                    "{expected_key} should be required"
                );
                assert!(
                    meta.config_keys[0].secret,
                    "{expected_key} should be secret"
                );
            }
        }
    }

    #[test]
    fn test_map_provider_alias_custom_uuid() {
        let _guard = env_lock::lock_env([("ASTER_PROVIDER_ALIAS_OVERRIDES", None::<&str>)]);

        // Custom providers in UUID form map to openai.
        assert_eq!(
            map_provider_alias("custom-ba4e7574-dd00-4784-945a-0f383dfa1272"),
            "openai"
        );
        assert_eq!(
            map_provider_alias("custom-12345678-1234-1234-1234-123456789abc"),
            "openai"
        );
        // Plain `custom` aliases map to openai as well.
        assert_eq!(map_provider_alias("custom"), "openai");
        assert_eq!(map_provider_alias("custom_openai"), "openai");
    }

    #[test]
    fn test_map_provider_alias_known_providers() {
        let _guard = env_lock::lock_env([("ASTER_PROVIDER_ALIAS_OVERRIDES", None::<&str>)]);

        // Known aliases map to their compatible provider.
        assert_eq!(map_provider_alias("deepseek"), "openai");
        assert_eq!(map_provider_alias("qwen"), "openai");
        assert_eq!(map_provider_alias("claude"), "anthropic");
        assert_eq!(map_provider_alias("gemini"), "google");
        assert_eq!(map_provider_alias("kiro"), "bedrock");
        // Natively supported providers keep their own name.
        assert_eq!(map_provider_alias("openai"), "openai");
        assert_eq!(map_provider_alias("anthropic"), "anthropic");
        assert_eq!(map_provider_alias("google"), "google");
    }

    #[test]
    fn test_map_provider_alias_fallback_to_lowercase() {
        let _guard = env_lock::lock_env([("ASTER_PROVIDER_ALIAS_OVERRIDES", None::<&str>)]);

        assert_eq!(map_provider_alias("OpenAI"), "openai");
        assert_eq!(
            map_provider_alias("My-Custom-Provider"),
            "my-custom-provider"
        );
    }

    #[test]
    fn test_map_provider_alias_env_override_json() {
        // `gemini` maps to `google` by default, so overriding it to `openai` proves the
        // JSON override takes precedence over the built-in table.
        let _guard = env_lock::lock_env([(
            "ASTER_PROVIDER_ALIAS_OVERRIDES",
            Some(r#"{"moonshotai":"openrouter","gemini":"openai"}"#),
        )]);

        assert_eq!(map_provider_alias("moonshotai"), "openrouter");
        assert_eq!(map_provider_alias("gemini"), "openai");
    }

    #[test]
    fn test_map_provider_alias_env_override_kv() {
        let _guard = env_lock::lock_env([(
            "ASTER_PROVIDER_ALIAS_OVERRIDES",
            Some("deepseek=openrouter,claude=openai"),
        )]);

        assert_eq!(map_provider_alias("deepseek"), "openrouter");
        assert_eq!(map_provider_alias("claude"), "openai");
    }

    #[test]
    fn test_parse_provider_alias_overrides_blank() {
        assert!(parse_provider_alias_overrides("").is_empty());
        assert!(parse_provider_alias_overrides("   ").is_empty());
    }

    #[test]
    fn test_parse_provider_alias_overrides_kv_formats() {
        // Mixed separators, whitespace, empty entries and malformed pairs.
        let overrides =
            parse_provider_alias_overrides(" DeepSeek = OpenRouter , , kimi:openai, bad, =x, y= ");
        assert_eq!(overrides.len(), 2);
        assert_eq!(
            overrides.get("deepseek").map(String::as_str),
            Some("openrouter")
        );
        assert_eq!(overrides.get("kimi").map(String::as_str), Some("openai"));
    }

    #[test]
    fn test_parse_provider_alias_overrides_json_normalizes() {
        let overrides = parse_provider_alias_overrides(
            r#"{" MoonshotAI ":" OpenRouter ","":"openai","x":""}"#,
        );
        assert_eq!(overrides.len(), 1);
        assert_eq!(
            overrides.get("moonshotai").map(String::as_str),
            Some("openrouter")
        );
    }

    #[cfg(not(feature = "provider-aws"))]
    #[tokio::test]
    async fn test_aws_providers_rejected_without_feature() {
        let _guard = env_lock::lock_env([("ASTER_PROVIDER_ALIAS_OVERRIDES", None::<&str>)]);

        let err = get_from_registry("aws-bedrock")
            .await
            .err()
            .expect("bedrock must be rejected without the provider-aws feature");
        assert!(err.to_string().contains("disabled at compile time"));
    }
}