Skip to main content

argyph_embed/
model_files.rs

1use std::path::{Path, PathBuf};
2
3use tokio::io::AsyncWriteExt;
4use tracing;
5
6use crate::error::{EmbedError, Result};
7use crate::model_hashes;
8
/// Identifier of the only local model this module accepts; also used as the
/// name of the per-model subdirectory inside the cache directory.
const BGE_SMALL_MODEL_ID: &str = "bge-small-en-v1.5";
/// Base URL from which the download URLs of the individual model files are built.
const HF_BASE: &str = "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main";

/// File name of the ONNX model: used for the cached copy and, under `onnx/`,
/// for the download URL.
const ONNX_FILENAME: &str = "model.onnx";
/// File name of the tokenizer definition: used for the cached copy and for the
/// download URL.
const TOKENIZER_FILENAME: &str = "tokenizer.json";
14
/// Locations of the files that make up a locally cached embedding model.
///
/// Values are produced by `ModelFiles::ensure_available`, which fills both
/// paths with locations inside the per-model cache directory.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelFiles {
    /// Path of the cached ONNX model file.
    pub onnx_path: PathBuf,
    /// Path of the cached tokenizer definition file.
    pub tokenizer_path: PathBuf,
}
20
21impl ModelFiles {
22    pub async fn ensure_available(model_id: &str, cache_dir: Option<&Path>) -> Result<ModelFiles> {
23        if model_id != BGE_SMALL_MODEL_ID {
24            return Err(EmbedError::Config(format!(
25                "unknown local model: {model_id}"
26            )));
27        }
28
29        let cache = cache_dir
30            .map(PathBuf::from)
31            .unwrap_or_else(Self::default_cache_dir);
32        let model_dir = cache.join(model_id);
33
34        let onnx_path = model_dir.join(ONNX_FILENAME);
35        let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
36
37        if Self::needs_download(&model_dir).await {
38            tracing::info!(
39                model_id = %model_id,
40                cache_dir = %model_dir.display(),
41                "downloading local model files"
42            );
43
44            tokio::fs::create_dir_all(&model_dir).await.map_err(|e| {
45                EmbedError::Config(format!(
46                    "failed to create cache dir {}: {e}",
47                    model_dir.display()
48                ))
49            })?;
50
51            Self::download_and_verify(
52                &format!("{HF_BASE}/onnx/{ONNX_FILENAME}"),
53                &onnx_path,
54                model_hashes::BGE_SMALL_ONNX_SHA256,
55            )
56            .await?;
57
58            Self::download_and_verify(
59                &format!("{HF_BASE}/{TOKENIZER_FILENAME}"),
60                &tokenizer_path,
61                model_hashes::BGE_SMALL_TOKENIZER_SHA256,
62            )
63            .await?;
64
65            tracing::info!(
66                model_id = %model_id,
67                "model files downloaded and verified"
68            );
69        }
70
71        Ok(ModelFiles {
72            onnx_path: model_dir.join(ONNX_FILENAME),
73            tokenizer_path: model_dir.join(TOKENIZER_FILENAME),
74        })
75    }
76
77    fn default_cache_dir() -> PathBuf {
78        let home = dirs_next().unwrap_or_else(|| PathBuf::from("."));
79        home.join(".cache").join("argyph").join("models")
80    }
81
82    async fn needs_download(model_dir: &Path) -> bool {
83        let onnx = model_dir.join(ONNX_FILENAME);
84        let tok = model_dir.join(TOKENIZER_FILENAME);
85
86        let onnx_ok = Self::file_hash_matches(&onnx, model_hashes::BGE_SMALL_ONNX_SHA256).await;
87        let tok_ok = Self::file_hash_matches(&tok, model_hashes::BGE_SMALL_TOKENIZER_SHA256).await;
88
89        !(onnx_ok && tok_ok)
90    }
91
92    async fn file_hash_matches(path: &Path, expected_hex: &str) -> bool {
93        match tokio::fs::read(path).await {
94            Ok(data) => {
95                use sha2::Digest;
96                let hash = sha2::Sha256::digest(&data);
97                let hex = hex::encode(hash);
98                hex == expected_hex
99            }
100            Err(_) => false,
101        }
102    }
103
104    async fn download_and_verify(url: &str, dest: &Path, expected_sha256: &str) -> Result<()> {
105        let tmp = dest.with_extension("tmp");
106
107        tracing::info!(%url, "downloading");
108        let response = reqwest::get(url)
109            .await
110            .map_err(|e| EmbedError::Config(format!("failed to download {url}: {e}")))?;
111
112        if !response.status().is_success() {
113            return Err(EmbedError::Config(format!(
114                "download failed for {url}: HTTP {}",
115                response.status().as_u16()
116            )));
117        }
118
119        let bytes = response
120            .bytes()
121            .await
122            .map_err(|e| EmbedError::Config(format!("failed to read response for {url}: {e}")))?;
123
124        {
125            use sha2::Digest;
126            let hash = sha2::Sha256::digest(&bytes);
127            let hex = hex::encode(hash);
128            if hex != expected_sha256 {
129                return Err(EmbedError::Config(format!(
130                    "SHA-256 mismatch for {url}: expected {expected_sha256}, got {hex}"
131                )));
132            }
133        }
134
135        let mut f = tokio::fs::File::create(&tmp).await.map_err(|e| {
136            EmbedError::Config(format!("failed to create temp file {}: {e}", tmp.display()))
137        })?;
138        f.write_all(&bytes).await.map_err(|e| {
139            EmbedError::Config(format!("failed to write temp file {}: {e}", tmp.display()))
140        })?;
141        f.flush().await.map_err(|e| {
142            EmbedError::Config(format!("failed to flush temp file {}: {e}", tmp.display()))
143        })?;
144        drop(f);
145
146        tokio::fs::rename(&tmp, dest).await.map_err(|e| {
147            EmbedError::Config(format!(
148                "failed to rename {} -> {}: {e}",
149                tmp.display(),
150                dest.display()
151            ))
152        })?;
153
154        tracing::info!(%url, "verified and cached");
155        Ok(())
156    }
157}
158
/// Best-effort lookup of the current user's home directory from the environment.
///
/// `HOME` is consulted first. On Windows, when `HOME` is unusable, the
/// concatenation of `HOMEDRIVE` and `HOMEPATH` is used instead. A variable
/// that is unset or empty is treated as unusable, so an empty `HOME` never
/// produces an empty path. Values are read as OS strings, so a home directory
/// whose name is not valid Unicode is still returned.
///
/// Returns `None` when no usable variable is found.
fn dirs_next() -> Option<PathBuf> {
    let non_empty = |name: &str| std::env::var_os(name).filter(|value| !value.is_empty());

    if let Some(home) = non_empty("HOME") {
        return Some(PathBuf::from(home));
    }

    #[cfg(target_os = "windows")]
    {
        if let (Some(drive), Some(path)) = (non_empty("HOMEDRIVE"), non_empty("HOMEPATH")) {
            // Plain concatenation (not `join`): HOMEPATH already starts with a
            // separator, e.g. `C:` + `\Users\name`.
            let mut home = drive;
            home.push(path);
            return Some(PathBuf::from(home));
        }
    }

    None
}
180
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// A model id other than the supported one is rejected with a `Config`
    /// error whose message mentions "unknown".
    #[tokio::test]
    async fn unknown_model_id_returns_config_error() {
        let outcome = ModelFiles::ensure_available("unknown-model", None).await;
        assert!(outcome.is_err());
        if let Err(err) = outcome {
            match err {
                EmbedError::Config(msg) => assert!(msg.contains("unknown")),
                other => panic!("expected Config error, got: {other:?}"),
            }
        }
    }

    /// A directory that does not exist has no verified files, so a download
    /// is required.
    #[tokio::test]
    async fn needs_download_true_for_empty_dir() {
        let scratch = std::env::temp_dir().join("argyph_test_empty");
        // Ignore the result: the directory may simply not exist yet.
        let _ = std::fs::remove_dir_all(&scratch);
        let must_fetch = ModelFiles::needs_download(&scratch).await;
        assert!(must_fetch);
    }
}