argyph_embed/
model_files.rs1use std::path::{Path, PathBuf};
2
3use tokio::io::AsyncWriteExt;
4use tracing;
5
6use crate::error::{EmbedError, Result};
7use crate::model_hashes;
8
/// File name of the ONNX model inside a model's cache directory (fetched from the upstream
/// `onnx/` sub-folder).
const ONNX_FILENAME: &str = "model.onnx";
/// File name of the tokenizer definition inside a model's cache directory.
const TOKENIZER_FILENAME: &str = "tokenizer.json";

/// Identifier of the only local model currently supported; also used as the name of its
/// cache sub-directory.
const BGE_SMALL_MODEL_ID: &str = "bge-small-en-v1.5";
/// Base URL for raw file downloads from the `main` revision on Hugging Face. Every download is
/// checked against a pinned SHA-256 digest, so an upstream change surfaces as an error rather
/// than silently swapping the model.
const HF_BASE: &str = "https://huggingface.co/BAAI/bge-small-en-v1.5/resolve/main";

/// Filesystem locations of a locally cached embedding model.
///
/// Produced by `ModelFiles::ensure_available` once both files are present and their SHA-256
/// digests have been verified.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelFiles {
    /// Path to the ONNX model file (`model.onnx`).
    pub onnx_path: PathBuf,
    /// Path to the tokenizer definition (`tokenizer.json`).
    pub tokenizer_path: PathBuf,
}

21impl ModelFiles {
22 pub async fn ensure_available(model_id: &str, cache_dir: Option<&Path>) -> Result<ModelFiles> {
23 if model_id != BGE_SMALL_MODEL_ID {
24 return Err(EmbedError::Config(format!(
25 "unknown local model: {model_id}"
26 )));
27 }
28
29 let cache = cache_dir
30 .map(PathBuf::from)
31 .unwrap_or_else(Self::default_cache_dir);
32 let model_dir = cache.join(model_id);
33
34 let onnx_path = model_dir.join(ONNX_FILENAME);
35 let tokenizer_path = model_dir.join(TOKENIZER_FILENAME);
36
37 if Self::needs_download(&model_dir).await {
38 tracing::info!(
39 model_id = %model_id,
40 cache_dir = %model_dir.display(),
41 "downloading local model files"
42 );
43
44 tokio::fs::create_dir_all(&model_dir).await.map_err(|e| {
45 EmbedError::Config(format!(
46 "failed to create cache dir {}: {e}",
47 model_dir.display()
48 ))
49 })?;
50
51 Self::download_and_verify(
52 &format!("{HF_BASE}/onnx/{ONNX_FILENAME}"),
53 &onnx_path,
54 model_hashes::BGE_SMALL_ONNX_SHA256,
55 )
56 .await?;
57
58 Self::download_and_verify(
59 &format!("{HF_BASE}/{TOKENIZER_FILENAME}"),
60 &tokenizer_path,
61 model_hashes::BGE_SMALL_TOKENIZER_SHA256,
62 )
63 .await?;
64
65 tracing::info!(
66 model_id = %model_id,
67 "model files downloaded and verified"
68 );
69 }
70
71 Ok(ModelFiles {
72 onnx_path: model_dir.join(ONNX_FILENAME),
73 tokenizer_path: model_dir.join(TOKENIZER_FILENAME),
74 })
75 }
76
77 fn default_cache_dir() -> PathBuf {
78 let home = dirs_next().unwrap_or_else(|| PathBuf::from("."));
79 home.join(".cache").join("argyph").join("models")
80 }
81
82 async fn needs_download(model_dir: &Path) -> bool {
83 let onnx = model_dir.join(ONNX_FILENAME);
84 let tok = model_dir.join(TOKENIZER_FILENAME);
85
86 let onnx_ok = Self::file_hash_matches(&onnx, model_hashes::BGE_SMALL_ONNX_SHA256).await;
87 let tok_ok = Self::file_hash_matches(&tok, model_hashes::BGE_SMALL_TOKENIZER_SHA256).await;
88
89 !(onnx_ok && tok_ok)
90 }
91
92 async fn file_hash_matches(path: &Path, expected_hex: &str) -> bool {
93 match tokio::fs::read(path).await {
94 Ok(data) => {
95 use sha2::Digest;
96 let hash = sha2::Sha256::digest(&data);
97 let hex = hex::encode(hash);
98 hex == expected_hex
99 }
100 Err(_) => false,
101 }
102 }
103
104 async fn download_and_verify(url: &str, dest: &Path, expected_sha256: &str) -> Result<()> {
105 let tmp = dest.with_extension("tmp");
106
107 tracing::info!(%url, "downloading");
108 let response = reqwest::get(url)
109 .await
110 .map_err(|e| EmbedError::Config(format!("failed to download {url}: {e}")))?;
111
112 if !response.status().is_success() {
113 return Err(EmbedError::Config(format!(
114 "download failed for {url}: HTTP {}",
115 response.status().as_u16()
116 )));
117 }
118
119 let bytes = response
120 .bytes()
121 .await
122 .map_err(|e| EmbedError::Config(format!("failed to read response for {url}: {e}")))?;
123
124 {
125 use sha2::Digest;
126 let hash = sha2::Sha256::digest(&bytes);
127 let hex = hex::encode(hash);
128 if hex != expected_sha256 {
129 return Err(EmbedError::Config(format!(
130 "SHA-256 mismatch for {url}: expected {expected_sha256}, got {hex}"
131 )));
132 }
133 }
134
135 let mut f = tokio::fs::File::create(&tmp).await.map_err(|e| {
136 EmbedError::Config(format!("failed to create temp file {}: {e}", tmp.display()))
137 })?;
138 f.write_all(&bytes).await.map_err(|e| {
139 EmbedError::Config(format!("failed to write temp file {}: {e}", tmp.display()))
140 })?;
141 f.flush().await.map_err(|e| {
142 EmbedError::Config(format!("failed to flush temp file {}: {e}", tmp.display()))
143 })?;
144 drop(f);
145
146 tokio::fs::rename(&tmp, dest).await.map_err(|e| {
147 EmbedError::Config(format!(
148 "failed to rename {} -> {}: {e}",
149 tmp.display(),
150 dest.display()
151 ))
152 })?;
153
154 tracing::info!(%url, "verified and cached");
155 Ok(())
156 }
157}
158
/// Best-effort lookup of the current user's home directory from the environment.
///
/// `HOME` is checked first on every platform. On Windows only, `HOMEDRIVE` + `HOMEPATH` is
/// used as a fallback. Empty variables are treated as unset. Values are read with
/// `std::env::var_os`, so a home path that is not valid UTF-8 is still honoured instead of
/// being silently dropped.
///
/// Returns `None` when no usable variable is set.
fn dirs_next() -> Option<PathBuf> {
    /// Read an environment variable as raw OS text, treating an empty value as absent.
    fn non_empty(name: &str) -> Option<std::ffi::OsString> {
        std::env::var_os(name).filter(|value| !value.is_empty())
    }

    non_empty("HOME")
        .or_else(|| {
            #[cfg(target_os = "windows")]
            {
                let mut joined = non_empty("HOMEDRIVE")?;
                joined.push(non_empty("HOMEPATH")?);
                Some(joined)
            }
            #[cfg(not(target_os = "windows"))]
            {
                None
            }
        })
        .map(PathBuf::from)
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// An unsupported model id is rejected up front with a `Config` error that mentions it is
    /// unknown; nothing is downloaded.
    #[tokio::test]
    async fn unknown_model_id_returns_config_error() {
        let result = ModelFiles::ensure_available("unknown-model", None).await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        match &err {
            EmbedError::Config(msg) => assert!(msg.contains("unknown")),
            other => panic!("expected Config error, got: {other:?}"),
        }
    }

    /// A model directory that does not exist (and so holds no files) always requires a
    /// download.
    #[tokio::test]
    async fn needs_download_true_for_empty_dir() {
        let dir = std::env::temp_dir().join("argyph_test_empty");
        // Start from a clean slate; the directory not existing is fine, so the result is ignored.
        let _ = std::fs::remove_dir_all(&dir);
        assert!(ModelFiles::needs_download(&dir).await);
    }
}