1use crate::Result;
2use crate::catalog::LlmModel;
3#[cfg(feature = "bedrock")]
4use crate::providers::bedrock::BedrockProvider;
5#[cfg(feature = "codex")]
6use crate::providers::codex::CodexProvider;
7use crate::providers::{
8 anthropic::AnthropicProvider,
9 gemini::GeminiProvider,
10 local::{llama_cpp::LlamaCppProvider, ollama::OllamaProvider},
11 openai::OpenAiProvider,
12 openai_compatible::generic::{self, GenericOpenAiProvider},
13 openrouter::OpenRouterProvider,
14};
15use crate::{
16 LlmError, ProviderConnectionConfig, ProviderConnectionOverrides, ProviderFactory, StreamingModelProvider,
17 alloyed::AlloyedModelProvider,
18};
19#[cfg(feature = "codex")]
20use aether_auth::OAuthCredentialStorage;
21use futures::future::BoxFuture;
22use std::collections::HashMap;
23#[cfg(feature = "codex")]
24use std::sync::Arc;
25
26#[doc = include_str!("docs/parser.md")]
27pub struct ModelProviderParser {
28 factories: HashMap<String, CreateProviderFn>,
29 provider_connections: ProviderConnectionOverrides,
30}
31
32impl ModelProviderParser {
33 pub fn new(factories: HashMap<String, CreateProviderFn>) -> Self {
35 Self { factories, provider_connections: ProviderConnectionOverrides::default() }
36 }
37}
38
39impl Default for ModelProviderParser {
40 fn default() -> Self {
42 let parser = Self::new(HashMap::new())
43 .with_provider::<AnthropicProvider>("anthropic")
44 .with_provider::<GeminiProvider>("gemini")
45 .with_provider::<OpenRouterProvider>("openrouter")
46 .with_provider::<OllamaProvider>("ollama")
47 .with_provider::<LlamaCppProvider>("llamacpp")
48 .with_provider::<OpenAiProvider>("openai")
49 .with_openai_provider("deepseek", &generic::DEEPSEEK)
50 .with_openai_provider("moonshot", &generic::MOONSHOT)
51 .with_openai_provider("zai", &generic::ZAI);
52
53 #[cfg(feature = "bedrock")]
54 let parser = parser.with_provider::<BedrockProvider>("bedrock");
55
56 parser
57 }
58}
59
60impl ModelProviderParser {
61 pub fn with_provider_connections(mut self, connections: ProviderConnectionOverrides) -> Self {
62 self.provider_connections = connections;
63 self
64 }
65
66 pub fn with_provider<P: ProviderFactory + StreamingModelProvider + 'static>(
67 mut self,
68 name: impl Into<String>,
69 ) -> Self {
70 self.factories.insert(
71 name.into(),
72 Box::new(|model: &str, connection: ProviderConnectionConfig| {
73 let model = model.to_string();
74 Box::pin(
75 async move { Ok(Box::new(P::from_env_with_connection(connection).await?.with_model(&model)) as _) },
76 )
77 }),
78 );
79 self
80 }
81
82 #[cfg(feature = "codex")]
83 pub fn with_codex_provider(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
84 self.factories.insert(
85 "codex".to_string(),
86 Box::new(move |model: &str, _connection: ProviderConnectionConfig| {
87 let store = Arc::clone(&store);
88 let model = model.to_string();
89 Box::pin(async move { Ok(Box::new(CodexProvider::new(store).with_model(&model)) as _) })
90 }),
91 );
92 self
93 }
94
95 pub fn with_openai_provider(mut self, name: impl Into<String>, config: &'static generic::ProviderConfig) -> Self {
96 self.factories.insert(
97 name.into(),
98 Box::new(move |model: &str, connection: ProviderConnectionConfig| {
99 let model = model.to_string();
100 Box::pin(async move {
101 Ok(
102 Box::new(
103 GenericOpenAiProvider::from_env_with_connection(config, connection)?.with_model(&model),
104 ) as _,
105 )
106 })
107 }),
108 );
109 self
110 }
111
112 pub async fn create_provider(&self, model: &LlmModel) -> Result<Box<dyn StreamingModelProvider>> {
114 let key = model.provider();
115 let factory = self.factories.get(key).ok_or_else(|| LlmError::Other(format!("Unknown provider: {key}")))?;
116 factory(&model.model_id(), self.provider_connections.config_for(key)).await
117 }
118
119 pub async fn parse(&self, models_str: &str) -> Result<(Box<dyn StreamingModelProvider>, LlmModel)> {
130 let provider_model_pairs: Vec<&str> = models_str.split(',').map(str::trim).collect();
131 if provider_model_pairs.is_empty() {
132 return Err(LlmError::Other("No models provided".to_string()));
133 }
134
135 let mut providers = Vec::new();
136 let mut first_identity: Option<LlmModel> = None;
137
138 for pair in provider_model_pairs {
139 let (provider_name, model) = pair.split_once(':').unwrap_or((pair, ""));
140
141 let factory = self
142 .factories
143 .get(provider_name)
144 .ok_or_else(|| LlmError::Other(format!("Unknown provider: {provider_name}")))?;
145
146 let connection = self.provider_connections.config_for(provider_name);
147 providers.push(factory(model, connection).await?);
148
149 if first_identity.is_none() {
150 first_identity = Some(pair.parse::<LlmModel>().map_err(LlmError::Other)?);
151 }
152 }
153
154 let identity = first_identity.ok_or_else(|| LlmError::Other("No providers parsed".to_string()))?;
155
156 let provider: Box<dyn StreamingModelProvider> = if providers.len() == 1 {
157 providers.into_iter().next().ok_or_else(|| LlmError::Other("No providers available".to_string()))?
158 } else {
159 Box::new(AlloyedModelProvider::new(providers))
160 };
161
162 Ok((provider, identity))
163 }
164}
165
166pub type CreateProviderFn = Box<
170 dyn Fn(&str, ProviderConnectionConfig) -> BoxFuture<'static, Result<Box<dyn StreamingModelProvider>>> + Send + Sync,
171>;
172
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_llamacpp() {
        let parser = ModelProviderParser::default();
        let (_, model) = parser.parse("llamacpp").await.expect("llamacpp should parse without a model id");
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_anthropic() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic:claude-3-5-sonnet-20241022").await;
        match result {
            Ok((_, model)) => {
                assert_eq!(model, LlmModel::Anthropic(crate::catalog::AnthropicModel::Claude35Sonnet20241022));
            }
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("API")
                        || err.contains("ANTHROPIC")
                        || err.contains("credentials")
                        || err.contains("JSON"),
                    "Should fail on API key or credentials, not parsing. Got: {err}"
                );
            }
        }
    }

    #[tokio::test]
    async fn test_parse_ollama() {
        let parser = ModelProviderParser::default();
        let (_, model) = parser.parse("ollama:llama3.2").await.expect("ollama spec should parse");
        assert_eq!(model, LlmModel::Ollama("llama3.2".to_string()));
    }

    #[tokio::test]
    async fn test_parse_openai() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openai:gpt-4.1").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENAI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_openrouter() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openrouter:google/gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(
                err.contains("API") || err.contains("OPENROUTER"),
                "Should fail on API key, not parsing. Got: {err}"
            );
        }
    }

    #[tokio::test]
    async fn test_parse_gemini() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("gemini:gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("GEMINI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_provider_without_model() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic").await;
        assert!(result.is_err());
    }

    #[cfg(feature = "bedrock")]
    #[tokio::test]
    async fn test_parse_bedrock_inference_profile_arn() {
        let parser = ModelProviderParser::default();
        let arn = "arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-opus-4-7";
        let spec = format!("bedrock:{arn}");
        let (provider, model) = parser.parse(&spec).await.expect("Bedrock ARN should parse");

        assert_eq!(model.to_string(), spec);
        assert_eq!(provider.display_name(), format!("Bedrock ({arn})"));
    }

    #[cfg(feature = "bedrock")]
    #[tokio::test]
    async fn test_parse_bedrock_application_inference_profile_arn() {
        let parser = ModelProviderParser::default();
        let arn = "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";
        let spec = format!("bedrock:{arn}");

        let (_, model) = parser.parse(&spec).await.expect("Application inference profile ARN should parse");
        assert_eq!(model.to_string(), spec);
        assert_eq!(model.context_window(), None);
    }

    #[tokio::test]
    async fn test_parse_unknown_provider() {
        let parser = ModelProviderParser::default();
        let err = parser.parse("unknown:model").await.err().expect("unknown provider should be rejected");
        assert!(err.to_string().contains("Unknown provider"), "Got: {err}");
    }

    #[tokio::test]
    async fn test_with_custom_provider() {
        let parser = ModelProviderParser::default().with_provider::<OllamaProvider>("custom");

        // Original check: building from a model value still works after adding a custom entry.
        let model = LlmModel::Ollama("test-model".to_string());
        assert!(parser.create_provider(&model).await.is_ok());

        // `create_provider` selects the factory by `model.provider()`, which need not be "custom".
        // Call the "custom" registration directly so this test actually exercises it.
        let factory = parser.factories.get("custom").expect("custom provider should be registered");
        let result = factory("test-model", parser.provider_connections.config_for("custom")).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_parse_single_provider() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_parse_multiple_providers() {
        let parser = ModelProviderParser::default();
        let (_, model) = parser.parse("llamacpp,ollama:llama3.2").await.expect("multi-provider spec should parse");
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_with_spaces() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp , ollama:llama3.2").await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_parser_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ModelProviderParser>();
    }
}
327}