1use crate::Result;
2use crate::catalog::LlmModel;
3#[cfg(feature = "bedrock")]
4use crate::providers::bedrock::BedrockProvider;
5#[cfg(feature = "codex")]
6use crate::providers::codex::CodexProvider;
7use crate::providers::{
8 anthropic::AnthropicProvider,
9 gemini::GeminiProvider,
10 local::{llama_cpp::LlamaCppProvider, ollama::OllamaProvider},
11 openai::OpenAiProvider,
12 openai_compatible::generic::{self, GenericOpenAiProvider},
13 openrouter::OpenRouterProvider,
14};
15use crate::{LlmError, ProviderFactory, StreamingModelProvider, alloyed::AlloyedModelProvider};
16use futures::future::BoxFuture;
17use std::collections::HashMap;
18
#[doc = include_str!("docs/parser.md")]
pub struct ModelProviderParser {
    /// Provider factories keyed by provider name.
    ///
    /// The name is the part before `:` in a model spec such as `"ollama:llama3.2"`, or
    /// `LlmModel::provider()` when building from an already-parsed model.
    factories: HashMap<String, CreateProviderFn>,
}
23
24impl ModelProviderParser {
25 pub fn new(factories: HashMap<String, CreateProviderFn>) -> Self {
27 Self { factories }
28 }
29}
30
31impl Default for ModelProviderParser {
32 fn default() -> Self {
34 let parser = Self::new(HashMap::new())
35 .with_provider::<AnthropicProvider>("anthropic")
36 .with_provider::<GeminiProvider>("gemini")
37 .with_provider::<OpenRouterProvider>("openrouter")
38 .with_provider::<OllamaProvider>("ollama")
39 .with_provider::<LlamaCppProvider>("llamacpp")
40 .with_provider::<OpenAiProvider>("openai")
41 .with_openai_provider("deepseek", &generic::DEEPSEEK)
42 .with_openai_provider("moonshot", &generic::MOONSHOT)
43 .with_openai_provider("zai", &generic::ZAI);
44
45 #[cfg(feature = "bedrock")]
46 let parser = parser.with_provider::<BedrockProvider>("bedrock");
47
48 #[cfg(feature = "codex")]
49 let parser = parser.with_provider::<CodexProvider>("codex");
50
51 parser
52 }
53}
54
55impl ModelProviderParser {
56 pub fn with_provider<P: ProviderFactory + StreamingModelProvider + 'static>(
57 mut self,
58 name: impl Into<String>,
59 ) -> Self {
60 self.factories.insert(
61 name.into(),
62 Box::new(|model: &str| {
63 let model = model.to_string();
64 Box::pin(async move { Ok(Box::new(P::from_env().await?.with_model(&model)) as _) })
65 }),
66 );
67 self
68 }
69
70 pub fn with_openai_provider(mut self, name: impl Into<String>, config: &'static generic::ProviderConfig) -> Self {
71 self.factories.insert(
72 name.into(),
73 Box::new(move |model: &str| {
74 let model = model.to_string();
75 Box::pin(async move { Ok(Box::new(GenericOpenAiProvider::from_env(config)?.with_model(&model)) as _) })
76 }),
77 );
78 self
79 }
80
81 pub async fn create_provider(&self, model: &LlmModel) -> Result<Box<dyn StreamingModelProvider>> {
83 let key = model.provider();
84 let factory = self.factories.get(key).ok_or_else(|| LlmError::Other(format!("Unknown provider: {key}")))?;
85 factory(&model.model_id()).await
86 }
87
88 pub async fn parse(&self, models_str: &str) -> Result<(Box<dyn StreamingModelProvider>, LlmModel)> {
99 let provider_model_pairs: Vec<&str> = models_str.split(',').map(str::trim).collect();
100 if provider_model_pairs.is_empty() {
101 return Err(LlmError::Other("No models provided".to_string()));
102 }
103
104 let mut providers = Vec::new();
105 let mut first_identity: Option<LlmModel> = None;
106
107 for pair in provider_model_pairs {
108 let (provider_name, model) = pair.split_once(':').unwrap_or((pair, ""));
109
110 let factory = self
111 .factories
112 .get(provider_name)
113 .ok_or_else(|| LlmError::Other(format!("Unknown provider: {provider_name}")))?;
114
115 providers.push(factory(model).await?);
116
117 if first_identity.is_none() {
118 first_identity = Some(pair.parse::<LlmModel>().map_err(LlmError::Other)?);
119 }
120 }
121
122 let identity = first_identity.ok_or_else(|| LlmError::Other("No providers parsed".to_string()))?;
123
124 let provider: Box<dyn StreamingModelProvider> = if providers.len() == 1 {
125 providers.into_iter().next().ok_or_else(|| LlmError::Other("No providers available".to_string()))?
126 } else {
127 Box::new(AlloyedModelProvider::new(providers))
128 };
129
130 Ok((provider, identity))
131 }
132}
133
/// Async factory that builds a provider for a model id.
///
/// [`ModelProviderParser::parse`] passes the part of a spec after `:`, or `""` when the spec
/// has no `:`. [`ModelProviderParser::create_provider`] passes `LlmModel::model_id()`.
///
/// The returned future is `'static`, so it cannot borrow the `&str` argument. The factory must
/// be `Send + Sync` so the parser itself can be shared across threads.
pub type CreateProviderFn =
    Box<dyn Fn(&str) -> BoxFuture<'static, Result<Box<dyn StreamingModelProvider>>> + Send + Sync>;
139
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_llamacpp() {
        let parser = ModelProviderParser::default();
        let (_, model) = parser.parse("llamacpp").await.expect("llamacpp should parse without configuration");
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_anthropic() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic:claude-3-5-sonnet-20241022").await;
        match result {
            Ok((_, model)) => {
                assert_eq!(model, LlmModel::Anthropic(crate::catalog::AnthropicModel::Claude35Sonnet20241022));
            }
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("API")
                        || err.contains("ANTHROPIC")
                        || err.contains("credentials")
                        || err.contains("JSON"),
                    "Should fail on API key or credentials, not parsing. Got: {err}"
                );
            }
        }
    }

    #[tokio::test]
    async fn test_parse_ollama() {
        let parser = ModelProviderParser::default();
        let (_, model) = parser.parse("ollama:llama3.2").await.expect("ollama should parse without configuration");
        assert_eq!(model, LlmModel::Ollama("llama3.2".to_string()));
    }

    #[tokio::test]
    async fn test_parse_openai() {
        let parser = ModelProviderParser::default();
        if let Err(e) = parser.parse("openai:gpt-4.1").await {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENAI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_openrouter() {
        let parser = ModelProviderParser::default();
        if let Err(e) = parser.parse("openrouter:google/gemini-2.5-flash").await {
            let err = e.to_string();
            assert!(
                err.contains("API") || err.contains("OPENROUTER"),
                "Should fail on API key, not parsing. Got: {err}"
            );
        }
    }

    #[tokio::test]
    async fn test_parse_gemini() {
        let parser = ModelProviderParser::default();
        if let Err(e) = parser.parse("gemini:gemini-2.5-flash").await {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("GEMINI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_provider_without_model() {
        let parser = ModelProviderParser::default();
        assert!(parser.parse("anthropic").await.is_err());
    }

    #[tokio::test]
    async fn test_parse_empty_input_is_rejected() {
        let parser = ModelProviderParser::default();
        assert!(parser.parse("").await.is_err());
    }

    #[tokio::test]
    async fn test_parse_unknown_provider() {
        let parser = ModelProviderParser::default();
        let Err(e) = parser.parse("unknown:model").await else {
            panic!("an unknown provider name should be rejected");
        };
        assert!(e.to_string().contains("Unknown provider"), "Got: {e}");
    }

    #[tokio::test]
    async fn test_with_custom_provider() {
        let parser = ModelProviderParser::default().with_provider::<OllamaProvider>("custom");

        // The custom name itself must be registered, and its factory must build a provider.
        let factory = parser.factories.get("custom").expect("`custom` should be registered");
        assert!(factory("test-model").await.is_ok());

        // Adding a custom name must not disturb the built-in registrations.
        let model = LlmModel::Ollama("test-model".to_string());
        assert!(parser.create_provider(&model).await.is_ok());
    }

    #[test]
    fn test_with_provider_replaces_existing_name() {
        let parser = ModelProviderParser::default();
        let before = parser.factories.len();
        let parser = parser.with_provider::<LlamaCppProvider>("ollama");
        assert_eq!(parser.factories.len(), before);
    }

    #[tokio::test]
    async fn test_parse_single_provider() {
        let parser = ModelProviderParser::default();
        assert!(parser.parse("llamacpp").await.is_ok());
    }

    #[tokio::test]
    async fn test_parse_multiple_providers() {
        let parser = ModelProviderParser::default();
        let (_, model) =
            parser.parse("llamacpp,ollama:llama3.2").await.expect("local providers should parse without configuration");
        // The identity comes from the first spec in the list.
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_with_spaces() {
        let parser = ModelProviderParser::default();
        assert!(parser.parse("llamacpp , ollama:llama3.2").await.is_ok());
    }

    #[test]
    fn test_parser_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ModelProviderParser>();
    }
}