use crate::llm::providers::{
AmazonBedrockProvider, AnthropicProvider, CloudflareWorkersAiProvider, DeepSeekProvider,
GoogleVertexProvider, MinimaxProvider, OpenAiProvider, OpenRouterProvider, ZaiProvider,
};
use crate::llm::traits::AiProvider;
use anyhow::Result;
/// Stateless factory that turns `provider:model` strings into provider instances.
///
/// All functionality is exposed through associated functions; the type carries no data.
pub struct ProviderFactory;

impl ProviderFactory {
    /// Splits a `provider:model` string into `(provider, model_name)`.
    ///
    /// Only the first `:` acts as the separator, so the model part may itself
    /// contain further colons. Neither part is trimmed or case-normalized.
    ///
    /// # Errors
    ///
    /// Returns an error when `model` contains no `:` at all, or when the text on
    /// either side of the first `:` is empty.
    pub fn parse_model(model: &str) -> Result<(String, String)> {
        match model.split_once(':') {
            Some((provider, model_name)) if !provider.is_empty() && !model_name.is_empty() => {
                Ok((provider.to_string(), model_name.to_string()))
            }
            // A separator is present but one side of it is empty.
            Some(_) => Err(anyhow::anyhow!(
                "Invalid model format. Use 'provider:model' (e.g., 'openai:gpt-4o')"
            )),
            // No separator at all.
            None => Err(anyhow::anyhow!(
                "Invalid model format '{}'. Must specify provider like 'openai:gpt-4o' or 'openrouter:anthropic/claude-3.5-sonnet'",
                model
            )),
        }
    }

    /// Instantiates the provider registered under `provider_name`.
    ///
    /// The lookup is case-insensitive: the name is lowercased before matching.
    ///
    /// # Errors
    ///
    /// Returns an error naming the rejected provider and listing every entry of
    /// [`ProviderFactory::supported_providers`] when the name is not recognized.
    pub fn create_provider(provider_name: &str) -> Result<Box<dyn AiProvider>> {
        match provider_name.to_lowercase().as_str() {
            "openrouter" => Ok(Box::new(OpenRouterProvider::new())),
            "openai" => Ok(Box::new(OpenAiProvider::new())),
            "anthropic" => Ok(Box::new(AnthropicProvider::new())),
            "google" => Ok(Box::new(GoogleVertexProvider::new())),
            "amazon" => Ok(Box::new(AmazonBedrockProvider::new())),
            "cloudflare" => Ok(Box::new(CloudflareWorkersAiProvider::new())),
            "deepseek" => Ok(Box::new(DeepSeekProvider::new())),
            "minimax" => Ok(Box::new(MinimaxProvider::new())),
            "zai" => Ok(Box::new(ZaiProvider::new())),
            // Build the list from `supported_providers` so the message cannot
            // drift from the advertised set when a provider is added.
            _ => Err(anyhow::anyhow!(
                "Unsupported provider: {}. Supported providers: {}",
                provider_name,
                Self::supported_providers().join(", ")
            )),
        }
    }

    /// Resolves a full `provider:model` string to a provider instance and the
    /// bare model name.
    ///
    /// The model name is returned exactly as written in the input (no case
    /// folding); only the provider lookup is case-insensitive.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`ProviderFactory::parse_model`] and
    /// [`ProviderFactory::create_provider`], and returns an error when the
    /// provider's `supports_model` check rejects the model name.
    pub fn get_provider_for_model(model: &str) -> Result<(Box<dyn AiProvider>, String)> {
        let (provider_name, model_name) = Self::parse_model(model)?;
        let provider = Self::create_provider(&provider_name)?;

        if !provider.supports_model(&model_name) {
            return Err(anyhow::anyhow!(
                "Provider '{}' does not support model '{}'",
                provider_name,
                model_name
            ));
        }

        Ok((provider, model_name))
    }

    /// Returns the lowercase names accepted by [`ProviderFactory::create_provider`].
    pub fn supported_providers() -> Vec<&'static str> {
        vec![
            "openrouter",
            "openai",
            "anthropic",
            "google",
            "amazon",
            "cloudflare",
            "deepseek",
            "minimax",
            "zai",
        ]
    }

    /// Checks that `model` has the `provider:model` shape without creating a provider.
    ///
    /// This validates syntax only; it does not check that the provider exists
    /// or that it supports the model.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`ProviderFactory::parse_model`].
    pub fn validate_model_format(model: &str) -> Result<()> {
        Self::parse_model(model).map(|_| ())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `ProviderFactory`: model-string parsing, provider lookup,
    //! and end-to-end resolution of `provider:model` strings.

    use super::*;

    #[test]
    fn test_parse_model() {
        // Well-formed inputs split at the first colon.
        let result = ProviderFactory::parse_model("openrouter:anthropic/claude-3.5-sonnet");
        assert!(result.is_ok());
        let (provider, model) = result.unwrap();
        assert_eq!(provider, "openrouter");
        assert_eq!(model, "anthropic/claude-3.5-sonnet");

        let result = ProviderFactory::parse_model("openai:gpt-4o");
        assert!(result.is_ok());
        let (provider, model) = result.unwrap();
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4o");

        let result = ProviderFactory::parse_model("deepseek:deepseek-chat");
        assert!(result.is_ok());
        let (provider, model) = result.unwrap();
        assert_eq!(provider, "deepseek");
        assert_eq!(model, "deepseek-chat");

        // Missing separator, empty provider, and empty model are all rejected.
        let result = ProviderFactory::parse_model("gpt-4o");
        assert!(result.is_err());
        let result = ProviderFactory::parse_model(":gpt-4o");
        assert!(result.is_err());
        let result = ProviderFactory::parse_model("openai:");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_model_splits_on_first_colon_only() {
        // Model identifiers may contain colons of their own; everything after
        // the first colon belongs to the model name.
        let (provider, model) =
            ProviderFactory::parse_model("amazon:anthropic.claude-3-sonnet-v1:0").unwrap();
        assert_eq!(provider, "amazon");
        assert_eq!(model, "anthropic.claude-3-sonnet-v1:0");
    }

    #[test]
    fn test_supported_providers() {
        let providers = ProviderFactory::supported_providers();
        assert!(providers.contains(&"openai"));
        assert!(providers.contains(&"anthropic"));
        assert!(providers.contains(&"openrouter"));
        assert!(providers.contains(&"google"));
        assert!(providers.contains(&"amazon"));
        assert!(providers.contains(&"cloudflare"));
        assert!(providers.contains(&"deepseek"));
        assert!(providers.contains(&"minimax"));
        assert!(providers.contains(&"zai"));
    }

    #[test]
    fn test_every_supported_provider_can_be_created() {
        // Guards against the advertised list and the factory match arms drifting apart.
        for name in ProviderFactory::supported_providers() {
            assert!(
                ProviderFactory::create_provider(name).is_ok(),
                "advertised provider '{}' could not be created",
                name
            );
        }
    }

    #[test]
    fn test_validate_model_format() {
        assert!(ProviderFactory::validate_model_format("openai:gpt-4o").is_ok());
        assert!(ProviderFactory::validate_model_format("anthropic:claude-3.5-sonnet").is_ok());
        assert!(ProviderFactory::validate_model_format("gpt-4o").is_err());
        assert!(ProviderFactory::validate_model_format(":model").is_err());
        assert!(ProviderFactory::validate_model_format("provider:").is_err());
    }

    #[test]
    fn test_create_provider() {
        assert!(ProviderFactory::create_provider("openai").is_ok());
        assert!(ProviderFactory::create_provider("anthropic").is_ok());
        assert!(ProviderFactory::create_provider("openrouter").is_ok());
        assert!(ProviderFactory::create_provider("google").is_ok());
        assert!(ProviderFactory::create_provider("amazon").is_ok());
        assert!(ProviderFactory::create_provider("cloudflare").is_ok());
        assert!(ProviderFactory::create_provider("deepseek").is_ok());
        assert!(ProviderFactory::create_provider("minimax").is_ok());
        assert!(ProviderFactory::create_provider("zai").is_ok());

        // Provider names are matched case-insensitively.
        assert!(ProviderFactory::create_provider("OpenAI").is_ok());
        assert!(ProviderFactory::create_provider("ANTHROPIC").is_ok());
        assert!(ProviderFactory::create_provider("MiniMax").is_ok());

        assert!(ProviderFactory::create_provider("invalid").is_err());
    }

    #[test]
    fn test_unsupported_provider_error_lists_all_providers() {
        // `match` instead of `unwrap_err` so the test does not require
        // `Box<dyn AiProvider>` to implement `Debug`.
        let message = match ProviderFactory::create_provider("invalid") {
            Ok(_) => panic!("unknown provider must be rejected"),
            Err(err) => err.to_string(),
        };

        assert!(message.contains("invalid"));
        for name in ProviderFactory::supported_providers() {
            assert!(
                message.contains(name),
                "error message does not mention '{}': {}",
                name,
                message
            );
        }
    }

    #[test]
    fn test_provider_capabilities() {
        let openai = ProviderFactory::create_provider("openai").unwrap();
        assert_eq!(openai.name(), "openai");
        assert!(openai.supports_model("gpt-4o"));
        assert!(openai.supports_vision("gpt-4o"));
        assert!(openai.supports_caching("gpt-4o"));

        let anthropic = ProviderFactory::create_provider("anthropic").unwrap();
        assert_eq!(anthropic.name(), "anthropic");
        assert!(anthropic.supports_model("claude-3.5-sonnet"));
        assert!(anthropic.supports_vision("claude-3.5-sonnet"));
        assert!(anthropic.supports_caching("claude-3.5-sonnet"));

        let openrouter = ProviderFactory::create_provider("openrouter").unwrap();
        assert_eq!(openrouter.name(), "openrouter");
        assert!(openrouter.supports_model("any-model"));
        assert!(openrouter.supports_vision("claude-3.5-sonnet"));
        assert!(openrouter.supports_caching("claude-3.5-sonnet"));
    }

    #[test]
    fn test_get_provider_for_model() {
        let result = ProviderFactory::get_provider_for_model("openai:gpt-4o");
        assert!(result.is_ok());
        let (provider, model) = result.unwrap();
        assert_eq!(provider.name(), "openai");
        assert_eq!(model, "gpt-4o");

        let result = ProviderFactory::get_provider_for_model("anthropic:claude-3.5-sonnet");
        assert!(result.is_ok());
        let (provider, model) = result.unwrap();
        assert_eq!(provider.name(), "anthropic");
        assert_eq!(model, "claude-3.5-sonnet");

        let result = ProviderFactory::get_provider_for_model("minimax:MiniMax-M2.1");
        assert!(result.is_ok());
        let (provider, model) = result.unwrap();
        assert_eq!(provider.name(), "minimax");
        assert_eq!(model, "MiniMax-M2.1");
        assert!(provider.supports_caching(&model));
        assert!(provider.supports_model(&model));

        // Missing provider prefix and unknown provider are both rejected.
        let result = ProviderFactory::get_provider_for_model("gpt-4o");
        assert!(result.is_err());
        let result = ProviderFactory::get_provider_for_model("invalid:model");
        assert!(result.is_err());
    }

    #[test]
    fn test_get_provider_for_model_case_insensitive() {
        // Provider prefix in any case resolves to the same provider.
        let result = ProviderFactory::get_provider_for_model("OPENAI:gpt-4o");
        assert!(result.is_ok());
        let (provider, model) = result.unwrap();
        assert_eq!(provider.name(), "openai");
        assert_eq!(model, "gpt-4o");

        let result = ProviderFactory::get_provider_for_model("Anthropic:claude-3.5-sonnet");
        assert!(result.is_ok());
        let (provider, model) = result.unwrap();
        assert_eq!(provider.name(), "anthropic");
        assert_eq!(model, "claude-3.5-sonnet");

        // The model name is returned with its original casing.
        let result = ProviderFactory::get_provider_for_model("openai:GPT-4O");
        assert!(result.is_ok());
        let (provider, model) = result.unwrap();
        assert_eq!(provider.name(), "openai");
        assert_eq!(model, "GPT-4O");
        assert!(provider.supports_model(&model));

        let result = ProviderFactory::get_provider_for_model("minimax:MINIMAX-M2.1");
        assert!(result.is_ok());
        let (provider, model) = result.unwrap();
        assert_eq!(provider.name(), "minimax");
        assert_eq!(model, "MINIMAX-M2.1");
        assert!(provider.supports_model(&model));
    }
}