Skip to main content

mermaid_cli/models/
backend.rs

//! Model factory - creates model instances from identifiers
//!
//! Parses model identifiers like "ollama/llama3" and creates
//! the appropriate adapter implementing the Model trait.
//! Also provides static convenience methods for common operations.
6
7use std::sync::Arc;
8use std::time::Duration;
9
10use super::config::BackendConfig;
11use super::error::{BackendError, ModelError, Result};
12use super::traits::Model;
13use crate::app::Config;
14
15/// Model factory - creates model instances
16pub struct ModelFactory {
17    config: Arc<BackendConfig>,
18}
19
20impl ModelFactory {
21    /// Create a new model factory with explicit backend config
22    pub fn new(config: BackendConfig) -> Self {
23        Self {
24            config: Arc::new(config),
25        }
26    }
27
28    /// Create a factory from app::Config
29    pub fn from_config(config: &Config) -> Self {
30        Self::new(Self::config_to_backend_config(config))
31    }
32
33    /// Create a model from a full identifier (e.g., "ollama/llama3")
34    pub async fn create_model(&self, model_id: &str) -> Result<Box<dyn Model>> {
35        // Parse model identifier: "provider/model_name" or just "model_name" (defaults to ollama)
36        let (provider, model_name) = parse_model_id(model_id);
37
38        match provider.to_lowercase().as_str() {
39            "ollama" => {
40                use super::adapters::ollama::OllamaAdapter;
41                let adapter = OllamaAdapter::new(model_name, self.config.clone()).await?;
42                Ok(Box::new(adapter))
43            },
44            _ => Err(ModelError::InvalidRequest(format!(
45                "Unknown provider: {}. Only ollama/ is supported.",
46                provider
47            ))),
48        }
49    }
50
51    // --- Static convenience API (absorbed from factory.rs) ---
52
53    /// Create a model instance from a model identifier with optional app config
54    ///
55    /// Format examples:
56    /// - "ollama/qwen3-coder:30b" - Explicit Ollama provider
57    /// - "qwen3-coder:30b" - Defaults to Ollama
58    /// - "kimi-k2.5:cloud" - Ollama cloud model
59    pub async fn create(model_id: &str, config: Option<&Config>) -> Result<Box<dyn Model>> {
60        let backend_config = config
61            .map(Self::config_to_backend_config)
62            .unwrap_or_default();
63        let factory = Self::new(backend_config);
64        factory.create_model(model_id).await
65    }
66
67    /// List all models from all available providers
68    ///
69    /// Returns a list of model identifiers in "provider/model" format.
70    /// Only includes providers that are currently available.
71    pub async fn list_all_models() -> Result<Vec<String>> {
72        let factory = Self::new(BackendConfig::default());
73        let providers = factory.available_providers_impl().await;
74
75        let mut all_models = Vec::new();
76        for provider in providers {
77            if let Ok(models) = factory.list_models(&provider).await {
78                for model_name in models {
79                    all_models.push(format!("{}/{}", provider, model_name));
80                }
81            }
82        }
83
84        all_models.sort();
85        Ok(all_models)
86    }
87
88    /// Get list of available providers (static convenience)
89    pub async fn available_providers() -> Vec<String> {
90        let factory = Self::new(BackendConfig::default());
91        factory.available_providers_impl().await
92    }
93
94    /// List available providers using this factory's config (instance method)
95    pub async fn available_providers_pub(&self) -> Vec<String> {
96        self.available_providers_impl().await
97    }
98
99    // --- Instance methods ---
100
101    /// List available providers with a fast single-shot health check
102    async fn available_providers_impl(&self) -> Vec<String> {
103        let mut providers = Vec::new();
104
105        // Quick Ollama check: single GET with 2s timeout, no retries
106        let url = format!(
107            "{}/api/tags",
108            self.config.ollama_url.trim().trim_end_matches('/')
109        );
110        if let Ok(client) = reqwest::Client::builder()
111            .timeout(Duration::from_secs(2))
112            .build()
113            && let Ok(resp) = client.get(&url).send().await
114            && resp.status().is_success()
115        {
116            providers.push("ollama".to_string());
117        }
118
119        providers
120    }
121
122    /// List all models from a provider without creating a model instance
123    pub async fn list_models(&self, provider: &str) -> Result<Vec<String>> {
124        match provider {
125            "ollama" => {
126                let url = format!(
127                    "{}/api/tags",
128                    self.config.ollama_url.trim().trim_end_matches('/')
129                );
130                let client = reqwest::Client::builder()
131                    .timeout(Duration::from_secs(5))
132                    .build()
133                    .map_err(|e| {
134                        ModelError::Backend(BackendError::ConnectionFailed {
135                            backend: "ollama".to_string(),
136                            url: url.clone(),
137                            reason: e.to_string(),
138                        })
139                    })?;
140                let response = client.get(&url).send().await.map_err(|e| {
141                    ModelError::Backend(BackendError::ConnectionFailed {
142                        backend: "ollama".to_string(),
143                        url: url.clone(),
144                        reason: e.to_string(),
145                    })
146                })?;
147                if !response.status().is_success() {
148                    return Err(ModelError::Backend(BackendError::HttpError {
149                        status: response.status().as_u16(),
150                        message: "Failed to list models".to_string(),
151                    }));
152                }
153                let tags: super::adapters::ollama::OllamaTagsResponse =
154                    response.json().await.map_err(|e| ModelError::ParseError {
155                        message: format!("Failed to parse tags response: {}", e),
156                        raw: None,
157                    })?;
158                Ok(tags.models.into_iter().map(|m| m.name).collect())
159            },
160            _ => Err(ModelError::InvalidRequest(format!(
161                "Unknown provider: {}",
162                provider
163            ))),
164        }
165    }
166
167    /// Convert app::Config to BackendConfig
168    fn config_to_backend_config(config: &Config) -> BackendConfig {
169        let ollama_url = format!("http://{}:{}", config.ollama.host, config.ollama.port);
170
171        BackendConfig {
172            ollama_url,
173            timeout_secs: 10,
174            max_idle_per_host: 10,
175        }
176    }
177}
178
/// Split a model identifier into `(provider, model_name)`.
///
/// The split happens at the first `/` only; anything after it, including
/// further slashes, belongs to the model name. Identifiers without a `/`
/// are treated as bare model names and assigned to the `ollama` provider.
///
/// Formats:
/// - "ollama/llama3" -> ("ollama", "llama3")
/// - "llama3" -> ("ollama", "llama3")  // defaults to ollama
/// - "llama3:latest" -> ("ollama", "llama3:latest")  // ollama tag format
fn parse_model_id(model_id: &str) -> (&str, &str) {
    model_id.split_once('/').unwrap_or(("ollama", model_id))
}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicit provider prefix is split off at the slash.
    #[test]
    fn test_parse_model_id_with_provider() {
        let (provider, model) = parse_model_id("ollama/llama3");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3");
    }

    /// A bare model name falls back to the ollama provider.
    #[test]
    fn test_parse_model_id_bare_name() {
        let (provider, model) = parse_model_id("llama3");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3");
    }

    /// A `:tag` suffix stays attached to the model name.
    #[test]
    fn test_parse_model_id_with_tag() {
        let (provider, model) = parse_model_id("ollama/llama3:latest");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3:latest");
    }

    /// A bare name with a tag still defaults to ollama and keeps the tag.
    #[test]
    fn test_parse_model_id_bare_with_tag() {
        let (provider, model) = parse_model_id("llama3:7b");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3:7b");
    }

    /// Table-driven check of the documented spec formats, run through
    /// `parse_model_id` itself so a regression in the parser is caught.
    #[test]
    fn test_model_spec_parsing() {
        // (spec, expected provider, expected model)
        let specs = [
            ("ollama/tinyllama", "ollama", "tinyllama"),
            ("qwen3-coder:30b", "ollama", "qwen3-coder:30b"),
            ("kimi-k2.5:cloud", "ollama", "kimi-k2.5:cloud"),
        ];

        for (spec, expected_provider, expected_model) in specs {
            let (provider, model) = parse_model_id(spec);
            assert_eq!(provider, expected_provider, "provider for {spec:?}");
            assert_eq!(model, expected_model, "model for {spec:?}");
        }
    }

    /// The provider component alone: explicit prefix vs. default.
    #[test]
    fn test_provider_extraction() {
        assert_eq!(parse_model_id("ollama/tinyllama").0, "ollama");
        assert_eq!(parse_model_id("qwen3-coder:30b").0, "ollama");
    }

    /// Only the first slash separates provider from model.
    #[test]
    fn test_parse_model_id_splits_on_first_slash_only() {
        assert_eq!(
            parse_model_id("ollama/library/llama3"),
            ("ollama", "library/llama3")
        );
    }

    /// Empty components are passed through unchanged by the parser.
    #[test]
    fn test_parse_model_id_empty_parts() {
        assert_eq!(parse_model_id("ollama/"), ("ollama", ""));
        assert_eq!(parse_model_id("/llama3"), ("", "llama3"));
        assert_eq!(parse_model_id(""), ("ollama", ""));
    }
}