use ollama_rs::Ollama;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Summary of a single model known to the registry.
///
/// Values are plain data, so they can be cloned and compared freely
/// (`PartialEq`/`Eq` make `==` and `assert_eq!` usable by callers and tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelInfo {
    /// Model name as reported by the server; used as the lookup key by the registry.
    pub name: String,
    /// Size reported for the model, in bytes.
    pub size_bytes: u64,
}
/// Registry of models available on an Ollama server.
///
/// Holds a client for the server plus an in-memory cache of the models
/// returned by the most recent discovery call.
pub struct ModelRegistry {
    /// Client used to query the server for its local models.
    ollama: Ollama,
    /// Cached model list, shared behind an async read/write lock.
    models: Arc<RwLock<Vec<ModelInfo>>>,
    /// `host:port` string built at construction time.
    endpoint: String,
}
impl ModelRegistry {
    /// Creates a registry for the Ollama server at `host` and `port`.
    ///
    /// `host` is handed to the Ollama client unchanged and is also joined with
    /// `port` as `host:port` to form the value returned by [`Self::endpoint`].
    /// The model cache starts out empty; call [`Self::discover`] to fill it.
    pub fn new(host: &str, port: u16) -> Self {
        let endpoint = format!("{}:{}", host, port);
        Self {
            ollama: Ollama::new(host.to_string(), port),
            models: Arc::new(RwLock::new(Vec::new())),
            endpoint,
        }
    }

    /// Creates a registry pointing at `http://localhost` on port `11434`.
    pub fn default_local() -> Self {
        Self::new("http://localhost", 11434)
    }

    /// Queries the server for its local models and replaces the cache with the result.
    ///
    /// Returns the freshly discovered list.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::ConnectionFailed`] carrying the client error's
    /// text when listing models fails. The cache is left untouched in that case.
    pub async fn discover(&self) -> Result<Vec<ModelInfo>, RegistryError> {
        let local_models = self
            .ollama
            .list_local_models()
            .await
            .map_err(|e| RegistryError::ConnectionFailed(e.to_string()))?;

        let models: Vec<ModelInfo> = local_models
            .into_iter()
            .map(|m| ModelInfo {
                name: m.name,
                size_bytes: m.size,
            })
            .collect();

        // Scope the write guard so the lock is released before logging.
        {
            let mut cache = self.models.write().await;
            *cache = models.clone();
        }

        tracing::info!(
            endpoint = %self.endpoint,
            count = models.len(),
            "Discovered {} models",
            models.len()
        );
        Ok(models)
    }

    /// Returns a copy of every cached model.
    pub async fn list(&self) -> Vec<ModelInfo> {
        self.models.read().await.clone()
    }

    /// Returns the cached model whose name equals `name` exactly, if any.
    pub async fn get(&self, name: &str) -> Option<ModelInfo> {
        let models = self.models.read().await;
        models.iter().find(|m| m.name == name).cloned()
    }

    /// Returns all cached models whose name starts with `prefix` (case-sensitive).
    pub async fn by_prefix(&self, prefix: &str) -> Vec<ModelInfo> {
        let models = self.models.read().await;
        models
            .iter()
            .filter(|m| m.name.starts_with(prefix))
            .cloned()
            .collect()
    }

    /// Returns all cached models whose lowercased name contains the lowercased `query`.
    pub async fn search(&self, query: &str) -> Vec<ModelInfo> {
        // Lowercase the query once rather than per model.
        let query_lower = query.to_lowercase();
        let models = self.models.read().await;
        models
            .iter()
            .filter(|m| m.name.to_lowercase().contains(&query_lower))
            .cloned()
            .collect()
    }

    /// Reports whether a model named exactly `name` is in the cache.
    pub async fn exists(&self, name: &str) -> bool {
        // Check in place under the read lock; no need to clone the entry.
        self.models.read().await.iter().any(|m| m.name == name)
    }

    /// Returns the first cached model, or `None` when the cache is empty.
    pub async fn first(&self) -> Option<ModelInfo> {
        let models = self.models.read().await;
        models.first().cloned()
    }

    /// Resolves `preferred` to a cached model, falling back to the first cached model.
    ///
    /// # Errors
    ///
    /// Returns [`RegistryError::NoModelsAvailable`] when `preferred` is not
    /// cached and the cache is empty.
    pub async fn resolve(&self, preferred: &str) -> Result<ModelInfo, RegistryError> {
        // One read guard for both the lookup and the fallback, so a concurrent
        // `discover` cannot swap the cache between the two steps.
        let models = self.models.read().await;
        models
            .iter()
            .find(|m| m.name == preferred)
            .or_else(|| models.first())
            .cloned()
            .ok_or_else(|| RegistryError::NoModelsAvailable {
                requested: preferred.to_string(),
            })
    }

    /// Returns the `host:port` string this registry was created with.
    pub fn endpoint(&self) -> &str {
        &self.endpoint
    }
}
impl Default for ModelRegistry {
fn default() -> Self {
Self::default_local()
}
}
/// Errors produced by [`ModelRegistry`] operations.
#[derive(Debug, thiserror::Error)]
pub enum RegistryError {
    /// Listing models from the server failed; carries the underlying error text.
    #[error("Failed to connect to Ollama: {0}")]
    ConnectionFailed(String),

    /// The requested model is not cached and there is no cached model to fall back to.
    #[error("Model '{requested}' not found and no fallback available")]
    NoModelsAvailable {
        /// Name of the model that was asked for.
        requested: String,
    },
}
impl From<RegistryError> for swarm_engine_core::error::SwarmError {
fn from(err: RegistryError) -> Self {
match err {
RegistryError::ConnectionFailed(msg) => {
swarm_engine_core::error::SwarmError::NetworkTransient { message: msg }
}
RegistryError::NoModelsAvailable { requested } => {
swarm_engine_core::error::SwarmError::Config {
message: format!("Model '{}' not found", requested),
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The local default registry reports the `host:port` endpoint it was built with.
    #[tokio::test]
    async fn test_registry_creation() {
        let expected = "http://localhost:11434";
        let local = ModelRegistry::default_local();
        assert_eq!(local.endpoint(), expected);
    }
}