mermaid_cli/models/
factory.rs1use super::backend::ModelFactory as InternalFactory;
7use super::config::BackendConfig;
8use super::error::Result;
9use super::traits::Model;
10use crate::app::Config;

/// Public entry point for constructing [`Model`] instances.
///
/// This is a field-less (unit) type: every operation is an associated function that builds a
/// short-lived backend factory from a [`BackendConfig`] and delegates to it. The derives are
/// provided so the type can be printed, copied and default-constructed like any other value.
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelFactory;

15impl ModelFactory {
16 pub async fn create(model_id: &str, config: Option<&Config>) -> Result<Box<dyn Model>> {
23 let backend_config = if let Some(cfg) = config {
24 Self::config_to_backend_config(cfg)
25 } else {
26 BackendConfig::default()
27 };
28
29 let factory = InternalFactory::new(backend_config);
30 factory.create_model(model_id).await
31 }
32
33 pub async fn create_default(model_id: &str) -> Result<Box<dyn Model>> {
35 let factory = InternalFactory::new(BackendConfig::default());
36 factory.create_model(model_id).await
37 }
38
39 pub async fn create_with_provider(
46 model_id: &str,
47 config: Option<&Config>,
48 provider: Option<&str>,
49 ) -> Result<Box<dyn Model>> {
50 let backend_config = if let Some(cfg) = config {
51 Self::config_to_backend_config(cfg)
52 } else {
53 BackendConfig::default()
54 };
55
56 let final_model_id = if let Some(provider_name) = provider {
58 if model_id.contains('/') {
59 model_id.to_string()
61 } else {
62 format!("{}/{}", provider_name, model_id)
64 }
65 } else {
66 model_id.to_string()
67 };
68
69 let factory = InternalFactory::new(backend_config);
70 factory.create_model(&final_model_id).await
71 }
72
73 pub async fn create_with_backend(
75 model_id: &str,
76 config: Option<&Config>,
77 backend: Option<&str>,
78 ) -> Result<Box<dyn Model>> {
79 Self::create_with_provider(model_id, config, backend).await
80 }
81
82 pub async fn get_available_backends() -> Vec<String> {
84 let factory = InternalFactory::new(BackendConfig::default());
85 factory.available_providers().await
86 }
87
88 pub async fn list_all_backend_models() -> Result<Vec<String>> {
93 let factory = InternalFactory::new(BackendConfig::default());
94 let providers = factory.available_providers().await;
95
96 let mut all_models = Vec::new();
97
98 for provider in providers {
99 let dummy_model_id = format!("{}/dummy", provider);
101 if let Ok(model) = factory.create_model(&dummy_model_id).await
102 && let Ok(models) = model.list_models().await {
103 for model_name in models {
104 all_models.push(format!("{}/{}", provider, model_name));
105 }
106 }
107 }
108
109 all_models.sort();
110 Ok(all_models)
111 }
112
113 fn config_to_backend_config(config: &Config) -> BackendConfig {
115 let ollama_url = format!("http://{}:{}", config.ollama.host, config.ollama.port);
117
118 BackendConfig {
119 ollama_url,
120 timeout_secs: 10,
121 request_timeout_secs: 120,
122 max_idle_per_host: 10,
123 health_check_interval_secs: 30,
124 }
125 }
126}
127
#[cfg(test)]
mod tests {
    /// A spec that splits on `/` into exactly two parts is `provider/model`; any other spec is
    /// treated as a bare model name with no provider.
    #[test]
    fn test_model_spec_parsing() {
        let cases = [
            ("ollama/tinyllama", Some("ollama"), "tinyllama"),
            ("qwen3-coder:30b", None, "qwen3-coder:30b"),
            ("kimi-k2.5:cloud", None, "kimi-k2.5:cloud"),
        ];

        for (spec, want_provider, want_model) in cases {
            let segments: Vec<&str> = spec.split('/').collect();
            match segments.as_slice() {
                [provider, model] => {
                    assert_eq!(Some(*provider), want_provider);
                    assert_eq!(*model, want_model);
                }
                _ => {
                    assert_eq!(None, want_provider);
                    assert_eq!(spec, want_model);
                }
            }
        }
    }

    /// The provider is the text before the first `/`, and absent when there is no `/`.
    #[test]
    fn test_provider_extraction() {
        fn extract_provider(spec: &str) -> Option<&str> {
            spec.split_once('/').map(|(provider, _rest)| provider)
        }

        assert_eq!(extract_provider("ollama/tinyllama"), Some("ollama"));
        assert_eq!(extract_provider("qwen3-coder:30b"), None);
    }
}