//! mermaid_cli/models/router.rs — routes model specs to inference backends.

use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::RwLock;

use super::backend::{Backend, BackendFactory};
use super::config::BackendConfig;
use super::error::{ModelError, Result};
14pub struct BackendRouter {
16 factory: BackendFactory,
17 model_cache: Arc<RwLock<HashMap<String, String>>>,
19 backend_cache: Arc<RwLock<Option<Vec<String>>>>,
21}
22
23impl BackendRouter {
24 pub fn new(config: BackendConfig) -> Self {
26 Self {
27 factory: BackendFactory::new(config),
28 model_cache: Arc::new(RwLock::new(HashMap::new())),
29 backend_cache: Arc::new(RwLock::new(None)),
30 }
31 }
32
33 pub async fn resolve_model(&self, model_spec: &str) -> Result<(Arc<dyn Backend>, String)> {
40 let (backend_hint, model_name) = parse_model_spec(model_spec);
42
43 if let Some(backend_name) = backend_hint {
45 let backend = self.factory.create_backend(backend_name).await?;
46 return Ok((backend, model_name.to_string()));
47 }
48
49 {
51 let cache = self.model_cache.read().await;
52 if let Some(backend_name) = cache.get(model_name) {
53 let backend = self.factory.create_backend(backend_name).await?;
54 return Ok((backend, model_name.to_string()));
55 }
56 }
57
58 let backend = self.discover_model(model_name).await?;
60 Ok((backend, model_name.to_string()))
61 }
62
63 async fn discover_model(&self, model_name: &str) -> Result<Arc<dyn Backend>> {
65 let backends_to_try = vec!["ollama", "vllm"];
67
68 for backend_name in &backends_to_try {
69 if let Ok(backend) = self.factory.create_backend(backend_name).await {
70 if backend.health_check().await.is_ok() {
72 if let Ok(true) = backend.has_model(model_name).await {
74 let mut cache = self.model_cache.write().await;
76 cache.insert(model_name.to_string(), backend_name.to_string());
77 return Ok(backend);
78 }
79 }
80 }
81 }
82
83 let searched: Vec<String> = backends_to_try.iter().map(|s| s.to_string()).collect();
85 Err(ModelError::ModelNotFound {
86 model: model_name.to_string(),
87 searched,
88 })
89 }
90
91 pub async fn available_backends(&self) -> Vec<String> {
93 {
95 let cache = self.backend_cache.read().await;
96 if let Some(ref backends) = *cache {
97 return backends.clone();
98 }
99 }
100
101 let backends = self.factory.available_backends().await;
103
104 {
106 let mut cache = self.backend_cache.write().await;
107 *cache = Some(backends.clone());
108 }
109
110 backends
111 }
112
113 pub async fn list_all_models(&self) -> Result<HashMap<String, Vec<String>>> {
115 let mut all_models = HashMap::new();
116 let backends = self.available_backends().await;
117
118 for backend_name in backends {
119 if let Ok(backend) = self.factory.create_backend(&backend_name).await {
120 if let Ok(models) = backend.list_models().await {
121 all_models.insert(backend_name, models);
122 }
123 }
124 }
125
126 Ok(all_models)
127 }
128
129 pub async fn clear_cache(&self) {
131 let mut model_cache = self.model_cache.write().await;
132 model_cache.clear();
133
134 let mut backend_cache = self.backend_cache.write().await;
135 *backend_cache = None;
136 }
137}
138
/// Splits a model spec into an optional backend hint and the model name.
///
/// Only the first `/` is significant: `"ollama/tinyllama"` becomes
/// `(Some("ollama"), "tinyllama")`, while a spec with no `/` (including
/// colon-tagged names like `"qwen3-coder:30b"`) has no backend hint.
fn parse_model_spec(spec: &str) -> (Option<&str>, &str) {
    match spec.split_once('/') {
        Some((backend, model)) => (Some(backend), model),
        None => (None, spec),
    }
}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the backend-hint / bare-name split of `parse_model_spec`.
    #[test]
    fn test_parse_model_spec() {
        // Explicit backend hint before the slash.
        assert_eq!(
            parse_model_spec("ollama/tinyllama"),
            (Some("ollama"), "tinyllama")
        );
        // A colon tag is part of the model name, not a backend hint.
        assert_eq!(
            parse_model_spec("qwen3-coder:30b"),
            (None, "qwen3-coder:30b")
        );
        assert_eq!(parse_model_spec("gpt-4"), (None, "gpt-4"));
        assert_eq!(
            parse_model_spec("vllm/llama-3-70b"),
            (Some("vllm"), "llama-3-70b")
        );
    }
}