swarm_engine_llm/
registry.rs1use ollama_rs::Ollama;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
/// Summary of a single model known to the registry.
///
/// Values are plain data: they can be cloned, printed with `{:?}`, and compared
/// for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelInfo {
    /// Model name; this is the key used by the registry's lookup methods.
    pub name: String,
    /// Model size in bytes.
    pub size_bytes: u64,
}
17
18pub struct ModelRegistry {
20 ollama: Ollama,
21 models: Arc<RwLock<Vec<ModelInfo>>>,
23 endpoint: String,
25}
26
27impl ModelRegistry {
28 pub fn new(host: &str, port: u16) -> Self {
30 let endpoint = format!("{}:{}", host, port);
31 Self {
32 ollama: Ollama::new(host.to_string(), port),
33 models: Arc::new(RwLock::new(Vec::new())),
34 endpoint,
35 }
36 }
37
38 pub fn default_local() -> Self {
40 Self::new("http://localhost", 11434)
41 }
42
43 pub async fn discover(&self) -> Result<Vec<ModelInfo>, RegistryError> {
45 let local_models = self
46 .ollama
47 .list_local_models()
48 .await
49 .map_err(|e| RegistryError::ConnectionFailed(e.to_string()))?;
50
51 let models: Vec<ModelInfo> = local_models
52 .into_iter()
53 .map(|m| ModelInfo {
54 name: m.name,
55 size_bytes: m.size,
56 })
57 .collect();
58
59 {
61 let mut cache = self.models.write().await;
62 *cache = models.clone();
63 }
64
65 tracing::info!(
66 endpoint = %self.endpoint,
67 count = models.len(),
68 "Discovered {} models",
69 models.len()
70 );
71
72 Ok(models)
73 }
74
75 pub async fn list(&self) -> Vec<ModelInfo> {
77 self.models.read().await.clone()
78 }
79
80 pub async fn get(&self, name: &str) -> Option<ModelInfo> {
82 let models = self.models.read().await;
83 models.iter().find(|m| m.name == name).cloned()
84 }
85
86 pub async fn by_prefix(&self, prefix: &str) -> Vec<ModelInfo> {
88 let models = self.models.read().await;
89 models
90 .iter()
91 .filter(|m| m.name.starts_with(prefix))
92 .cloned()
93 .collect()
94 }
95
96 pub async fn search(&self, query: &str) -> Vec<ModelInfo> {
98 let query_lower = query.to_lowercase();
99 let models = self.models.read().await;
100 models
101 .iter()
102 .filter(|m| m.name.to_lowercase().contains(&query_lower))
103 .cloned()
104 .collect()
105 }
106
107 pub async fn exists(&self, name: &str) -> bool {
109 self.get(name).await.is_some()
110 }
111
112 pub async fn first(&self) -> Option<ModelInfo> {
114 let models = self.models.read().await;
115 models.first().cloned()
116 }
117
118 pub async fn resolve(&self, preferred: &str) -> Result<ModelInfo, RegistryError> {
120 if let Some(model) = self.get(preferred).await {
122 return Ok(model);
123 }
124
125 self.first()
127 .await
128 .ok_or_else(|| RegistryError::NoModelsAvailable {
129 requested: preferred.to_string(),
130 })
131 }
132
133 pub fn endpoint(&self) -> &str {
135 &self.endpoint
136 }
137}
138
139impl Default for ModelRegistry {
140 fn default() -> Self {
141 Self::default_local()
142 }
143}
144
145#[derive(Debug, thiserror::Error)]
147pub enum RegistryError {
148 #[error("Failed to connect to Ollama: {0}")]
149 ConnectionFailed(String),
150
151 #[error("Model '{requested}' not found and no fallback available")]
152 NoModelsAvailable { requested: String },
153}
154
155impl From<RegistryError> for swarm_engine_core::error::SwarmError {
156 fn from(err: RegistryError) -> Self {
157 match err {
158 RegistryError::ConnectionFailed(msg) => {
159 swarm_engine_core::error::SwarmError::NetworkTransient { message: msg }
160 }
161 RegistryError::NoModelsAvailable { requested } => {
162 swarm_engine_core::error::SwarmError::Config {
163 message: format!("Model '{}' not found", requested),
164 }
165 }
166 }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// The default local registry reports the endpoint built from its host and port.
    #[tokio::test]
    async fn test_registry_creation() {
        let expected_endpoint = "http://localhost:11434";
        let local = ModelRegistry::default_local();
        assert_eq!(local.endpoint(), expected_endpoint);
    }
}