use std::sync::Arc;
use crate::config::ProviderConfig;
use crate::io::AgentIO;
use crate::llm::anthropic::{AnthropicProvider, ANTHROPIC_API_BASE};
use crate::llm::gemini::{GeminiProvider, GEMINI_API_BASE};
use crate::llm::openai::{OpenAiProvider, COPILOT_API_BASE};
use crate::llm::retry::RetryConfig;
use crate::llm::LlmProvider;
/// Stateless factory that turns a `ProviderConfig` into a concrete `LlmProvider`.
///
/// The type carries no data; everything is exposed through associated functions.
pub struct ProviderRegistry;

impl ProviderRegistry {
    /// Builds the provider implementation selected by `config.api_base`.
    ///
    /// The base URL is checked in this order:
    ///
    /// 1. exactly `COPILOT_API_BASE` → an `OpenAiProvider` in Copilot mode,
    ///    authenticated with `oauth_token` (an empty string when `None`);
    /// 2. accepted by `is_anthropic` → an `AnthropicProvider`;
    /// 3. accepted by `is_gemini` → a `GeminiProvider`;
    /// 4. anything else → a generic OpenAI-compatible `OpenAiProvider`
    ///    targeting `config.api_base`.
    ///
    /// Only the two `OpenAiProvider` branches are given `io` and `retry`.
    /// The Anthropic and Gemini providers are built from `config.api_key` and
    /// `model` alone, so `config.api_base`, `io` and `retry` never reach them.
    /// NOTE(review): confirm this asymmetry is intentional.
    pub fn create_provider(
        config: &ProviderConfig,
        model: &str,
        retry: &RetryConfig,
        io: &Arc<dyn AgentIO>,
        oauth_token: Option<String>,
    ) -> Arc<dyn LlmProvider> {
        let base = config.api_base.as_str();
        let model = model.to_owned();

        if base == COPILOT_API_BASE {
            // A missing token is not rejected here; it degrades to an empty string.
            let copilot = OpenAiProvider::new_copilot(oauth_token.unwrap_or_default(), model)
                .with_io(Arc::clone(io))
                .with_retry(retry.clone());
            return Arc::new(copilot);
        }

        if is_anthropic(base) {
            return Arc::new(AnthropicProvider::new(config.api_key.clone(), model));
        }

        if is_gemini(base) {
            return Arc::new(GeminiProvider::new(config.api_key.clone(), model));
        }

        // Fallback: every other base URL is treated as an OpenAI-compatible endpoint.
        let generic = OpenAiProvider::new(config.api_base.clone(), config.api_key.clone(), model)
            .with_io(Arc::clone(io))
            .with_retry(retry.clone());
        Arc::new(generic)
    }

    /// Returns the built-in presets as `(display name, api_base, default model)` tuples.
    ///
    /// The Anthropic, Gemini and Copilot entries use the same base constants that
    /// `create_provider` dispatches on. Every other entry falls through to its
    /// generic `OpenAiProvider` branch.
    // Silences the `dead_code` lint: inside this file the function is only
    // called from the tests. NOTE(review): confirm whether a non-test caller
    // exists elsewhere; if so, this attribute can be dropped.
    #[allow(dead_code)]
    pub fn builtin_providers() -> Vec<(&'static str, &'static str, &'static str)> {
        let presets = [
            ("OpenAI", "https://api.openai.com/v1", "gpt-4o"),
            ("Anthropic", ANTHROPIC_API_BASE, "claude-3-5-sonnet-20241022"),
            ("Google Gemini", GEMINI_API_BASE, "gemini-2.0-flash"),
            ("DeepSeek", "https://api.deepseek.com/v1", "deepseek-chat"),
            (
                "Qwen",
                "https://dashscope.aliyuncs.com/compatible-mode/v1",
                "qwen-turbo",
            ),
            ("GLM", "https://open.bigmodel.cn/api/paas/v4", "glm-4"),
            ("Ollama", "http://localhost:11434/v1", "llama3"),
            ("GitHub Copilot", COPILOT_API_BASE, "gpt-4o"),
        ];
        presets.to_vec()
    }
}
/// Returns `true` when `api_base` should be served by the native Anthropic provider.
///
/// Accepts the `ANTHROPIC_API_BASE` constant verbatim, or any URL on the
/// `https://api.anthropic.com` origin, optionally followed by a numeric port,
/// a path, a query or a fragment.
fn is_anthropic(api_base: &str) -> bool {
    api_base == ANTHROPIC_API_BASE || is_on_origin(api_base, "https://api.anthropic.com")
}

/// Returns `true` when `api_base` should be served by the native Gemini provider.
///
/// Accepts the `GEMINI_API_BASE` constant verbatim, or any URL on the
/// `https://generativelanguage.googleapis.com` origin, optionally followed by a
/// numeric port, a path, a query or a fragment.
fn is_gemini(api_base: &str) -> bool {
    api_base == GEMINI_API_BASE
        || is_on_origin(api_base, "https://generativelanguage.googleapis.com")
}

/// Reports whether `url` starts with `origin` (scheme + host) *and* the host ends there.
///
/// A plain `starts_with(origin)` also accepts look-alike hosts such as
/// `https://api.anthropic.com.example.org` or `https://api.anthropic.community`.
/// That silently routes a third-party endpoint to the wrong provider. Here the
/// origin may be followed only by nothing, an optional numeric `:port`, and then
/// a path (`/`), query (`?`) or fragment (`#`).
///
/// The `origin` prefix is compared ASCII case-insensitively, because RFC 3986
/// defines scheme and host as case-insensitive.
fn is_on_origin(url: &str, origin: &str) -> bool {
    // `get` yields `None` instead of panicking when `url` is shorter than
    // `origin` or the cut would split a multi-byte character.
    let head = match url.get(..origin.len()) {
        Some(head) => head,
        None => return false,
    };
    if !head.eq_ignore_ascii_case(origin) {
        return false;
    }
    // In bounds and on a char boundary: the `get` above succeeded.
    let rest = &url[origin.len()..];
    // Skip an optional `:port`. Anything non-numeric after the colon
    // (e.g. `:x@evil.host`, a userinfo trick) leaves a character that fails
    // the final check.
    let rest = match rest.strip_prefix(':') {
        Some(after_colon) => after_colon.trim_start_matches(|c: char| c.is_ascii_digit()),
        None => rest,
    };
    matches!(rest.chars().next(), None | Some('/') | Some('?') | Some('#'))
}
#[cfg(test)]
mod tests {
    //! Unit tests for base-URL classification and provider construction.

    use super::*;
    use crate::io::NullIO;

    /// Config pointing at `api_base` with a placeholder key.
    fn make_config(api_base: &str) -> ProviderConfig {
        ProviderConfig {
            api_base: String::from(api_base),
            api_key: String::from("test-key"),
        }
    }

    /// A `NullIO` wrapped as the `AgentIO` trait object `create_provider` expects.
    fn null_io() -> Arc<dyn AgentIO> {
        Arc::new(NullIO)
    }

    /// The default retry policy.
    fn default_retry() -> RetryConfig {
        RetryConfig::default()
    }

    /// Runs `create_provider` for `api_base`/`model` with default retry and a `NullIO`.
    /// The construction tests only check that this completes without panicking.
    fn build(api_base: &str, model: &str, oauth_token: Option<String>) -> Arc<dyn LlmProvider> {
        let config = make_config(api_base);
        ProviderRegistry::create_provider(
            &config,
            model,
            &default_retry(),
            &null_io(),
            oauth_token,
        )
    }

    #[test]
    fn test_is_anthropic_sentinel() {
        for base in [ANTHROPIC_API_BASE, "anthropic"] {
            assert!(is_anthropic(base), "expected anthropic: {}", base);
        }
    }

    #[test]
    fn test_is_anthropic_full_url() {
        for base in ["https://api.anthropic.com/v1", "https://api.anthropic.com"] {
            assert!(is_anthropic(base), "expected anthropic: {}", base);
        }
    }

    #[test]
    fn test_is_anthropic_false_for_others() {
        let others = [
            "https://api.openai.com/v1",
            "gemini",
            "copilot",
            "https://api.deepseek.com/v1",
        ];
        for base in others {
            assert!(!is_anthropic(base), "unexpected anthropic match: {}", base);
        }
    }

    #[test]
    fn test_is_gemini_sentinel() {
        for base in [GEMINI_API_BASE, "gemini"] {
            assert!(is_gemini(base), "expected gemini: {}", base);
        }
    }

    #[test]
    fn test_is_gemini_full_url() {
        let urls = [
            "https://generativelanguage.googleapis.com/v1beta",
            "https://generativelanguage.googleapis.com",
        ];
        for base in urls {
            assert!(is_gemini(base), "expected gemini: {}", base);
        }
    }

    #[test]
    fn test_is_gemini_false_for_others() {
        for base in ["https://api.openai.com/v1", "anthropic", "copilot"] {
            assert!(!is_gemini(base), "unexpected gemini match: {}", base);
        }
    }

    #[test]
    fn test_builtin_providers_contains_all() {
        let bases: Vec<&str> = ProviderRegistry::builtin_providers()
            .into_iter()
            .map(|(_, base, _)| base)
            .collect();

        let required = [
            (ANTHROPIC_API_BASE, "Anthropic missing"),
            (GEMINI_API_BASE, "Gemini missing"),
            (COPILOT_API_BASE, "Copilot missing"),
        ];
        for (expected, label) in required {
            assert!(bases.contains(&expected), "{}", label);
        }

        let has_openai = bases
            .iter()
            .any(|b| b.starts_with("https://api.openai.com"));
        assert!(has_openai, "OpenAI missing");
    }

    #[test]
    fn test_builtin_providers_fields_non_empty() {
        let presets = ProviderRegistry::builtin_providers();
        for &(name, base, model) in &presets {
            assert!(!name.is_empty(), "empty display name for base={}", base);
            assert!(!base.is_empty(), "empty api_base for name={}", name);
            assert!(!model.is_empty(), "empty model for name={}", name);
        }
    }

    #[test]
    fn test_create_provider_copilot() {
        let _provider = build(COPILOT_API_BASE, "gpt-4o", Some(String::from("fake-token")));
    }

    #[test]
    fn test_create_provider_anthropic_sentinel() {
        let _provider = build(ANTHROPIC_API_BASE, "claude-3-5-sonnet-20241022", None);
    }

    #[test]
    fn test_create_provider_gemini_sentinel() {
        let _provider = build(GEMINI_API_BASE, "gemini-2.0-flash", None);
    }

    #[test]
    fn test_create_provider_anthropic_url() {
        let _provider = build(
            "https://api.anthropic.com/v1",
            "claude-3-5-sonnet-20241022",
            None,
        );
    }

    #[test]
    fn test_create_provider_gemini_url() {
        let _provider = build(
            "https://generativelanguage.googleapis.com/v1beta",
            "gemini-2.0-flash",
            None,
        );
    }

    #[test]
    fn test_create_provider_openai_fallback() {
        let _provider = build("https://api.deepseek.com/v1", "deepseek-chat", None);
    }

    #[test]
    fn test_create_provider_ollama() {
        let _provider = build("http://localhost:11434/v1", "llama3", None);
    }
}