Skip to main content

llm/
parser.rs

1use crate::Result;
2use crate::catalog::LlmModel;
3#[cfg(feature = "bedrock")]
4use crate::providers::bedrock::BedrockProvider;
5#[cfg(feature = "codex")]
6use crate::providers::codex::CodexProvider;
7use crate::providers::{
8    anthropic::AnthropicProvider,
9    gemini::GeminiProvider,
10    local::{llama_cpp::LlamaCppProvider, ollama::OllamaProvider},
11    openai::OpenAiProvider,
12    openai_compatible::generic::{self, GenericOpenAiProvider},
13    openrouter::OpenRouterProvider,
14};
15use crate::{LlmError, ProviderFactory, StreamingModelProvider, alloyed::AlloyedModelProvider};
16#[cfg(feature = "codex")]
17use aether_auth::OAuthCredentialStorage;
18use futures::future::BoxFuture;
19use std::collections::HashMap;
20#[cfg(feature = "codex")]
21use std::sync::Arc;
22
#[doc = include_str!("docs/parser.md")]
pub struct ModelProviderParser {
    // Registered provider factories, keyed by the provider name that
    // `create_provider` and `parse` use for lookup.
    factories: HashMap<String, CreateProviderFn>,
}
27
28impl ModelProviderParser {
29    /// Create a new parser with custom provider factories
30    pub fn new(factories: HashMap<String, CreateProviderFn>) -> Self {
31        Self { factories }
32    }
33}
34
35impl Default for ModelProviderParser {
36    /// Create a parser with all built-in providers registered
37    fn default() -> Self {
38        let parser = Self::new(HashMap::new())
39            .with_provider::<AnthropicProvider>("anthropic")
40            .with_provider::<GeminiProvider>("gemini")
41            .with_provider::<OpenRouterProvider>("openrouter")
42            .with_provider::<OllamaProvider>("ollama")
43            .with_provider::<LlamaCppProvider>("llamacpp")
44            .with_provider::<OpenAiProvider>("openai")
45            .with_openai_provider("deepseek", &generic::DEEPSEEK)
46            .with_openai_provider("moonshot", &generic::MOONSHOT)
47            .with_openai_provider("zai", &generic::ZAI);
48
49        #[cfg(feature = "bedrock")]
50        let parser = parser.with_provider::<BedrockProvider>("bedrock");
51
52        parser
53    }
54}
55
56impl ModelProviderParser {
57    pub fn with_provider<P: ProviderFactory + StreamingModelProvider + 'static>(
58        mut self,
59        name: impl Into<String>,
60    ) -> Self {
61        self.factories.insert(
62            name.into(),
63            Box::new(|model: &str| {
64                let model = model.to_string();
65                Box::pin(async move { Ok(Box::new(P::from_env().await?.with_model(&model)) as _) })
66            }),
67        );
68        self
69    }
70
71    #[cfg(feature = "codex")]
72    pub fn with_codex_provider(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
73        self.factories.insert(
74            "codex".to_string(),
75            Box::new(move |model: &str| {
76                let store = Arc::clone(&store);
77                let model = model.to_string();
78                Box::pin(async move { Ok(Box::new(CodexProvider::new(store).with_model(&model)) as _) })
79            }),
80        );
81        self
82    }
83
84    pub fn with_openai_provider(mut self, name: impl Into<String>, config: &'static generic::ProviderConfig) -> Self {
85        self.factories.insert(
86            name.into(),
87            Box::new(move |model: &str| {
88                let model = model.to_string();
89                Box::pin(async move { Ok(Box::new(GenericOpenAiProvider::from_env(config)?.with_model(&model)) as _) })
90            }),
91        );
92        self
93    }
94
95    /// Create a provider from a typed `LlmModel`
96    pub async fn create_provider(&self, model: &LlmModel) -> Result<Box<dyn StreamingModelProvider>> {
97        let key = model.provider();
98        let factory = self.factories.get(key).ok_or_else(|| LlmError::Other(format!("Unknown provider: {key}")))?;
99        factory(&model.model_id()).await
100    }
101
102    /// Parse a model specification string and create a provider instance.
103    ///
104    /// Returns both the provider and an `LlmModel` describing the identity
105    /// of the first (or only) provider in the spec.
106    ///
107    /// # Format
108    ///
109    /// - `"provider:model"` - Single provider (e.g., "anthropic:claude-3.5-sonnet")
110    /// - `"provider1:model1,provider2:model2"` - Multiple providers create an `AlloyedModelProvider`
111    ///
112    pub async fn parse(&self, models_str: &str) -> Result<(Box<dyn StreamingModelProvider>, LlmModel)> {
113        let provider_model_pairs: Vec<&str> = models_str.split(',').map(str::trim).collect();
114        if provider_model_pairs.is_empty() {
115            return Err(LlmError::Other("No models provided".to_string()));
116        }
117
118        let mut providers = Vec::new();
119        let mut first_identity: Option<LlmModel> = None;
120
121        for pair in provider_model_pairs {
122            let (provider_name, model) = pair.split_once(':').unwrap_or((pair, ""));
123
124            let factory = self
125                .factories
126                .get(provider_name)
127                .ok_or_else(|| LlmError::Other(format!("Unknown provider: {provider_name}")))?;
128
129            providers.push(factory(model).await?);
130
131            if first_identity.is_none() {
132                first_identity = Some(pair.parse::<LlmModel>().map_err(LlmError::Other)?);
133            }
134        }
135
136        let identity = first_identity.ok_or_else(|| LlmError::Other("No providers parsed".to_string()))?;
137
138        let provider: Box<dyn StreamingModelProvider> = if providers.len() == 1 {
139            providers.into_iter().next().ok_or_else(|| LlmError::Other("No providers available".to_string()))?
140        } else {
141            Box::new(AlloyedModelProvider::new(providers))
142        };
143
144        Ok((provider, identity))
145    }
146}
147
/// Factory function type for creating model providers
///
/// Takes a model name and returns a boxed future that resolves to a boxed
/// `StreamingModelProvider` or an error. The factory itself must be `Send + Sync`, and the
/// returned future is `'static`, so it cannot borrow from the `&str` argument.
pub type CreateProviderFn =
    Box<dyn Fn(&str) -> BoxFuture<'static, Result<Box<dyn StreamingModelProvider>>> + Send + Sync>;
153
#[cfg(test)]
mod tests {
    use super::*;

    // A bare provider name (no `:model`) is accepted for llama.cpp and yields an empty model id.
    #[tokio::test]
    async fn test_parse_llamacpp() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    // Succeeds when credentials are available; otherwise the failure must come from
    // provider construction, not from spec parsing.
    #[tokio::test]
    async fn test_parse_anthropic() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic:claude-3-5-sonnet-20241022").await;
        match result {
            Ok((_, model)) => {
                assert_eq!(model, LlmModel::Anthropic(crate::catalog::AnthropicModel::Claude35Sonnet20241022));
            }
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("API")
                        || err.contains("ANTHROPIC")
                        || err.contains("credentials")
                        || err.contains("JSON"),
                    "Should fail on API key or credentials, not parsing. Got: {err}"
                );
            }
        }
    }

    #[tokio::test]
    async fn test_parse_ollama() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("ollama:llama3.2").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::Ollama("llama3.2".to_string()));
    }

    #[tokio::test]
    async fn test_parse_openai() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openai:gpt-4.1").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENAI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_openrouter() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openrouter:google/gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENROUTER"), "Should fail on API key, not parsing");
        }
    }

    #[tokio::test]
    async fn test_parse_gemini() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("gemini:gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("GEMINI"), "Should fail on API key, not parsing");
        }
    }

    #[tokio::test]
    async fn test_parse_provider_without_model() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic").await;
        assert!(result.is_err());
    }

    // Blank specs must be rejected rather than producing a provider.
    #[tokio::test]
    async fn test_parse_empty_spec() {
        let parser = ModelProviderParser::default();
        assert!(parser.parse("").await.is_err());
        assert!(parser.parse(" , ").await.is_err());
    }

    #[cfg(feature = "bedrock")]
    #[tokio::test]
    async fn test_parse_bedrock_inference_profile_arn() {
        // ARNs (including those with extra `:` separators) must round-trip through the parser
        // without being misinterpreted as `provider:model` splits — the parser splits on the
        // first `:` only, so everything after `bedrock:` becomes the model string.
        let parser = ModelProviderParser::default();
        let arn = "arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-opus-4-7";
        let spec = format!("bedrock:{arn}");
        let (provider, model) = parser.parse(&spec).await.expect("Bedrock ARN should parse");

        assert_eq!(model.to_string(), spec);
        assert_eq!(provider.display_name(), format!("Bedrock ({arn})"));
    }

    #[cfg(feature = "bedrock")]
    #[tokio::test]
    async fn test_parse_bedrock_application_inference_profile_arn() {
        let parser = ModelProviderParser::default();
        let arn = "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";
        let spec = format!("bedrock:{arn}");

        let (_, model) = parser.parse(&spec).await.expect("Application inference profile ARN should parse");
        assert_eq!(model.to_string(), spec);
        assert_eq!(model.context_window(), None);
    }

    #[tokio::test]
    async fn test_parse_unknown_provider() {
        let parser = ModelProviderParser::default();
        // `.err()` is used instead of `unwrap_err()` so the `Ok` payload need not be `Debug`.
        let err = parser.parse("unknown:model").await.err().expect("unknown provider should be rejected");
        assert!(err.to_string().contains("Unknown provider"));
    }

    #[tokio::test]
    async fn test_with_custom_provider() {
        let parser = ModelProviderParser::default().with_provider::<OllamaProvider>("custom");

        // Call the factory stored under "custom" directly. The `LlmModel::Ollama` lookup
        // below presumably resolves through the built-in entry, so on its own it would
        // pass even if the custom registration were missing.
        let factory = parser.factories.get("custom").expect("custom provider should be registered");
        assert!(factory("test-model").await.is_ok());

        let model = LlmModel::Ollama("test-model".to_string());
        let result = parser.create_provider(&model).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_parse_single_provider() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp").await;
        assert!(result.is_ok());
    }

    // With several providers, the returned identity is that of the first segment.
    #[tokio::test]
    async fn test_parse_multiple_providers() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp,ollama:llama3.2").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    // Whitespace around the comma separator is trimmed before lookup.
    #[tokio::test]
    async fn test_parse_with_spaces() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp , ollama:llama3.2").await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_parser_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ModelProviderParser>();
    }
}