1use std::sync::Arc;
9
10use tokio::sync::Mutex;
11
12use crate::app::Config;
13use crate::models::config::BackendConfig;
14use crate::models::{ModelError, Result, lookup_provider};
15use crate::utils::{resolve_api_key, resolve_api_key_with_fallback};
16
17const GEMINI_API_KEY_ENV: &str = "GOOGLE_API_KEY";
18const GEMINI_LEGACY_API_KEY_ENV: &str = "GEMINI_API_KEY";
19
20fn require_key(provider: &str, env_var: &str) -> Result<String> {
25 resolve_api_key(env_var, None).ok_or_else(|| {
26 ModelError::Authentication(format!("{} requires env var {}", provider, env_var))
27 })
28}
29
30fn require_key_with_fallback(
31 provider: &str,
32 env_var: &str,
33 fallback_env_var: &str,
34) -> Result<String> {
35 resolve_api_key_with_fallback(env_var, fallback_env_var, None).ok_or_else(|| {
36 ModelError::Authentication(format!(
37 "{} requires env var {} (or legacy {})",
38 provider, env_var, fallback_env_var
39 ))
40 })
41}
42
43use super::model::{
44 AnthropicProvider, GeminiProvider, ModelProvider, OllamaProvider, OpenAICompatProvider,
45};
46
47pub struct ProviderFactory {
51 config: Arc<Config>,
52 cache: Mutex<std::collections::HashMap<String, Arc<dyn ModelProvider>>>,
53}
54
55impl ProviderFactory {
56 pub fn new(config: Config) -> Self {
57 Self {
58 config: Arc::new(config),
59 cache: Mutex::new(std::collections::HashMap::new()),
60 }
61 }
62
63 pub fn config(&self) -> &Config {
64 &self.config
65 }
66
67 pub async fn resolve(&self, model_id: &str) -> Result<Arc<dyn ModelProvider>> {
71 {
72 let cache = self.cache.lock().await;
73 if let Some(p) = cache.get(model_id) {
74 return Ok(Arc::clone(p));
75 }
76 }
77
78 let provider = build_provider(&self.config, model_id).await?;
79 let arc: Arc<dyn ModelProvider> = Arc::from(provider);
80
81 let mut cache = self.cache.lock().await;
82 cache.insert(model_id.to_string(), Arc::clone(&arc));
83 Ok(arc)
84 }
85}
86
87async fn build_provider(config: &Config, model_id: &str) -> Result<Box<dyn ModelProvider>> {
95 let (provider, model_name) = parse_model_id(model_id);
96 let provider_lc = provider.to_lowercase();
97
98 if provider_lc == "ollama" {
101 let backend = ollama_backend_config(config);
102 let p = OllamaProvider::with_app_config(
103 model_name,
104 Arc::new(backend),
105 Arc::new(config.clone()),
106 )
107 .await?;
108 return Ok(Box::new(p));
109 }
110
111 if provider_lc == "anthropic" {
113 let user_cfg = config.providers.get("anthropic");
114 let base_url = user_cfg
115 .and_then(|c| c.base_url.clone())
116 .unwrap_or_else(|| "https://api.anthropic.com/v1".to_string());
117 let api_key_env = user_cfg
118 .and_then(|c| c.api_key_env.as_deref())
119 .unwrap_or("ANTHROPIC_API_KEY");
120 let api_key = require_key("anthropic", api_key_env)?;
121 let p = AnthropicProvider::new(api_key, model_name.to_string(), base_url)?;
122 return Ok(Box::new(p));
123 }
124
125 if provider_lc == "gemini" {
127 let user_cfg = config.providers.get("gemini");
128 let base_url = user_cfg
129 .and_then(|c| c.base_url.clone())
130 .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string());
131 let api_key = match user_cfg.and_then(|c| c.api_key_env.as_deref()) {
132 Some(api_key_env) => require_key("gemini", api_key_env)?,
133 None => {
134 require_key_with_fallback("gemini", GEMINI_API_KEY_ENV, GEMINI_LEGACY_API_KEY_ENV)?
135 },
136 };
137 let p = GeminiProvider::new(api_key, model_name.to_string(), base_url)?;
138 return Ok(Box::new(p));
139 }
140
141 if let Some(profile) = lookup_provider(&provider_lc) {
143 let user_cfg = config.providers.get(&provider_lc);
144 let base_url = user_cfg
145 .and_then(|c| c.base_url.clone())
146 .unwrap_or_else(|| profile.base_url.to_string());
147 let api_key_env = user_cfg
148 .and_then(|c| c.api_key_env.as_deref())
149 .unwrap_or(profile.api_key_env);
150 let api_key = require_key(&provider_lc, api_key_env)?;
151 let extra_headers = user_cfg
152 .map(|c| c.extra_headers.clone())
153 .unwrap_or_default();
154 let p = OpenAICompatProvider::new(
155 profile,
156 base_url,
157 api_key,
158 model_name.to_string(),
159 extra_headers,
160 )?;
161 return Ok(Box::new(p));
162 }
163
164 if let Some(user_cfg) = config.providers.get(&provider_lc)
167 && let Some(profile) = user_profile_to_static(&provider_lc, user_cfg)
168 {
169 let base_url = user_cfg.base_url.clone().ok_or_else(|| {
170 ModelError::InvalidRequest(format!(
171 "custom provider '{}' requires base_url in config",
172 provider_lc
173 ))
174 })?;
175 let api_key_env = user_cfg.api_key_env.as_deref().ok_or_else(|| {
176 ModelError::InvalidRequest(format!(
177 "custom provider '{}' requires api_key_env in config",
178 provider_lc
179 ))
180 })?;
181 let api_key = require_key(&provider_lc, api_key_env)?;
182 let p = OpenAICompatProvider::new(
183 profile,
184 base_url,
185 api_key,
186 model_name.to_string(),
187 user_cfg.extra_headers.clone(),
188 )?;
189 return Ok(Box::new(p));
190 }
191
192 Err(ModelError::InvalidRequest(format!(
193 "Unknown provider '{}' (model_id: {})",
194 provider, model_id
195 )))
196}
197
/// Splits `model_id` at its first `/` into `(provider, model_name)`.
///
/// Ids without a `/` are treated as Ollama models: the provider is `"ollama"`
/// and the whole input is the model name. Only the first `/` is significant,
/// so the model name may itself contain slashes.
fn parse_model_id(model_id: &str) -> (String, &str) {
    let Some((provider, model_name)) = model_id.split_once('/') else {
        return (String::from("ollama"), model_id);
    };
    (provider.to_owned(), model_name)
}
206
207fn ollama_backend_config(config: &Config) -> BackendConfig {
208 BackendConfig {
209 ollama_url: format!("http://{}:{}", config.ollama.host, config.ollama.port),
210 max_idle_per_host: 10,
211 timeout_secs: 10,
212 }
213}
214
215fn user_profile_to_static(
222 name: &str,
223 user_cfg: &crate::app::UserProviderConfig,
224) -> Option<&'static crate::models::ProviderProfile> {
225 use crate::models::{ProviderProfile, ReasoningExtraction, ReasoningStrategy};
226
227 let compat = user_cfg.compat.as_deref().unwrap_or("openai");
228 let strategy = match compat {
229 "openai" => ReasoningStrategy::None,
230 "openai-effort" => ReasoningStrategy::Effort,
231 "openrouter" => ReasoningStrategy::OpenRouterShape,
232 _ => ReasoningStrategy::None,
233 };
234
235 let profile = Box::new(ProviderProfile {
236 name: Box::leak(name.to_string().into_boxed_str()),
237 base_url: Box::leak(
238 user_cfg
239 .base_url
240 .clone()
241 .unwrap_or_default()
242 .into_boxed_str(),
243 ),
244 api_key_env: Box::leak(
245 user_cfg
246 .api_key_env
247 .clone()
248 .unwrap_or_default()
249 .into_boxed_str(),
250 ),
251 extra_headers: &[],
252 reasoning_strategy: strategy,
253 reasoning_extraction: ReasoningExtraction::None,
254 });
255 Some(Box::leak(profile))
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Returns an env var name made of `prefix`, the current process id, and
    /// a per-process counter, so each test works on its own variables.
    fn unique_env(prefix: &str) -> String {
        static N: AtomicUsize = AtomicUsize::new(0);
        let pid = std::process::id();
        let seq = N.fetch_add(1, Ordering::SeqCst);
        format!("{}_{}_{}", prefix, pid, seq)
    }

    #[test]
    fn parse_bare_name_defaults_to_ollama() {
        let (provider, model) = parse_model_id("qwen3-coder:30b");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "qwen3-coder:30b");
    }

    #[test]
    fn parse_prefixed() {
        let (provider, model) = parse_model_id("anthropic/claude-opus-4-7");
        assert_eq!(provider, "anthropic");
        assert_eq!(model, "claude-opus-4-7");
    }

    #[test]
    fn gemini_key_resolution_accepts_legacy_fallback() {
        let primary = unique_env("MERMAID_FACTORY_GEMINI_PRIMARY");
        let legacy = unique_env("MERMAID_FACTORY_GEMINI_LEGACY");
        // Primary unset, legacy set: the legacy value must be picked up.
        let vars = [(primary.as_str(), None), (legacy.as_str(), Some("legacy"))];
        temp_env::with_vars(vars, || {
            let key = require_key_with_fallback("gemini", &primary, &legacy)
                .expect("legacy fallback should resolve");
            assert_eq!(key, "legacy");
        });
    }

    #[test]
    fn gemini_key_resolution_prefers_google_primary() {
        let primary = unique_env("MERMAID_FACTORY_GEMINI_PRIMARY2");
        let legacy = unique_env("MERMAID_FACTORY_GEMINI_LEGACY2");
        // Both set: the primary variable takes precedence.
        let vars = [
            (primary.as_str(), Some("google")),
            (legacy.as_str(), Some("legacy")),
        ];
        temp_env::with_vars(vars, || {
            let key = require_key_with_fallback("gemini", &primary, &legacy)
                .expect("primary should resolve");
            assert_eq!(key, "google");
        });
    }

    #[tokio::test]
    async fn factory_reports_unknown_provider_clearly() {
        let factory = ProviderFactory::new(Config::default());
        let Err(err) = factory.resolve("totally-made-up/model").await else {
            panic!("expected error");
        };
        let msg = err.to_string();
        assert!(
            msg.contains("totally-made-up") || msg.contains("Unknown provider"),
            "error message: {}",
            msg
        );
    }
}