// mermaid_cli/models/factory.rs

/// Factory for creating model instances with unified backend architecture
///
/// This factory provides a simple API for creating models that automatically
/// handle backend discovery, routing, and configuration.

6use super::config::BackendConfig;
7use super::error::Result;
8use super::model::{create_model, create_model_default};
9use super::router::BackendRouter;
10use super::traits::Model;
11use crate::app::Config;
12
/// Factory for creating model instances
///
/// Stateless namespace type: every operation is an associated function
/// (e.g. [`ModelFactory::create`]), so no instance state is carried.
pub struct ModelFactory;
15
16impl ModelFactory {
17    /// Create a model instance from a model identifier
18    ///
19    /// Format examples:
20    /// - "ollama/qwen3-coder:30b" - Explicit backend
21    /// - "qwen3-coder:30b" - Auto-detect backend
22    /// - "gpt-4" - Search backends (likely vLLM or LiteLLM)
23    ///
24    /// The factory will automatically discover available backends and route
25    /// the request appropriately.
26    pub async fn create(model_id: &str, config: Option<&Config>) -> Result<Box<dyn Model>> {
27        let backend_config = if let Some(cfg) = config {
28            Self::config_to_backend_config(cfg)
29        } else {
30            BackendConfig::default()
31        };
32
33        create_model(model_id, backend_config).await
34    }
35
36    /// Create a model with default configuration
37    pub async fn create_default(model_id: &str) -> Result<Box<dyn Model>> {
38        create_model_default(model_id).await
39    }
40
41    /// Create a model with explicit backend preference
42    ///
43    /// If backend is specified, the model_id will be prefixed with the backend name
44    /// if it's not already specified. For example:
45    /// - backend="ollama", model_id="tinyllama" -> "ollama/tinyllama"
46    /// - backend="vllm", model_id="gpt-4" -> "vllm/gpt-4"
47    /// - backend=None, model_id="qwen3-coder:30b" -> auto-detect backend
48    pub async fn create_with_backend(
49        model_id: &str,
50        config: Option<&Config>,
51        backend: Option<&str>,
52    ) -> Result<Box<dyn Model>> {
53        let backend_config = if let Some(cfg) = config {
54            Self::config_to_backend_config(cfg)
55        } else {
56            BackendConfig::default()
57        };
58
59        // If backend is explicitly specified, prefix the model_id
60        let final_model_id = if let Some(backend_name) = backend {
61            if model_id.contains('/') {
62                // Already has a provider prefix
63                model_id.to_string()
64            } else {
65                // Add backend prefix
66                format!("{}/{}", backend_name, model_id)
67            }
68        } else {
69            model_id.to_string()
70        };
71
72        create_model(&final_model_id, backend_config).await
73    }
74
75    /// List all available models from all backends
76    pub async fn list_all_backend_models() -> Result<Vec<String>> {
77        let router = BackendRouter::new(BackendConfig::default());
78        let all_models = router.list_all_models().await?;
79
80        // Flatten the backend -> models map into a single list with backend prefix
81        let mut model_list = Vec::new();
82        for (backend_name, models) in all_models {
83            for model in models {
84                model_list.push(format!("{}/{}", backend_name, model));
85            }
86        }
87
88        model_list.sort();
89        Ok(model_list)
90    }
91
92    /// List available models (alias for list_all_backend_models)
93    pub async fn list_available() -> Result<Vec<String>> {
94        Self::list_all_backend_models().await
95    }
96
97    /// Get available backends
98    pub async fn get_available_backends() -> Vec<String> {
99        let router = BackendRouter::new(BackendConfig::default());
100        router.available_backends().await
101    }
102
103    /// Validate that a model is accessible
104    pub async fn validate(model_id: &str, config: Option<&Config>) -> Result<bool> {
105        match Self::create(model_id, config).await {
106            Ok(model) => model.validate_connection().await,
107            Err(_) => Ok(false),
108        }
109    }
110
111    /// Convert app::Config to BackendConfig
112    fn config_to_backend_config(config: &Config) -> BackendConfig {
113        // Construct Ollama URL from host and port
114        let ollama_url = format!("http://{}:{}", config.ollama.host, config.ollama.port);
115
116        BackendConfig {
117            ollama_url,
118            vllm_url: std::env::var("VLLM_API_BASE")
119                .unwrap_or_else(|_| "http://localhost:8000".to_string()),
120            litellm_url: config.litellm.proxy_url.clone(),
121            litellm_master_key: config.litellm.master_key.clone(),
122            timeout_secs: 10,
123            request_timeout_secs: 120,
124            max_idle_per_host: 10,
125            health_check_interval_secs: 30,
126        }
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// A single '/' splits a spec into backend prefix and model name;
    /// specs without '/' carry no backend.
    #[test]
    fn test_model_spec_parsing() {
        let cases = [
            ("ollama/tinyllama", Some("ollama"), "tinyllama"),
            ("qwen3-coder:30b", None, "qwen3-coder:30b"),
            ("gpt-4", None, "gpt-4"),
        ];

        for (spec, expected_backend, expected_model) in cases {
            match spec.split_once('/') {
                Some((backend, model)) => {
                    assert_eq!(Some(backend), expected_backend);
                    assert_eq!(model, expected_model);
                }
                None => {
                    assert_eq!(expected_backend, None);
                    assert_eq!(spec, expected_model);
                }
            }
        }
    }

    /// Ollama specs are recognized purely by their "ollama/" prefix.
    #[test]
    fn test_ollama_provider_detection() {
        let is_ollama = |spec: &str| spec.starts_with("ollama/");

        assert!(is_ollama("ollama/tinyllama"));
        assert!(is_ollama("ollama/llama2"));
        assert!(!is_ollama("openai/gpt-4"));
        assert!(!is_ollama("qwen3-coder:30b"));
    }

    /// The provider is everything before the first '/', if one exists.
    #[test]
    fn test_provider_extraction() {
        fn extract_provider(spec: &str) -> Option<&str> {
            spec.split_once('/').map(|(provider, _)| provider)
        }

        assert_eq!(extract_provider("ollama/tinyllama"), Some("ollama"));
        assert_eq!(extract_provider("vllm/gpt-4"), Some("vllm"));
        assert_eq!(extract_provider("qwen3-coder:30b"), None);
    }

    /// The model name is everything after the first '/', or the whole spec.
    #[test]
    fn test_model_name_extraction() {
        fn extract_model(spec: &str) -> &str {
            spec.split_once('/').map_or(spec, |(_, model)| model)
        }

        assert_eq!(extract_model("ollama/tinyllama"), "tinyllama");
        assert_eq!(extract_model("vllm/gpt-4"), "gpt-4");
        assert_eq!(extract_model("qwen3-coder:30b"), "qwen3-coder:30b");
    }
}