Skip to main content

llm/
parser.rs

1use crate::Result;
2use crate::catalog::LlmModel;
3#[cfg(feature = "bedrock")]
4use crate::providers::bedrock::BedrockProvider;
5#[cfg(feature = "codex")]
6use crate::providers::codex::CodexProvider;
7use crate::providers::{
8    anthropic::AnthropicProvider,
9    gemini::GeminiProvider,
10    local::{llama_cpp::LlamaCppProvider, ollama::OllamaProvider},
11    openai::OpenAiProvider,
12    openai_compatible::generic::{self, GenericOpenAiProvider},
13    openrouter::OpenRouterProvider,
14};
15use crate::{
16    LlmError, ProviderConnectionConfig, ProviderConnectionOverrides, ProviderFactory, StreamingModelProvider,
17    alloyed::AlloyedModelProvider,
18};
19#[cfg(feature = "codex")]
20use aether_auth::OAuthCredentialStorage;
21use futures::future::BoxFuture;
22use std::collections::HashMap;
23#[cfg(feature = "codex")]
24use std::sync::Arc;
25
#[doc = include_str!("docs/parser.md")]
pub struct ModelProviderParser {
    /// Provider factories keyed by provider name (the part of a model spec before the first `:`).
    factories: HashMap<String, CreateProviderFn>,
    /// Per-provider connection overrides; `config_for(name)` is consulted each time a factory runs.
    provider_connections: ProviderConnectionOverrides,
}
31
32impl ModelProviderParser {
33    /// Create a new parser with custom provider factories
34    pub fn new(factories: HashMap<String, CreateProviderFn>) -> Self {
35        Self { factories, provider_connections: ProviderConnectionOverrides::default() }
36    }
37}
38
39impl Default for ModelProviderParser {
40    /// Create a parser with all built-in providers registered
41    fn default() -> Self {
42        let parser = Self::new(HashMap::new())
43            .with_provider::<AnthropicProvider>("anthropic")
44            .with_provider::<GeminiProvider>("gemini")
45            .with_provider::<OpenRouterProvider>("openrouter")
46            .with_provider::<OllamaProvider>("ollama")
47            .with_provider::<LlamaCppProvider>("llamacpp")
48            .with_provider::<OpenAiProvider>("openai")
49            .with_openai_provider("deepseek", &generic::DEEPSEEK)
50            .with_openai_provider("moonshot", &generic::MOONSHOT)
51            .with_openai_provider("zai", &generic::ZAI);
52
53        #[cfg(feature = "bedrock")]
54        let parser = parser.with_provider::<BedrockProvider>("bedrock");
55
56        parser
57    }
58}
59
60impl ModelProviderParser {
61    pub fn with_provider_connections(mut self, connections: ProviderConnectionOverrides) -> Self {
62        self.provider_connections = connections;
63        self
64    }
65
66    pub fn with_provider<P: ProviderFactory + StreamingModelProvider + 'static>(
67        mut self,
68        name: impl Into<String>,
69    ) -> Self {
70        self.factories.insert(
71            name.into(),
72            Box::new(|model: &str, connection: ProviderConnectionConfig| {
73                let model = model.to_string();
74                Box::pin(
75                    async move { Ok(Box::new(P::from_env_with_connection(connection).await?.with_model(&model)) as _) },
76                )
77            }),
78        );
79        self
80    }
81
82    #[cfg(feature = "codex")]
83    pub fn with_codex_provider(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
84        self.factories.insert(
85            "codex".to_string(),
86            Box::new(move |model: &str, _connection: ProviderConnectionConfig| {
87                let store = Arc::clone(&store);
88                let model = model.to_string();
89                Box::pin(async move { Ok(Box::new(CodexProvider::new(store).with_model(&model)) as _) })
90            }),
91        );
92        self
93    }
94
95    pub fn with_openai_provider(mut self, name: impl Into<String>, config: &'static generic::ProviderConfig) -> Self {
96        self.factories.insert(
97            name.into(),
98            Box::new(move |model: &str, connection: ProviderConnectionConfig| {
99                let model = model.to_string();
100                Box::pin(async move {
101                    Ok(
102                        Box::new(
103                            GenericOpenAiProvider::from_env_with_connection(config, connection)?.with_model(&model),
104                        ) as _,
105                    )
106                })
107            }),
108        );
109        self
110    }
111
112    /// Create a provider from a typed `LlmModel`
113    pub async fn create_provider(&self, model: &LlmModel) -> Result<Box<dyn StreamingModelProvider>> {
114        let key = model.provider();
115        let factory = self.factories.get(key).ok_or_else(|| LlmError::Other(format!("Unknown provider: {key}")))?;
116        factory(&model.model_id(), self.provider_connections.config_for(key)).await
117    }
118
119    /// Parse a model specification string and create a provider instance.
120    ///
121    /// Returns both the provider and an `LlmModel` describing the identity
122    /// of the first (or only) provider in the spec.
123    ///
124    /// # Format
125    ///
126    /// - `"provider:model"` - Single provider (e.g., "anthropic:claude-3.5-sonnet")
127    /// - `"provider1:model1,provider2:model2"` - Multiple providers create an `AlloyedModelProvider`
128    ///
129    pub async fn parse(&self, models_str: &str) -> Result<(Box<dyn StreamingModelProvider>, LlmModel)> {
130        let provider_model_pairs: Vec<&str> = models_str.split(',').map(str::trim).collect();
131        if provider_model_pairs.is_empty() {
132            return Err(LlmError::Other("No models provided".to_string()));
133        }
134
135        let bedrock_has_inference_profile_arn =
136            self.provider_connections.config_for("bedrock").inference_profile_arn.is_some();
137        let mut seen_bedrock = false;
138        let mut providers = Vec::new();
139        let mut first_identity: Option<LlmModel> = None;
140
141        for pair in provider_model_pairs {
142            let (provider_name, model) = pair.split_once(':').unwrap_or((pair, ""));
143
144            if provider_name == "bedrock" && bedrock_has_inference_profile_arn {
145                if seen_bedrock {
146                    return Err(LlmError::Other(
147                        "providers.bedrock.inferenceProfileArn cannot be used with multiple bedrock models in one alloy spec"
148                            .to_string(),
149                    ));
150                }
151                seen_bedrock = true;
152            }
153
154            let factory = self
155                .factories
156                .get(provider_name)
157                .ok_or_else(|| LlmError::Other(format!("Unknown provider: {provider_name}")))?;
158
159            let connection = self.provider_connections.config_for(provider_name);
160            providers.push(factory(model, connection).await?);
161
162            if first_identity.is_none() {
163                first_identity = Some(pair.parse::<LlmModel>().map_err(LlmError::Other)?);
164            }
165        }
166
167        let identity = first_identity.ok_or_else(|| LlmError::Other("No providers parsed".to_string()))?;
168
169        let provider: Box<dyn StreamingModelProvider> = if providers.len() == 1 {
170            providers.into_iter().next().ok_or_else(|| LlmError::Other("No providers available".to_string()))?
171        } else {
172            Box::new(AlloyedModelProvider::new(providers))
173        };
174
175        Ok((provider, identity))
176    }
177}
178
179/// Factory function type for creating model providers
180///
181/// Takes a model name and returns a boxed future that resolves to a `StreamingModelProvider`
182pub type CreateProviderFn = Box<
183    dyn Fn(&str, ProviderConnectionConfig) -> BoxFuture<'static, Result<Box<dyn StreamingModelProvider>>> + Send + Sync,
184>;
185
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_llamacpp() {
        let parser = ModelProviderParser::default();
        let (_, model) = parser.parse("llamacpp").await.expect("llamacpp should parse without a model name");
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_anthropic() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic:claude-3-5-sonnet-20241022").await;
        match result {
            Ok((_, model)) => {
                assert_eq!(model, LlmModel::Anthropic(crate::catalog::AnthropicModel::Claude35Sonnet20241022));
            }
            Err(e) => {
                // Without credentials the provider may fail to build; that must not be a parse error.
                let err = e.to_string();
                assert!(
                    err.contains("API")
                        || err.contains("ANTHROPIC")
                        || err.contains("credentials")
                        || err.contains("JSON"),
                    "Should fail on API key or credentials, not parsing. Got: {err}"
                );
            }
        }
    }

    #[tokio::test]
    async fn test_parse_ollama() {
        let parser = ModelProviderParser::default();
        let (_, model) = parser.parse("ollama:llama3.2").await.expect("ollama spec should parse");
        assert_eq!(model, LlmModel::Ollama("llama3.2".to_string()));
    }

    #[tokio::test]
    async fn test_parse_openai() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openai:gpt-4.1").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENAI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_openrouter() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openrouter:google/gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(
                err.contains("API") || err.contains("OPENROUTER"),
                "Should fail on API key, not parsing. Got: {err}"
            );
        }
    }

    #[tokio::test]
    async fn test_parse_gemini() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("gemini:gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("GEMINI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_provider_without_model() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_parse_rejects_empty_spec() {
        let parser = ModelProviderParser::default();
        for spec in ["", "   ", ","] {
            assert!(parser.parse(spec).await.is_err(), "empty spec {spec:?} should be rejected");
        }
    }

    #[cfg(feature = "bedrock")]
    #[tokio::test]
    async fn test_parse_rejects_bedrock_inference_profile_arn() {
        let parser = ModelProviderParser::default();
        let spec = "bedrock:arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-opus-4-7";

        let error = match parser.parse(spec).await {
            Ok(_) => panic!("Bedrock ARN should be rejected"),
            Err(error) => error.to_string(),
        };

        assert!(error.contains("providers.bedrock.inferenceProfileArn"), "{error}");
    }

    #[cfg(feature = "bedrock")]
    #[tokio::test]
    async fn test_parse_rejects_bedrock_application_inference_profile_arn() {
        let parser = ModelProviderParser::default();
        let spec = "bedrock:arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";

        let error = match parser.parse(spec).await {
            Ok(_) => panic!("Bedrock ARN should be rejected"),
            Err(error) => error.to_string(),
        };

        assert!(error.contains("providers.bedrock.inferenceProfileArn"), "{error}");
    }

    #[tokio::test]
    async fn test_parse_unknown_provider() {
        let parser = ModelProviderParser::default();
        let error = match parser.parse("unknown:model").await {
            Ok(_) => panic!("unknown provider should be rejected"),
            Err(error) => error.to_string(),
        };
        assert!(error.contains("Unknown provider"), "{error}");
    }

    #[tokio::test]
    async fn test_with_custom_provider() {
        let model = LlmModel::Ollama("test-model".to_string());

        // A parser with no factories cannot resolve the model...
        let empty = ModelProviderParser::new(HashMap::new());
        assert!(empty.create_provider(&model).await.is_err());

        // ...but one whose only factory was added via `with_provider` can.
        let parser = ModelProviderParser::new(HashMap::new()).with_provider::<OllamaProvider>("ollama");
        assert!(parser.create_provider(&model).await.is_ok());
    }

    #[tokio::test]
    async fn test_parse_single_provider() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_parse_multiple_providers() {
        let parser = ModelProviderParser::default();
        let (_, model) = parser.parse("llamacpp,ollama:llama3.2").await.expect("alloy spec should parse");
        // The returned identity is always the first entry of the spec.
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_with_spaces() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp , ollama:llama3.2").await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_parser_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ModelProviderParser>();
    }
}