blazen_model_cache/
lib.rs1use std::path::{Path, PathBuf};
8use std::sync::{Arc, LazyLock};
9
10static DOWNLOAD_LOCKS: LazyLock<dashmap::DashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>> =
14 LazyLock::new(dashmap::DashMap::new);
15
16#[derive(Debug, thiserror::Error)]
18pub enum CacheError {
19 #[error("failed to download model: {0}")]
21 Download(String),
22
23 #[error("cache directory error: {0}")]
25 CacheDir(String),
26
27 #[error("IO error: {0}")]
29 Io(#[from] std::io::Error),
30}
31
/// Receiver for download progress notifications.
///
/// Implementations must be `Send + Sync` because the callback is shared
/// behind an `Arc` and invoked from asynchronous download code.
pub trait ProgressCallback: Send + Sync {
    /// Reports the current state of a download.
    ///
    /// * `downloaded_bytes` - number of bytes received so far.
    /// * `total_bytes` - expected size of the file, or `None` when it is not
    ///   known.
    fn on_progress(&self, downloaded_bytes: u64, total_bytes: Option<u64>);
}
42
43#[derive(Clone)]
49struct HfProgressAdapter {
50 callback: Arc<dyn ProgressCallback>,
51 downloaded: u64,
52 total: Option<u64>,
53}
54
55impl HfProgressAdapter {
56 fn new(callback: Arc<dyn ProgressCallback>) -> Self {
57 Self {
58 callback,
59 downloaded: 0,
60 total: None,
61 }
62 }
63}
64
65impl hf_hub::api::tokio::Progress for HfProgressAdapter {
66 async fn init(&mut self, size: usize, _filename: &str) {
67 self.total = Some(size as u64);
68 self.downloaded = 0;
69 self.callback.on_progress(0, self.total);
70 }
71
72 async fn update(&mut self, size: usize) {
73 self.downloaded += size as u64;
74 self.callback.on_progress(self.downloaded, self.total);
75 }
76
77 async fn finish(&mut self) {
78 self.callback
79 .on_progress(self.downloaded.max(1), self.total);
80 }
81}
82
83#[derive(Clone)]
85struct NoProgress;
86
87impl hf_hub::api::tokio::Progress for NoProgress {
88 async fn init(&mut self, _size: usize, _filename: &str) {}
89 async fn update(&mut self, _size: usize) {}
90 async fn finish(&mut self) {}
91}
92
/// Handle to an on-disk cache of downloaded model files.
///
/// Files are stored under `<cache_dir>/<repo_id>/<filename>`. The handle only
/// owns the root path, so it is cheap to clone and safe to print with `{:?}`.
#[derive(Debug, Clone)]
pub struct ModelCache {
    /// Root directory under which cached files are placed.
    cache_dir: PathBuf,
}
113
114impl ModelCache {
115 pub fn new() -> Result<Self, CacheError> {
125 let cache_dir = if let Ok(dir) = std::env::var("BLAZEN_CACHE_DIR") {
126 PathBuf::from(dir).join("models")
127 } else {
128 dirs::cache_dir()
129 .ok_or_else(|| {
130 CacheError::CacheDir(
131 "could not determine home cache directory; \
132 set BLAZEN_CACHE_DIR to override"
133 .to_string(),
134 )
135 })?
136 .join("blazen")
137 .join("models")
138 };
139
140 Ok(Self { cache_dir })
141 }
142
143 #[must_use]
148 pub fn with_dir(cache_dir: impl Into<PathBuf>) -> Self {
149 Self {
150 cache_dir: cache_dir.into(),
151 }
152 }
153
154 #[must_use]
156 pub fn cache_dir(&self) -> &Path {
157 &self.cache_dir
158 }
159
160 #[must_use]
162 pub fn is_cached(&self, repo_id: &str, filename: &str) -> bool {
163 self.cached_path(repo_id, filename).is_file()
164 }
165
166 pub async fn download(
185 &self,
186 repo_id: &str,
187 filename: &str,
188 progress: Option<Arc<dyn ProgressCallback>>,
189 ) -> Result<PathBuf, CacheError> {
190 let dest = self.cached_path(repo_id, filename);
191
192 let lock = DOWNLOAD_LOCKS
196 .entry(dest.clone())
197 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
198 .value()
199 .clone();
200 let _guard = lock.lock().await;
201
202 if dest.is_file() {
205 return Ok(dest);
206 }
207
208 if let Some(parent) = dest.parent() {
210 tokio::fs::create_dir_all(parent).await?;
211 }
212
213 let api = hf_hub::api::tokio::ApiBuilder::new()
215 .with_progress(false) .build()
217 .map_err(|e| CacheError::Download(e.to_string()))?;
218
219 let repo = api.model(repo_id.to_string());
220
221 let hf_path = if let Some(cb) = progress {
223 let adapter = HfProgressAdapter::new(Arc::clone(&cb));
224 repo.download_with_progress(filename, adapter)
225 .await
226 .map_err(|e| CacheError::Download(e.to_string()))?
227 } else {
228 let noop = NoProgress;
229 repo.download_with_progress(filename, noop)
230 .await
231 .map_err(|e| CacheError::Download(e.to_string()))?
232 };
233
234 let hf_path_resolved = tokio::fs::canonicalize(&hf_path)
241 .await
242 .unwrap_or_else(|_| hf_path.clone());
243
244 if dest != hf_path_resolved {
246 if tokio::fs::hard_link(&hf_path_resolved, &dest)
248 .await
249 .is_err()
250 {
251 tokio::fs::copy(&hf_path_resolved, &dest).await?;
253 }
254 }
255
256 if !dest.is_file() {
258 return Err(CacheError::Io(std::io::Error::new(
259 std::io::ErrorKind::NotFound,
260 format!(
261 "download completed but cache path is missing: {}",
262 dest.display()
263 ),
264 )));
265 }
266
267 Ok(dest)
268 }
269
270 fn cached_path(&self, repo_id: &str, filename: &str) -> PathBuf {
272 self.cache_dir.join(repo_id).join(filename)
273 }
274}
275
#[cfg(test)]
// Mutating the process environment requires `unsafe` blocks in these tests.
#[allow(unsafe_code)]
mod tests {
    use super::*;

    /// Serialises every test that touches `BLAZEN_CACHE_DIR`.
    ///
    /// The test harness runs tests on several threads, and the environment is
    /// process-global, so unsynchronised set/remove calls make these tests
    /// race with each other.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Sets or clears `BLAZEN_CACHE_DIR` for the lifetime of the guard and
    /// restores the previous value on drop, including during a panic unwind.
    struct EnvVarGuard {
        previous: Option<std::ffi::OsString>,
        // Dropped after `Drop::drop` has restored the variable.
        _serial: std::sync::MutexGuard<'static, ()>,
    }

    impl EnvVarGuard {
        /// Takes the env lock, remembers the current value, then applies
        /// `value` (`None` removes the variable).
        fn apply(value: Option<&str>) -> Self {
            // A poisoned lock only means another env test failed; the guard
            // it held has already restored the variable, so carry on.
            let serial = ENV_LOCK
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let previous = std::env::var_os("BLAZEN_CACHE_DIR");

            // SAFETY: tests in this module only modify the variable while
            // holding `ENV_LOCK`. NOTE(review): assumes no non-Rust code in
            // the test binary reads the environment concurrently.
            unsafe {
                match value {
                    Some(val) => std::env::set_var("BLAZEN_CACHE_DIR", val),
                    None => std::env::remove_var("BLAZEN_CACHE_DIR"),
                }
            }

            Self {
                previous,
                _serial: serial,
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: `ENV_LOCK` is still held here; see `EnvVarGuard::apply`.
            unsafe {
                match self.previous.take() {
                    Some(val) => std::env::set_var("BLAZEN_CACHE_DIR", val),
                    None => std::env::remove_var("BLAZEN_CACHE_DIR"),
                }
            }
        }
    }

    /// Without the override, the root ends in `blazen/models`.
    #[test]
    fn test_default_cache_dir() {
        let _env = EnvVarGuard::apply(None);

        let cache = ModelCache::new().expect("default cache should succeed");
        let path = cache.cache_dir();

        assert!(
            path.ends_with("blazen/models"),
            "expected path ending with blazen/models, got: {}",
            path.display()
        );
    }

    /// With the override set, the root is `<override>/models`.
    #[test]
    fn test_default_cache_dir_from_env() {
        let _env = EnvVarGuard::apply(Some("/tmp/blazen-test-cache"));

        let cache = ModelCache::new().expect("env-based cache should succeed");
        assert_eq!(
            cache.cache_dir(),
            Path::new("/tmp/blazen-test-cache/models")
        );
    }

    /// `with_dir` uses the given directory verbatim.
    #[test]
    fn test_with_dir() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert_eq!(cache.cache_dir(), dir.path());
    }

    /// An empty cache reports nothing as cached.
    #[test]
    fn test_is_cached_false_initially() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert!(!cache.is_cached("foo/bar", "model.gguf"));
    }

    /// A file placed by hand at the expected location counts as cached.
    #[test]
    fn test_is_cached_true_after_manual_placement() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());

        let file_dir = dir.path().join("my-org/my-model");
        std::fs::create_dir_all(&file_dir).unwrap();
        std::fs::write(file_dir.join("config.json"), b"{}").unwrap();

        assert!(cache.is_cached("my-org/my-model", "config.json"));
    }

    /// The on-disk layout is `<root>/<repo_id>/<filename>`.
    #[test]
    fn test_cached_path_layout() {
        let cache = ModelCache::with_dir("/fake/cache");
        let path = cache.cached_path("org/model", "weights.bin");
        assert_eq!(path, PathBuf::from("/fake/cache/org/model/weights.bin"));
    }

    /// Tasks taking the lock for one destination never overlap.
    ///
    /// Each task bumps a counter while holding the lock; seeing a non-zero
    /// previous value would mean two tasks were inside at the same time.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_downloads_serialize_same_path() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(tmp.path().to_path_buf());
        let dest = cache.cached_path("test/repo", "file.bin");

        let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..4 {
            let dest_clone = dest.clone();
            let counter_clone = Arc::clone(&counter);
            handles.push(tokio::spawn(async move {
                let lock = DOWNLOAD_LOCKS
                    .entry(dest_clone)
                    .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
                    .value()
                    .clone();
                let _guard = lock.lock().await;
                let prev = counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                // Stay inside long enough for an overlapping task to show up.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                counter_clone.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                prev
            }));
        }

        let results = futures_util::future::join_all(handles).await;
        for r in results {
            let prev = r.expect("task panicked");
            assert_eq!(prev, 0, "another task held the lock concurrently");
        }
    }

    /// Downloads a small file, then checks the second call returns the same
    /// path.
    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_and_cache() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());

        let path = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("download should succeed");

        assert!(path.is_file(), "downloaded file should exist");
        let meta = std::fs::metadata(&path).expect("metadata");
        assert!(meta.len() > 0, "downloaded file should be non-empty");

        let path2 = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("cached download should succeed");
        assert_eq!(path, path2);
    }

    /// A supplied progress callback is invoked at least once.
    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_with_progress() {
        use std::sync::atomic::{AtomicU64, Ordering};

        struct TestProgress {
            calls: AtomicU64,
        }

        impl ProgressCallback for TestProgress {
            fn on_progress(&self, _downloaded_bytes: u64, _total_bytes: Option<u64>) {
                self.calls.fetch_add(1, Ordering::Relaxed);
            }
        }

        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        let progress = Arc::new(TestProgress {
            calls: AtomicU64::new(0),
        });

        let cb: Arc<dyn ProgressCallback> = Arc::clone(&progress) as Arc<dyn ProgressCallback>;

        cache
            .download(
                "hf-internal-testing/tiny-random-gpt2",
                "config.json",
                Some(cb),
            )
            .await
            .expect("download should succeed");

        assert!(
            progress.calls.load(Ordering::Relaxed) > 0,
            "progress callback should have been called at least once"
        );
    }
}