argyph-embed 1.0.4

Local-first MCP server giving AI coding agents fast, structured, and semantic context over any codebase.
Documentation
use std::path::{Path, PathBuf};

use tokio::io::AsyncWriteExt;
use tracing;

use crate::error::{EmbedError, Result};
use crate::model_hashes;

/// Identifier of the only local model this module can provision. It is compared against the
/// requested model id and doubles as the name of the model's sub-directory inside the cache.
const BGE_SMALL_MODEL_ID: &str = "bge-small-en-v1.5";
/// Base URL that the download URLs for the model files are built from.
const HF_BASE: &str = "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main";

/// File name of the ONNX model, both in the local cache and (under `onnx/`) in the download URL.
const ONNX_FILENAME: &str = "model.onnx";
/// File name of the tokenizer definition, both in the local cache and in the download URL.
const TOKENIZER_FILENAME: &str = "tokenizer.json";

/// Locations of the cached files that make up a local embedding model.
///
/// Values are produced by `ModelFiles::ensure_available`, which returns them only after the
/// files have passed (or been downloaded and passed) their SHA-256 checks.
#[derive(Debug)]
pub struct ModelFiles {
    /// Path of the cached ONNX model file.
    pub onnx_path: PathBuf,
    /// Path of the cached tokenizer file.
    pub tokenizer_path: PathBuf,
}

impl ModelFiles {
    /// Makes sure the ONNX model and tokenizer for `model_id` are present in the local cache,
    /// downloading any file that is missing or fails its SHA-256 check.
    ///
    /// `cache_dir` overrides the cache root; when it is `None` the default cache directory is
    /// used. The files live in a sub-directory of the cache root named after `model_id`.
    ///
    /// Only files that actually fail verification are fetched, so a missing tokenizer does
    /// not trigger a second download of an ONNX model that is already valid.
    ///
    /// # Errors
    ///
    /// Returns `EmbedError::Config` when `model_id` is not the supported model, when the
    /// cache directory cannot be created, or when a download, its checksum verification, or
    /// writing it to disk fails.
    pub async fn ensure_available(model_id: &str, cache_dir: Option<&Path>) -> Result<ModelFiles> {
        if model_id != BGE_SMALL_MODEL_ID {
            return Err(EmbedError::Config(format!(
                "unknown local model: {model_id}"
            )));
        }

        let cache = cache_dir
            .map(PathBuf::from)
            .unwrap_or_else(Self::default_cache_dir);
        let model_dir = cache.join(model_id);

        let onnx_path = model_dir.join(ONNX_FILENAME);
        let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);

        if Self::needs_download(&model_dir).await {
            tracing::info!(
                model_id = %model_id,
                cache_dir = %model_dir.display(),
                "downloading local model files"
            );

            tokio::fs::create_dir_all(&model_dir).await.map_err(|e| {
                EmbedError::Config(format!(
                    "failed to create cache dir {}: {e}",
                    model_dir.display()
                ))
            })?;

            // `needs_download` only says that *something* is wrong. Re-check each file so a
            // file that is already valid is not fetched again; the extra hash is far cheaper
            // than re-downloading the model.
            if !Self::file_hash_matches(&onnx_path, model_hashes::BGE_SMALL_ONNX_SHA256).await {
                Self::download_and_verify(
                    &format!("{HF_BASE}/onnx/{ONNX_FILENAME}"),
                    &onnx_path,
                    model_hashes::BGE_SMALL_ONNX_SHA256,
                )
                .await?;
            }

            if !Self::file_hash_matches(&tokenizer_path, model_hashes::BGE_SMALL_TOKENIZER_SHA256)
                .await
            {
                Self::download_and_verify(
                    &format!("{HF_BASE}/{TOKENIZER_FILENAME}"),
                    &tokenizer_path,
                    model_hashes::BGE_SMALL_TOKENIZER_SHA256,
                )
                .await?;
            }

            tracing::info!(
                model_id = %model_id,
                "model files downloaded and verified"
            );
        }

        Ok(ModelFiles {
            onnx_path,
            tokenizer_path,
        })
    }

    /// Returns the cache root used when the caller supplies none:
    /// `<home>/.cache/argyph/models`, with `<home>` replaced by `.` when no home directory
    /// can be resolved.
    fn default_cache_dir() -> PathBuf {
        let mut root = match dirs_next() {
            Some(home) => home,
            None => PathBuf::from("."),
        };
        root.extend([".cache", "argyph", "models"]);
        root
    }

    /// Reports whether at least one of the two cached files under `model_dir` is absent,
    /// unreadable, or does not match its pinned SHA-256 digest.
    async fn needs_download(model_dir: &Path) -> bool {
        let onnx_verified = Self::file_hash_matches(
            &model_dir.join(ONNX_FILENAME),
            model_hashes::BGE_SMALL_ONNX_SHA256,
        )
        .await;
        let tokenizer_verified = Self::file_hash_matches(
            &model_dir.join(TOKENIZER_FILENAME),
            model_hashes::BGE_SMALL_TOKENIZER_SHA256,
        )
        .await;

        // Both checks have already run by this point; a download is needed if either failed.
        !onnx_verified || !tokenizer_verified
    }

    /// Returns `true` only when `path` can be read and the hex-encoded SHA-256 digest of its
    /// contents equals `expected_hex`. Any read error is reported as `false`.
    async fn file_hash_matches(path: &Path, expected_hex: &str) -> bool {
        use sha2::{Digest, Sha256};

        tokio::fs::read(path)
            .await
            .is_ok_and(|contents| hex::encode(Sha256::digest(&contents)) == expected_hex)
    }

    async fn download_and_verify(url: &str, dest: &Path, expected_sha256: &str) -> Result<()> {
        let tmp = dest.with_extension("tmp");

        tracing::info!(%url, "downloading");
        let response = reqwest::get(url)
            .await
            .map_err(|e| EmbedError::Config(format!("failed to download {url}: {e}")))?;

        if !response.status().is_success() {
            return Err(EmbedError::Config(format!(
                "download failed for {url}: HTTP {}",
                response.status().as_u16()
            )));
        }

        let bytes = response
            .bytes()
            .await
            .map_err(|e| EmbedError::Config(format!("failed to read response for {url}: {e}")))?;

        {
            use sha2::Digest;
            let hash = sha2::Sha256::digest(&bytes);
            let hex = hex::encode(hash);
            if hex != expected_sha256 {
                return Err(EmbedError::Config(format!(
                    "SHA-256 mismatch for {url}: expected {expected_sha256}, got {hex}"
                )));
            }
        }

        let mut f = tokio::fs::File::create(&tmp).await.map_err(|e| {
            EmbedError::Config(format!("failed to create temp file {}: {e}", tmp.display()))
        })?;
        f.write_all(&bytes).await.map_err(|e| {
            EmbedError::Config(format!("failed to write temp file {}: {e}", tmp.display()))
        })?;
        f.flush().await.map_err(|e| {
            EmbedError::Config(format!("failed to flush temp file {}: {e}", tmp.display()))
        })?;
        drop(f);

        tokio::fs::rename(&tmp, dest).await.map_err(|e| {
            EmbedError::Config(format!(
                "failed to rename {} -> {}: {e}",
                tmp.display(),
                dest.display()
            ))
        })?;

        tracing::info!(%url, "verified and cached");
        Ok(())
    }
}

/// Resolves the current user's home directory from environment variables.
///
/// `HOME` is used when it is set to a non-empty value. On Windows, when `HOME` is unset or
/// empty, `HOMEDRIVE` and `HOMEPATH` are concatenated instead, provided both are non-empty.
/// Returns `None` when no home directory can be determined.
///
/// Values are read with `var_os`, so paths that are not valid UTF-8 are kept rather than
/// treated as missing.
fn dirs_next() -> Option<PathBuf> {
    std::env::var_os("HOME")
        // An empty `HOME` carries no location; treat it like an unset variable so the
        // platform fallback below still gets a chance.
        .filter(|home| !home.is_empty())
        // `or_else` keeps the fallback lookups lazy: they run only when `HOME` is unusable.
        .or_else(|| {
            #[cfg(target_os = "windows")]
            {
                let drive = std::env::var_os("HOMEDRIVE").unwrap_or_default();
                let path = std::env::var_os("HOMEPATH").unwrap_or_default();
                if drive.is_empty() || path.is_empty() {
                    None
                } else {
                    let mut home = drive;
                    home.push(path);
                    Some(home)
                }
            }
            #[cfg(not(target_os = "windows"))]
            {
                None
            }
        })
        .map(PathBuf::from)
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// A model id other than the supported one is rejected with a `Config` error whose
    /// message mentions "unknown".
    #[tokio::test]
    async fn unknown_model_id_returns_config_error() {
        let result = ModelFiles::ensure_available("unknown-model", None).await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        if let EmbedError::Config(msg) = &err {
            assert!(msg.contains("unknown"));
        } else {
            panic!("expected Config error, got: {err:?}");
        }
    }

    /// A model directory that does not exist must be reported as needing a download.
    #[tokio::test]
    async fn needs_download_true_for_empty_dir() {
        let mut dir = std::env::temp_dir();
        dir.push("argyph_test_empty");
        // Best-effort removal so leftovers from an earlier run cannot affect the result.
        let _ = std::fs::remove_dir_all(&dir);

        assert!(ModelFiles::needs_download(&dir).await);
    }
}