1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::Duration;
10
11use super::adapters::openai_compat::OpenAICompatAdapter;
12use super::config::BackendConfig;
13use super::error::{BackendError, ModelError, Result};
14use super::providers::{
15 CompatStyle, ProviderProfile, ReasoningExtraction, ReasoningStrategy, lookup_provider,
16};
17use super::traits::Model;
18use crate::app::Config;
19use crate::utils::resolve_api_key;
20
21pub struct ModelFactory {
30 config: Arc<BackendConfig>,
31 user_config: Option<Arc<Config>>,
32}
33
34impl ModelFactory {
35 pub fn new(config: BackendConfig) -> Self {
39 Self {
40 config: Arc::new(config),
41 user_config: None,
42 }
43 }
44
45 pub fn from_config(config: &Config) -> Self {
49 Self {
50 config: Arc::new(Self::config_to_backend_config(config)),
51 user_config: Some(Arc::new(config.clone())),
52 }
53 }
54
55 pub async fn create_model(&self, model_id: &str) -> Result<Box<dyn Model>> {
65 let (provider, model_name) = parse_model_id(model_id);
66 let provider_lc = provider.to_lowercase();
67
68 if provider_lc == "ollama" {
70 use super::adapters::ollama::OllamaAdapter;
71 let adapter = OllamaAdapter::new(model_name, self.config.clone()).await?;
72 return Ok(Box::new(adapter));
73 }
74
75 let user_config = self.user_config.as_ref().ok_or_else(|| {
78 ModelError::InvalidRequest(format!(
79 "Provider '{}' requires app config; use ModelFactory::from_config",
80 provider
81 ))
82 })?;
83
84 if provider_lc == "anthropic" {
86 use super::adapters::anthropic::AnthropicAdapter;
87 let user_cfg = user_config.providers.get("anthropic");
88 let base_url = user_cfg
89 .and_then(|c| c.base_url.clone())
90 .unwrap_or_else(|| "https://api.anthropic.com/v1".to_string());
91 let api_key = resolve_api_key(
92 "ANTHROPIC_API_KEY",
93 user_cfg.and_then(|c| c.api_key_env.as_deref()),
94 )
95 .ok_or_else(|| {
96 ModelError::Authentication(
97 "ANTHROPIC_API_KEY not set (or set [providers.anthropic].api_key_env to a \
98 custom env var)"
99 .to_string(),
100 )
101 })?;
102 let adapter = AnthropicAdapter::new(api_key, model_name.to_string(), base_url)?;
103 return Ok(Box::new(adapter));
104 }
105
106 if provider_lc == "gemini" {
108 use super::adapters::gemini::GeminiAdapter;
109 let user_cfg = user_config.providers.get("gemini");
110 let base_url = user_cfg
111 .and_then(|c| c.base_url.clone())
112 .unwrap_or_else(|| "https://generativelanguage.googleapis.com/v1beta".to_string());
113 let api_key = resolve_api_key(
114 "GOOGLE_API_KEY",
115 user_cfg.and_then(|c| c.api_key_env.as_deref()),
116 )
117 .ok_or_else(|| {
118 ModelError::Authentication(
119 "GOOGLE_API_KEY not set (or set [providers.gemini].api_key_env to a custom \
120 env var)"
121 .to_string(),
122 )
123 })?;
124 let adapter = GeminiAdapter::new(api_key, model_name.to_string(), base_url)?;
125 return Ok(Box::new(adapter));
126 }
127
128 if let Some(profile) = lookup_provider(&provider_lc) {
130 let user_cfg = user_config.providers.get(&provider_lc);
131 let base_url = user_cfg
132 .and_then(|c| c.base_url.clone())
133 .unwrap_or_else(|| profile.base_url.to_string());
134 let api_key = resolve_api_key(
135 profile.api_key_env,
136 user_cfg.and_then(|c| c.api_key_env.as_deref()),
137 )
138 .ok_or_else(|| {
139 ModelError::Authentication(format!(
140 "{} API key not set (env: {})",
141 provider_lc, profile.api_key_env
142 ))
143 })?;
144 let mut headers: HashMap<String, String> = profile
145 .extra_headers
146 .iter()
147 .map(|(k, v)| (k.to_string(), v.to_string()))
148 .collect();
149 if let Some(c) = user_cfg {
150 for (k, v) in &c.extra_headers {
151 headers.insert(k.clone(), v.clone());
152 }
153 }
154 let adapter = OpenAICompatAdapter::new(
155 profile,
156 base_url,
157 api_key,
158 model_name.to_string(),
159 headers,
160 )?;
161 return Ok(Box::new(adapter));
162 }
163
164 if let Some(user_cfg) = user_config.providers.get(&provider_lc) {
167 let base_url = user_cfg.base_url.clone().ok_or_else(|| {
168 ModelError::Config(super::error::ConfigError::MissingRequired(format!(
169 "providers.{}.base_url",
170 provider_lc
171 )))
172 })?;
173 let api_key_env = user_cfg.api_key_env.as_deref().ok_or_else(|| {
174 ModelError::Config(super::error::ConfigError::MissingRequired(format!(
175 "providers.{}.api_key_env",
176 provider_lc
177 )))
178 })?;
179 let api_key = resolve_api_key(api_key_env, None).ok_or_else(|| {
180 ModelError::Authentication(format!(
181 "{} API key not set (env: {})",
182 provider_lc, api_key_env
183 ))
184 })?;
185 let compat_style: CompatStyle = match user_cfg.compat.as_deref() {
186 Some(s) => {
187 serde_json::from_value::<CompatStyle>(serde_json::Value::String(s.to_string()))
188 .map_err(|_| {
189 ModelError::Config(super::error::ConfigError::InvalidValue {
190 field: format!("providers.{}.compat", provider_lc),
191 value: s.to_string(),
192 reason: "must be one of: openai, openai-effort, openrouter"
193 .to_string(),
194 })
195 })?
196 },
197 None => CompatStyle::Openai,
198 };
199 let profile = synthesize_custom_profile(&provider_lc, compat_style);
204 let headers: HashMap<String, String> = user_cfg.extra_headers.clone();
205 let adapter = OpenAICompatAdapter::new(
206 profile,
207 base_url,
208 api_key,
209 model_name.to_string(),
210 headers,
211 )?;
212 return Ok(Box::new(adapter));
213 }
214
215 Err(ModelError::InvalidRequest(format!(
216 "Unknown provider: '{}'. Built-in providers: ollama, anthropic, gemini, openai, \
217 groq, openrouter, cerebras, deepinfra, together. Add custom providers to \
218 ~/.config/mermaid/config.toml under [providers.<name>].",
219 provider
220 )))
221 }
222
223 pub async fn create(model_id: &str, config: Option<&Config>) -> Result<Box<dyn Model>> {
234 let factory = match config {
235 Some(c) => Self::from_config(c),
236 None => Self::new(BackendConfig::default()),
237 };
238 factory.create_model(model_id).await
239 }
240
241 pub async fn list_all_models(config: &Config) -> Result<Vec<String>> {
243 let factory = Self::from_config(config);
244 let providers = factory.available_providers_impl().await;
245
246 let mut all_models = Vec::new();
247 for provider in providers {
248 if let Ok(models) = factory.list_models(&provider).await {
249 for model_name in models {
250 all_models.push(format!("{}/{}", provider, model_name));
251 }
252 }
253 }
254
255 all_models.sort();
256 Ok(all_models)
257 }
258
259 pub async fn available_providers() -> Vec<String> {
261 let factory = Self::new(BackendConfig::default());
262 factory.available_providers_impl().await
263 }
264
265 pub async fn available_providers_pub(&self) -> Vec<String> {
267 self.available_providers_impl().await
268 }
269
270 async fn available_providers_impl(&self) -> Vec<String> {
279 let mut providers = Vec::new();
280
281 let url = format!(
283 "{}/api/tags",
284 self.config.ollama_url.trim().trim_end_matches('/')
285 );
286 if let Ok(client) = reqwest::Client::builder()
287 .timeout(Duration::from_secs(2))
288 .build()
289 && let Ok(resp) = client.get(&url).send().await
290 && resp.status().is_success()
291 {
292 providers.push("ollama".to_string());
293 }
294
295 if let Some(user_config) = self.user_config.as_ref() {
300 let user_cfg = user_config.providers.get("anthropic");
301 if resolve_api_key(
302 "ANTHROPIC_API_KEY",
303 user_cfg.and_then(|c| c.api_key_env.as_deref()),
304 )
305 .is_some()
306 {
307 providers.push("anthropic".to_string());
308 }
309 }
310
311 if let Some(user_config) = self.user_config.as_ref() {
315 let user_cfg = user_config.providers.get("gemini");
316 if resolve_api_key(
317 "GOOGLE_API_KEY",
318 user_cfg.and_then(|c| c.api_key_env.as_deref()),
319 )
320 .is_some()
321 {
322 providers.push("gemini".to_string());
323 }
324 }
325
326 if let Some(user_config) = self.user_config.as_ref() {
329 for profile in super::providers::REGISTRY {
331 let user_cfg = user_config.providers.get(profile.name);
332 let api_key_present = resolve_api_key(
333 profile.api_key_env,
334 user_cfg.and_then(|c| c.api_key_env.as_deref()),
335 )
336 .is_some();
337 if api_key_present {
338 providers.push(profile.name.to_string());
339 }
340 }
341 for (name, cfg) in &user_config.providers {
344 if lookup_provider(name).is_some() {
345 continue; }
347 if let (Some(_url), Some(env)) = (&cfg.base_url, cfg.api_key_env.as_deref())
348 && resolve_api_key(env, None).is_some()
349 {
350 providers.push(name.clone());
351 }
352 }
353 }
354
355 providers
356 }
357
358 pub async fn list_models(&self, provider: &str) -> Result<Vec<String>> {
360 let lc = provider.to_lowercase();
361 if lc == "anthropic" {
366 return Ok(vec![
367 "claude-opus-4-7".to_string(),
368 "claude-sonnet-4-6".to_string(),
369 "claude-opus-4-6".to_string(),
370 "claude-sonnet-4-5".to_string(),
371 "claude-opus-4-5".to_string(),
372 "claude-haiku-4-5".to_string(),
373 ]);
374 }
375 if lc == "gemini" {
386 return Ok(vec![
387 "gemini-pro-latest".to_string(),
388 "gemini-flash-latest".to_string(),
389 "gemini-3.1-pro-preview".to_string(),
390 "gemini-3-flash-preview".to_string(),
391 "gemini-3.1-flash-lite-preview".to_string(),
392 "gemini-2.5-pro".to_string(),
393 "gemini-2.5-flash".to_string(),
394 "gemini-2.5-flash-lite".to_string(),
395 ]);
396 }
397 if lc == "ollama" {
398 let url = format!(
399 "{}/api/tags",
400 self.config.ollama_url.trim().trim_end_matches('/')
401 );
402 let client = reqwest::Client::builder()
403 .timeout(Duration::from_secs(5))
404 .build()
405 .map_err(|e| {
406 ModelError::Backend(BackendError::ConnectionFailed {
407 backend: "ollama".to_string(),
408 url: url.clone(),
409 reason: e.to_string(),
410 })
411 })?;
412 let response = client.get(&url).send().await.map_err(|e| {
413 ModelError::Backend(BackendError::ConnectionFailed {
414 backend: "ollama".to_string(),
415 url: url.clone(),
416 reason: e.to_string(),
417 })
418 })?;
419 if !response.status().is_success() {
420 return Err(ModelError::Backend(BackendError::HttpError {
421 status: response.status().as_u16(),
422 message: "Failed to list models".to_string(),
423 }));
424 }
425 let tags: super::adapters::ollama::OllamaTagsResponse =
426 response.json().await.map_err(|e| ModelError::ParseError {
427 message: format!("Failed to parse tags response: {}", e),
428 raw: None,
429 })?;
430 return Ok(tags.models.into_iter().map(|m| m.name).collect());
431 }
432
433 let synthetic_model_id = format!("{}/_", lc);
440 let adapter = self.create_model(&synthetic_model_id).await?;
441 adapter.list_models().await
442 }
443
444 fn config_to_backend_config(config: &Config) -> BackendConfig {
446 let ollama_url = format!("http://{}:{}", config.ollama.host, config.ollama.port);
447 BackendConfig {
448 ollama_url,
449 timeout_secs: 10,
450 max_idle_per_host: 10,
451 }
452 }
453}
454
455fn synthesize_custom_profile(name: &str, compat: CompatStyle) -> &'static ProviderProfile {
462 let leaked_name: &'static str = Box::leak(name.to_string().into_boxed_str());
466 let extraction = match compat {
467 CompatStyle::Openai => ReasoningExtraction::None,
468 CompatStyle::OpenaiEffort => ReasoningExtraction::None,
469 CompatStyle::Openrouter => ReasoningExtraction::DeltaContentField("reasoning"),
470 };
471 let profile = ProviderProfile {
472 name: leaked_name,
473 base_url: "user-defined",
477 api_key_env: "user-defined",
478 extra_headers: &[],
479 reasoning_strategy: match compat {
480 CompatStyle::Openai => ReasoningStrategy::None,
481 CompatStyle::OpenaiEffort => ReasoningStrategy::Effort,
482 CompatStyle::Openrouter => ReasoningStrategy::OpenRouterShape,
483 },
484 reasoning_extraction: extraction,
485 };
486 Box::leak(Box::new(profile))
487}
488
/// Split a model identifier into `(provider, model)`.
///
/// Only the first `/` separates the two parts, so model names that themselves
/// contain slashes (e.g. `together/deepseek-ai/DeepSeek-R1`) survive intact.
/// An identifier without any `/` is treated as a bare Ollama model name.
fn parse_model_id(model_id: &str) -> (&str, &str) {
    model_id.split_once('/').unwrap_or(("ollama", model_id))
}
507
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_model_id_with_provider() {
        let (provider, model) = parse_model_id("ollama/llama3");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3");
    }

    #[test]
    fn test_parse_model_id_bare_name() {
        let (provider, model) = parse_model_id("llama3");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3");
    }

    #[test]
    fn test_parse_model_id_with_tag() {
        let (provider, model) = parse_model_id("ollama/llama3:latest");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3:latest");
    }

    #[test]
    fn test_parse_model_id_bare_with_tag() {
        let (provider, model) = parse_model_id("llama3:7b");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3:7b");
    }

    /// Only the first `/` separates provider and model: hosted catalogues such
    /// as Together use `org/model` names that must be passed through whole.
    #[test]
    fn test_parse_model_id_keeps_slashes_in_model_name() {
        let (provider, model) = parse_model_id("together/deepseek-ai/DeepSeek-R1");
        assert_eq!(provider, "together");
        assert_eq!(model, "deepseek-ai/DeepSeek-R1");
    }

    /// Model specs as users type them resolve through the production parser;
    /// bare names (including `:tag` suffixes) default to Ollama.
    #[test]
    fn test_model_spec_parsing() {
        let specs = [
            ("ollama/tinyllama", "ollama", "tinyllama"),
            ("qwen3-coder:30b", "ollama", "qwen3-coder:30b"),
            ("kimi-k2.5:cloud", "ollama", "kimi-k2.5:cloud"),
        ];
        for (spec, expected_provider, expected_model) in specs {
            assert_eq!(
                parse_model_id(spec),
                (expected_provider, expected_model),
                "spec: {}",
                spec
            );
        }
    }

    /// The provider is everything before the first `/`; without a `/` the
    /// parser falls back to Ollama.
    #[test]
    fn test_provider_extraction() {
        assert_eq!(parse_model_id("ollama/tinyllama").0, "ollama");
        assert_eq!(parse_model_id("openrouter/anthropic/claude").0, "openrouter");
        assert_eq!(parse_model_id("qwen3-coder:30b").0, "ollama");
    }

    /// An unrecognised provider must produce an error that both says it is
    /// unknown and names the offending provider.
    #[tokio::test]
    async fn unknown_provider_returns_clear_error() {
        let cfg = Config::default();
        let factory = ModelFactory::from_config(&cfg);
        match factory.create_model("nonexistent/foo").await {
            Ok(_) => panic!("expected unknown-provider error"),
            Err(e) => {
                let msg = e.to_string();
                assert!(
                    msg.contains("Unknown provider"),
                    "expected 'Unknown provider' in: {}",
                    msg
                );
                assert!(
                    msg.contains("nonexistent"),
                    "error should name the bad provider; got: {}",
                    msg
                );
            },
        }
    }

    /// A registry provider with no key in the environment fails with an auth
    /// error naming the env var to set.
    #[tokio::test]
    async fn missing_api_key_returns_authentication_error() {
        temp_env::async_with_vars([("GROQ_API_KEY", None::<&str>)], async {
            let cfg = Config::default();
            let factory = ModelFactory::from_config(&cfg);
            match factory.create_model("groq/qwen-qwq-32b").await {
                Ok(_) => panic!("expected auth error"),
                Err(e) => {
                    let msg = e.to_string();
                    assert!(
                        msg.contains("API key") || msg.contains("Authentication"),
                        "expected auth error, got: {}",
                        msg
                    );
                    assert!(
                        msg.contains("GROQ_API_KEY"),
                        "error should name the env var; got: {}",
                        msg
                    );
                },
            }
        })
        .await;
    }

    /// Anthropic is routed to its native adapter: a missing key yields an auth
    /// error rather than falling through to "Unknown provider".
    #[tokio::test]
    async fn anthropic_missing_api_key_returns_authentication_error() {
        temp_env::async_with_vars([("ANTHROPIC_API_KEY", None::<&str>)], async {
            let cfg = Config::default();
            let factory = ModelFactory::from_config(&cfg);
            match factory.create_model("anthropic/claude-sonnet-4-6").await {
                Ok(_) => panic!("expected auth error"),
                Err(e) => {
                    let msg = e.to_string();
                    assert!(
                        msg.contains("ANTHROPIC_API_KEY"),
                        "error should name the env var; got: {}",
                        msg
                    );
                    assert!(
                        !msg.contains("Unknown provider"),
                        "anthropic must be a known provider; got: {}",
                        msg
                    );
                },
            }
        })
        .await;
    }

    /// The Anthropic list is static, so it must succeed without a network call
    /// or API key.
    #[tokio::test]
    async fn anthropic_list_models_returns_curated_list() {
        let cfg = Config::default();
        let factory = ModelFactory::from_config(&cfg);
        let models = factory
            .list_models("anthropic")
            .await
            .expect("curated list should always succeed");
        assert!(
            models.iter().any(|m| m == "claude-sonnet-4-6"),
            "expected sonnet-4-6 in curated list; got {:?}",
            models
        );
        assert!(
            models.iter().any(|m| m == "claude-opus-4-7"),
            "expected opus-4-7 in curated list; got {:?}",
            models
        );
    }

    /// Gemini is routed to its native adapter: a missing key yields an auth
    /// error rather than falling through to "Unknown provider".
    #[tokio::test]
    async fn gemini_missing_api_key_returns_authentication_error() {
        temp_env::async_with_vars([("GOOGLE_API_KEY", None::<&str>)], async {
            let cfg = Config::default();
            let factory = ModelFactory::from_config(&cfg);
            match factory.create_model("gemini/gemini-3.1-pro-preview").await {
                Ok(_) => panic!("expected auth error"),
                Err(e) => {
                    let msg = e.to_string();
                    assert!(
                        msg.contains("GOOGLE_API_KEY"),
                        "error should name the env var; got: {}",
                        msg
                    );
                    assert!(
                        !msg.contains("Unknown provider"),
                        "gemini must be a known provider; got: {}",
                        msg
                    );
                },
            }
        })
        .await;
    }

    /// The Gemini list is static and includes both the rolling aliases and
    /// pinned preview ids; it must not advertise ids the API rejects.
    #[tokio::test]
    async fn gemini_list_models_returns_curated_list() {
        let cfg = Config::default();
        let factory = ModelFactory::from_config(&cfg);
        let models = factory
            .list_models("gemini")
            .await
            .expect("curated list should always succeed");
        assert!(
            models.iter().any(|m| m == "gemini-pro-latest"),
            "expected gemini-pro-latest alias in curated list; got {:?}",
            models
        );
        assert!(
            models.iter().any(|m| m == "gemini-3.1-pro-preview"),
            "expected gemini-3.1-pro-preview in curated list; got {:?}",
            models
        );
        assert!(
            !models.iter().any(|m| m == "gemini-3-pro"),
            "gemini-3-pro is not a valid API ID as of 2026-04; \
             must not be in curated list; got {:?}",
            models
        );
    }
}