mermaid_cli/models/
backend.rs1use std::sync::Arc;
8use std::time::Duration;
9
10use super::config::BackendConfig;
11use super::error::{BackendError, ModelError, Result};
12use super::traits::Model;
13use crate::app::Config;
14
15pub struct ModelFactory {
17 config: Arc<BackendConfig>,
18}
19
20impl ModelFactory {
21 pub fn new(config: BackendConfig) -> Self {
23 Self {
24 config: Arc::new(config),
25 }
26 }
27
28 pub fn from_config(config: &Config) -> Self {
30 Self::new(Self::config_to_backend_config(config))
31 }
32
33 pub async fn create_model(&self, model_id: &str) -> Result<Box<dyn Model>> {
35 let (provider, model_name) = parse_model_id(model_id);
37
38 match provider.to_lowercase().as_str() {
39 "ollama" => {
40 use super::adapters::ollama::OllamaAdapter;
41 let adapter = OllamaAdapter::new(model_name, self.config.clone()).await?;
42 Ok(Box::new(adapter))
43 },
44 _ => Err(ModelError::InvalidRequest(format!(
45 "Unknown provider: {}. Only ollama/ is supported.",
46 provider
47 ))),
48 }
49 }
50
51 pub async fn create(model_id: &str, config: Option<&Config>) -> Result<Box<dyn Model>> {
60 let backend_config = config
61 .map(Self::config_to_backend_config)
62 .unwrap_or_default();
63 let factory = Self::new(backend_config);
64 factory.create_model(model_id).await
65 }
66
67 pub async fn list_all_models() -> Result<Vec<String>> {
72 let factory = Self::new(BackendConfig::default());
73 let providers = factory.available_providers_impl().await;
74
75 let mut all_models = Vec::new();
76 for provider in providers {
77 if let Ok(models) = factory.list_models(&provider).await {
78 for model_name in models {
79 all_models.push(format!("{}/{}", provider, model_name));
80 }
81 }
82 }
83
84 all_models.sort();
85 Ok(all_models)
86 }
87
88 pub async fn available_providers() -> Vec<String> {
90 let factory = Self::new(BackendConfig::default());
91 factory.available_providers_impl().await
92 }
93
94 pub async fn available_providers_pub(&self) -> Vec<String> {
96 self.available_providers_impl().await
97 }
98
99 async fn available_providers_impl(&self) -> Vec<String> {
103 let mut providers = Vec::new();
104
105 let url = format!(
107 "{}/api/tags",
108 self.config.ollama_url.trim().trim_end_matches('/')
109 );
110 if let Ok(client) = reqwest::Client::builder()
111 .timeout(Duration::from_secs(2))
112 .build()
113 && let Ok(resp) = client.get(&url).send().await
114 && resp.status().is_success()
115 {
116 providers.push("ollama".to_string());
117 }
118
119 providers
120 }
121
122 pub async fn list_models(&self, provider: &str) -> Result<Vec<String>> {
124 match provider {
125 "ollama" => {
126 let url = format!(
127 "{}/api/tags",
128 self.config.ollama_url.trim().trim_end_matches('/')
129 );
130 let client = reqwest::Client::builder()
131 .timeout(Duration::from_secs(5))
132 .build()
133 .map_err(|e| {
134 ModelError::Backend(BackendError::ConnectionFailed {
135 backend: "ollama".to_string(),
136 url: url.clone(),
137 reason: e.to_string(),
138 })
139 })?;
140 let response = client.get(&url).send().await.map_err(|e| {
141 ModelError::Backend(BackendError::ConnectionFailed {
142 backend: "ollama".to_string(),
143 url: url.clone(),
144 reason: e.to_string(),
145 })
146 })?;
147 if !response.status().is_success() {
148 return Err(ModelError::Backend(BackendError::HttpError {
149 status: response.status().as_u16(),
150 message: "Failed to list models".to_string(),
151 }));
152 }
153 let tags: super::adapters::ollama::OllamaTagsResponse =
154 response.json().await.map_err(|e| ModelError::ParseError {
155 message: format!("Failed to parse tags response: {}", e),
156 raw: None,
157 })?;
158 Ok(tags.models.into_iter().map(|m| m.name).collect())
159 },
160 _ => Err(ModelError::InvalidRequest(format!(
161 "Unknown provider: {}",
162 provider
163 ))),
164 }
165 }
166
167 fn config_to_backend_config(config: &Config) -> BackendConfig {
169 let ollama_url = format!("http://{}:{}", config.ollama.host, config.ollama.port);
170
171 BackendConfig {
172 ollama_url,
173 timeout_secs: 10,
174 max_idle_per_host: 10,
175 }
176 }
177}
178
/// Splits a model id into `(provider, model)`.
///
/// The split happens at the first `/`, so `"ollama/llama3:latest"` yields
/// `("ollama", "llama3:latest")`. Any further slashes stay in the model part.
/// An id containing no `/` is treated as an Ollama model name.
fn parse_model_id(model_id: &str) -> (&str, &str) {
    model_id.split_once('/').unwrap_or(("ollama", model_id))
}
196
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_model_id_with_provider() {
        let (provider, model) = parse_model_id("ollama/llama3");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3");
    }

    #[test]
    fn test_parse_model_id_bare_name() {
        let (provider, model) = parse_model_id("llama3");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3");
    }

    #[test]
    fn test_parse_model_id_with_tag() {
        let (provider, model) = parse_model_id("ollama/llama3:latest");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3:latest");
    }

    #[test]
    fn test_parse_model_id_bare_with_tag() {
        let (provider, model) = parse_model_id("llama3:7b");
        assert_eq!(provider, "ollama");
        assert_eq!(model, "llama3:7b");
    }

    /// Table-driven check of real-world specs against `parse_model_id` itself
    /// (bare names default to the `ollama` provider).
    #[test]
    fn test_model_spec_parsing() {
        let specs = [
            ("ollama/tinyllama", "ollama", "tinyllama"),
            ("qwen3-coder:30b", "ollama", "qwen3-coder:30b"),
            ("kimi-k2.5:cloud", "ollama", "kimi-k2.5:cloud"),
        ];

        for (spec, expected_provider, expected_model) in specs {
            assert_eq!(
                parse_model_id(spec),
                (expected_provider, expected_model),
                "spec: {spec}"
            );
        }
    }

    /// An explicit provider prefix is preserved verbatim, even when unknown,
    /// so `create_model` can report it in its error.
    #[test]
    fn test_provider_extraction() {
        assert_eq!(parse_model_id("ollama/tinyllama").0, "ollama");
        assert_eq!(parse_model_id("openai/gpt-4").0, "openai");
        assert_eq!(parse_model_id("qwen3-coder:30b").0, "ollama");
    }

    #[test]
    fn test_parse_model_id_splits_on_first_slash_only() {
        assert_eq!(
            parse_model_id("ollama/library/llama3"),
            ("ollama", "library/llama3")
        );
    }

    #[test]
    fn test_parse_model_id_empty_parts() {
        assert_eq!(parse_model_id("/llama3"), ("", "llama3"));
        assert_eq!(parse_model_id("ollama/"), ("ollama", ""));
        assert_eq!(parse_model_id(""), ("ollama", ""));
    }
}