Skip to main content

blazen_model_cache/
lib.rs

1//! Shared model download and cache layer for Blazen local-inference backends.
2//!
3//! Provides [`ModelCache`] for downloading and caching ML models from
4//! [`HuggingFace` Hub](https://huggingface.co). Designed to be shared by all
5//! local-inference backends (fastembed, mistral.rs, whisper.cpp, etc.).
6//!
7//! ## wasm32 support
8//!
//! On `wasm32-*` targets the underlying download stack (`hf-hub`, `dirs`,
//! `tokio::fs`) is not available. There, [`ModelCache::new`] and
//! [`ModelCache::download`] are stubs that always return
//! [`CacheError::Unsupported`]. [`ModelCache::with_dir`] and
//! [`ModelCache::cache_dir`] still work, and [`ModelCache::is_cached`] always
//! reports `false`. Browser/Worker callers should obtain model bytes through a
//! different mechanism (e.g. `fetch()` on the JS side, pre-bundled assets, or a
//! manually populated cache directory).
14
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17#[cfg(not(target_arch = "wasm32"))]
18use std::sync::LazyLock;
19
20/// Per-destination-path mutexes that serialize concurrent [`ModelCache::download`]
21/// calls for the same file. Different files download in parallel; same-file
22/// callers wait so hf-hub's internal blob lock isn't raced.
23#[cfg(not(target_arch = "wasm32"))]
24static DOWNLOAD_LOCKS: LazyLock<dashmap::DashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>> =
25    LazyLock::new(dashmap::DashMap::new);
26
27/// Errors that can occur during model cache operations.
28#[derive(Debug, thiserror::Error)]
29pub enum CacheError {
30    /// A model download failed.
31    #[error("failed to download model: {0}")]
32    Download(String),
33
34    /// The cache directory could not be resolved or created.
35    #[error("cache directory error: {0}")]
36    CacheDir(String),
37
38    /// An underlying I/O error.
39    #[error("IO error: {0}")]
40    Io(#[from] std::io::Error),
41
42    /// The requested operation is not supported on this target (e.g. WASM,
43    /// where the underlying `HuggingFace` Hub download stack is unavailable).
44    #[error("model cache operation not supported on this target: {0}")]
45    Unsupported(String),
46}
47
/// Receives byte-level progress notifications while a model file downloads.
///
/// Implement it on any `Send + Sync` type and pass an
/// `Arc<dyn ProgressCallback>` to [`ModelCache::download`].
pub trait ProgressCallback: Send + Sync {
    /// Report how far the current download has progressed.
    ///
    /// * `downloaded_bytes` - Cumulative number of bytes received so far.
    /// * `total_bytes` - The full file size, or `None` when the server did not
    ///   report one.
    fn on_progress(&self, downloaded_bytes: u64, total_bytes: Option<u64>);
}
58
59/// Adapter that bridges our [`ProgressCallback`] trait to `hf_hub`'s
60/// [`Progress`](hf_hub::api::tokio::Progress) trait.
61///
62/// Uses `Arc` so the adapter is `Clone + Send + Sync + 'static` as
63/// required by [`hf_hub::api::tokio::ApiRepo::download_with_progress`].
64#[cfg(not(target_arch = "wasm32"))]
65#[derive(Clone)]
66struct HfProgressAdapter {
67    callback: Arc<dyn ProgressCallback>,
68    downloaded: u64,
69    total: Option<u64>,
70}
71
72#[cfg(not(target_arch = "wasm32"))]
73impl HfProgressAdapter {
74    fn new(callback: Arc<dyn ProgressCallback>) -> Self {
75        Self {
76            callback,
77            downloaded: 0,
78            total: None,
79        }
80    }
81}
82
83#[cfg(not(target_arch = "wasm32"))]
84impl hf_hub::api::tokio::Progress for HfProgressAdapter {
85    async fn init(&mut self, size: usize, _filename: &str) {
86        self.total = Some(size as u64);
87        self.downloaded = 0;
88        self.callback.on_progress(0, self.total);
89    }
90
91    async fn update(&mut self, size: usize) {
92        self.downloaded += size as u64;
93        self.callback.on_progress(self.downloaded, self.total);
94    }
95
96    async fn finish(&mut self) {
97        self.callback
98            .on_progress(self.downloaded.max(1), self.total);
99    }
100}
101
102/// A no-op progress implementation used when no callback is provided.
103#[cfg(not(target_arch = "wasm32"))]
104#[derive(Clone)]
105struct NoProgress;
106
107#[cfg(not(target_arch = "wasm32"))]
108impl hf_hub::api::tokio::Progress for NoProgress {
109    async fn init(&mut self, _size: usize, _filename: &str) {}
110    async fn update(&mut self, _size: usize) {}
111    async fn finish(&mut self) {}
112}
113
/// Local cache for ML models downloaded from `HuggingFace` Hub.
///
/// Models are stored under `{cache_dir}/{repo_id}/{filename}`. Files are
/// downloaded only once; subsequent calls return the cached path immediately.
///
/// `repo_id` and `filename` are always resolved *inside* `cache_dir`. Root,
/// drive-prefix, `.` and `..` components are dropped, so an absolute or
/// `../`-laden filename cannot point the cache at a file elsewhere on disk, or
/// report a hit for one.
///
/// # Examples
///
/// ```no_run
/// # async fn example() -> Result<(), blazen_model_cache::CacheError> {
/// use blazen_model_cache::ModelCache;
///
/// let cache = ModelCache::new()?;
/// let path = cache.download("bert-base-uncased", "config.json", None).await?;
/// println!("model config at: {}", path.display());
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct ModelCache {
    cache_dir: PathBuf,
}

impl ModelCache {
    /// Create a cache rooted at a specific directory.
    ///
    /// The directory does not need to exist yet; it will be created on the
    /// first download.
    ///
    /// Available on every target. Only [`Self::new`] and [`Self::download`]
    /// require the native `HuggingFace` Hub download stack.
    #[must_use]
    pub fn with_dir(cache_dir: impl Into<PathBuf>) -> Self {
        Self {
            cache_dir: cache_dir.into(),
        }
    }

    /// The root cache directory path.
    #[must_use]
    pub fn cache_dir(&self) -> &Path {
        &self.cache_dir
    }

    /// Check if a file is already present in the cache (without downloading).
    ///
    /// On wasm32 this always returns `false`. There is no filesystem to
    /// inspect, so callers should not rely on cache hits in the browser.
    #[must_use]
    pub fn is_cached(&self, repo_id: &str, filename: &str) -> bool {
        #[cfg(not(target_arch = "wasm32"))]
        {
            self.cached_path(repo_id, filename).is_file()
        }
        #[cfg(target_arch = "wasm32")]
        {
            let _ = (repo_id, filename);
            false
        }
    }

    /// Compute the expected cache path for a repo/file pair.
    ///
    /// Only the normal path segments of `repo_id` and `filename` are appended
    /// to `cache_dir`. `Path::join` with an absolute argument *replaces* the
    /// base path. A naive `cache_dir.join(repo_id).join(filename)` would
    /// therefore let a filename such as `/etc/passwd` make
    /// [`Self::is_cached`] report a hit for an arbitrary file, and let
    /// `download` hand back a path outside the cache.
    fn cached_path(&self, repo_id: &str, filename: &str) -> PathBuf {
        let segments = [repo_id, filename]
            .into_iter()
            .flat_map(|part| Path::new(part).components())
            .filter_map(|component| match component {
                std::path::Component::Normal(segment) => Some(segment),
                _ => None,
            });
        let mut path = self.cache_dir.clone();
        path.extend(segments);
        path
    }
}
178
179// -- Native-only methods (filesystem + HuggingFace Hub) -----------------------
180
181#[cfg(not(target_arch = "wasm32"))]
182impl ModelCache {
183    /// Create a cache in the default location.
184    ///
185    /// Uses `$BLAZEN_CACHE_DIR/models/` if the `BLAZEN_CACHE_DIR` environment
186    /// variable is set, otherwise falls back to `~/.cache/blazen/models/`.
187    ///
188    /// # Errors
189    ///
190    /// Returns [`CacheError::CacheDir`] if the home directory cannot be
191    /// determined and `BLAZEN_CACHE_DIR` is not set.
192    pub fn new() -> Result<Self, CacheError> {
193        let cache_dir = if let Ok(dir) = std::env::var("BLAZEN_CACHE_DIR") {
194            PathBuf::from(dir).join("models")
195        } else {
196            dirs::cache_dir()
197                .ok_or_else(|| {
198                    CacheError::CacheDir(
199                        "could not determine home cache directory; \
200                         set BLAZEN_CACHE_DIR to override"
201                            .to_string(),
202                    )
203                })?
204                .join("blazen")
205                .join("models")
206        };
207
208        Ok(Self { cache_dir })
209    }
210
211    /// Download a file from `HuggingFace` Hub if it is not already cached.
212    ///
213    /// Returns the local filesystem path to the cached file.
214    ///
215    /// The file is first downloaded via `hf-hub` into its own managed cache,
216    /// then hard-linked (or copied as fallback) into our
217    /// `{cache_dir}/{repo_id}/{filename}` layout so that callers get a stable,
218    /// predictable path.
219    ///
220    /// # Progress
221    ///
222    /// Pass an `Arc<dyn ProgressCallback>` to receive byte-level progress
223    /// updates during the download. Pass `None` to download silently.
224    ///
225    /// # Errors
226    ///
227    /// Returns [`CacheError::Download`] if the `HuggingFace` API request fails,
228    /// or [`CacheError::Io`] if filesystem operations fail.
229    pub async fn download(
230        &self,
231        repo_id: &str,
232        filename: &str,
233        progress: Option<Arc<dyn ProgressCallback>>,
234    ) -> Result<PathBuf, CacheError> {
235        let dest = self.cached_path(repo_id, filename);
236
237        // Serialize concurrent callers for the same destination path. Different
238        // files download in parallel; same-file callers wait so hf-hub's internal
239        // blob lock isn't raced.
240        let lock = DOWNLOAD_LOCKS
241            .entry(dest.clone())
242            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
243            .value()
244            .clone();
245        let _guard = lock.lock().await;
246
247        // Already cached -- return immediately (re-check inside the lock so
248        // subsequent callers observe the file created by whoever went first).
249        if dest.is_file() {
250            return Ok(dest);
251        }
252
253        // Ensure the target directory exists.
254        if let Some(parent) = dest.parent() {
255            tokio::fs::create_dir_all(parent).await?;
256        }
257
258        // Build the hf-hub async API.
259        let api = hf_hub::api::tokio::ApiBuilder::new()
260            .with_progress(false) // We handle progress ourselves.
261            .build()
262            .map_err(|e| CacheError::Download(e.to_string()))?;
263
264        let repo = api.model(repo_id.to_string());
265
266        // Download through hf-hub (it manages its own cache under ~/.cache/huggingface).
267        let hf_path = if let Some(cb) = progress {
268            let adapter = HfProgressAdapter::new(Arc::clone(&cb));
269            repo.download_with_progress(filename, adapter)
270                .await
271                .map_err(|e| CacheError::Download(e.to_string()))?
272        } else {
273            let noop = NoProgress;
274            repo.download_with_progress(filename, noop)
275                .await
276                .map_err(|e| CacheError::Download(e.to_string()))?
277        };
278
279        // hf-hub returns `snapshots/main/<filename>` as a symlink into
280        // `blobs/<hash>`. Resolve to the real blob so hard-linking targets the
281        // actual file; otherwise on some filesystems we would hard-link the
282        // symlink itself, which can behave unexpectedly if the snapshot is
283        // pruned later. If canonicalization fails (e.g. broken chain), fall
284        // back to the original path and let the copy fallback handle it.
285        let hf_path_resolved = tokio::fs::canonicalize(&hf_path)
286            .await
287            .unwrap_or_else(|_| hf_path.clone());
288
289        // Link or copy the file into our own cache layout.
290        if dest != hf_path_resolved {
291            // Try hard link first (instant, no extra disk space).
292            if tokio::fs::hard_link(&hf_path_resolved, &dest)
293                .await
294                .is_err()
295            {
296                // Cross-device or unsupported -- fall back to copy.
297                tokio::fs::copy(&hf_path_resolved, &dest).await?;
298            }
299        }
300
301        // Postcondition: dest must exist after a successful download.
302        if !dest.is_file() {
303            return Err(CacheError::Io(std::io::Error::new(
304                std::io::ErrorKind::NotFound,
305                format!(
306                    "download completed but cache path is missing: {}",
307                    dest.display()
308                ),
309            )));
310        }
311
312        Ok(dest)
313    }
314}
315
316// -- wasm32 stubs -------------------------------------------------------------
317
#[cfg(target_arch = "wasm32")]
impl ModelCache {
    /// Always fails with [`CacheError::Unsupported`] on wasm32.
    ///
    /// Resolving a default location needs `dirs::cache_dir()` or a
    /// `BLAZEN_CACHE_DIR` lookup, and neither yields a usable path in the
    /// browser. Construct the cache with [`Self::with_dir`] and an explicit
    /// virtual path instead.
    ///
    /// # Errors
    ///
    /// Always returns [`CacheError::Unsupported`].
    pub fn new() -> Result<Self, CacheError> {
        Err(CacheError::Unsupported(String::from(
            "ModelCache::new() is not supported on wasm32; use ModelCache::with_dir() instead",
        )))
    }

    /// Always fails with [`CacheError::Unsupported`] on wasm32.
    ///
    /// Neither the `HuggingFace` Hub client (`hf-hub`) nor `tokio::fs` builds
    /// for `wasm32-*` targets, so nothing can be downloaded from here.
    /// Browser/Worker callers should fetch model bytes with the JavaScript
    /// `fetch()` API and hand them over directly.
    ///
    /// # Errors
    ///
    /// Always returns [`CacheError::Unsupported`].
    pub async fn download(
        &self,
        repo_id: &str,
        filename: &str,
        progress: Option<Arc<dyn ProgressCallback>>,
    ) -> Result<PathBuf, CacheError> {
        // The inputs exist only for signature parity with the native build.
        let _ = repo_id;
        let _ = filename;
        drop(progress);

        let message = format!(
            "ModelCache::download() is not supported on wasm32 (cache_dir={})",
            self.cache_dir.display()
        );
        Err(CacheError::Unsupported(message))
    }
}
360
#[cfg(all(test, not(target_arch = "wasm32")))]
#[allow(unsafe_code)] // env::set_var / env::remove_var require unsafe in edition 2024
mod tests {
    use super::*;

    /// Serializes every test in this module that reads or writes
    /// `BLAZEN_CACHE_DIR`.
    ///
    /// libtest runs `#[test]` functions concurrently on several threads by
    /// default; it does *not* default to `--test-threads=1`. Without this
    /// lock, the two env-var tests below can interleave: one clears the
    /// variable while the other expects it set, and each then observes the
    /// other's cache directory.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Take [`ENV_LOCK`], tolerating poisoning so one failed env test does not
    /// turn every later env test into a `PoisonError` failure.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Set (`Some`) or remove (`None`) `BLAZEN_CACHE_DIR`.
    ///
    /// Callers must hold the guard returned by [`lock_env`].
    fn set_cache_dir_var(value: Option<&std::ffi::OsStr>) {
        // SAFETY: every test in this module that touches `BLAZEN_CACHE_DIR`
        // holds `ENV_LOCK`, so those tests never race each other. Other
        // concurrently running tests may still read unrelated variables
        // through `std::env`, which the standard library synchronizes with
        // `set_var`/`remove_var`. The remaining hazard, foreign code calling
        // `getenv` directly, is accepted because this is test-only code.
        unsafe {
            match value {
                Some(val) => std::env::set_var("BLAZEN_CACHE_DIR", val),
                None => std::env::remove_var("BLAZEN_CACHE_DIR"),
            }
        }
    }

    #[test]
    fn test_default_cache_dir() {
        // When BLAZEN_CACHE_DIR is not set, the cache should live under the
        // platform cache directory (e.g. ~/.cache/blazen/models/ on Linux).
        let _env = lock_env();
        let prev = std::env::var_os("BLAZEN_CACHE_DIR");

        set_cache_dir_var(None);
        let result = ModelCache::new();
        // Restore before asserting so a failed assertion cannot leave the
        // variable cleared for the rest of the test binary.
        set_cache_dir_var(prev.as_deref());

        let cache = result.expect("default cache should succeed");
        let path = cache.cache_dir();

        // Should end with blazen/models
        assert!(
            path.ends_with("blazen/models"),
            "expected path ending with blazen/models, got: {}",
            path.display()
        );
    }

    #[test]
    fn test_default_cache_dir_from_env() {
        let _env = lock_env();
        let prev = std::env::var_os("BLAZEN_CACHE_DIR");

        set_cache_dir_var(Some(std::ffi::OsStr::new("/tmp/blazen-test-cache")));
        let result = ModelCache::new();
        // Restore before asserting (see `test_default_cache_dir`).
        set_cache_dir_var(prev.as_deref());

        let cache = result.expect("env-based cache should succeed");
        assert_eq!(
            cache.cache_dir(),
            Path::new("/tmp/blazen-test-cache/models")
        );
    }

    #[test]
    fn test_with_dir() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert_eq!(cache.cache_dir(), dir.path());
    }

    #[test]
    fn test_is_cached_false_initially() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert!(!cache.is_cached("foo/bar", "model.gguf"));
    }

    #[test]
    fn test_is_cached_true_after_manual_placement() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());

        // Manually create the file to simulate a cached download.
        let file_dir = dir.path().join("my-org/my-model");
        std::fs::create_dir_all(&file_dir).unwrap();
        std::fs::write(file_dir.join("config.json"), b"{}").unwrap();

        assert!(cache.is_cached("my-org/my-model", "config.json"));
    }

    #[test]
    fn test_cached_path_layout() {
        let cache = ModelCache::with_dir("/fake/cache");
        let path = cache.cached_path("org/model", "weights.bin");
        assert_eq!(path, PathBuf::from("/fake/cache/org/model/weights.bin"));
    }

    /// Verifies that the per-path lock actually serializes concurrent callers
    /// targeting the same destination. We can't easily mock hf-hub inside
    /// `download()`, so this test exercises the serialization primitive
    /// directly. If the lock map misbehaves (e.g. hands out independent
    /// mutexes for the same path), more than one task will sit inside the
    /// critical section at the same time, and another task will observe a
    /// counter greater than zero.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_downloads_serialize_same_path() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(tmp.path().to_path_buf());
        let dest = cache.cached_path("test/repo", "file.bin");

        let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..4 {
            let dest_clone = dest.clone();
            let counter_clone = Arc::clone(&counter);
            handles.push(tokio::spawn(async move {
                let lock = DOWNLOAD_LOCKS
                    .entry(dest_clone)
                    .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
                    .value()
                    .clone();
                let _guard = lock.lock().await;
                // If another task already holds the lock, it would have
                // incremented the counter before us; the assertion below
                // would then catch the violation.
                let prev = counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                // Simulate in-flight work to widen the race window.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                counter_clone.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                prev
            }));
        }

        let results = futures_util::future::join_all(handles).await;
        for r in results {
            let prev = r.expect("task panicked");
            assert_eq!(prev, 0, "another task held the lock concurrently");
        }
    }

    /// Integration test that actually downloads from `HuggingFace` Hub.
    ///
    /// Ignored by default because it requires network access. Run with:
    /// ```sh
    /// cargo test -p blazen-model-cache -- --ignored
    /// ```
    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_and_cache() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());

        // Download a tiny file (~600 bytes).
        let path = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("download should succeed");

        // File should exist and have non-zero size.
        assert!(path.is_file(), "downloaded file should exist");
        let meta = std::fs::metadata(&path).expect("metadata");
        assert!(meta.len() > 0, "downloaded file should be non-empty");

        // Second call should return the cached path instantly.
        let path2 = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("cached download should succeed");
        assert_eq!(path, path2);
    }

    /// Integration test verifying progress callback fires.
    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_with_progress() {
        use std::sync::atomic::{AtomicU64, Ordering};

        struct TestProgress {
            calls: AtomicU64,
        }

        impl ProgressCallback for TestProgress {
            fn on_progress(&self, _downloaded_bytes: u64, _total_bytes: Option<u64>) {
                self.calls.fetch_add(1, Ordering::Relaxed);
            }
        }

        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        let progress = Arc::new(TestProgress {
            calls: AtomicU64::new(0),
        });

        // Clone the Arc so we retain a handle for assertions.
        let cb: Arc<dyn ProgressCallback> = Arc::clone(&progress) as Arc<dyn ProgressCallback>;

        cache
            .download(
                "hf-internal-testing/tiny-random-gpt2",
                "config.json",
                Some(cb),
            )
            .await
            .expect("download should succeed");

        assert!(
            progress.calls.load(Ordering::Relaxed) > 0,
            "progress callback should have been called at least once"
        );
    }
}