Skip to main content

mermaid_cli/providers/
factory.rs

1//! Runtime provider construction.
2//!
3//! `ProviderFactory` turns `(Config, model_id)` into the right
4//! `Arc<dyn ModelProvider>`. The effect runner holds one of these
5//! and asks it to build a provider the first time a new model is
6//! referenced; subsequent lookups hit the cache.
7
8use std::sync::Arc;
9
10use tokio::sync::Mutex;
11
12use crate::app::Config;
13use crate::models::config::BackendConfig;
14use crate::models::{ModelError, Result, lookup_provider};
15use crate::utils::{resolve_api_key, resolve_api_key_with_fallback};
16
/// Primary environment variable consulted for the Gemini API key when the
/// user config does not name one explicitly.
const GEMINI_API_KEY_ENV: &str = "GOOGLE_API_KEY";
/// Legacy environment variable, consulted only when `GEMINI_API_KEY_ENV`
/// does not resolve.
const GEMINI_LEGACY_API_KEY_ENV: &str = "GEMINI_API_KEY";
19
20/// Resolve an API key or return a clear `ModelError` when the env
21/// var isn't set. Takes `default_env` (the registry-default name)
22/// and allows no override — the factory passes the already-resolved
23/// env name.
24fn require_key(provider: &str, env_var: &str) -> Result<String> {
25    resolve_api_key(env_var, None).ok_or_else(|| {
26        ModelError::Authentication(format!("{} requires env var {}", provider, env_var))
27    })
28}
29
30fn require_key_with_fallback(
31    provider: &str,
32    env_var: &str,
33    fallback_env_var: &str,
34) -> Result<String> {
35    resolve_api_key_with_fallback(env_var, fallback_env_var, None).ok_or_else(|| {
36        ModelError::Authentication(format!(
37            "{} requires env var {} (or legacy {})",
38            provider, env_var, fallback_env_var
39        ))
40    })
41}
42
43use super::model::{
44    AnthropicProvider, GeminiProvider, ModelProvider, OllamaProvider, OpenAICompatProvider,
45};
46
47/// Per-process provider cache. Providers are expensive to construct
48/// (HTTP client, connection pool, capability lookup) so the effect
49/// runner asks for them lazily and reuses across turns.
50pub struct ProviderFactory {
51    config: Arc<Config>,
52    cache: Mutex<std::collections::HashMap<String, Arc<dyn ModelProvider>>>,
53}
54
55impl ProviderFactory {
56    pub fn new(config: Config) -> Self {
57        Self {
58            config: Arc::new(config),
59            cache: Mutex::new(std::collections::HashMap::new()),
60        }
61    }
62
63    pub fn config(&self) -> &Config {
64        &self.config
65    }
66
67    /// Resolve (or lazily construct) a provider for the given model
68    /// ID. Hits the cache on the second and subsequent calls for the
69    /// same ID.
70    pub async fn resolve(&self, model_id: &str) -> Result<Arc<dyn ModelProvider>> {
71        {
72            let cache = self.cache.lock().await;
73            if let Some(p) = cache.get(model_id) {
74                return Ok(Arc::clone(p));
75            }
76        }
77
78        let provider = build_provider(&self.config, model_id).await?;
79        let arc: Arc<dyn ModelProvider> = Arc::from(provider);
80
81        let mut cache = self.cache.lock().await;
82        cache.insert(model_id.to_string(), Arc::clone(&arc));
83        Ok(arc)
84    }
85}
86
87/// Build a provider for the given `model_id`:
88///   1. `ollama/<model>` → OllamaProvider.
89///   2. `anthropic/<model>` → AnthropicProvider.
90///   3. `gemini/<model>` → GeminiProvider.
91///   4. Other builtin providers (openai, openrouter, groq, …) → OpenAICompatProvider.
92///   5. User-defined `[providers.<name>]` → custom OpenAICompatProvider.
93///   6. Bare model name → OllamaProvider.
94async fn build_provider(config: &Config, model_id: &str) -> Result<Box<dyn ModelProvider>> {
95    let (provider, model_name) = parse_model_id(model_id);
96    let provider_lc = provider.to_lowercase();
97
98    // 1. Ollama (and bare names). F11: pass Arc<Config> so the wrapper
99    // can forward Ollama hardware options to the adapter.
100    if provider_lc == "ollama" {
101        let backend = ollama_backend_config(config);
102        let p = OllamaProvider::with_app_config(
103            model_name,
104            Arc::new(backend),
105            Arc::new(config.clone()),
106        )
107        .await?;
108        return Ok(Box::new(p));
109    }
110
111    // 2. Anthropic — bespoke API shape.
112    if provider_lc == "anthropic" {
113        let user_cfg = config.providers.get("anthropic");
114        let base_url = user_cfg
115            .and_then(|c| c.base_url.clone())
116            .unwrap_or_else(|| "https://api.anthropic.com/v1".to_string());
117        let api_key_env = user_cfg
118            .and_then(|c| c.api_key_env.as_deref())
119            .unwrap_or("ANTHROPIC_API_KEY");
120        let api_key = require_key("anthropic", api_key_env)?;
121        let p = AnthropicProvider::new(api_key, model_name.to_string(), base_url)?;
122        return Ok(Box::new(p));
123    }
124
125    // 3. Gemini — GCP AI Studio shape.
126    if provider_lc == "gemini" {
127        let user_cfg = config.providers.get("gemini");
128        let base_url = user_cfg
129            .and_then(|c| c.base_url.clone())
130            .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string());
131        let api_key = match user_cfg.and_then(|c| c.api_key_env.as_deref()) {
132            Some(api_key_env) => require_key("gemini", api_key_env)?,
133            None => {
134                require_key_with_fallback("gemini", GEMINI_API_KEY_ENV, GEMINI_LEGACY_API_KEY_ENV)?
135            },
136        };
137        let p = GeminiProvider::new(api_key, model_name.to_string(), base_url)?;
138        return Ok(Box::new(p));
139    }
140
141    // 4 + 5. OpenAI-compatible registry or user-custom.
142    if let Some(profile) = lookup_provider(&provider_lc) {
143        let user_cfg = config.providers.get(&provider_lc);
144        let base_url = user_cfg
145            .and_then(|c| c.base_url.clone())
146            .unwrap_or_else(|| profile.base_url.to_string());
147        let api_key_env = user_cfg
148            .and_then(|c| c.api_key_env.as_deref())
149            .unwrap_or(profile.api_key_env);
150        let api_key = require_key(&provider_lc, api_key_env)?;
151        let extra_headers = user_cfg
152            .map(|c| c.extra_headers.clone())
153            .unwrap_or_default();
154        let p = OpenAICompatProvider::new(
155            profile,
156            base_url,
157            api_key,
158            model_name.to_string(),
159            extra_headers,
160        )?;
161        return Ok(Box::new(p));
162    }
163
164    // User-custom: no registry entry, but the user has [providers.<name>]
165    // in config with a declared `compat` field.
166    if let Some(user_cfg) = config.providers.get(&provider_lc)
167        && let Some(profile) = user_profile_to_static(&provider_lc, user_cfg)
168    {
169        let base_url = user_cfg.base_url.clone().ok_or_else(|| {
170            ModelError::InvalidRequest(format!(
171                "custom provider '{}' requires base_url in config",
172                provider_lc
173            ))
174        })?;
175        let api_key_env = user_cfg.api_key_env.as_deref().ok_or_else(|| {
176            ModelError::InvalidRequest(format!(
177                "custom provider '{}' requires api_key_env in config",
178                provider_lc
179            ))
180        })?;
181        let api_key = require_key(&provider_lc, api_key_env)?;
182        let p = OpenAICompatProvider::new(
183            profile,
184            base_url,
185            api_key,
186            model_name.to_string(),
187            user_cfg.extra_headers.clone(),
188        )?;
189        return Ok(Box::new(p));
190    }
191
192    Err(ModelError::InvalidRequest(format!(
193        "Unknown provider '{}' (model_id: {})",
194        provider, model_id
195    )))
196}
197
/// Split `model_id` at its first `/` into `(provider, model)`.
///
/// Everything after that first slash, including any further slashes,
/// is the model name. A string with no `/` at all is an Ollama model
/// name by convention.
fn parse_model_id(model_id: &str) -> (String, &str) {
    let Some((provider, model)) = model_id.split_once('/') else {
        return (String::from("ollama"), model_id);
    };
    (provider.to_owned(), model)
}
206
207fn ollama_backend_config(config: &Config) -> BackendConfig {
208    BackendConfig {
209        ollama_url: format!("http://{}:{}", config.ollama.host, config.ollama.port),
210        max_idle_per_host: 10,
211        timeout_secs: 10,
212    }
213}
214
215/// Convert a user-defined `[providers.<name>]` entry into a `&'static
216/// ProviderProfile`. We need `&'static` because `ProviderProfile`'s
217/// lifetime is tied to the registry constants; we leak a tiny owned
218/// copy so custom providers can participate without redesigning the
219/// profile type. Leaked allocations are bounded by the number of
220/// custom providers (typically 0-3).
221fn user_profile_to_static(
222    name: &str,
223    user_cfg: &crate::app::UserProviderConfig,
224) -> Option<&'static crate::models::ProviderProfile> {
225    use crate::models::{ProviderProfile, ReasoningExtraction, ReasoningStrategy};
226
227    let compat = user_cfg.compat.as_deref().unwrap_or("openai");
228    let strategy = match compat {
229        "openai" => ReasoningStrategy::None,
230        "openai-effort" => ReasoningStrategy::Effort,
231        "openrouter" => ReasoningStrategy::OpenRouterShape,
232        _ => ReasoningStrategy::None,
233    };
234
235    let profile = Box::new(ProviderProfile {
236        name: Box::leak(name.to_string().into_boxed_str()),
237        base_url: Box::leak(
238            user_cfg
239                .base_url
240                .clone()
241                .unwrap_or_default()
242                .into_boxed_str(),
243        ),
244        api_key_env: Box::leak(
245            user_cfg
246                .api_key_env
247                .clone()
248                .unwrap_or_default()
249                .into_boxed_str(),
250        ),
251        extra_headers: &[],
252        reasoning_strategy: strategy,
253        reasoning_extraction: ReasoningExtraction::None,
254    });
255    Some(Box::leak(profile))
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Env var name unique to this process and call (pid + counter), so
    /// tests running in parallel never touch the same variable.
    fn unique_env(prefix: &str) -> String {
        static N: AtomicUsize = AtomicUsize::new(0);
        format!(
            "{}_{}_{}",
            prefix,
            std::process::id(),
            N.fetch_add(1, Ordering::SeqCst)
        )
    }

    #[test]
    fn parse_bare_name_defaults_to_ollama() {
        let (p, m) = parse_model_id("qwen3-coder:30b");
        assert_eq!(p, "ollama");
        assert_eq!(m, "qwen3-coder:30b");
    }

    #[test]
    fn parse_prefixed() {
        let (p, m) = parse_model_id("anthropic/claude-opus-4-7");
        assert_eq!(p, "anthropic");
        assert_eq!(m, "claude-opus-4-7");
    }

    /// OpenRouter-style IDs carry a vendor path in the model name; only
    /// the first `/` separates provider from model.
    #[test]
    fn parse_keeps_nested_slashes_in_model_name() {
        let (p, m) = parse_model_id("openrouter/meta-llama/llama-3.1-70b");
        assert_eq!(p, "openrouter");
        assert_eq!(m, "meta-llama/llama-3.1-70b");
    }

    #[test]
    fn parse_prefix_with_empty_model() {
        let (p, m) = parse_model_id("anthropic/");
        assert_eq!(p, "anthropic");
        assert_eq!(m, "");
    }

    #[test]
    fn gemini_key_resolution_accepts_legacy_fallback() {
        let primary = unique_env("MERMAID_FACTORY_GEMINI_PRIMARY");
        let legacy = unique_env("MERMAID_FACTORY_GEMINI_LEGACY");
        temp_env::with_vars(
            [(primary.as_str(), None), (legacy.as_str(), Some("legacy"))],
            || {
                let resolved = require_key_with_fallback("gemini", &primary, &legacy)
                    .expect("legacy fallback should resolve");
                assert_eq!(resolved, "legacy");
            },
        );
    }

    #[test]
    fn gemini_key_resolution_prefers_google_primary() {
        let primary = unique_env("MERMAID_FACTORY_GEMINI_PRIMARY2");
        let legacy = unique_env("MERMAID_FACTORY_GEMINI_LEGACY2");
        temp_env::with_vars(
            [
                (primary.as_str(), Some("google")),
                (legacy.as_str(), Some("legacy")),
            ],
            || {
                let resolved = require_key_with_fallback("gemini", &primary, &legacy)
                    .expect("primary should resolve");
                assert_eq!(resolved, "google");
            },
        );
    }

    #[tokio::test]
    async fn factory_reports_unknown_provider_clearly() {
        let cfg = Config::default();
        let f = ProviderFactory::new(cfg);
        match f.resolve("totally-made-up/model").await {
            Ok(_) => panic!("expected error"),
            Err(e) => {
                let msg = format!("{}", e);
                assert!(
                    msg.contains("totally-made-up") || msg.contains("Unknown provider"),
                    "error message: {}",
                    msg
                );
            },
        }
    }
}