1use crate::Result;
2use crate::catalog::LlmModel;
3#[cfg(feature = "bedrock")]
4use crate::providers::bedrock::BedrockProvider;
5#[cfg(feature = "codex")]
6use crate::providers::codex::CodexProvider;
7use crate::providers::{
8 anthropic::AnthropicProvider,
9 gemini::GeminiProvider,
10 local::{llama_cpp::LlamaCppProvider, ollama::OllamaProvider},
11 openai::OpenAiProvider,
12 openai_compatible::generic::{self, GenericOpenAiProvider},
13 openrouter::OpenRouterProvider,
14};
15use crate::{
16 LlmError, ProviderConnectionConfig, ProviderConnectionOverrides, ProviderFactory, StreamingModelProvider,
17 alloyed::AlloyedModelProvider,
18};
19#[cfg(feature = "codex")]
20use aether_auth::OAuthCredentialStorage;
21use futures::future::BoxFuture;
22use std::collections::HashMap;
23#[cfg(feature = "codex")]
24use std::sync::Arc;
25
26#[doc = include_str!("docs/parser.md")]
27pub struct ModelProviderParser {
28 factories: HashMap<String, CreateProviderFn>,
29 provider_connections: ProviderConnectionOverrides,
30}
31
32impl ModelProviderParser {
33 pub fn new(factories: HashMap<String, CreateProviderFn>) -> Self {
35 Self { factories, provider_connections: ProviderConnectionOverrides::default() }
36 }
37}
38
39impl Default for ModelProviderParser {
40 fn default() -> Self {
42 let parser = Self::new(HashMap::new())
43 .with_provider::<AnthropicProvider>("anthropic")
44 .with_provider::<GeminiProvider>("gemini")
45 .with_provider::<OpenRouterProvider>("openrouter")
46 .with_provider::<OllamaProvider>("ollama")
47 .with_provider::<LlamaCppProvider>("llamacpp")
48 .with_provider::<OpenAiProvider>("openai")
49 .with_openai_provider("deepseek", &generic::DEEPSEEK)
50 .with_openai_provider("moonshot", &generic::MOONSHOT)
51 .with_openai_provider("zai", &generic::ZAI);
52
53 #[cfg(feature = "bedrock")]
54 let parser = parser.with_provider::<BedrockProvider>("bedrock");
55
56 parser
57 }
58}
59
60impl ModelProviderParser {
61 pub fn with_provider_connections(mut self, connections: ProviderConnectionOverrides) -> Self {
62 self.provider_connections = connections;
63 self
64 }
65
66 pub fn with_provider<P: ProviderFactory + StreamingModelProvider + 'static>(
67 mut self,
68 name: impl Into<String>,
69 ) -> Self {
70 self.factories.insert(
71 name.into(),
72 Box::new(|model: &str, connection: ProviderConnectionConfig| {
73 let model = model.to_string();
74 Box::pin(
75 async move { Ok(Box::new(P::from_env_with_connection(connection).await?.with_model(&model)) as _) },
76 )
77 }),
78 );
79 self
80 }
81
82 #[cfg(feature = "codex")]
83 pub fn with_codex_provider(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
84 self.factories.insert(
85 "codex".to_string(),
86 Box::new(move |model: &str, _connection: ProviderConnectionConfig| {
87 let store = Arc::clone(&store);
88 let model = model.to_string();
89 Box::pin(async move { Ok(Box::new(CodexProvider::new(store).with_model(&model)) as _) })
90 }),
91 );
92 self
93 }
94
95 pub fn with_openai_provider(mut self, name: impl Into<String>, config: &'static generic::ProviderConfig) -> Self {
96 self.factories.insert(
97 name.into(),
98 Box::new(move |model: &str, connection: ProviderConnectionConfig| {
99 let model = model.to_string();
100 Box::pin(async move {
101 Ok(
102 Box::new(
103 GenericOpenAiProvider::from_env_with_connection(config, connection)?.with_model(&model),
104 ) as _,
105 )
106 })
107 }),
108 );
109 self
110 }
111
112 pub async fn create_provider(&self, model: &LlmModel) -> Result<Box<dyn StreamingModelProvider>> {
114 let key = model.provider();
115 let factory = self.factories.get(key).ok_or_else(|| LlmError::Other(format!("Unknown provider: {key}")))?;
116 factory(&model.model_id(), self.provider_connections.config_for(key)).await
117 }
118
119 pub async fn parse(&self, models_str: &str) -> Result<(Box<dyn StreamingModelProvider>, LlmModel)> {
130 let provider_model_pairs: Vec<&str> = models_str.split(',').map(str::trim).collect();
131 if provider_model_pairs.is_empty() {
132 return Err(LlmError::Other("No models provided".to_string()));
133 }
134
135 let bedrock_has_inference_profile_arn =
136 self.provider_connections.config_for("bedrock").inference_profile_arn.is_some();
137 let mut seen_bedrock = false;
138 let mut providers = Vec::new();
139 let mut first_identity: Option<LlmModel> = None;
140
141 for pair in provider_model_pairs {
142 let (provider_name, model) = pair.split_once(':').unwrap_or((pair, ""));
143
144 if provider_name == "bedrock" && bedrock_has_inference_profile_arn {
145 if seen_bedrock {
146 return Err(LlmError::Other(
147 "providers.bedrock.inferenceProfileArn cannot be used with multiple bedrock models in one alloy spec"
148 .to_string(),
149 ));
150 }
151 seen_bedrock = true;
152 }
153
154 let factory = self
155 .factories
156 .get(provider_name)
157 .ok_or_else(|| LlmError::Other(format!("Unknown provider: {provider_name}")))?;
158
159 let connection = self.provider_connections.config_for(provider_name);
160 providers.push(factory(model, connection).await?);
161
162 if first_identity.is_none() {
163 first_identity = Some(pair.parse::<LlmModel>().map_err(LlmError::Other)?);
164 }
165 }
166
167 let identity = first_identity.ok_or_else(|| LlmError::Other("No providers parsed".to_string()))?;
168
169 let provider: Box<dyn StreamingModelProvider> = if providers.len() == 1 {
170 providers.into_iter().next().ok_or_else(|| LlmError::Other("No providers available".to_string()))?
171 } else {
172 Box::new(AlloyedModelProvider::new(providers))
173 };
174
175 Ok((provider, identity))
176 }
177}
178
179pub type CreateProviderFn = Box<
183 dyn Fn(&str, ProviderConnectionConfig) -> BoxFuture<'static, Result<Box<dyn StreamingModelProvider>>> + Send + Sync,
184>;
185
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `spec` with the default parser and returns the identity.
    ///
    /// Panics with the underlying error text when parsing fails, so a failing test shows why.
    async fn parse_ok(spec: &str) -> LlmModel {
        match ModelProviderParser::default().parse(spec).await {
            Ok((_, model)) => model,
            Err(e) => panic!("expected `{spec}` to parse, got: {e}"),
        }
    }

    /// Parses `spec` with the default parser and returns the error text.
    ///
    /// Panics when the spec is unexpectedly accepted.
    async fn parse_err(spec: &str) -> String {
        match ModelProviderParser::default().parse(spec).await {
            Ok(_) => panic!("expected `{spec}` to be rejected"),
            Err(e) => e.to_string(),
        }
    }

    /// Parses `spec` with the default parser, tolerating failures that mention one of `markers`.
    ///
    /// Returns the identity on success. On failure it asserts that the error text contains at
    /// least one marker (the tests use markers for API-key / credential errors) and returns
    /// `None`.
    async fn parse_allowing_credential_errors(spec: &str, markers: &[&str]) -> Option<LlmModel> {
        match ModelProviderParser::default().parse(spec).await {
            Ok((_, model)) => Some(model),
            Err(e) => {
                let err = e.to_string();
                assert!(
                    markers.iter().any(|marker| err.contains(*marker)),
                    "Should fail on API key or credentials, not parsing. Got: {err}"
                );
                None
            }
        }
    }

    #[tokio::test]
    async fn test_parse_llamacpp() {
        // A bare provider name yields an empty model id.
        assert_eq!(parse_ok("llamacpp").await, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_anthropic() {
        let markers = ["API", "ANTHROPIC", "credentials", "JSON"];
        let parsed = parse_allowing_credential_errors("anthropic:claude-3-5-sonnet-20241022", &markers).await;
        if let Some(model) = parsed {
            assert_eq!(model, LlmModel::Anthropic(crate::catalog::AnthropicModel::Claude35Sonnet20241022));
        }
    }

    #[tokio::test]
    async fn test_parse_ollama() {
        assert_eq!(parse_ok("ollama:llama3.2").await, LlmModel::Ollama("llama3.2".to_string()));
    }

    #[tokio::test]
    async fn test_parse_openai() {
        let _identity = parse_allowing_credential_errors("openai:gpt-4.1", &["API", "OPENAI"]).await;
    }

    #[tokio::test]
    async fn test_parse_openrouter() {
        let _identity =
            parse_allowing_credential_errors("openrouter:google/gemini-2.5-flash", &["API", "OPENROUTER"]).await;
    }

    #[tokio::test]
    async fn test_parse_gemini() {
        let _identity = parse_allowing_credential_errors("gemini:gemini-2.5-flash", &["API", "GEMINI"]).await;
    }

    #[tokio::test]
    async fn test_parse_provider_without_model() {
        // Any error is acceptable here; only rejection is asserted.
        let _error = parse_err("anthropic").await;
    }

    #[cfg(feature = "bedrock")]
    #[tokio::test]
    async fn test_parse_rejects_bedrock_inference_profile_arn() {
        let spec = "bedrock:arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-opus-4-7";

        let error = parse_err(spec).await;

        assert!(error.contains("providers.bedrock.inferenceProfileArn"), "{error}");
    }

    #[cfg(feature = "bedrock")]
    #[tokio::test]
    async fn test_parse_rejects_bedrock_application_inference_profile_arn() {
        let spec = "bedrock:arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";

        let error = parse_err(spec).await;

        assert!(error.contains("providers.bedrock.inferenceProfileArn"), "{error}");
    }

    #[tokio::test]
    async fn test_parse_unknown_provider() {
        let error = parse_err("unknown:model").await;
        assert!(error.contains("Unknown provider"), "{error}");
    }

    #[tokio::test]
    async fn test_with_custom_provider() {
        let parser = ModelProviderParser::default().with_provider::<OllamaProvider>("custom");

        // NOTE(review): `create_provider` looks the factory up by `model.provider()`, so this
        // may resolve the built-in ollama registration rather than `custom` - confirm.
        let model = LlmModel::Ollama("test-model".to_string());
        if let Err(e) = parser.create_provider(&model).await {
            panic!("create_provider should succeed, got: {e}");
        }
    }

    #[tokio::test]
    async fn test_parse_single_provider() {
        let _identity = parse_ok("llamacpp").await;
    }

    #[tokio::test]
    async fn test_parse_multiple_providers() {
        // The identity comes from the first entry of the spec.
        assert_eq!(parse_ok("llamacpp,ollama:llama3.2").await, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_with_spaces() {
        let _identity = parse_ok("llamacpp , ollama:llama3.2").await;
    }

    #[test]
    fn test_parser_is_send_sync() {
        // Compile-time check only: fails to build if the parser loses `Send` or `Sync`.
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ModelProviderParser>();
    }
}