Skip to main content

llm/
parser.rs

1use crate::Result;
2use crate::catalog::LlmModel;
3#[cfg(feature = "bedrock")]
4use crate::providers::bedrock::BedrockProvider;
5#[cfg(feature = "codex")]
6use crate::providers::codex::CodexProvider;
7use crate::providers::{
8    anthropic::AnthropicProvider,
9    gemini::GeminiProvider,
10    local::{llama_cpp::LlamaCppProvider, ollama::OllamaProvider},
11    openai::OpenAiProvider,
12    openai_compatible::generic::{self, GenericOpenAiProvider},
13    openrouter::OpenRouterProvider,
14};
15use crate::{LlmError, ProviderFactory, StreamingModelProvider, alloyed::AlloyedModelProvider};
16#[cfg(feature = "codex")]
17use aether_auth::OAuthCredentialStorage;
18use futures::future::BoxFuture;
19use std::collections::HashMap;
20#[cfg(feature = "codex")]
21use std::sync::Arc;
22
#[doc = include_str!("docs/parser.md")]
pub struct ModelProviderParser {
    /// Provider factories keyed by provider name, i.e. the part of a model spec before the `:`.
    factories: HashMap<String, CreateProviderFn>,
}
27
28impl ModelProviderParser {
29    /// Create a new parser with custom provider factories
30    pub fn new(factories: HashMap<String, CreateProviderFn>) -> Self {
31        Self { factories }
32    }
33}
34
35impl Default for ModelProviderParser {
36    /// Create a parser with all built-in providers registered
37    fn default() -> Self {
38        let parser = Self::new(HashMap::new())
39            .with_provider::<AnthropicProvider>("anthropic")
40            .with_provider::<GeminiProvider>("gemini")
41            .with_provider::<OpenRouterProvider>("openrouter")
42            .with_provider::<OllamaProvider>("ollama")
43            .with_provider::<LlamaCppProvider>("llamacpp")
44            .with_provider::<OpenAiProvider>("openai")
45            .with_openai_provider("deepseek", &generic::DEEPSEEK)
46            .with_openai_provider("moonshot", &generic::MOONSHOT)
47            .with_openai_provider("zai", &generic::ZAI);
48
49        #[cfg(feature = "bedrock")]
50        let parser = parser.with_provider::<BedrockProvider>("bedrock");
51
52        parser
53    }
54}
55
56impl ModelProviderParser {
57    pub fn with_provider<P: ProviderFactory + StreamingModelProvider + 'static>(
58        mut self,
59        name: impl Into<String>,
60    ) -> Self {
61        self.factories.insert(
62            name.into(),
63            Box::new(|model: &str| {
64                let model = model.to_string();
65                Box::pin(async move { Ok(Box::new(P::from_env().await?.with_model(&model)) as _) })
66            }),
67        );
68        self
69    }
70
71    #[cfg(feature = "codex")]
72    pub fn with_codex_provider(mut self, store: Arc<dyn OAuthCredentialStorage>) -> Self {
73        self.factories.insert(
74            "codex".to_string(),
75            Box::new(move |model: &str| {
76                let store = Arc::clone(&store);
77                let model = model.to_string();
78                Box::pin(async move { Ok(Box::new(CodexProvider::new(store).with_model(&model)) as _) })
79            }),
80        );
81        self
82    }
83
84    pub fn with_openai_provider(mut self, name: impl Into<String>, config: &'static generic::ProviderConfig) -> Self {
85        self.factories.insert(
86            name.into(),
87            Box::new(move |model: &str| {
88                let model = model.to_string();
89                Box::pin(async move { Ok(Box::new(GenericOpenAiProvider::from_env(config)?.with_model(&model)) as _) })
90            }),
91        );
92        self
93    }
94
95    /// Create a provider from a typed `LlmModel`
96    pub async fn create_provider(&self, model: &LlmModel) -> Result<Box<dyn StreamingModelProvider>> {
97        let key = model.provider();
98        let factory = self.factories.get(key).ok_or_else(|| LlmError::Other(format!("Unknown provider: {key}")))?;
99        factory(&model.model_id()).await
100    }
101
102    /// Parse a model specification string and create a provider instance.
103    ///
104    /// Returns both the provider and an `LlmModel` describing the identity
105    /// of the first (or only) provider in the spec.
106    ///
107    /// # Format
108    ///
109    /// - `"provider:model"` - Single provider (e.g., "anthropic:claude-3.5-sonnet")
110    /// - `"provider1:model1,provider2:model2"` - Multiple providers create an `AlloyedModelProvider`
111    ///
112    pub async fn parse(&self, models_str: &str) -> Result<(Box<dyn StreamingModelProvider>, LlmModel)> {
113        let provider_model_pairs: Vec<&str> = models_str.split(',').map(str::trim).collect();
114        if provider_model_pairs.is_empty() {
115            return Err(LlmError::Other("No models provided".to_string()));
116        }
117
118        let mut providers = Vec::new();
119        let mut first_identity: Option<LlmModel> = None;
120
121        for pair in provider_model_pairs {
122            let (provider_name, model) = pair.split_once(':').unwrap_or((pair, ""));
123
124            let factory = self
125                .factories
126                .get(provider_name)
127                .ok_or_else(|| LlmError::Other(format!("Unknown provider: {provider_name}")))?;
128
129            providers.push(factory(model).await?);
130
131            if first_identity.is_none() {
132                first_identity = Some(pair.parse::<LlmModel>().map_err(LlmError::Other)?);
133            }
134        }
135
136        let identity = first_identity.ok_or_else(|| LlmError::Other("No providers parsed".to_string()))?;
137
138        let provider: Box<dyn StreamingModelProvider> = if providers.len() == 1 {
139            providers.into_iter().next().ok_or_else(|| LlmError::Other("No providers available".to_string()))?
140        } else {
141            Box::new(AlloyedModelProvider::new(providers))
142        };
143
144        Ok((provider, identity))
145    }
146}
147
148/// Factory function type for creating model providers
149///
150/// Takes a model name and returns a boxed future that resolves to a `StreamingModelProvider`
151pub type CreateProviderFn =
152    Box<dyn Fn(&str) -> BoxFuture<'static, Result<Box<dyn StreamingModelProvider>>> + Send + Sync>;
153
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_parse_llamacpp() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_anthropic() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic:claude-3-5-sonnet-20241022").await;
        match result {
            Ok((_, model)) => {
                assert_eq!(model, LlmModel::Anthropic(crate::catalog::AnthropicModel::Claude35Sonnet20241022));
            }
            Err(e) => {
                let err = e.to_string();
                assert!(
                    err.contains("API")
                        || err.contains("ANTHROPIC")
                        || err.contains("credentials")
                        || err.contains("JSON"),
                    "Should fail on API key or credentials, not parsing. Got: {err}"
                );
            }
        }
    }

    #[tokio::test]
    async fn test_parse_ollama() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("ollama:llama3.2").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::Ollama("llama3.2".to_string()));
    }

    #[tokio::test]
    async fn test_parse_openai() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openai:gpt-4.1").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENAI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_openrouter() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("openrouter:google/gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(
                err.contains("API") || err.contains("OPENROUTER"),
                "Should fail on API key, not parsing. Got: {err}"
            );
        }
    }

    #[tokio::test]
    async fn test_parse_gemini() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("gemini:gemini-2.5-flash").await;
        if let Err(e) = result {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("GEMINI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_provider_without_model() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("anthropic").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_parse_empty_spec_is_err() {
        let parser = ModelProviderParser::default();
        assert!(parser.parse("").await.is_err());
        assert!(parser.parse("   ").await.is_err());
    }

    #[tokio::test]
    async fn test_parse_unknown_provider() {
        let parser = ModelProviderParser::default();
        let Err(e) = parser.parse("unknown:model").await else {
            panic!("parsing an unregistered provider should fail");
        };
        assert!(e.to_string().contains("Unknown provider"), "Got: {e}");
    }

    #[tokio::test]
    async fn test_with_custom_provider() {
        let parser = ModelProviderParser::default().with_provider::<OllamaProvider>("custom");

        // Invoke the factory registered under the custom name directly. Going through
        // `create_provider(&LlmModel::Ollama(..))` would dispatch to the built-in "ollama" key
        // and never touch the custom registration.
        let factory = parser.factories.get("custom").expect("custom provider should be registered");
        assert!(factory("test-model").await.is_ok());

        // Registering a custom name must not displace the built-in providers.
        let model = LlmModel::Ollama("test-model".to_string());
        assert!(parser.create_provider(&model).await.is_ok());
    }

    #[tokio::test]
    async fn test_parse_single_provider() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_parse_multiple_providers() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp,ollama:llama3.2").await;
        assert!(result.is_ok());
        let (_, model) = result.unwrap();
        assert_eq!(model, LlmModel::LlamaCpp(String::new()));
    }

    #[tokio::test]
    async fn test_parse_with_spaces() {
        let parser = ModelProviderParser::default();
        let result = parser.parse("llamacpp , ollama:llama3.2").await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_parser_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ModelProviderParser>();
    }
}