mermaid_cli/models/
factory.rs1use super::config::BackendConfig;
7use super::error::Result;
8use super::model::{create_model, create_model_default};
9use super::router::BackendRouter;
10use super::traits::Model;
11use crate::app::Config;

/// Stateless entry point for constructing [`Model`] instances and querying backends.
///
/// Every operation is an associated function; the type itself carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelFactory;

16impl ModelFactory {
17 pub async fn create(model_id: &str, config: Option<&Config>) -> Result<Box<dyn Model>> {
27 let backend_config = if let Some(cfg) = config {
28 Self::config_to_backend_config(cfg)
29 } else {
30 BackendConfig::default()
31 };
32
33 create_model(model_id, backend_config).await
34 }
35
36 pub async fn create_default(model_id: &str) -> Result<Box<dyn Model>> {
38 create_model_default(model_id).await
39 }
40
41 pub async fn create_with_backend(
49 model_id: &str,
50 config: Option<&Config>,
51 backend: Option<&str>,
52 ) -> Result<Box<dyn Model>> {
53 let backend_config = if let Some(cfg) = config {
54 Self::config_to_backend_config(cfg)
55 } else {
56 BackendConfig::default()
57 };
58
59 let final_model_id = if let Some(backend_name) = backend {
61 if model_id.contains('/') {
62 model_id.to_string()
64 } else {
65 format!("{}/{}", backend_name, model_id)
67 }
68 } else {
69 model_id.to_string()
70 };
71
72 create_model(&final_model_id, backend_config).await
73 }
74
75 pub async fn list_all_backend_models() -> Result<Vec<String>> {
77 let router = BackendRouter::new(BackendConfig::default());
78 let all_models = router.list_all_models().await?;
79
80 let mut model_list = Vec::new();
82 for (backend_name, models) in all_models {
83 for model in models {
84 model_list.push(format!("{}/{}", backend_name, model));
85 }
86 }
87
88 model_list.sort();
89 Ok(model_list)
90 }
91
92 pub async fn list_available() -> Result<Vec<String>> {
94 Self::list_all_backend_models().await
95 }
96
97 pub async fn get_available_backends() -> Vec<String> {
99 let router = BackendRouter::new(BackendConfig::default());
100 router.available_backends().await
101 }
102
103 pub async fn validate(model_id: &str, config: Option<&Config>) -> Result<bool> {
105 match Self::create(model_id, config).await {
106 Ok(model) => model.validate_connection().await,
107 Err(_) => Ok(false),
108 }
109 }
110
111 fn config_to_backend_config(config: &Config) -> BackendConfig {
113 let ollama_url = format!("http://{}:{}", config.ollama.host, config.ollama.port);
115
116 BackendConfig {
117 ollama_url,
118 vllm_url: std::env::var("VLLM_API_BASE")
119 .unwrap_or_else(|_| "http://localhost:8000".to_string()),
120 litellm_url: config.litellm.proxy_url.clone(),
121 litellm_master_key: config.litellm.master_key.clone(),
122 timeout_secs: 10,
123 request_timeout_secs: 120,
124 max_idle_per_host: 10,
125 health_check_interval_secs: 30,
126 }
127 }
128}

#[cfg(test)]
mod tests {
    // These tests pin the `backend/model` spec-string conventions the factory relies
    // on; they do not call into the factory itself, so nothing is imported from `super`.

    #[test]
    fn test_model_spec_parsing() {
        let specs = [
            ("ollama/tinyllama", Some("ollama"), "tinyllama"),
            ("qwen3-coder:30b", None, "qwen3-coder:30b"),
            ("gpt-4", None, "gpt-4"),
        ];

        for (spec, expected_backend, expected_model) in specs {
            match spec.split_once('/') {
                Some((backend, model)) => {
                    assert_eq!(Some(backend), expected_backend);
                    assert_eq!(model, expected_model);
                }
                None => {
                    assert_eq!(None, expected_backend);
                    assert_eq!(spec, expected_model);
                }
            }
        }
    }

    #[test]
    fn test_ollama_provider_detection() {
        assert!("ollama/tinyllama".starts_with("ollama/"));
        assert!("ollama/llama2".starts_with("ollama/"));
        assert!(!"openai/gpt-4".starts_with("ollama/"));
        assert!(!"qwen3-coder:30b".starts_with("ollama/"));
    }

    #[test]
    fn test_provider_extraction() {
        /// Returns the segment before the first `/`, or `None` when there is no `/`.
        fn extract_provider(spec: &str) -> Option<&str> {
            spec.split_once('/').map(|(provider, _)| provider)
        }

        assert_eq!(extract_provider("ollama/tinyllama"), Some("ollama"));
        assert_eq!(extract_provider("vllm/gpt-4"), Some("vllm"));
        assert_eq!(extract_provider("qwen3-coder:30b"), None);
        // Only the first `/` separates the provider.
        assert_eq!(extract_provider("vllm/meta-llama/Llama-3"), Some("vllm"));
        assert_eq!(extract_provider(""), None);
    }

    #[test]
    fn test_model_name_extraction() {
        /// Returns everything after the first `/`, or the whole spec when there is none.
        fn extract_model(spec: &str) -> &str {
            spec.split_once('/').map_or(spec, |(_, model)| model)
        }

        assert_eq!(extract_model("ollama/tinyllama"), "tinyllama");
        assert_eq!(extract_model("vllm/gpt-4"), "gpt-4");
        assert_eq!(extract_model("qwen3-coder:30b"), "qwen3-coder:30b");
        // Namespaced model ids keep their inner `/`.
        assert_eq!(extract_model("vllm/meta-llama/Llama-3"), "meta-llama/Llama-3");
        assert_eq!(extract_model(""), "");
    }
}