// blazen_model_cache/lib.rs
1//! Shared model download and cache layer for Blazen local-inference backends.
2//!
3//! Provides [`ModelCache`] for downloading and caching ML models from
4//! [`HuggingFace` Hub](https://huggingface.co). Designed to be shared by all
5//! local-inference backends (fastembed, mistral.rs, whisper.cpp, etc.).
6
7use std::path::{Path, PathBuf};
8use std::sync::{Arc, LazyLock};
9
/// Per-destination-path mutexes that serialize concurrent [`ModelCache::download`]
/// calls for the same file. Different files download in parallel; same-file
/// callers wait so hf-hub's internal blob lock isn't raced.
///
/// NOTE(review): entries are inserted on first use and never removed, so the
/// map grows with the number of distinct destination paths ever downloaded in
/// this process. Each entry is a path plus an `Arc<Mutex<()>>`, so this is
/// acceptable for typical model counts — confirm if paths are unbounded.
static DOWNLOAD_LOCKS: LazyLock<dashmap::DashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>> =
    LazyLock::new(dashmap::DashMap::new);
15
16/// Errors that can occur during model cache operations.
17#[derive(Debug, thiserror::Error)]
18pub enum CacheError {
19    /// A model download failed.
20    #[error("failed to download model: {0}")]
21    Download(String),
22
23    /// The cache directory could not be resolved or created.
24    #[error("cache directory error: {0}")]
25    CacheDir(String),
26
27    /// An underlying I/O error.
28    #[error("IO error: {0}")]
29    Io(#[from] std::io::Error),
30}
31
/// Observer interface for download progress reporting.
///
/// Implement this on your own type to be notified as bytes arrive.
/// Implementations must be `Send + Sync` because the download machinery
/// may report progress from another task.
pub trait ProgressCallback: Send + Sync {
    /// Invoked repeatedly while the download is in flight.
    ///
    /// * `downloaded_bytes` - Cumulative byte count received so far.
    /// * `total_bytes` - Full file size when the server reported one.
    fn on_progress(&self, downloaded_bytes: u64, total_bytes: Option<u64>);
}
42
43/// Adapter that bridges our [`ProgressCallback`] trait to `hf_hub`'s
44/// [`Progress`](hf_hub::api::tokio::Progress) trait.
45///
46/// Uses `Arc` so the adapter is `Clone + Send + Sync + 'static` as
47/// required by [`hf_hub::api::tokio::ApiRepo::download_with_progress`].
48#[derive(Clone)]
49struct HfProgressAdapter {
50    callback: Arc<dyn ProgressCallback>,
51    downloaded: u64,
52    total: Option<u64>,
53}
54
55impl HfProgressAdapter {
56    fn new(callback: Arc<dyn ProgressCallback>) -> Self {
57        Self {
58            callback,
59            downloaded: 0,
60            total: None,
61        }
62    }
63}
64
impl hf_hub::api::tokio::Progress for HfProgressAdapter {
    /// Called once by hf-hub before the transfer begins: record the total
    /// size and emit an initial zero-byte progress event.
    async fn init(&mut self, size: usize, _filename: &str) {
        self.total = Some(size as u64);
        self.downloaded = 0;
        self.callback.on_progress(0, self.total);
    }

    /// Called per received chunk: accumulate and forward the running total.
    async fn update(&mut self, size: usize) {
        self.downloaded += size as u64;
        self.callback.on_progress(self.downloaded, self.total);
    }

    /// Final progress event after the transfer completes.
    ///
    /// NOTE(review): `max(1)` forces a non-zero "downloaded" figure even
    /// when no bytes were transferred (presumably so UIs register
    /// completion); for a zero-byte file this reports 1 byte, which can
    /// exceed `total` — confirm this is intended.
    async fn finish(&mut self) {
        self.callback
            .on_progress(self.downloaded.max(1), self.total);
    }
}
82
83/// A no-op progress implementation used when no callback is provided.
84#[derive(Clone)]
85struct NoProgress;
86
87impl hf_hub::api::tokio::Progress for NoProgress {
88    async fn init(&mut self, _size: usize, _filename: &str) {}
89    async fn update(&mut self, _size: usize) {}
90    async fn finish(&mut self) {}
91}
92
/// Local cache for ML models downloaded from `HuggingFace` Hub.
///
/// Models are stored under `{cache_dir}/{repo_id}/{filename}`. Files are
/// downloaded only once; subsequent calls return the cached path immediately.
///
/// # Examples
///
/// ```no_run
/// # async fn example() -> Result<(), blazen_model_cache::CacheError> {
/// use blazen_model_cache::ModelCache;
///
/// let cache = ModelCache::new()?;
/// let path = cache.download("bert-base-uncased", "config.json", None).await?;
/// println!("model config at: {}", path.display());
/// # Ok(())
/// # }
/// ```
pub struct ModelCache {
    // Root directory; individual files live at `{cache_dir}/{repo_id}/{filename}`.
    cache_dir: PathBuf,
}
113
114impl ModelCache {
115    /// Create a cache in the default location.
116    ///
117    /// Uses `$BLAZEN_CACHE_DIR/models/` if the `BLAZEN_CACHE_DIR` environment
118    /// variable is set, otherwise falls back to `~/.cache/blazen/models/`.
119    ///
120    /// # Errors
121    ///
122    /// Returns [`CacheError::CacheDir`] if the home directory cannot be
123    /// determined and `BLAZEN_CACHE_DIR` is not set.
124    pub fn new() -> Result<Self, CacheError> {
125        let cache_dir = if let Ok(dir) = std::env::var("BLAZEN_CACHE_DIR") {
126            PathBuf::from(dir).join("models")
127        } else {
128            dirs::cache_dir()
129                .ok_or_else(|| {
130                    CacheError::CacheDir(
131                        "could not determine home cache directory; \
132                         set BLAZEN_CACHE_DIR to override"
133                            .to_string(),
134                    )
135                })?
136                .join("blazen")
137                .join("models")
138        };
139
140        Ok(Self { cache_dir })
141    }
142
143    /// Create a cache rooted at a specific directory.
144    ///
145    /// The directory does not need to exist yet; it will be created on the
146    /// first download.
147    #[must_use]
148    pub fn with_dir(cache_dir: impl Into<PathBuf>) -> Self {
149        Self {
150            cache_dir: cache_dir.into(),
151        }
152    }
153
154    /// The root cache directory path.
155    #[must_use]
156    pub fn cache_dir(&self) -> &Path {
157        &self.cache_dir
158    }
159
160    /// Check if a file is already present in the cache (without downloading).
161    #[must_use]
162    pub fn is_cached(&self, repo_id: &str, filename: &str) -> bool {
163        self.cached_path(repo_id, filename).is_file()
164    }
165
    /// Download a file from `HuggingFace` Hub if it is not already cached.
    ///
    /// Returns the local filesystem path to the cached file.
    ///
    /// The file is first downloaded via `hf-hub` into its own managed cache,
    /// then hard-linked (or copied as fallback) into our
    /// `{cache_dir}/{repo_id}/{filename}` layout so that callers get a stable,
    /// predictable path.
    ///
    /// # Progress
    ///
    /// Pass an `Arc<dyn ProgressCallback>` to receive byte-level progress
    /// updates during the download. Pass `None` to download silently.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Download`] if the `HuggingFace` API request fails,
    /// or [`CacheError::Io`] if filesystem operations fail.
    pub async fn download(
        &self,
        repo_id: &str,
        filename: &str,
        progress: Option<Arc<dyn ProgressCallback>>,
    ) -> Result<PathBuf, CacheError> {
        let dest = self.cached_path(repo_id, filename);

        // Serialize concurrent callers for the same destination path. Different
        // files download in parallel; same-file callers wait so hf-hub's internal
        // blob lock isn't raced. The Arc is cloned out of the map entry so the
        // DashMap shard guard is released before we await the mutex.
        let lock = DOWNLOAD_LOCKS
            .entry(dest.clone())
            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
            .value()
            .clone();
        let _guard = lock.lock().await;

        // Already cached -- return immediately (re-check inside the lock so
        // subsequent callers observe the file created by whoever went first).
        if dest.is_file() {
            return Ok(dest);
        }

        // Ensure the target directory exists.
        if let Some(parent) = dest.parent() {
            tokio::fs::create_dir_all(parent).await?;
        }

        // Build the hf-hub async API.
        let api = hf_hub::api::tokio::ApiBuilder::new()
            .with_progress(false) // We handle progress ourselves.
            .build()
            .map_err(|e| CacheError::Download(e.to_string()))?;

        let repo = api.model(repo_id.to_string());

        // Download through hf-hub (it manages its own cache under ~/.cache/huggingface).
        // The two branches exist because the adapter and the no-op progress
        // types are distinct monomorphizations of `download_with_progress`.
        let hf_path = if let Some(cb) = progress {
            let adapter = HfProgressAdapter::new(Arc::clone(&cb));
            repo.download_with_progress(filename, adapter)
                .await
                .map_err(|e| CacheError::Download(e.to_string()))?
        } else {
            let noop = NoProgress;
            repo.download_with_progress(filename, noop)
                .await
                .map_err(|e| CacheError::Download(e.to_string()))?
        };

        // hf-hub returns `snapshots/main/<filename>` as a symlink into
        // `blobs/<hash>`. Resolve to the real blob so hard-linking targets the
        // actual file; otherwise on some filesystems we would hard-link the
        // symlink itself, which can behave unexpectedly if the snapshot is
        // pruned later. If canonicalization fails (e.g. broken chain), fall
        // back to the original path and let the copy fallback handle it.
        let hf_path_resolved = tokio::fs::canonicalize(&hf_path)
            .await
            .unwrap_or_else(|_| hf_path.clone());

        // Link or copy the file into our own cache layout. Skip entirely if
        // hf-hub already resolved to our destination path.
        if dest != hf_path_resolved {
            // Try hard link first (instant, no extra disk space).
            if tokio::fs::hard_link(&hf_path_resolved, &dest)
                .await
                .is_err()
            {
                // Cross-device or unsupported -- fall back to copy.
                tokio::fs::copy(&hf_path_resolved, &dest).await?;
            }
        }

        // Postcondition: dest must exist after a successful download.
        if !dest.is_file() {
            return Err(CacheError::Io(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                format!(
                    "download completed but cache path is missing: {}",
                    dest.display()
                ),
            )));
        }

        Ok(dest)
    }
269
270    /// Compute the expected cache path for a repo/file pair.
271    fn cached_path(&self, repo_id: &str, filename: &str) -> PathBuf {
272        self.cache_dir.join(repo_id).join(filename)
273    }
274}
275
#[cfg(test)]
#[allow(unsafe_code)] // env::set_var / env::remove_var require unsafe in edition 2024
mod tests {
    use super::*;

    #[test]
    fn test_default_cache_dir() {
        // When BLAZEN_CACHE_DIR is not set, the cache should live under the
        // platform cache directory (e.g. ~/.cache/blazen/models/ on Linux).
        // We temporarily remove the env var to test the default path.
        let had_var = std::env::var("BLAZEN_CACHE_DIR").ok();

        // SAFETY: mutating the process-global environment; the variable is
        // restored immediately after the assertion. NOTE(review): the libtest
        // harness runs `#[test]` functions on multiple threads by default
        // (it does NOT default to `--test-threads=1`), so this test and
        // `test_default_cache_dir_from_env` can race on the same variable.
        // Env-var tests are inherently racy; serialize them with a shared
        // lock or run with `--test-threads=1` if flakiness appears.
        unsafe {
            std::env::remove_var("BLAZEN_CACHE_DIR");
        }

        let cache = ModelCache::new().expect("default cache should succeed");
        let path = cache.cache_dir();

        // Should end with blazen/models
        assert!(
            path.ends_with("blazen/models"),
            "expected path ending with blazen/models, got: {}",
            path.display()
        );

        // Restore env var if it was set.
        if let Some(val) = had_var {
            // SAFETY: restoring the original value.
            unsafe {
                std::env::set_var("BLAZEN_CACHE_DIR", val);
            }
        }
    }

    #[test]
    fn test_default_cache_dir_from_env() {
        let prev = std::env::var("BLAZEN_CACHE_DIR").ok();

        // SAFETY: see `test_default_cache_dir`.
        unsafe {
            std::env::set_var("BLAZEN_CACHE_DIR", "/tmp/blazen-test-cache");
        }

        let cache = ModelCache::new().expect("env-based cache should succeed");
        assert_eq!(
            cache.cache_dir(),
            Path::new("/tmp/blazen-test-cache/models")
        );

        // Restore.
        // SAFETY: see `test_default_cache_dir`.
        unsafe {
            match prev {
                Some(val) => std::env::set_var("BLAZEN_CACHE_DIR", val),
                None => std::env::remove_var("BLAZEN_CACHE_DIR"),
            }
        }
    }

    #[test]
    fn test_with_dir() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert_eq!(cache.cache_dir(), dir.path());
    }

    #[test]
    fn test_is_cached_false_initially() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert!(!cache.is_cached("foo/bar", "model.gguf"));
    }

    #[test]
    fn test_is_cached_true_after_manual_placement() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());

        // Manually create the file to simulate a cached download.
        let file_dir = dir.path().join("my-org/my-model");
        std::fs::create_dir_all(&file_dir).unwrap();
        std::fs::write(file_dir.join("config.json"), b"{}").unwrap();

        assert!(cache.is_cached("my-org/my-model", "config.json"));
    }

    #[test]
    fn test_cached_path_layout() {
        let cache = ModelCache::with_dir("/fake/cache");
        let path = cache.cached_path("org/model", "weights.bin");
        assert_eq!(path, PathBuf::from("/fake/cache/org/model/weights.bin"));
    }

    /// Verifies that the per-path lock actually serializes concurrent callers
    /// targeting the same destination. We can't easily mock hf-hub inside
    /// `download()`, so this test exercises the serialization primitive
    /// directly: if the lock map misbehaves (e.g. hands out independent
    /// mutexes for the same path), more than one task will sit inside the
    /// critical section at the same time and the counter will exceed zero
    /// when observed by another task.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_downloads_serialize_same_path() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(tmp.path().to_path_buf());
        let dest = cache.cached_path("test/repo", "file.bin");

        let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..4 {
            let dest_clone = dest.clone();
            let counter_clone = Arc::clone(&counter);
            handles.push(tokio::spawn(async move {
                let lock = DOWNLOAD_LOCKS
                    .entry(dest_clone)
                    .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
                    .value()
                    .clone();
                let _guard = lock.lock().await;
                // If another task already holds the lock, it would have
                // incremented the counter before us; the assertion below
                // would then catch the violation.
                let prev = counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                // Simulate in-flight work to widen the race window.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                counter_clone.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                prev
            }));
        }

        let results = futures_util::future::join_all(handles).await;
        for r in results {
            let prev = r.expect("task panicked");
            assert_eq!(prev, 0, "another task held the lock concurrently");
        }
    }

    /// Integration test that actually downloads from `HuggingFace` Hub.
    ///
    /// Ignored by default because it requires network access. Run with:
    /// ```sh
    /// cargo test -p blazen-model-cache -- --ignored
    /// ```
    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_and_cache() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());

        // Download a tiny file (~600 bytes).
        let path = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("download should succeed");

        // File should exist and have non-zero size.
        assert!(path.is_file(), "downloaded file should exist");
        let meta = std::fs::metadata(&path).expect("metadata");
        assert!(meta.len() > 0, "downloaded file should be non-empty");

        // Second call should return the cached path instantly.
        let path2 = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("cached download should succeed");
        assert_eq!(path, path2);
    }

    /// Integration test verifying progress callback fires.
    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_with_progress() {
        use std::sync::atomic::{AtomicU64, Ordering};

        struct TestProgress {
            calls: AtomicU64,
        }

        impl ProgressCallback for TestProgress {
            fn on_progress(&self, _downloaded_bytes: u64, _total_bytes: Option<u64>) {
                self.calls.fetch_add(1, Ordering::Relaxed);
            }
        }

        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        let progress = Arc::new(TestProgress {
            calls: AtomicU64::new(0),
        });

        // Clone the Arc so we retain a handle for assertions.
        let cb: Arc<dyn ProgressCallback> = Arc::clone(&progress) as Arc<dyn ProgressCallback>;

        cache
            .download(
                "hf-internal-testing/tiny-random-gpt2",
                "config.json",
                Some(cb),
            )
            .await
            .expect("download should succeed");

        assert!(
            progress.calls.load(Ordering::Relaxed) > 0,
            "progress callback should have been called at least once"
        );
    }
}