Skip to main content

mermaid_cli/models/
backend.rs

1//! Model factory - creates model instances from identifiers
2//!
3//! Parses model identifiers like "ollama/llama3" and creates
4//! the appropriate adapter implementing the Model trait.
5//! Also provides static convenience methods for common operations.
6
7use std::sync::Arc;
8use std::time::Duration;
9
10use super::config::BackendConfig;
11use super::error::{BackendError, ModelError, Result};
12use super::traits::Model;
13use crate::app::Config;
14
15/// Model factory - creates model instances
16pub struct ModelFactory {
17    config: Arc<BackendConfig>,
18}
19
20impl ModelFactory {
21    /// Create a new model factory with explicit backend config
22    pub fn new(config: BackendConfig) -> Self {
23        Self {
24            config: Arc::new(config),
25        }
26    }
27
28    /// Create a factory from app::Config
29    pub fn from_config(config: &Config) -> Self {
30        Self::new(Self::config_to_backend_config(config))
31    }
32
33    /// Create a model from a full identifier (e.g., "ollama/llama3")
34    pub async fn create_model(&self, model_id: &str) -> Result<Box<dyn Model>> {
35        // Parse model identifier: "provider/model_name" or just "model_name" (defaults to ollama)
36        let (provider, model_name) = parse_model_id(model_id);
37
38        match provider.to_lowercase().as_str() {
39            "ollama" => {
40                use super::adapters::ollama::OllamaAdapter;
41                let adapter = OllamaAdapter::new(model_name, self.config.clone()).await?;
42                Ok(Box::new(adapter))
43            },
44            _ => Err(ModelError::InvalidRequest(format!(
45                "Unknown provider: {}. Only ollama/ is supported.",
46                provider
47            ))),
48        }
49    }
50
51    // --- Static convenience API (absorbed from factory.rs) ---
52
53    /// Create a model instance from a model identifier with optional app config
54    ///
55    /// Format examples:
56    /// - "ollama/qwen3-coder:30b" - Explicit Ollama provider
57    /// - "qwen3-coder:30b" - Defaults to Ollama
58    /// - "kimi-k2.5:cloud" - Ollama cloud model
59    pub async fn create(model_id: &str, config: Option<&Config>) -> Result<Box<dyn Model>> {
60        let backend_config = config
61            .map(Self::config_to_backend_config)
62            .unwrap_or_default();
63        let factory = Self::new(backend_config);
64        factory.create_model(model_id).await
65    }
66
67    /// List all models from all available providers
68    ///
69    /// Returns a list of model identifiers in "provider/model" format.
70    /// Only includes providers that are currently available.
71    pub async fn list_all_models() -> Result<Vec<String>> {
72        let factory = Self::new(BackendConfig::default());
73        let providers = factory.available_providers_impl().await;
74
75        let mut all_models = Vec::new();
76        for provider in providers {
77            if let Ok(models) = factory.list_models(&provider).await {
78                for model_name in models {
79                    all_models.push(format!("{}/{}", provider, model_name));
80                }
81            }
82        }
83
84        all_models.sort();
85        Ok(all_models)
86    }
87
88    /// Get list of available providers (static convenience)
89    pub async fn available_providers() -> Vec<String> {
90        let factory = Self::new(BackendConfig::default());
91        factory.available_providers_impl().await
92    }
93
94    // --- Instance methods ---
95
96    /// List available providers with a fast single-shot health check
97    async fn available_providers_impl(&self) -> Vec<String> {
98        let mut providers = Vec::new();
99
100        // Quick Ollama check: single GET with 2s timeout, no retries
101        let url = format!(
102            "{}/api/tags",
103            self.config.ollama_url.trim().trim_end_matches('/')
104        );
105        if let Ok(client) = reqwest::Client::builder()
106            .timeout(Duration::from_secs(2))
107            .build()
108            && let Ok(resp) = client.get(&url).send().await
109            && resp.status().is_success()
110        {
111            providers.push("ollama".to_string());
112        }
113
114        providers
115    }
116
117    /// List all models from a provider without creating a model instance
118    pub async fn list_models(&self, provider: &str) -> Result<Vec<String>> {
119        match provider {
120            "ollama" => {
121                let url = format!(
122                    "{}/api/tags",
123                    self.config.ollama_url.trim().trim_end_matches('/')
124                );
125                let client = reqwest::Client::builder()
126                    .timeout(Duration::from_secs(5))
127                    .build()
128                    .map_err(|e| {
129                        ModelError::Backend(BackendError::ConnectionFailed {
130                            backend: "ollama".to_string(),
131                            url: url.clone(),
132                            reason: e.to_string(),
133                        })
134                    })?;
135                let response = client.get(&url).send().await.map_err(|e| {
136                    ModelError::Backend(BackendError::ConnectionFailed {
137                        backend: "ollama".to_string(),
138                        url: url.clone(),
139                        reason: e.to_string(),
140                    })
141                })?;
142                if !response.status().is_success() {
143                    return Err(ModelError::Backend(BackendError::HttpError {
144                        status: response.status().as_u16(),
145                        message: "Failed to list models".to_string(),
146                    }));
147                }
148                let tags: super::adapters::ollama::OllamaTagsResponse =
149                    response.json().await.map_err(|e| ModelError::ParseError {
150                        message: format!("Failed to parse tags response: {}", e),
151                        raw: None,
152                    })?;
153                Ok(tags.models.into_iter().map(|m| m.name).collect())
154            },
155            _ => Err(ModelError::InvalidRequest(format!(
156                "Unknown provider: {}",
157                provider
158            ))),
159        }
160    }
161
162    /// Convert app::Config to BackendConfig
163    fn config_to_backend_config(config: &Config) -> BackendConfig {
164        let ollama_url = format!("http://{}:{}", config.ollama.host, config.ollama.port);
165
166        BackendConfig {
167            ollama_url,
168            timeout_secs: 10,
169            max_idle_per_host: 10,
170        }
171    }
172}
173
/// Split a model identifier into `(provider, model_name)`.
///
/// The split happens at the first `/` only; everything after it, including
/// any further slashes, is the model name. Identifiers without a `/` are
/// assigned to the `ollama` provider.
///
/// Examples:
/// - "ollama/llama3" -> ("ollama", "llama3")
/// - "llama3" -> ("ollama", "llama3")
/// - "llama3:latest" -> ("ollama", "llama3:latest")  // ollama tag format
fn parse_model_id(model_id: &str) -> (&str, &str) {
    match model_id.split_once('/') {
        Some(provider_and_model) => provider_and_model,
        // Bare model names fall back to the ollama provider.
        None => ("ollama", model_id),
    }
}
191
#[cfg(test)]
mod tests {
    //! Unit tests for `parse_model_id`.
    //!
    //! Every test goes through the real parser so a regression in
    //! `parse_model_id` cannot hide behind test-local splitting logic.

    use super::*;

    #[test]
    fn test_parse_model_id_with_provider() {
        let (provider, model) = parse_model_id("ollama/llama3");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3");
    }

    #[test]
    fn test_parse_model_id_bare_name() {
        let (provider, model) = parse_model_id("llama3");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3");
    }

    #[test]
    fn test_parse_model_id_with_tag() {
        let (provider, model) = parse_model_id("ollama/llama3:latest");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3:latest");
    }

    #[test]
    fn test_parse_model_id_bare_with_tag() {
        let (provider, model) = parse_model_id("llama3:7b");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3:7b");
    }

    #[test]
    fn test_model_spec_parsing() {
        // (spec, expected provider, expected model); bare specs default to ollama.
        let specs = [
            ("ollama/tinyllama", "ollama", "tinyllama"),
            ("qwen3-coder:30b", "ollama", "qwen3-coder:30b"),
            ("kimi-k2.5:cloud", "ollama", "kimi-k2.5:cloud"),
        ];

        for (spec, expected_provider, expected_model) in specs {
            assert_eq!(
                parse_model_id(spec),
                (expected_provider, expected_model),
                "unexpected parse for spec {spec:?}"
            );
        }
    }

    #[test]
    fn test_provider_extraction() {
        // An explicit provider is returned verbatim; `create_model` does the lowercasing.
        assert_eq!(parse_model_id("ollama/tinyllama").0, "ollama");
        assert_eq!(parse_model_id("Ollama/tinyllama").0, "Ollama");

        // No slash: the provider falls back to ollama.
        assert_eq!(parse_model_id("qwen3-coder:30b").0, "ollama");

        // Only the first slash separates provider from model.
        assert_eq!(
            parse_model_id("ollama/library/llama3"),
            ("ollama", "library/llama3")
        );

        // Empty segments are passed through untouched.
        assert_eq!(parse_model_id("ollama/"), ("ollama", ""));
        assert_eq!(parse_model_id("/llama3"), ("", "llama3"));
    }
}