use super::{
anthropic::AnthropicProvider, byteplus::BytePlusProvider, chatgpt::ChatGPTProvider,
deepseek::DeepSeekProvider, gemini::GeminiProvider, juglans::JuglansProvider,
qwen::QwenProvider, xai::XaiProvider, LlmProvider,
};
use dashmap::DashMap;
use std::collections::HashMap;
use std::env;
use std::sync::Arc;
use super::claude_code::ClaudeCodeProvider;
use super::mcp_types::McpSession;
/// Per-provider overrides consumed by [`ProviderFactory::new_with_config`].
///
/// That constructor treats an empty string the same as `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LlmProviderConfig {
    /// API key exported to the provider's key environment variable when non-empty.
    pub api_key: Option<String>,
    /// Base URL exported to the provider's base-URL environment variable when non-empty and
    /// when the provider has such a variable mapped; otherwise ignored.
    pub base_url: Option<String>,
}
/// Resolves model identifiers to LLM provider instances.
///
/// Holds one shared instance of each built-in provider plus a `DashMap` of additional
/// providers keyed by name. Every field is an `Arc`, so a clone shares the same provider
/// instances and the same `extra` map as the original.
#[derive(Clone)]
pub struct ProviderFactory {
    anthropic: Arc<AnthropicProvider>,
    byteplus: Arc<BytePlusProvider>,
    chatgpt: Arc<ChatGPTProvider>,
    deepseek: Arc<DeepSeekProvider>,
    gemini: Arc<GeminiProvider>,
    juglans: Arc<JuglansProvider>,
    qwen: Arc<QwenProvider>,
    xai: Arc<XaiProvider>,
    /// Providers registered at runtime (e.g. `"claude-code"` by `new_with_mcp`). The map sits
    /// behind an `Arc`, so registrations are visible to every clone of the factory.
    extra: Arc<DashMap<String, Arc<dyn LlmProvider>>>,
}
impl ProviderFactory {
    /// Creates a factory holding a freshly constructed instance of every built-in provider
    /// and an empty registry of extra providers.
    pub fn new() -> Self {
        Self {
            anthropic: Arc::new(AnthropicProvider::new()),
            byteplus: Arc::new(BytePlusProvider::new()),
            chatgpt: Arc::new(ChatGPTProvider::new()),
            deepseek: Arc::new(DeepSeekProvider::new()),
            gemini: Arc::new(GeminiProvider::new()),
            juglans: Arc::new(JuglansProvider::new()),
            qwen: Arc::new(QwenProvider::new()),
            xai: Arc::new(XaiProvider::new()),
            extra: Arc::new(DashMap::new()),
        }
    }

    /// Exports per-provider settings from `configs` into environment variables, then builds
    /// the factory with [`ProviderFactory::new`].
    ///
    /// `configs` is keyed by provider name: `openai`, `anthropic`, `deepseek`, `gemini`,
    /// `qwen`, `byteplus`, `xai` or `juglans`; any other key is ignored.
    ///
    /// - A non-empty `api_key` is written to that provider's key variable.
    /// - A non-empty `base_url` is written to its base-URL variable.
    /// - `deepseek`, `gemini`, `qwen` and `xai` have no base-URL variable mapped here, so
    ///   their `base_url` is ignored.
    ///
    /// NOTE(review): presumably the provider constructors read these variables; that is not
    /// visible in this file.
    ///
    /// This mutates the process-wide environment through [`std::env::set_var`]. Per its
    /// documentation, on non-Windows targets that is only sound in a single-threaded program,
    /// so call this during start-up, before other threads exist.
    ///
    /// # Panics
    ///
    /// Panics (inside `set_var`) if a configured value contains a NUL byte.
    pub fn new_with_config(configs: &HashMap<String, LlmProviderConfig>) -> Self {
        // (config key, API-key variable, optional base-URL variable)
        const ENV_BINDINGS: &[(&str, &str, Option<&str>)] = &[
            ("openai", "OPENAI_API_KEY", Some("OPENAI_API_BASE")),
            ("anthropic", "ANTHROPIC_API_KEY", Some("ANTHROPIC_BASE_URL")),
            ("deepseek", "DEEPSEEK_API_KEY", None),
            ("gemini", "GEMINI_API_KEY", None),
            ("qwen", "QWEN_API_KEY", None),
            ("byteplus", "ARK_API_KEY", Some("ARK_API_BASE")),
            ("xai", "XAI_API_KEY", None),
            ("juglans", "JUGLANS_API_KEY", Some("JUGLANS_API_BASE")),
        ];

        for &(name, key_env, url_env) in ENV_BINDINGS {
            if let Some(cfg) = configs.get(name) {
                // Empty strings are treated as "not configured" so they never clobber an
                // existing environment value.
                if let Some(key) = cfg.api_key.as_deref().filter(|k| !k.is_empty()) {
                    env::set_var(key_env, key);
                }
                if let (Some(url_env), Some(url)) =
                    (url_env, cfg.base_url.as_deref().filter(|u| !u.is_empty()))
                {
                    env::set_var(url_env, url);
                }
            }
        }
        Self::new()
    }

    /// Creates a default factory and registers a [`ClaudeCodeProvider`] under
    /// `"claude-code"`, handing it the shared MCP `tool_sessions` map.
    pub fn new_with_mcp(tool_sessions: Arc<DashMap<String, McpSession>>) -> Self {
        // `register_provider` only needs `&self` (DashMap allows insertion through a shared
        // reference), so the factory binding does not need to be mutable.
        let factory = Self::new();
        let mut claude_code = ClaudeCodeProvider::new();
        claude_code.set_tool_sessions(tool_sessions);
        factory.register_provider("claude-code", Arc::new(claude_code));
        factory
    }

    /// Registers `provider` under `name`, replacing any provider already registered under it.
    ///
    /// The name is stored lower-cased because [`ProviderFactory::get_provider`] lower-cases a
    /// provider name before looking it up; a mixed-case key could never be matched.
    ///
    /// When a provider is selected by name (a `<name>/` model prefix or
    /// `DEFAULT_LLM_PROVIDER`), a registered provider takes precedence over a built-in of the
    /// same name.
    pub fn register_provider(&self, name: &str, provider: Arc<dyn LlmProvider>) {
        self.extra.insert(name.to_lowercase(), provider);
    }

    /// Picks the provider for `model` and returns it together with the model name to send to
    /// that provider.
    ///
    /// Resolution order:
    /// 1. `"<provider>/<model>"`: the prefix (case-insensitive) selects the provider.
    ///    - A registered extra provider is tried first.
    ///    - Then a built-in name or alias: `openai`/`chatgpt`, `anthropic`/`claude`,
    ///      `deepseek`, `qwen`, `gemini`, `byteplus`/`ark`, `xai`, `juglans`.
    ///    - Otherwise ChatGPT is used.
    ///    - Only the text after the first `/` is returned as the model name.
    /// 2. `"default"` (case-insensitive): the `DEFAULT_LLM_PROVIDER` environment variable is
    ///    resolved by name exactly like a prefix in step 1. Unset or unknown means ChatGPT.
    ///    The model name is returned unchanged.
    /// 3. Otherwise the lower-cased model name is matched against keywords; the first hit wins:
    ///    - `claude-code`, only if that extra provider is registered;
    ///    - `claude` → Anthropic;
    ///    - `qwen`, `gemini`, `deepseek`;
    ///    - `grok` → xAI;
    ///    - `doubao` or an `ep-` prefix → BytePlus;
    ///    - anything else → ChatGPT.
    ///
    ///    The model name is returned unchanged.
    ///
    /// NOTE(review): an unrecognised prefix still has the prefix stripped, so an id such as
    /// `org/model` reaches the ChatGPT provider as `model`; confirm this is intended.
    pub fn get_provider(&self, model: &str) -> (Arc<dyn LlmProvider>, String) {
        if let Some((provider_name, actual_model)) = model.split_once('/') {
            return (self.provider_by_name(provider_name), actual_model.to_string());
        }

        let m = model.to_lowercase();

        if m == "default" {
            // Resolve the configured default with the same lookup as an explicit prefix, so
            // every built-in name/alias and every registered provider can be the default.
            let default_provider = env::var("DEFAULT_LLM_PROVIDER").unwrap_or_default();
            return (
                self.provider_by_name(default_provider.trim()),
                model.to_string(),
            );
        }

        // Checked before the generic "claude" keyword; if no claude-code provider is
        // registered, such a model falls through to Anthropic below.
        if m.contains("claude-code") {
            if let Some(p) = self.extra.get("claude-code") {
                return (p.value().clone(), model.to_string());
            }
        }

        let provider: Arc<dyn LlmProvider> = if m.contains("claude") {
            self.anthropic.clone()
        } else if m.contains("qwen") {
            self.qwen.clone()
        } else if m.contains("gemini") {
            self.gemini.clone()
        } else if m.contains("deepseek") {
            self.deepseek.clone()
        } else if m.contains("grok") {
            self.xai.clone()
        } else if m.contains("doubao") || m.starts_with("ep-") {
            self.byteplus.clone()
        } else {
            self.chatgpt.clone()
        };
        (provider, model.to_string())
    }

    /// Resolves a provider by case-insensitive name: a registered extra provider first, then a
    /// built-in name or alias, falling back to ChatGPT for anything unrecognised.
    fn provider_by_name(&self, name: &str) -> Arc<dyn LlmProvider> {
        let name = name.to_lowercase();
        if let Some(p) = self.extra.get(&name) {
            return p.value().clone();
        }
        match self.builtin_provider(&name) {
            Some(p) => p,
            None => self.chatgpt.clone(),
        }
    }

    /// Maps an already lower-cased built-in provider name or alias to its provider.
    fn builtin_provider(&self, name: &str) -> Option<Arc<dyn LlmProvider>> {
        let provider: Arc<dyn LlmProvider> = match name {
            "openai" | "chatgpt" => self.chatgpt.clone(),
            "anthropic" | "claude" => self.anthropic.clone(),
            "deepseek" => self.deepseek.clone(),
            "qwen" => self.qwen.clone(),
            "gemini" => self.gemini.clone(),
            "byteplus" | "ark" => self.byteplus.clone(),
            "xai" => self.xai.clone(),
            "juglans" => self.juglans.clone(),
            _ => return None,
        };
        Some(provider)
    }
}

impl Default for ProviderFactory {
    /// Equivalent to [`ProviderFactory::new`].
    fn default() -> Self {
        Self::new()
    }
}