mermaid_cli/models/
backend.rs1use std::sync::Arc;
8use std::time::Duration;
9
10use super::config::BackendConfig;
11use super::error::{BackendError, ModelError, Result};
12use super::traits::Model;
13use crate::app::Config;
14
15pub struct ModelFactory {
17 config: Arc<BackendConfig>,
18}
19
20impl ModelFactory {
21 pub fn new(config: BackendConfig) -> Self {
23 Self {
24 config: Arc::new(config),
25 }
26 }
27
28 pub fn from_config(config: &Config) -> Self {
30 Self::new(Self::config_to_backend_config(config))
31 }
32
33 pub async fn create_model(&self, model_id: &str) -> Result<Box<dyn Model>> {
35 let (provider, model_name) = parse_model_id(model_id);
37
38 match provider.to_lowercase().as_str() {
39 "ollama" => {
40 use super::adapters::ollama::OllamaAdapter;
41 let adapter = OllamaAdapter::new(model_name, self.config.clone()).await?;
42 Ok(Box::new(adapter))
43 },
44 _ => Err(ModelError::InvalidRequest(format!(
45 "Unknown provider: {}. Only ollama/ is supported.",
46 provider
47 ))),
48 }
49 }
50
51 pub async fn create(model_id: &str, config: Option<&Config>) -> Result<Box<dyn Model>> {
60 let backend_config = config
61 .map(Self::config_to_backend_config)
62 .unwrap_or_default();
63 let factory = Self::new(backend_config);
64 factory.create_model(model_id).await
65 }
66
67 pub async fn list_all_models() -> Result<Vec<String>> {
72 let factory = Self::new(BackendConfig::default());
73 let providers = factory.available_providers_impl().await;
74
75 let mut all_models = Vec::new();
76 for provider in providers {
77 if let Ok(models) = factory.list_models(&provider).await {
78 for model_name in models {
79 all_models.push(format!("{}/{}", provider, model_name));
80 }
81 }
82 }
83
84 all_models.sort();
85 Ok(all_models)
86 }
87
88 pub async fn available_providers() -> Vec<String> {
90 let factory = Self::new(BackendConfig::default());
91 factory.available_providers_impl().await
92 }
93
94 async fn available_providers_impl(&self) -> Vec<String> {
98 let mut providers = Vec::new();
99
100 let url = format!(
102 "{}/api/tags",
103 self.config.ollama_url.trim().trim_end_matches('/')
104 );
105 if let Ok(client) = reqwest::Client::builder()
106 .timeout(Duration::from_secs(2))
107 .build()
108 && let Ok(resp) = client.get(&url).send().await
109 && resp.status().is_success()
110 {
111 providers.push("ollama".to_string());
112 }
113
114 providers
115 }
116
117 pub async fn list_models(&self, provider: &str) -> Result<Vec<String>> {
119 match provider {
120 "ollama" => {
121 let url = format!(
122 "{}/api/tags",
123 self.config.ollama_url.trim().trim_end_matches('/')
124 );
125 let client = reqwest::Client::builder()
126 .timeout(Duration::from_secs(5))
127 .build()
128 .map_err(|e| {
129 ModelError::Backend(BackendError::ConnectionFailed {
130 backend: "ollama".to_string(),
131 url: url.clone(),
132 reason: e.to_string(),
133 })
134 })?;
135 let response = client.get(&url).send().await.map_err(|e| {
136 ModelError::Backend(BackendError::ConnectionFailed {
137 backend: "ollama".to_string(),
138 url: url.clone(),
139 reason: e.to_string(),
140 })
141 })?;
142 if !response.status().is_success() {
143 return Err(ModelError::Backend(BackendError::HttpError {
144 status: response.status().as_u16(),
145 message: "Failed to list models".to_string(),
146 }));
147 }
148 let tags: super::adapters::ollama::OllamaTagsResponse =
149 response.json().await.map_err(|e| ModelError::ParseError {
150 message: format!("Failed to parse tags response: {}", e),
151 raw: None,
152 })?;
153 Ok(tags.models.into_iter().map(|m| m.name).collect())
154 },
155 _ => Err(ModelError::InvalidRequest(format!(
156 "Unknown provider: {}",
157 provider
158 ))),
159 }
160 }
161
162 fn config_to_backend_config(config: &Config) -> BackendConfig {
164 let ollama_url = format!("http://{}:{}", config.ollama.host, config.ollama.port);
165
166 BackendConfig {
167 ollama_url,
168 timeout_secs: 10,
169 max_idle_per_host: 10,
170 }
171 }
172}
173
/// Splits a model spec into `(provider, model)` at the first `/`.
///
/// A spec without any `/` is a bare model name and defaults to the `ollama`
/// provider. Everything after the first `/` belongs to the model name, so
/// `ollama/hf.co/user/repo:tag` yields `("ollama", "hf.co/user/repo:tag")`.
/// As a consequence, a namespaced Ollama model must carry the `ollama/`
/// prefix: a bare `user/model` is parsed as provider `user`.
fn parse_model_id(model_id: &str) -> (&str, &str) {
    match model_id.split_once('/') {
        Some((provider, model)) => (provider, model),
        None => ("ollama", model_id),
    }
}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_model_id_with_provider() {
        assert_eq!(parse_model_id("ollama/llama3"), ("ollama", "llama3"));
    }

    #[test]
    fn test_parse_model_id_bare_name() {
        assert_eq!(parse_model_id("llama3"), ("ollama", "llama3"));
    }

    #[test]
    fn test_parse_model_id_with_tag() {
        assert_eq!(parse_model_id("ollama/llama3:latest"), ("ollama", "llama3:latest"));
    }

    #[test]
    fn test_parse_model_id_bare_with_tag() {
        assert_eq!(parse_model_id("llama3:7b"), ("ollama", "llama3:7b"));
    }

    /// Table-driven check that real-world specs parse through `parse_model_id`
    /// (the previous version re-implemented the split inline and never called it).
    #[test]
    fn test_model_spec_parsing() {
        // (spec, explicit provider prefix if any, expected model name)
        let specs = [
            ("ollama/tinyllama", Some("ollama"), "tinyllama"),
            ("qwen3-coder:30b", None, "qwen3-coder:30b"),
            ("kimi-k2.5:cloud", None, "kimi-k2.5:cloud"),
        ];

        for (spec, explicit_provider, expected_model) in specs {
            let (provider, model) = parse_model_id(spec);
            // Specs without a provider prefix fall back to ollama.
            assert_eq!(provider, explicit_provider.unwrap_or("ollama"), "provider for {spec}");
            assert_eq!(model, expected_model, "model for {spec}");
        }
    }

    #[test]
    fn test_provider_extraction() {
        assert_eq!(parse_model_id("ollama/tinyllama").0, "ollama");
        assert_eq!(parse_model_id("qwen3-coder:30b").0, "ollama");
        // Unknown prefixes are returned verbatim so `create_model` can reject them.
        assert_eq!(parse_model_id("openai/gpt-4o").0, "openai");
    }

    #[test]
    fn test_parse_model_id_splits_on_first_slash_only() {
        assert_eq!(
            parse_model_id("ollama/hf.co/user/repo:Q4_K_M"),
            ("ollama", "hf.co/user/repo:Q4_K_M")
        );
    }

    #[test]
    fn test_parse_model_id_empty_model_name() {
        assert_eq!(parse_model_id("ollama/"), ("ollama", ""));
        assert_eq!(parse_model_id(""), ("ollama", ""));
    }
}