// mermaid_cli/models/router.rs

//! Backend router with lazy discovery
//!
//! Intelligently routes model requests to the appropriate backend without
//! heavy upfront scanning. Discovers backends on-demand and caches results.

use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::RwLock;

use super::backend::{Backend, BackendFactory};
use super::config::BackendConfig;
use super::error::{ModelError, Result};

14/// Smart router for backend discovery and selection
15pub struct BackendRouter {
16    factory: BackendFactory,
17    /// Cache of model -> backend mappings
18    model_cache: Arc<RwLock<HashMap<String, String>>>,
19    /// Cache of available backends
20    backend_cache: Arc<RwLock<Option<Vec<String>>>>,
21}
22
23impl BackendRouter {
24    /// Create a new backend router
25    pub fn new(config: BackendConfig) -> Self {
26        Self {
27            factory: BackendFactory::new(config),
28            model_cache: Arc::new(RwLock::new(HashMap::new())),
29            backend_cache: Arc::new(RwLock::new(None)),
30        }
31    }
32
33    /// Resolve a model string to a backend
34    ///
35    /// Format examples:
36    /// - "ollama/qwen3-coder:30b" -> Explicit backend
37    /// - "qwen3-coder:30b" -> Search backends
38    /// - "gpt-4" -> Search backends (likely vLLM or LiteLLM)
39    pub async fn resolve_model(&self, model_spec: &str) -> Result<(Arc<dyn Backend>, String)> {
40        // Parse model spec
41        let (backend_hint, model_name) = parse_model_spec(model_spec);
42
43        // If backend is explicitly specified, use it
44        if let Some(backend_name) = backend_hint {
45            let backend = self.factory.create_backend(backend_name).await?;
46            return Ok((backend, model_name.to_string()));
47        }
48
49        // Check cache first
50        {
51            let cache = self.model_cache.read().await;
52            if let Some(backend_name) = cache.get(model_name) {
53                let backend = self.factory.create_backend(backend_name).await?;
54                return Ok((backend, model_name.to_string()));
55            }
56        }
57
58        // No cache hit - discover backends lazily
59        let backend = self.discover_model(model_name).await?;
60        Ok((backend, model_name.to_string()))
61    }
62
63    /// Discover which backend has a specific model
64    async fn discover_model(&self, model_name: &str) -> Result<Arc<dyn Backend>> {
65        // Try backends in priority order: Ollama, vLLM
66        let backends_to_try = vec!["ollama", "vllm"];
67
68        for backend_name in &backends_to_try {
69            if let Ok(backend) = self.factory.create_backend(backend_name).await {
70                // Check if backend is available
71                if backend.health_check().await.is_ok() {
72                    // Check if model exists
73                    if let Ok(true) = backend.has_model(model_name).await {
74                        // Cache the result
75                        let mut cache = self.model_cache.write().await;
76                        cache.insert(model_name.to_string(), backend_name.to_string());
77                        return Ok(backend);
78                    }
79                }
80            }
81        }
82
83        // Model not found on any backend
84        let searched: Vec<String> = backends_to_try.iter().map(|s| s.to_string()).collect();
85        Err(ModelError::ModelNotFound {
86            model: model_name.to_string(),
87            searched,
88        })
89    }
90
91    /// Get list of available backends
92    pub async fn available_backends(&self) -> Vec<String> {
93        // Check cache first
94        {
95            let cache = self.backend_cache.read().await;
96            if let Some(ref backends) = *cache {
97                return backends.clone();
98            }
99        }
100
101        // Discover available backends
102        let backends = self.factory.available_backends().await;
103
104        // Cache the result
105        {
106            let mut cache = self.backend_cache.write().await;
107            *cache = Some(backends.clone());
108        }
109
110        backends
111    }
112
113    /// List all models from all available backends
114    pub async fn list_all_models(&self) -> Result<HashMap<String, Vec<String>>> {
115        let mut all_models = HashMap::new();
116        let backends = self.available_backends().await;
117
118        for backend_name in backends {
119            if let Ok(backend) = self.factory.create_backend(&backend_name).await {
120                if let Ok(models) = backend.list_models().await {
121                    all_models.insert(backend_name, models);
122                }
123            }
124        }
125
126        Ok(all_models)
127    }
128
129    /// Clear caches (useful for testing or when backends change)
130    pub async fn clear_cache(&self) {
131        let mut model_cache = self.model_cache.write().await;
132        model_cache.clear();
133
134        let mut backend_cache = self.backend_cache.write().await;
135        *backend_cache = None;
136    }
137}
138
/// Parse a model specification into `(backend_hint, model_name)`.
///
/// Splits on the FIRST `/` only, so model names may themselves contain
/// slashes.
///
/// Examples:
/// - "ollama/qwen3-coder:30b" -> (Some("ollama"), "qwen3-coder:30b")
/// - "qwen3-coder:30b" -> (None, "qwen3-coder:30b")
/// - "gpt-4" -> (None, "gpt-4")
fn parse_model_spec(spec: &str) -> (Option<&str>, &str) {
    match spec.split_once('/') {
        Some((backend, model)) => (Some(backend), model),
        None => (None, spec),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_model_spec() {
        assert_eq!(
            parse_model_spec("ollama/tinyllama"),
            (Some("ollama"), "tinyllama")
        );
        assert_eq!(
            parse_model_spec("qwen3-coder:30b"),
            (None, "qwen3-coder:30b")
        );
        assert_eq!(parse_model_spec("gpt-4"), (None, "gpt-4"));
        assert_eq!(
            parse_model_spec("vllm/llama-3-70b"),
            (Some("vllm"), "llama-3-70b")
        );
        // Only the first '/' is the backend separator.
        assert_eq!(parse_model_spec("a/b/c"), (Some("a"), "b/c"));
        // Degenerate inputs must not panic.
        assert_eq!(parse_model_spec(""), (None, ""));
        assert_eq!(parse_model_spec("/m"), (Some(""), "m"));
    }
}