use crate::Result;
use crate::catalog::LlmModel;
#[cfg(feature = "bedrock")]
use crate::providers::bedrock::BedrockProvider;
#[cfg(feature = "codex")]
use crate::providers::codex::CodexProvider;
use crate::providers::{
anthropic::AnthropicProvider,
gemini::GeminiProvider,
local::{llama_cpp::LlamaCppProvider, ollama::OllamaProvider},
openai::OpenAiProvider,
openai_compatible::generic::{self, GenericOpenAiProvider},
openrouter::OpenRouterProvider,
};
use crate::{LlmError, ProviderFactory, StreamingModelProvider, alloyed::AlloyedModelProvider};
use futures::future::BoxFuture;
use std::collections::HashMap;
#[doc = include_str!("docs/parser.md")]
pub struct ModelProviderParser {
    // Maps a provider name (e.g. "anthropic", "ollama") to the factory
    // closure that asynchronously builds a provider for a given model id.
    factories: HashMap<String, CreateProviderFn>,
}
impl ModelProviderParser {
pub fn new(factories: HashMap<String, CreateProviderFn>) -> Self {
Self { factories }
}
}
impl Default for ModelProviderParser {
    /// Builds a parser with every built-in provider pre-registered.
    ///
    /// Providers behind cargo features ("bedrock", "codex") are added only
    /// when those features are enabled; the `let registry = registry...`
    /// shadowing keeps the builder chain working under `#[cfg]`.
    fn default() -> Self {
        let registry = Self::new(HashMap::new())
            .with_provider::<AnthropicProvider>("anthropic")
            .with_provider::<OpenAiProvider>("openai")
            .with_provider::<GeminiProvider>("gemini")
            .with_provider::<OpenRouterProvider>("openrouter")
            .with_provider::<OllamaProvider>("ollama")
            .with_provider::<LlamaCppProvider>("llamacpp")
            .with_openai_provider("deepseek", &generic::DEEPSEEK)
            .with_openai_provider("moonshot", &generic::MOONSHOT)
            .with_openai_provider("zai", &generic::ZAI);
        #[cfg(feature = "bedrock")]
        let registry = registry.with_provider::<BedrockProvider>("bedrock");
        #[cfg(feature = "codex")]
        let registry = registry.with_provider::<CodexProvider>("codex");
        registry
    }
}
impl ModelProviderParser {
    /// Registers a provider factory under `name`, returning `self` for
    /// builder-style chaining.
    ///
    /// The stored factory builds `P` from environment configuration and
    /// pins it to the requested model id.
    pub fn with_provider<P: ProviderFactory + StreamingModelProvider + 'static>(
        mut self,
        name: impl Into<String>,
    ) -> Self {
        self.factories.insert(
            name.into(),
            Box::new(|model: &str| {
                // Owned copy so the boxed future can be 'static.
                let model = model.to_string();
                Box::pin(async move { Ok(Box::new(P::from_env().await?.with_model(&model)) as _) })
            }),
        );
        self
    }
    /// Registers an OpenAI-compatible provider driven by a static
    /// [`generic::ProviderConfig`], returning `self` for chaining.
    pub fn with_openai_provider(mut self, name: impl Into<String>, config: &'static generic::ProviderConfig) -> Self {
        self.factories.insert(
            name.into(),
            Box::new(move |model: &str| {
                let model = model.to_string();
                Box::pin(async move { Ok(Box::new(GenericOpenAiProvider::from_env(config)?.with_model(&model)) as _) })
            }),
        );
        self
    }
    /// Instantiates the provider registered for `model`'s provider key.
    ///
    /// # Errors
    /// Returns `LlmError::Other` when no factory is registered for the key,
    /// or propagates whatever error the factory itself produces.
    pub async fn create_provider(&self, model: &LlmModel) -> Result<Box<dyn StreamingModelProvider>> {
        let key = model.provider();
        let factory = self.factories.get(key).ok_or_else(|| LlmError::Other(format!("Unknown provider: {key}")))?;
        factory(&model.model_id()).await
    }
    /// Parses a comma-separated list of `provider[:model]` pairs.
    ///
    /// A single pair yields that provider directly; multiple pairs are
    /// wrapped in an [`AlloyedModelProvider`]. The returned [`LlmModel`]
    /// identity comes from the first pair.
    ///
    /// # Errors
    /// Fails when the input is empty, a provider name is unknown, a factory
    /// fails, or the first pair does not parse as an `LlmModel`.
    pub async fn parse(&self, models_str: &str) -> Result<(Box<dyn StreamingModelProvider>, LlmModel)> {
        // `str::split` always yields at least one item (even for ""), so the
        // previous `pairs.is_empty()` check could never fire and empty input
        // surfaced as a confusing `Unknown provider: ` error. Guard on the
        // raw input instead so the intended error is actually reachable.
        if models_str.trim().is_empty() {
            return Err(LlmError::Other("No models provided".to_string()));
        }
        let provider_model_pairs: Vec<&str> = models_str.split(',').map(str::trim).collect();
        let mut providers = Vec::new();
        let mut first_identity: Option<LlmModel> = None;
        for pair in provider_model_pairs {
            // "provider:model"; a bare "provider" means an empty model id.
            let (provider_name, model) = pair.split_once(':').unwrap_or((pair, ""));
            let factory = self
                .factories
                .get(provider_name)
                .ok_or_else(|| LlmError::Other(format!("Unknown provider: {provider_name}")))?;
            providers.push(factory(model).await?);
            if first_identity.is_none() {
                first_identity = Some(pair.parse::<LlmModel>().map_err(LlmError::Other)?);
            }
        }
        // The loop ran at least once, so both checks below are guaranteed to
        // pass; they are kept as cheap defensive errors rather than unwraps.
        let identity = first_identity.ok_or_else(|| LlmError::Other("No providers parsed".to_string()))?;
        let provider: Box<dyn StreamingModelProvider> = if providers.len() == 1 {
            providers.into_iter().next().ok_or_else(|| LlmError::Other("No providers available".to_string()))?
        } else {
            Box::new(AlloyedModelProvider::new(providers))
        };
        Ok((provider, identity))
    }
}
/// Factory closure stored in the parser's registry: given a model id string,
/// asynchronously builds a boxed streaming provider configured for that model.
pub type CreateProviderFn =
    Box<dyn Fn(&str) -> BoxFuture<'static, Result<Box<dyn StreamingModelProvider>>> + Send + Sync>;
#[cfg(test)]
mod tests {
    use super::*;

    // Local providers (llamacpp, ollama) are expected to construct without
    // credentials; API-backed providers may legitimately fail on missing keys,
    // so those tests only assert that any failure is credential-shaped.

    #[tokio::test]
    async fn test_parse_llamacpp() {
        let parsed = ModelProviderParser::default().parse("llamacpp").await;
        assert!(parsed.is_ok());
        let (_, model) = parsed.unwrap();
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }
    #[tokio::test]
    async fn test_parse_anthropic() {
        let parsed = ModelProviderParser::default().parse("anthropic:claude-3-5-sonnet-20241022").await;
        match parsed {
            Ok((_, model)) => {
                assert_eq!(model, LlmModel::Anthropic(crate::catalog::AnthropicModel::Claude35Sonnet20241022));
            }
            Err(e) => {
                let err = e.to_string();
                let credential_failure = err.contains("API")
                    || err.contains("ANTHROPIC")
                    || err.contains("credentials")
                    || err.contains("JSON");
                assert!(credential_failure, "Should fail on API key or credentials, not parsing. Got: {err}");
            }
        }
    }
    #[tokio::test]
    async fn test_parse_ollama() {
        let parsed = ModelProviderParser::default().parse("ollama:llama3.2").await;
        assert!(parsed.is_ok());
        let (_, model) = parsed.unwrap();
        assert_eq!(model, LlmModel::Ollama("llama3.2".to_string()));
    }
    #[tokio::test]
    async fn test_parse_openai() {
        if let Err(e) = ModelProviderParser::default().parse("openai:gpt-4.1").await {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENAI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }
    #[tokio::test]
    async fn test_parse_openrouter() {
        if let Err(e) = ModelProviderParser::default().parse("openrouter:google/gemini-2.5-flash").await {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENROUTER"), "Should fail on API key, not parsing");
        }
    }
    #[tokio::test]
    async fn test_parse_gemini() {
        if let Err(e) = ModelProviderParser::default().parse("gemini:gemini-2.5-flash").await {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("GEMINI"), "Should fail on API key, not parsing");
        }
    }
    #[tokio::test]
    async fn test_parse_provider_without_model() {
        // "anthropic" alone has an empty model id, which the identity parse rejects.
        assert!(ModelProviderParser::default().parse("anthropic").await.is_err());
    }
    #[tokio::test]
    async fn test_parse_unknown_provider() {
        let parsed = ModelProviderParser::default().parse("unknown:model").await;
        assert!(parsed.is_err());
        if let Err(e) = parsed {
            assert!(e.to_string().contains("Unknown provider"));
        }
    }
    #[tokio::test]
    async fn test_with_custom_provider() {
        let parser = ModelProviderParser::default().with_provider::<OllamaProvider>("custom");
        let identity = LlmModel::Ollama("test-model".to_string());
        assert!(parser.create_provider(&identity).await.is_ok());
    }
    #[tokio::test]
    async fn test_parse_single_provider() {
        assert!(ModelProviderParser::default().parse("llamacpp").await.is_ok());
    }
    #[tokio::test]
    async fn test_parse_multiple_providers() {
        let parsed = ModelProviderParser::default().parse("llamacpp,ollama:llama3.2").await;
        assert!(parsed.is_ok());
        // Identity is taken from the first pair in the list.
        let (_, model) = parsed.unwrap();
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }
    #[tokio::test]
    async fn test_parse_with_spaces() {
        assert!(ModelProviderParser::default().parse("llamacpp , ollama:llama3.2").await.is_ok());
    }
    #[test]
    fn test_parser_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ModelProviderParser>();
    }
}