swarm-engine-llm 0.1.6

LLM integration backends for SwarmEngine
Documentation
//! Model Registry - Ollamaモデルの動的検出と管理
//!
//! Ollamaと連携してインストール済みモデルを自動検出

use ollama_rs::Ollama;
use std::sync::Arc;
use tokio::sync::RwLock;

/// モデル情報
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// モデル名 (e.g., "qwen2.5-coder:1.5b")
    pub name: String,
    /// サイズ (bytes)
    pub size_bytes: u64,
}

/// Model Registry
pub struct ModelRegistry {
    ollama: Ollama,
    /// キャッシュされたモデル一覧
    models: Arc<RwLock<Vec<ModelInfo>>>,
    /// エンドポイント
    endpoint: String,
}

impl ModelRegistry {
    /// 新規作成
    pub fn new(host: &str, port: u16) -> Self {
        let endpoint = format!("{}:{}", host, port);
        Self {
            ollama: Ollama::new(host.to_string(), port),
            models: Arc::new(RwLock::new(Vec::new())),
            endpoint,
        }
    }

    /// デフォルト (localhost:11434)
    pub fn default_local() -> Self {
        Self::new("http://localhost", 11434)
    }

    /// モデル一覧を検出してキャッシュ
    pub async fn discover(&self) -> Result<Vec<ModelInfo>, RegistryError> {
        let local_models = self
            .ollama
            .list_local_models()
            .await
            .map_err(|e| RegistryError::ConnectionFailed(e.to_string()))?;

        let models: Vec<ModelInfo> = local_models
            .into_iter()
            .map(|m| ModelInfo {
                name: m.name,
                size_bytes: m.size,
            })
            .collect();

        // キャッシュ更新
        {
            let mut cache = self.models.write().await;
            *cache = models.clone();
        }

        tracing::info!(
            endpoint = %self.endpoint,
            count = models.len(),
            "Discovered {} models",
            models.len()
        );

        Ok(models)
    }

    /// キャッシュからモデル一覧を取得
    pub async fn list(&self) -> Vec<ModelInfo> {
        self.models.read().await.clone()
    }

    /// モデル名で検索
    pub async fn get(&self, name: &str) -> Option<ModelInfo> {
        let models = self.models.read().await;
        models.iter().find(|m| m.name == name).cloned()
    }

    /// 名前のプレフィックスでフィルタ (e.g., "hf.co/LiquidAI", "qwen")
    pub async fn by_prefix(&self, prefix: &str) -> Vec<ModelInfo> {
        let models = self.models.read().await;
        models
            .iter()
            .filter(|m| m.name.starts_with(prefix))
            .cloned()
            .collect()
    }

    /// 名前に含まれる文字列でフィルタ
    pub async fn search(&self, query: &str) -> Vec<ModelInfo> {
        let query_lower = query.to_lowercase();
        let models = self.models.read().await;
        models
            .iter()
            .filter(|m| m.name.to_lowercase().contains(&query_lower))
            .cloned()
            .collect()
    }

    /// モデルが存在するか確認
    pub async fn exists(&self, name: &str) -> bool {
        self.get(name).await.is_some()
    }

    /// 最初に見つかったモデルを返す(フォールバック用)
    pub async fn first(&self) -> Option<ModelInfo> {
        let models = self.models.read().await;
        models.first().cloned()
    }

    /// モデル名を解決(存在確認 + フォールバック)
    pub async fn resolve(&self, preferred: &str) -> Result<ModelInfo, RegistryError> {
        // 優先モデルが存在すればそれを返す
        if let Some(model) = self.get(preferred).await {
            return Ok(model);
        }

        // なければ最初のモデルを返す
        self.first()
            .await
            .ok_or_else(|| RegistryError::NoModelsAvailable {
                requested: preferred.to_string(),
            })
    }

    /// エンドポイントを取得
    pub fn endpoint(&self) -> &str {
        &self.endpoint
    }
}

impl Default for ModelRegistry {
    fn default() -> Self {
        Self::default_local()
    }
}

/// Registry エラー
#[derive(Debug, thiserror::Error)]
pub enum RegistryError {
    #[error("Failed to connect to Ollama: {0}")]
    ConnectionFailed(String),

    #[error("Model '{requested}' not found and no fallback available")]
    NoModelsAvailable { requested: String },
}

impl From<RegistryError> for swarm_engine_core::error::SwarmError {
    fn from(err: RegistryError) -> Self {
        match err {
            RegistryError::ConnectionFailed(msg) => {
                swarm_engine_core::error::SwarmError::NetworkTransient { message: msg }
            }
            RegistryError::NoModelsAvailable { requested } => {
                swarm_engine_core::error::SwarmError::Config {
                    message: format!("Model '{}' not found", requested),
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_registry_creation() {
        let registry = ModelRegistry::default_local();
        assert_eq!(registry.endpoint(), "http://localhost:11434");
    }
}