1use crate::Result;
2use crate::catalog::LlmModel;
3#[cfg(feature = "bedrock")]
4use crate::providers::bedrock::BedrockProvider;
5#[cfg(feature = "codex")]
6use crate::providers::codex::CodexProvider;
7use crate::providers::{
8 anthropic::AnthropicProvider,
9 gemini::GeminiProvider,
10 local::{llama_cpp::LlamaCppProvider, ollama::OllamaProvider},
11 openai::OpenAiProvider,
12 openai_compatible::generic::{self, GenericOpenAiProvider},
13 openrouter::OpenRouterProvider,
14};
15use crate::{LlmError, ProviderFactory, StreamingModelProvider, alloyed::AlloyedModelProvider};
16#[cfg(feature = "codex")]
17use aether_auth::OAuthCredentialStorage;
18use futures::future::BoxFuture;
19use std::collections::HashMap;
20#[cfg(feature = "codex")]
21use std::sync::Arc;
22
23#[doc = include_str!("docs/parser.md")]
24pub struct ModelProviderParser {
25 factories: HashMap<String, CreateProviderFn>,
26}
27
28impl ModelProviderParser {
29 pub fn new(factories: HashMap<String, CreateProviderFn>) -> Self {
31 Self { factories }
32 }
33}
34
35impl Default for ModelProviderParser {
36 fn default() -> Self {
38 let parser = Self::new(HashMap::new())
39 .with_provider::<AnthropicProvider>("anthropic")
40 .with_provider::<GeminiProvider>("gemini")
41 .with_provider::<OpenRouterProvider>("openrouter")
42 .with_provider::<OllamaProvider>("ollama")
43 .with_provider::<LlamaCppProvider>("llamacpp")
44 .with_provider::<OpenAiProvider>("openai")
45 .with_openai_provider("deepseek", &generic::DEEPSEEK)
46 .with_openai_provider("moonshot", &generic::MOONSHOT)
47 .with_openai_provider("zai", &generic::ZAI);
48
49 #[cfg(feature = "bedrock")]
50 let parser = parser.with_provider::<BedrockProvider>("bedrock");
51
52 parser
53 }
54}
55
56impl ModelProviderParser {
57 pub fn with_provider<P: ProviderFactory + StreamingModelProvider + 'static>(
58 mut self,
59 name: impl Into<String>,
60 ) -> Self {
61 self.factories.insert(
62 name.into(),
63 Box::new(|model: &str| {
64 let model = model.to_string();
65 Box::pin(async move { Ok(Box::new(P::from_env().await?.with_model(&model)) as _) })
66 }),
67 );
68 self
69 }
70
71 #[cfg(feature = "codex")]
72 pub fn with_codex_provider(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
73 self.factories.insert(
74 "codex".to_string(),
75 Box::new(move |model: &str| {
76 let store = Arc::clone(&store);
77 let model = model.to_string();
78 Box::pin(async move { Ok(Box::new(CodexProvider::new(store).with_model(&model)) as _) })
79 }),
80 );
81 self
82 }
83
84 pub fn with_openai_provider(mut self, name: impl Into<String>, config: &'static generic::ProviderConfig) -> Self {
85 self.factories.insert(
86 name.into(),
87 Box::new(move |model: &str| {
88 let model = model.to_string();
89 Box::pin(async move { Ok(Box::new(GenericOpenAiProvider::from_env(config)?.with_model(&model)) as _) })
90 }),
91 );
92 self
93 }
94
95 pub async fn create_provider(&self, model: &LlmModel) -> Result<Box<dyn StreamingModelProvider>> {
97 let key = model.provider();
98 let factory = self.factories.get(key).ok_or_else(|| LlmError::Other(format!("Unknown provider: {key}")))?;
99 factory(&model.model_id()).await
100 }
101
102 pub async fn parse(&self, models_str: &str) -> Result<(Box<dyn StreamingModelProvider>, LlmModel)> {
113 let provider_model_pairs: Vec<&str> = models_str.split(',').map(str::trim).collect();
114 if provider_model_pairs.is_empty() {
115 return Err(LlmError::Other("No models provided".to_string()));
116 }
117
118 let mut providers = Vec::new();
119 let mut first_identity: Option<LlmModel> = None;
120
121 for pair in provider_model_pairs {
122 let (provider_name, model) = pair.split_once(':').unwrap_or((pair, ""));
123
124 let factory = self
125 .factories
126 .get(provider_name)
127 .ok_or_else(|| LlmError::Other(format!("Unknown provider: {provider_name}")))?;
128
129 providers.push(factory(model).await?);
130
131 if first_identity.is_none() {
132 first_identity = Some(pair.parse::<LlmModel>().map_err(LlmError::Other)?);
133 }
134 }
135
136 let identity = first_identity.ok_or_else(|| LlmError::Other("No providers parsed".to_string()))?;
137
138 let provider: Box<dyn StreamingModelProvider> = if providers.len() == 1 {
139 providers.into_iter().next().ok_or_else(|| LlmError::Other("No providers available".to_string()))?
140 } else {
141 Box::new(AlloyedModelProvider::new(providers))
142 };
143
144 Ok((provider, identity))
145 }
146}
147
148pub type CreateProviderFn =
152 Box<dyn Fn(&str) -> BoxFuture<'static, Result<Box<dyn StreamingModelProvider>>> + Send + Sync>;
153
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare provider name parses to that provider with an empty model id.
    #[tokio::test]
    async fn test_parse_llamacpp() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    /// Parsing either succeeds or fails for credential reasons, depending on
    /// the environment; it must not fail while parsing the model string.
    #[tokio::test]
    async fn test_parse_anthropic() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic:claude-3-5-sonnet-20241022").await;
        match result {
            Ok((_, model)) => {
                assert_eq!(model, LlmModel::Anthropic(crate::catalog::AnthropicModel::Claude35Sonnet20241022));
            }
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("API")
                        || err.contains("ANTHROPIC")
                        || err.contains("credentials")
                        || err.contains("JSON"),
                    "Should fail on API key or credentials, not parsing. Got: {err}"
                );
            }
        }
    }

    #[tokio::test]
    async fn test_parse_ollama() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("ollama:llama3.2").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::Ollama("llama3.2".to_string()));
    }

    #[tokio::test]
    async fn test_parse_openai() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openai:gpt-4.1").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENAI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_openrouter() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openrouter:google/gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(
                err.contains("API") || err.contains("OPENROUTER"),
                "Should fail on API key, not parsing. Got: {err}"
            );
        }
    }

    #[tokio::test]
    async fn test_parse_gemini() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("gemini:gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("GEMINI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_provider_without_model() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_parse_unknown_provider() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("unknown:model").await;
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unknown provider"));
        }
    }

    /// Blank input must be rejected rather than producing a provider.
    #[tokio::test]
    async fn test_parse_empty_input() {
        let parser = ModelProviderParser::default();
        assert!(parser.parse("").await.is_err());
        assert!(parser.parse("   ").await.is_err());
    }

    #[tokio::test]
    async fn test_with_custom_provider() {
        let parser = ModelProviderParser::default().with_provider::<OllamaProvider>("custom");

        // Exercise the newly registered name directly; going through an
        // `LlmModel::Ollama` alone would only hit the built-in "ollama" entry.
        let factory = parser.factories.get("custom").expect("custom provider should be registered");
        assert!(factory("test-model").await.is_ok());

        // Built-in registrations must still work after adding a custom one.
        let model = LlmModel::Ollama("test-model".to_string());
        let result = parser.create_provider(&model).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_parse_single_provider() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp").await;
        assert!(result.is_ok());
    }

    /// With several entries, the reported identity is that of the first one.
    #[tokio::test]
    async fn test_parse_multiple_providers() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp,ollama:llama3.2").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_with_spaces() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp , ollama:llama3.2").await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_parser_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ModelProviderParser>();
    }
}