Skip to main content

llm/
parser.rs

1use crate::Result;
2use crate::catalog::LlmModel;
3#[cfg(feature = "bedrock")]
4use crate::providers::bedrock::BedrockProvider;
5#[cfg(feature = "codex")]
6use crate::providers::codex::CodexProvider;
7use crate::providers::{
8    anthropic::AnthropicProvider,
9    gemini::GeminiProvider,
10    local::{llama_cpp::LlamaCppProvider, ollama::OllamaProvider},
11    openai::OpenAiProvider,
12    openai_compatible::generic::{self, GenericOpenAiProvider},
13    openrouter::OpenRouterProvider,
14};
15use crate::{
16    LlmError, ProviderConnectionConfig, ProviderConnectionOverrides, ProviderFactory, StreamingModelProvider,
17    alloyed::AlloyedModelProvider,
18};
19#[cfg(feature = "codex")]
20use aether_auth::OAuthCredentialStorage;
21use futures::future::BoxFuture;
22use std::collections::HashMap;
23#[cfg(feature = "codex")]
24use std::sync::Arc;
25
26#[doc = include_str!("docs/parser.md")]
27pub struct ModelProviderParser {
28    factories: HashMap<String, CreateProviderFn>,
29    provider_connections: ProviderConnectionOverrides,
30}
31
32impl ModelProviderParser {
33    /// Create a new parser with custom provider factories
34    pub fn new(factories: HashMap<String, CreateProviderFn>) -> Self {
35        Self { factories, provider_connections: ProviderConnectionOverrides::default() }
36    }
37}
38
39impl Default for ModelProviderParser {
40    /// Create a parser with all built-in providers registered
41    fn default() -> Self {
42        let parser = Self::new(HashMap::new())
43            .with_provider::<AnthropicProvider>("anthropic")
44            .with_provider::<GeminiProvider>("gemini")
45            .with_provider::<OpenRouterProvider>("openrouter")
46            .with_provider::<OllamaProvider>("ollama")
47            .with_provider::<LlamaCppProvider>("llamacpp")
48            .with_provider::<OpenAiProvider>("openai")
49            .with_openai_provider("deepseek", &generic::DEEPSEEK)
50            .with_openai_provider("moonshot", &generic::MOONSHOT)
51            .with_openai_provider("zai", &generic::ZAI);
52
53        #[cfg(feature = "bedrock")]
54        let parser = parser.with_provider::<BedrockProvider>("bedrock");
55
56        parser
57    }
58}
59
60impl ModelProviderParser {
61    pub fn with_provider_connections(mut self, connections: ProviderConnectionOverrides) -> Self {
62        self.provider_connections = connections;
63        self
64    }
65
66    pub fn with_provider<P: ProviderFactory + StreamingModelProvider + 'static>(
67        mut self,
68        name: impl Into<String>,
69    ) -> Self {
70        self.factories.insert(
71            name.into(),
72            Box::new(|model: &str, connection: ProviderConnectionConfig| {
73                let model = model.to_string();
74                Box::pin(
75                    async move { Ok(Box::new(P::from_env_with_connection(connection).await?.with_model(&model)) as _) },
76                )
77            }),
78        );
79        self
80    }
81
82    #[cfg(feature = "codex")]
83    pub fn with_codex_provider(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
84        self.factories.insert(
85            "codex".to_string(),
86            Box::new(move |model: &str, _connection: ProviderConnectionConfig| {
87                let store = Arc::clone(&store);
88                let model = model.to_string();
89                Box::pin(async move { Ok(Box::new(CodexProvider::new(store).with_model(&model)) as _) })
90            }),
91        );
92        self
93    }
94
95    pub fn with_openai_provider(mut self, name: impl Into<String>, config: &'static generic::ProviderConfig) -> Self {
96        self.factories.insert(
97            name.into(),
98            Box::new(move |model: &str, connection: ProviderConnectionConfig| {
99                let model = model.to_string();
100                Box::pin(async move {
101                    Ok(
102                        Box::new(
103                            GenericOpenAiProvider::from_env_with_connection(config, connection)?.with_model(&model),
104                        ) as _,
105                    )
106                })
107            }),
108        );
109        self
110    }
111
112    /// Create a provider from a typed `LlmModel`
113    pub async fn create_provider(&self, model: &LlmModel) -> Result<Box<dyn StreamingModelProvider>> {
114        let key = model.provider();
115        let factory = self.factories.get(key).ok_or_else(|| LlmError::Other(format!("Unknown provider: {key}")))?;
116        factory(&model.model_id(), self.provider_connections.config_for(key)).await
117    }
118
119    /// Parse a model specification string and create a provider instance.
120    ///
121    /// Returns both the provider and an `LlmModel` describing the identity
122    /// of the first (or only) provider in the spec.
123    ///
124    /// # Format
125    ///
126    /// - `"provider:model"` - Single provider (e.g., "anthropic:claude-3.5-sonnet")
127    /// - `"provider1:model1,provider2:model2"` - Multiple providers create an `AlloyedModelProvider`
128    ///
129    pub async fn parse(&self, models_str: &str) -> Result<(Box<dyn StreamingModelProvider>, LlmModel)> {
130        let provider_model_pairs: Vec<&str> = models_str.split(',').map(str::trim).collect();
131        if provider_model_pairs.is_empty() {
132            return Err(LlmError::Other("No models provided".to_string()));
133        }
134
135        let mut providers = Vec::new();
136        let mut first_identity: Option<LlmModel> = None;
137
138        for pair in provider_model_pairs {
139            let (provider_name, model) = pair.split_once(':').unwrap_or((pair, ""));
140
141            let factory = self
142                .factories
143                .get(provider_name)
144                .ok_or_else(|| LlmError::Other(format!("Unknown provider: {provider_name}")))?;
145
146            let connection = self.provider_connections.config_for(provider_name);
147            providers.push(factory(model, connection).await?);
148
149            if first_identity.is_none() {
150                first_identity = Some(pair.parse::<LlmModel>().map_err(LlmError::Other)?);
151            }
152        }
153
154        let identity = first_identity.ok_or_else(|| LlmError::Other("No providers parsed".to_string()))?;
155
156        let provider: Box<dyn StreamingModelProvider> = if providers.len() == 1 {
157            providers.into_iter().next().ok_or_else(|| LlmError::Other("No providers available".to_string()))?
158        } else {
159            Box::new(AlloyedModelProvider::new(providers))
160        };
161
162        Ok((provider, identity))
163    }
164}
165
166/// Factory function type for creating model providers
167///
168/// Takes a model name and returns a boxed future that resolves to a `StreamingModelProvider`
169pub type CreateProviderFn = Box<
170    dyn Fn(&str, ProviderConnectionConfig) -> BoxFuture<'static, Result<Box<dyn StreamingModelProvider>>> + Send + Sync,
171>;
172
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_llamacpp() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_anthropic() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic:claude-3-5-sonnet-20241022").await;
        match result {
            Ok((_, model)) => {
                assert_eq!(model, LlmModel::Anthropic(crate::catalog::AnthropicModel::Claude35Sonnet20241022));
            }
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("API")
                        || err.contains("ANTHROPIC")
                        || err.contains("credentials")
                        || err.contains("JSON"),
                    "Should fail on API key or credentials, not parsing. Got: {err}"
                );
            }
        }
    }

    #[tokio::test]
    async fn test_parse_ollama() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("ollama:llama3.2").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::Ollama("llama3.2".to_string()));
    }

    #[tokio::test]
    async fn test_parse_openai() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openai:gpt-4.1").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENAI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_openrouter() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openrouter:google/gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            // Include the actual error so a failure here is diagnosable, as in the OpenAI test.
            assert!(
                err.contains("API") || err.contains("OPENROUTER"),
                "Should fail on API key, not parsing. Got: {err}"
            );
        }
    }

    #[tokio::test]
    async fn test_parse_gemini() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("gemini:gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            // Include the actual error so a failure here is diagnosable, as in the OpenAI test.
            assert!(err.contains("API") || err.contains("GEMINI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_provider_without_model() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_parse_empty_input() {
        // Blank specs must be rejected rather than producing a provider.
        let parser = ModelProviderParser::default();
        assert!(parser.parse("").await.is_err());
        assert!(parser.parse("   ").await.is_err());
    }

    #[cfg(feature = "bedrock")]
    #[tokio::test]
    async fn test_parse_bedrock_inference_profile_arn() {
        // ARNs (including those with extra `:` separators) must round-trip through the parser
        // without being misinterpreted as `provider:model` splits — the parser splits on the
        // first `:` only, so everything after `bedrock:` becomes the model string.
        let parser = ModelProviderParser::default();
        let arn = "arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-opus-4-7";
        let spec = format!("bedrock:{arn}");
        let (provider, model) = parser.parse(&spec).await.expect("Bedrock ARN should parse");

        assert_eq!(model.to_string(), spec);
        assert_eq!(provider.display_name(), format!("Bedrock ({arn})"));
    }

    #[cfg(feature = "bedrock")]
    #[tokio::test]
    async fn test_parse_bedrock_application_inference_profile_arn() {
        let parser = ModelProviderParser::default();
        let arn = "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";
        let spec = format!("bedrock:{arn}");

        let (_, model) = parser.parse(&spec).await.expect("Application inference profile ARN should parse");
        assert_eq!(model.to_string(), spec);
        assert_eq!(model.context_window(), None);
    }

    #[tokio::test]
    async fn test_parse_unknown_provider() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("unknown:model").await;
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unknown provider"));
        }
    }

    #[tokio::test]
    async fn test_with_custom_provider() {
        let parser = ModelProviderParser::default().with_provider::<OllamaProvider>("custom");

        // Invoke the factory stored under the custom name directly; going through an
        // `LlmModel::Ollama` would resolve to the built-in "ollama" entry and never touch it.
        let factory = parser.factories.get("custom").expect("custom provider should be registered");
        let connection = parser.provider_connections.config_for("custom");
        assert!(factory("test-model", connection).await.is_ok());

        // Built-in registrations must keep working alongside the custom one.
        let model = LlmModel::Ollama("test-model".to_string());
        let result = parser.create_provider(&model).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_parse_single_provider() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_parse_multiple_providers() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp,ollama:llama3.2").await;
        assert!(result.is_ok());
        // The reported identity is that of the first entry in the spec.
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_with_spaces() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp , ollama:llama3.2").await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_parser_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ModelProviderParser>();
    }
}