Skip to main content

mermaid_cli/models/
factory.rs

//! Factory for creating model instances.
//!
//! This module provides the public API for creating models. It converts the
//! application `Config` into a backend configuration and delegates to the
//! internal `ModelFactory`.
5
6use super::backend::ModelFactory as InternalFactory;
7use super::config::BackendConfig;
8use super::error::Result;
9use super::traits::Model;
10use crate::app::Config;
11
12/// Factory for creating model instances
13pub struct ModelFactory;
14
15impl ModelFactory {
16    /// Create a model instance from a model identifier
17    ///
18    /// Format examples:
19    /// - "ollama/qwen3-coder:30b" - Explicit Ollama provider
20    /// - "qwen3-coder:30b" - Defaults to Ollama
21    /// - "kimi-k2.5:cloud" - Ollama cloud model
22    pub async fn create(model_id: &str, config: Option<&Config>) -> Result<Box<dyn Model>> {
23        let backend_config = if let Some(cfg) = config {
24            Self::config_to_backend_config(cfg)
25        } else {
26            BackendConfig::default()
27        };
28
29        let factory = InternalFactory::new(backend_config);
30        factory.create_model(model_id).await
31    }
32
33    /// Create a model with default configuration
34    pub async fn create_default(model_id: &str) -> Result<Box<dyn Model>> {
35        let factory = InternalFactory::new(BackendConfig::default());
36        factory.create_model(model_id).await
37    }
38
39    /// Create a model with explicit provider preference
40    ///
41    /// If provider is specified, the model_id will be prefixed with the provider name
42    /// if it's not already specified. For example:
43    /// - provider="ollama", model_id="tinyllama" -> "ollama/tinyllama"
44    /// - provider=None, model_id="qwen3-coder:30b" -> defaults to ollama
45    pub async fn create_with_provider(
46        model_id: &str,
47        config: Option<&Config>,
48        provider: Option<&str>,
49    ) -> Result<Box<dyn Model>> {
50        let backend_config = if let Some(cfg) = config {
51            Self::config_to_backend_config(cfg)
52        } else {
53            BackendConfig::default()
54        };
55
56        // If provider is explicitly specified, prefix the model_id
57        let final_model_id = if let Some(provider_name) = provider {
58            if model_id.contains('/') {
59                // Already has a provider prefix
60                model_id.to_string()
61            } else {
62                // Add provider prefix
63                format!("{}/{}", provider_name, model_id)
64            }
65        } else {
66            model_id.to_string()
67        };
68
69        let factory = InternalFactory::new(backend_config);
70        factory.create_model(&final_model_id).await
71    }
72
73    /// Create a model with explicit backend preference (alias for create_with_provider)
74    pub async fn create_with_backend(
75        model_id: &str,
76        config: Option<&Config>,
77        backend: Option<&str>,
78    ) -> Result<Box<dyn Model>> {
79        Self::create_with_provider(model_id, config, backend).await
80    }
81
82    /// Get available backends (providers)
83    pub async fn get_available_backends() -> Vec<String> {
84        let factory = InternalFactory::new(BackendConfig::default());
85        factory.available_providers().await
86    }
87
88    /// List all models from all available backends
89    ///
90    /// Returns a list of model identifiers in "provider/model" format.
91    /// Only includes backends that are currently available.
92    pub async fn list_all_backend_models() -> Result<Vec<String>> {
93        let factory = InternalFactory::new(BackendConfig::default());
94        let providers = factory.available_providers().await;
95
96        let mut all_models = Vec::new();
97
98        for provider in providers {
99            // Create a dummy model to list models from this provider
100            let dummy_model_id = format!("{}/dummy", provider);
101            if let Ok(model) = factory.create_model(&dummy_model_id).await {
102                if let Ok(models) = model.list_models().await {
103                    for model_name in models {
104                        all_models.push(format!("{}/{}", provider, model_name));
105                    }
106                }
107            }
108        }
109
110        all_models.sort();
111        Ok(all_models)
112    }
113
114    /// Convert app::Config to BackendConfig
115    fn config_to_backend_config(config: &Config) -> BackendConfig {
116        // Construct Ollama URL from host and port
117        let ollama_url = format!("http://{}:{}", config.ollama.host, config.ollama.port);
118
119        BackendConfig {
120            ollama_url,
121            timeout_secs: 10,
122            request_timeout_secs: 120,
123            max_idle_per_host: 10,
124            health_check_interval_secs: 30,
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    /// A spec with exactly one `/` splits into provider and model; anything
    /// else is a bare model name with no provider.
    #[test]
    fn test_model_spec_parsing() {
        let cases = [
            ("ollama/tinyllama", Some("ollama"), "tinyllama"),
            ("qwen3-coder:30b", None, "qwen3-coder:30b"),
            ("kimi-k2.5:cloud", None, "kimi-k2.5:cloud"),
        ];

        for &(spec, expected_provider, expected_model) in &cases {
            let segments: Vec<&str> = spec.split('/').collect();
            match segments.as_slice() {
                [provider, model] => {
                    assert_eq!(Some(*provider), expected_provider);
                    assert_eq!(*model, expected_model);
                }
                _ => {
                    assert_eq!(expected_provider, None);
                    assert_eq!(spec, expected_model);
                }
            }
        }
    }

    /// The provider is whatever precedes the first `/`, if there is one.
    #[test]
    fn test_provider_extraction() {
        fn extract_provider(spec: &str) -> Option<&str> {
            spec.split_once('/').map(|(provider, _rest)| provider)
        }

        assert_eq!(extract_provider("ollama/tinyllama"), Some("ollama"));
        assert_eq!(extract_provider("qwen3-coder:30b"), None);
    }
}