mermaid_cli/models/
factory.rs1use super::backend::ModelFactory as InternalFactory;
7use super::config::BackendConfig;
8use super::error::Result;
9use super::traits::Model;
10use crate::app::Config;
11
/// Stateless public facade for creating [`Model`] instances.
///
/// Holds no data. Its associated functions construct an internal backend
/// factory (`super::backend::ModelFactory`) per call and delegate to it.
pub struct ModelFactory;
15impl ModelFactory {
16 pub async fn create(model_id: &str, config: Option<&Config>) -> Result<Box<dyn Model>> {
23 let backend_config = if let Some(cfg) = config {
24 Self::config_to_backend_config(cfg)
25 } else {
26 BackendConfig::default()
27 };
28
29 let factory = InternalFactory::new(backend_config);
30 factory.create_model(model_id).await
31 }
32
33 pub async fn create_default(model_id: &str) -> Result<Box<dyn Model>> {
35 let factory = InternalFactory::new(BackendConfig::default());
36 factory.create_model(model_id).await
37 }
38
39 pub async fn create_with_provider(
46 model_id: &str,
47 config: Option<&Config>,
48 provider: Option<&str>,
49 ) -> Result<Box<dyn Model>> {
50 let backend_config = if let Some(cfg) = config {
51 Self::config_to_backend_config(cfg)
52 } else {
53 BackendConfig::default()
54 };
55
56 let final_model_id = if let Some(provider_name) = provider {
58 if model_id.contains('/') {
59 model_id.to_string()
61 } else {
62 format!("{}/{}", provider_name, model_id)
64 }
65 } else {
66 model_id.to_string()
67 };
68
69 let factory = InternalFactory::new(backend_config);
70 factory.create_model(&final_model_id).await
71 }
72
73 pub async fn create_with_backend(
75 model_id: &str,
76 config: Option<&Config>,
77 backend: Option<&str>,
78 ) -> Result<Box<dyn Model>> {
79 Self::create_with_provider(model_id, config, backend).await
80 }
81
82 pub async fn get_available_backends() -> Vec<String> {
84 let factory = InternalFactory::new(BackendConfig::default());
85 factory.available_providers().await
86 }
87
88 pub async fn list_all_backend_models() -> Result<Vec<String>> {
93 let factory = InternalFactory::new(BackendConfig::default());
94 let providers = factory.available_providers().await;
95
96 let mut all_models = Vec::new();
97
98 for provider in providers {
99 let dummy_model_id = format!("{}/dummy", provider);
101 if let Ok(model) = factory.create_model(&dummy_model_id).await {
102 if let Ok(models) = model.list_models().await {
103 for model_name in models {
104 all_models.push(format!("{}/{}", provider, model_name));
105 }
106 }
107 }
108 }
109
110 all_models.sort();
111 Ok(all_models)
112 }
113
114 fn config_to_backend_config(config: &Config) -> BackendConfig {
116 let ollama_url = format!("http://{}:{}", config.ollama.host, config.ollama.port);
118
119 BackendConfig {
120 ollama_url,
121 timeout_secs: 10,
122 request_timeout_secs: 120,
123 max_idle_per_host: 10,
124 health_check_interval_secs: 30,
125 }
126 }
127}
128
#[cfg(test)]
mod tests {
    /// A `provider/model` spec splits into exactly two segments. Bare model
    /// names, which may carry `:` tags or dots, yield no provider.
    #[test]
    fn test_model_spec_parsing() {
        let cases = [
            ("ollama/tinyllama", Some("ollama"), "tinyllama"),
            ("qwen3-coder:30b", None, "qwen3-coder:30b"),
            ("kimi-k2.5:cloud", None, "kimi-k2.5:cloud"),
        ];

        for &(spec, expected_provider, expected_model) in &cases {
            let segments: Vec<&str> = spec.split('/').collect();
            match segments.as_slice() {
                [provider, model] => {
                    assert_eq!(Some(*provider), expected_provider);
                    assert_eq!(*model, expected_model);
                }
                _ => {
                    assert_eq!(None, expected_provider);
                    assert_eq!(spec, expected_model);
                }
            }
        }
    }

    /// The provider is the text before the first `/`; no `/` means no provider.
    #[test]
    fn test_provider_extraction() {
        fn extract_provider(spec: &str) -> Option<&str> {
            spec.split_once('/').map(|(provider, _rest)| provider)
        }

        assert_eq!(extract_provider("ollama/tinyllama"), Some("ollama"));
        assert_eq!(extract_provider("qwen3-coder:30b"), None);
    }
}