// llm/parser.rs
1use crate::Result;
2use crate::catalog::LlmModel;
3#[cfg(feature = "bedrock")]
4use crate::providers::bedrock::BedrockProvider;
5#[cfg(feature = "codex")]
6use crate::providers::codex::CodexProvider;
7use crate::providers::{
8    anthropic::AnthropicProvider,
9    gemini::GeminiProvider,
10    local::{llama_cpp::LlamaCppProvider, ollama::OllamaProvider},
11    openai::OpenAiProvider,
12    openai_compatible::generic::{self, GenericOpenAiProvider},
13    openrouter::OpenRouterProvider,
14};
15use crate::{LlmError, ProviderFactory, StreamingModelProvider, alloyed::AlloyedModelProvider};
16use futures::future::BoxFuture;
17use std::collections::HashMap;
18
#[doc = include_str!("docs/parser.md")]
pub struct ModelProviderParser {
    // Registered provider factories, keyed by provider name (e.g. "anthropic").
    // Populated via `with_provider` / `with_openai_provider`.
    factories: HashMap<String, CreateProviderFn>,
}
23
24impl ModelProviderParser {
25    /// Create a new parser with custom provider factories
26    pub fn new(factories: HashMap<String, CreateProviderFn>) -> Self {
27        Self { factories }
28    }
29}
30
31impl Default for ModelProviderParser {
32    /// Create a parser with all built-in providers registered
33    fn default() -> Self {
34        let parser = Self::new(HashMap::new())
35            .with_provider::<AnthropicProvider>("anthropic")
36            .with_provider::<GeminiProvider>("gemini")
37            .with_provider::<OpenRouterProvider>("openrouter")
38            .with_provider::<OllamaProvider>("ollama")
39            .with_provider::<LlamaCppProvider>("llamacpp")
40            .with_provider::<OpenAiProvider>("openai")
41            .with_openai_provider("deepseek", &generic::DEEPSEEK)
42            .with_openai_provider("moonshot", &generic::MOONSHOT)
43            .with_openai_provider("zai", &generic::ZAI);
44
45        #[cfg(feature = "bedrock")]
46        let parser = parser.with_provider::<BedrockProvider>("bedrock");
47
48        #[cfg(feature = "codex")]
49        let parser = parser.with_provider::<CodexProvider>("codex");
50
51        parser
52    }
53}
54
55impl ModelProviderParser {
56    pub fn with_provider<P: ProviderFactory + StreamingModelProvider + 'static>(
57        mut self,
58        name: impl Into<String>,
59    ) -> Self {
60        self.factories.insert(
61            name.into(),
62            Box::new(|model: &str| {
63                let model = model.to_string();
64                Box::pin(async move { Ok(Box::new(P::from_env().await?.with_model(&model)) as _) })
65            }),
66        );
67        self
68    }
69
70    pub fn with_openai_provider(mut self, name: impl Into<String>, config: &'static generic::ProviderConfig) -> Self {
71        self.factories.insert(
72            name.into(),
73            Box::new(move |model: &str| {
74                let model = model.to_string();
75                Box::pin(async move { Ok(Box::new(GenericOpenAiProvider::from_env(config)?.with_model(&model)) as _) })
76            }),
77        );
78        self
79    }
80
81    /// Create a provider from a typed `LlmModel`
82    pub async fn create_provider(&self, model: &LlmModel) -> Result<Box<dyn StreamingModelProvider>> {
83        let key = model.provider();
84        let factory = self.factories.get(key).ok_or_else(|| LlmError::Other(format!("Unknown provider: {key}")))?;
85        factory(&model.model_id()).await
86    }
87
88    /// Parse a model specification string and create a provider instance.
89    ///
90    /// Returns both the provider and an `LlmModel` describing the identity
91    /// of the first (or only) provider in the spec.
92    ///
93    /// # Format
94    ///
95    /// - `"provider:model"` - Single provider (e.g., "anthropic:claude-3.5-sonnet")
96    /// - `"provider1:model1,provider2:model2"` - Multiple providers create an `AlloyedModelProvider`
97    ///
98    pub async fn parse(&self, models_str: &str) -> Result<(Box<dyn StreamingModelProvider>, LlmModel)> {
99        let provider_model_pairs: Vec<&str> = models_str.split(',').map(str::trim).collect();
100        if provider_model_pairs.is_empty() {
101            return Err(LlmError::Other("No models provided".to_string()));
102        }
103
104        let mut providers = Vec::new();
105        let mut first_identity: Option<LlmModel> = None;
106
107        for pair in provider_model_pairs {
108            let (provider_name, model) = pair.split_once(':').unwrap_or((pair, ""));
109
110            let factory = self
111                .factories
112                .get(provider_name)
113                .ok_or_else(|| LlmError::Other(format!("Unknown provider: {provider_name}")))?;
114
115            providers.push(factory(model).await?);
116
117            if first_identity.is_none() {
118                first_identity = Some(pair.parse::<LlmModel>().map_err(LlmError::Other)?);
119            }
120        }
121
122        let identity = first_identity.ok_or_else(|| LlmError::Other("No providers parsed".to_string()))?;
123
124        let provider: Box<dyn StreamingModelProvider> = if providers.len() == 1 {
125            providers.into_iter().next().ok_or_else(|| LlmError::Other("No providers available".to_string()))?
126        } else {
127            Box::new(AlloyedModelProvider::new(providers))
128        };
129
130        Ok((provider, identity))
131    }
132}
133
/// Factory function type for creating model providers.
///
/// Maps a model identifier (e.g. `"claude-3.5-sonnet"`) to a boxed future
/// that resolves to a ready-to-use `StreamingModelProvider`. The `'static`
/// future lifetime means the closure must move owned data into the async
/// block (factories here clone the `&str` model name into a `String` first).
pub type CreateProviderFn =
    Box<dyn Fn(&str) -> BoxFuture<'static, Result<Box<dyn StreamingModelProvider>>> + Send + Sync>;
139
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare provider name (no colon) yields an empty model id.
    #[tokio::test]
    async fn test_parse_llamacpp() {
        let parsed = ModelProviderParser::default().parse("llamacpp").await;
        assert!(parsed.is_ok());
        let (_, identity) = parsed.unwrap();
        assert_eq!(identity, LlmModel::LlamaCpp(String::new()));
    }

    /// Either succeeds offline, or fails only at the credentials stage —
    /// never on spec parsing.
    #[tokio::test]
    async fn test_parse_anthropic() {
        let parser = ModelProviderParser::default();
        match parser.parse("anthropic:claude-3-5-sonnet-20241022").await {
            Ok((_, identity)) => {
                assert_eq!(identity, LlmModel::Anthropic(crate::catalog::AnthropicModel::Claude35Sonnet20241022));
            }
            Err(e) => {
                let err = e.to_string();
                let credential_failure = err.contains("API")
                    || err.contains("ANTHROPIC")
                    || err.contains("credentials")
                    || err.contains("JSON");
                assert!(credential_failure, "Should fail on API key or credentials, not parsing. Got: {err}");
            }
        }
    }

    #[tokio::test]
    async fn test_parse_ollama() {
        let parsed = ModelProviderParser::default().parse("ollama:llama3.2").await;
        assert!(parsed.is_ok());
        let (_, identity) = parsed.unwrap();
        assert_eq!(identity, LlmModel::Ollama("llama3.2".to_string()));
    }

    #[tokio::test]
    async fn test_parse_openai() {
        if let Err(e) = ModelProviderParser::default().parse("openai:gpt-4.1").await {
            let err = e.to_string();
            assert!(err.contains("API") || err.contains("OPENAI"), "Should fail on API key, not parsing. Got: {err}");
        }
    }

    #[tokio::test]
    async fn test_parse_openrouter() {
        if let Err(e) = ModelProviderParser::default().parse("openrouter:google/gemini-2.5-flash").await {
            let msg = e.to_string();
            assert!(msg.contains("API") || msg.contains("OPENROUTER"), "Should fail on API key, not parsing");
        }
    }

    #[tokio::test]
    async fn test_parse_gemini() {
        if let Err(e) = ModelProviderParser::default().parse("gemini:gemini-2.5-flash").await {
            let msg = e.to_string();
            assert!(msg.contains("API") || msg.contains("GEMINI"), "Should fail on API key, not parsing");
        }
    }

    /// "anthropic" alone maps to an empty model id, which is rejected.
    #[tokio::test]
    async fn test_parse_provider_without_model() {
        assert!(ModelProviderParser::default().parse("anthropic").await.is_err());
    }

    #[tokio::test]
    async fn test_parse_unknown_provider() {
        match ModelProviderParser::default().parse("unknown:model").await {
            Ok(_) => panic!("expected an error for an unregistered provider"),
            Err(e) => assert!(e.to_string().contains("Unknown provider")),
        }
    }

    /// A factory registered at runtime is usable via `create_provider`.
    #[tokio::test]
    async fn test_with_custom_provider() {
        let parser = ModelProviderParser::default().with_provider::<OllamaProvider>("custom");
        let model = LlmModel::Ollama("test-model".to_string());
        assert!(parser.create_provider(&model).await.is_ok());
    }

    #[tokio::test]
    async fn test_parse_single_provider() {
        assert!(ModelProviderParser::default().parse("llamacpp").await.is_ok());
    }

    /// Multiple entries alloy; the reported identity is the first entry's.
    #[tokio::test]
    async fn test_parse_multiple_providers() {
        let parsed = ModelProviderParser::default().parse("llamacpp,ollama:llama3.2").await;
        assert!(parsed.is_ok());
        let (_, identity) = parsed.unwrap();
        assert_eq!(identity, LlmModel::LlamaCpp(String::new()));
    }

    /// Whitespace around commas is trimmed.
    #[tokio::test]
    async fn test_parse_with_spaces() {
        assert!(ModelProviderParser::default().parse("llamacpp , ollama:llama3.2").await.is_ok());
    }

    #[test]
    fn test_parser_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ModelProviderParser>();
    }
}
267}