Skip to main content

blazen_model_cache/
lib.rs

1//! Shared model download and cache layer for Blazen local-inference backends.
2//!
3//! Provides [`ModelCache`] for downloading and caching ML models from
4//! [`HuggingFace` Hub](https://huggingface.co). Designed to be shared by all
5//! local-inference backends (fastembed, mistral.rs, whisper.cpp, etc.).
6//!
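//! ## wasm32 support
//!
//! On `wasm32-*` targets the underlying download stack (`hf-hub`, `dirs`,
//! `tokio::fs`) is not available. On browser targets (`wasm32` without WASI),
//! `ModelCache::new` and `ModelCache::download` are stubs that always return
//! [`CacheError::Unsupported`], and `ModelCache::is_cached` always returns
//! `false`; only `ModelCache::with_dir` and `ModelCache::cache_dir` behave as
//! on native targets. Browser/Worker callers should obtain model bytes through
//! a different mechanism (e.g. `fetch()` on the JS side, pre-bundled assets,
//! or a manually populated cache directory). WASI builds instead expose
//! `WasiModelCache`, re-exported from the `wasi` module.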
14
15use std::path::{Path, PathBuf};
16#[cfg(not(target_arch = "wasm32"))]
17use std::sync::Arc;
18#[cfg(not(target_arch = "wasm32"))]
19use std::sync::LazyLock;
20
21#[cfg(all(target_arch = "wasm32", target_os = "wasi"))]
22pub mod wasi;
23#[cfg(all(target_arch = "wasm32", target_os = "wasi"))]
24pub use wasi::{WasiHttpFetch, WasiModelCache};
25
26/// Per-destination-path mutexes that serialize concurrent [`ModelCache::download`]
27/// calls for the same file. Different files download in parallel; same-file
28/// callers wait so hf-hub's internal blob lock isn't raced.
29#[cfg(not(target_arch = "wasm32"))]
30static DOWNLOAD_LOCKS: LazyLock<dashmap::DashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>> =
31    LazyLock::new(dashmap::DashMap::new);
32
/// Errors that can occur during model cache operations.
///
/// The `Display` text of each variant is produced by `thiserror` from the
/// `#[error(...)]` attribute, with `{0}` replaced by the wrapped value.
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
    /// A model download failed.
    ///
    /// Carries the stringified error reported by the download client.
    #[error("failed to download model: {0}")]
    Download(String),

    /// The cache directory could not be resolved or created.
    #[error("cache directory error: {0}")]
    CacheDir(String),

    /// An underlying I/O error.
    ///
    /// `#[from]` lets `?` convert a [`std::io::Error`] into this variant.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The requested operation is not supported on this target (e.g. WASM,
    /// where the underlying `HuggingFace` Hub download stack is unavailable).
    #[error("model cache operation not supported on this target: {0}")]
    Unsupported(String),
}
53
/// Callback trait for receiving download progress updates.
///
/// Implement this on your own type to get notified as bytes are downloaded.
/// The `Send + Sync` bound allows a single implementation to be shared behind
/// an `Arc<dyn ProgressCallback>`. The method takes `&self`, so implementations
/// that keep state need interior mutability (for example atomics).
pub trait ProgressCallback: Send + Sync {
    /// Called periodically during download.
    ///
    /// * `downloaded_bytes` - Total bytes downloaded so far.
    /// * `total_bytes` - Total file size if known by the server; `None` when
    ///   the size has not been reported.
    fn on_progress(&self, downloaded_bytes: u64, total_bytes: Option<u64>);
}
64
65/// Adapter that bridges our [`ProgressCallback`] trait to `hf_hub`'s
66/// [`Progress`](hf_hub::api::tokio::Progress) trait.
67///
68/// Uses `Arc` so the adapter is `Clone + Send + Sync + 'static` as
69/// required by [`hf_hub::api::tokio::ApiRepo::download_with_progress`].
70#[cfg(not(target_arch = "wasm32"))]
71#[derive(Clone)]
72struct HfProgressAdapter {
73    callback: Arc<dyn ProgressCallback>,
74    downloaded: u64,
75    total: Option<u64>,
76}
77
78#[cfg(not(target_arch = "wasm32"))]
79impl HfProgressAdapter {
80    fn new(callback: Arc<dyn ProgressCallback>) -> Self {
81        Self {
82            callback,
83            downloaded: 0,
84            total: None,
85        }
86    }
87}
88
89#[cfg(not(target_arch = "wasm32"))]
90impl hf_hub::api::tokio::Progress for HfProgressAdapter {
91    async fn init(&mut self, size: usize, _filename: &str) {
92        self.total = Some(size as u64);
93        self.downloaded = 0;
94        self.callback.on_progress(0, self.total);
95    }
96
97    async fn update(&mut self, size: usize) {
98        self.downloaded += size as u64;
99        self.callback.on_progress(self.downloaded, self.total);
100    }
101
102    async fn finish(&mut self) {
103        self.callback
104            .on_progress(self.downloaded.max(1), self.total);
105    }
106}
107
/// A no-op progress implementation used when no callback is provided.
///
/// Every hook is an empty function, so using it reports nothing.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Clone)]
struct NoProgress;

#[cfg(not(target_arch = "wasm32"))]
impl hf_hub::api::tokio::Progress for NoProgress {
    // Each hook intentionally ignores its arguments and does nothing.
    async fn init(&mut self, _size: usize, _filename: &str) {}
    async fn update(&mut self, _size: usize) {}
    async fn finish(&mut self) {}
}
119
/// Local cache for ML models downloaded from `HuggingFace` Hub.
///
/// Models are stored under `{cache_dir}/{repo_id}/{filename}`. Files are
/// downloaded only once; subsequent calls return the cached path immediately.
///
/// The handle only stores the root directory, so it is cheap to clone and can
/// be printed with `{:?}` for diagnostics.
///
/// # Examples
///
/// ```no_run
/// # async fn example() -> Result<(), blazen_model_cache::CacheError> {
/// use blazen_model_cache::ModelCache;
///
/// let cache = ModelCache::new()?;
/// let path = cache.download("bert-base-uncased", "config.json", None).await?;
/// println!("model config at: {}", path.display());
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone)]
pub struct ModelCache {
    /// Root directory under which `{repo_id}/{filename}` entries live.
    cache_dir: PathBuf,
}

impl ModelCache {
    /// Create a cache rooted at a specific directory.
    ///
    /// The directory does not need to exist yet; it will be created on the
    /// first download.
    ///
    /// Available on every target — only [`Self::new`] and [`Self::download`]
    /// require the native `HuggingFace` Hub download stack.
    #[must_use]
    pub fn with_dir(cache_dir: impl Into<PathBuf>) -> Self {
        Self {
            cache_dir: cache_dir.into(),
        }
    }

    /// The root cache directory path.
    #[must_use]
    pub fn cache_dir(&self) -> &Path {
        &self.cache_dir
    }

    /// Check if a file is already present in the cache (without downloading).
    ///
    /// On native targets this reports whether
    /// `{cache_dir}/{repo_id}/{filename}` currently exists as a regular file.
    ///
    /// On wasm32 this always returns `false` — there is no filesystem to
    /// inspect, so callers should not rely on cache hits in the browser.
    #[must_use]
    pub fn is_cached(&self, repo_id: &str, filename: &str) -> bool {
        #[cfg(not(target_arch = "wasm32"))]
        {
            self.cached_path(repo_id, filename).is_file()
        }
        #[cfg(target_arch = "wasm32")]
        {
            // Arguments are unused on this target; consume them to avoid
            // unused-variable warnings.
            let _ = (repo_id, filename);
            false
        }
    }

    /// Compute the expected cache path for a repo/file pair.
    ///
    /// Both components are appended with [`Path::join`], so a `repo_id` such
    /// as `org/model` yields nested directories.
    #[cfg(not(target_arch = "wasm32"))]
    fn cached_path(&self, repo_id: &str, filename: &str) -> PathBuf {
        self.cache_dir.join(repo_id).join(filename)
    }
}
185
186// -- Native-only methods (filesystem + HuggingFace Hub) -----------------------
187
188#[cfg(not(target_arch = "wasm32"))]
189impl ModelCache {
190    /// Create a cache in the default location.
191    ///
192    /// Uses `$BLAZEN_CACHE_DIR/models/` if the `BLAZEN_CACHE_DIR` environment
193    /// variable is set, otherwise falls back to `~/.cache/blazen/models/`.
194    ///
195    /// # Errors
196    ///
197    /// Returns [`CacheError::CacheDir`] if the home directory cannot be
198    /// determined and `BLAZEN_CACHE_DIR` is not set.
199    pub fn new() -> Result<Self, CacheError> {
200        let cache_dir = if let Ok(dir) = std::env::var("BLAZEN_CACHE_DIR") {
201            PathBuf::from(dir).join("models")
202        } else {
203            dirs::cache_dir()
204                .ok_or_else(|| {
205                    CacheError::CacheDir(
206                        "could not determine home cache directory; \
207                         set BLAZEN_CACHE_DIR to override"
208                            .to_string(),
209                    )
210                })?
211                .join("blazen")
212                .join("models")
213        };
214
215        Ok(Self { cache_dir })
216    }
217
218    /// Download a file from `HuggingFace` Hub if it is not already cached.
219    ///
220    /// Returns the local filesystem path to the cached file.
221    ///
222    /// The file is first downloaded via `hf-hub` into its own managed cache,
223    /// then hard-linked (or copied as fallback) into our
224    /// `{cache_dir}/{repo_id}/{filename}` layout so that callers get a stable,
225    /// predictable path.
226    ///
227    /// # Progress
228    ///
229    /// Pass an `Arc<dyn ProgressCallback>` to receive byte-level progress
230    /// updates during the download. Pass `None` to download silently.
231    ///
232    /// # Errors
233    ///
234    /// Returns [`CacheError::Download`] if the `HuggingFace` API request fails,
235    /// or [`CacheError::Io`] if filesystem operations fail.
236    pub async fn download(
237        &self,
238        repo_id: &str,
239        filename: &str,
240        progress: Option<Arc<dyn ProgressCallback>>,
241    ) -> Result<PathBuf, CacheError> {
242        let dest = self.cached_path(repo_id, filename);
243
244        // Serialize concurrent callers for the same destination path. Different
245        // files download in parallel; same-file callers wait so hf-hub's internal
246        // blob lock isn't raced.
247        let lock = DOWNLOAD_LOCKS
248            .entry(dest.clone())
249            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
250            .value()
251            .clone();
252        let _guard = lock.lock().await;
253
254        // Already cached -- return immediately (re-check inside the lock so
255        // subsequent callers observe the file created by whoever went first).
256        if dest.is_file() {
257            return Ok(dest);
258        }
259
260        // Ensure the target directory exists.
261        if let Some(parent) = dest.parent() {
262            tokio::fs::create_dir_all(parent).await?;
263        }
264
265        // Build the hf-hub async API.
266        let api = hf_hub::api::tokio::ApiBuilder::new()
267            .with_progress(false) // We handle progress ourselves.
268            .build()
269            .map_err(|e| CacheError::Download(e.to_string()))?;
270
271        let repo = api.model(repo_id.to_string());
272
273        // Download through hf-hub (it manages its own cache under ~/.cache/huggingface).
274        let hf_path = if let Some(cb) = progress {
275            let adapter = HfProgressAdapter::new(Arc::clone(&cb));
276            repo.download_with_progress(filename, adapter)
277                .await
278                .map_err(|e| CacheError::Download(e.to_string()))?
279        } else {
280            let noop = NoProgress;
281            repo.download_with_progress(filename, noop)
282                .await
283                .map_err(|e| CacheError::Download(e.to_string()))?
284        };
285
286        // hf-hub returns `snapshots/main/<filename>` as a symlink into
287        // `blobs/<hash>`. Resolve to the real blob so hard-linking targets the
288        // actual file; otherwise on some filesystems we would hard-link the
289        // symlink itself, which can behave unexpectedly if the snapshot is
290        // pruned later. If canonicalization fails (e.g. broken chain), fall
291        // back to the original path and let the copy fallback handle it.
292        let hf_path_resolved = tokio::fs::canonicalize(&hf_path)
293            .await
294            .unwrap_or_else(|_| hf_path.clone());
295
296        // Link or copy the file into our own cache layout.
297        if dest != hf_path_resolved {
298            // Try hard link first (instant, no extra disk space).
299            if tokio::fs::hard_link(&hf_path_resolved, &dest)
300                .await
301                .is_err()
302            {
303                // Cross-device or unsupported -- fall back to copy.
304                tokio::fs::copy(&hf_path_resolved, &dest).await?;
305            }
306        }
307
308        // Postcondition: dest must exist after a successful download.
309        if !dest.is_file() {
310            return Err(CacheError::Io(std::io::Error::new(
311                std::io::ErrorKind::NotFound,
312                format!(
313                    "download completed but cache path is missing: {}",
314                    dest.display()
315                ),
316            )));
317        }
318
319        Ok(dest)
320    }
321}
322
323// -- wasm32 stubs (browser only; wasi gets a real in-memory cache via `wasi::WasiModelCache`)
324
#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))]
impl ModelCache {
    /// Stub that returns [`CacheError::Unsupported`] on wasm32.
    ///
    /// On wasm32 there is no `dirs::cache_dir()` and no `BLAZEN_CACHE_DIR`
    /// environment lookup that would produce a usable path. Use
    /// [`Self::with_dir`] with an explicit virtual path instead.
    ///
    /// # Errors
    ///
    /// Always returns [`CacheError::Unsupported`].
    pub fn new() -> Result<Self, CacheError> {
        Err(CacheError::Unsupported(
            "ModelCache::new() is not supported on wasm32; use ModelCache::with_dir() instead"
                .to_string(),
        ))
    }

    /// Stub that always returns [`CacheError::Unsupported`] on wasm32.
    ///
    /// The `HuggingFace` Hub client (`hf-hub`) and `tokio::fs` are not
    /// compatible with the `wasm32-*` targets, so model files cannot be
    /// downloaded from this side. Browser/Worker callers should fetch model
    /// bytes via the JavaScript `fetch()` API and pass them in directly.
    ///
    /// The signature mirrors the native `download` so callers compile
    /// unchanged on both targets. `Arc` is spelled with its full path because
    /// the crate-level `use std::sync::Arc` is compiled out on wasm32.
    ///
    /// # Errors
    ///
    /// Always returns [`CacheError::Unsupported`].
    pub async fn download(
        &self,
        repo_id: &str,
        filename: &str,
        progress: Option<std::sync::Arc<dyn ProgressCallback>>,
    ) -> Result<PathBuf, CacheError> {
        // Nothing is downloaded here; consume the arguments to avoid
        // unused-variable warnings.
        let _ = (repo_id, filename, progress);
        Err(CacheError::Unsupported(format!(
            "ModelCache::download() is not supported on wasm32 \
             (cache_dir={})",
            self.cache_dir.display()
        )))
    }
}
367
#[cfg(all(test, not(target_arch = "wasm32")))]
#[allow(unsafe_code)] // env::set_var / env::remove_var require unsafe in edition 2024
mod tests {
    use super::*;

    /// Name of the environment variable consulted by [`ModelCache::new`].
    const CACHE_DIR_VAR: &str = "BLAZEN_CACHE_DIR";

    /// Serializes the tests that mutate [`CACHE_DIR_VAR`].
    ///
    /// The test harness runs `#[test]` functions on multiple threads by
    /// default, so without this lock one test could remove the variable while
    /// another expects it to be set.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// RAII helper that overrides [`CACHE_DIR_VAR`] for the duration of a test
    /// and restores the previous value on drop, even if the test panics.
    struct EnvVarGuard {
        /// Value the variable had before the override (`None` if unset).
        previous: Option<std::ffi::OsString>,
        /// Held for the guard's lifetime; released after `Drop::drop` has
        /// restored the variable.
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl EnvVarGuard {
        /// Set the variable to `value`, or remove it when `value` is `None`.
        fn apply(value: Option<&str>) -> Self {
            // A poisoned lock only means another env test panicked; its guard
            // still restored the variable, so continuing is fine.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let previous = std::env::var_os(CACHE_DIR_VAR);

            // SAFETY: every test in this module that writes the variable does
            // so through this guard while holding `ENV_LOCK`, so writes never
            // overlap. NOTE(review): other tests may still read the
            // environment through `std` (e.g. temp-dir lookup); confirm no
            // test reads it through non-`std` code concurrently.
            unsafe {
                match value {
                    Some(v) => std::env::set_var(CACHE_DIR_VAR, v),
                    None => std::env::remove_var(CACHE_DIR_VAR),
                }
            }

            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: `ENV_LOCK` is still held here (fields are dropped after
            // this body runs), so no other guarded test is writing.
            unsafe {
                match self.previous.take() {
                    Some(v) => std::env::set_var(CACHE_DIR_VAR, v),
                    None => std::env::remove_var(CACHE_DIR_VAR),
                }
            }
        }
    }

    #[test]
    fn test_default_cache_dir() {
        // When BLAZEN_CACHE_DIR is not set, the cache should live under the
        // platform cache directory (e.g. ~/.cache/blazen/models/ on Linux).
        let _env = EnvVarGuard::apply(None);

        let cache = ModelCache::new().expect("default cache should succeed");
        let path = cache.cache_dir();

        // Should end with blazen/models
        assert!(
            path.ends_with("blazen/models"),
            "expected path ending with blazen/models, got: {}",
            path.display()
        );
    }

    #[test]
    fn test_default_cache_dir_from_env() {
        let _env = EnvVarGuard::apply(Some("/tmp/blazen-test-cache"));

        let cache = ModelCache::new().expect("env-based cache should succeed");
        assert_eq!(
            cache.cache_dir(),
            Path::new("/tmp/blazen-test-cache/models")
        );
    }

    #[test]
    fn test_with_dir() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert_eq!(cache.cache_dir(), dir.path());
    }

    #[test]
    fn test_is_cached_false_initially() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert!(!cache.is_cached("foo/bar", "model.gguf"));
    }

    #[test]
    fn test_is_cached_true_after_manual_placement() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());

        // Manually create the file to simulate a cached download.
        let file_dir = dir.path().join("my-org/my-model");
        std::fs::create_dir_all(&file_dir).unwrap();
        std::fs::write(file_dir.join("config.json"), b"{}").unwrap();

        assert!(cache.is_cached("my-org/my-model", "config.json"));
    }

    #[test]
    fn test_cached_path_layout() {
        let cache = ModelCache::with_dir("/fake/cache");
        let path = cache.cached_path("org/model", "weights.bin");
        assert_eq!(path, PathBuf::from("/fake/cache/org/model/weights.bin"));
    }

    /// Verifies that the per-path lock actually serializes concurrent callers
    /// targeting the same destination. We can't easily mock hf-hub inside
    /// `download()`, so this test exercises the serialization primitive
    /// directly: if the lock map misbehaves (e.g. hands out independent
    /// mutexes for the same path), more than one task will sit inside the
    /// critical section at the same time and the counter will exceed zero
    /// when observed by another task.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_downloads_serialize_same_path() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        let tmp = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(tmp.path().to_path_buf());
        let dest = cache.cached_path("test/repo", "file.bin");

        // Number of tasks currently inside the critical section.
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..4 {
            let dest_clone = dest.clone();
            let counter_clone = Arc::clone(&counter);
            handles.push(tokio::spawn(async move {
                let lock = DOWNLOAD_LOCKS
                    .entry(dest_clone)
                    .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
                    .value()
                    .clone();
                let _guard = lock.lock().await;
                // If another task already holds the lock, it would have
                // incremented the counter before us; the assertion below
                // would then catch the violation.
                let prev = counter_clone.fetch_add(1, Ordering::SeqCst);
                // Simulate in-flight work to widen the race window.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                counter_clone.fetch_sub(1, Ordering::SeqCst);
                prev
            }));
        }

        let results = futures_util::future::join_all(handles).await;
        for r in results {
            let prev = r.expect("task panicked");
            assert_eq!(prev, 0, "another task held the lock concurrently");
        }
    }

    /// Integration test that actually downloads from `HuggingFace` Hub.
    ///
    /// Ignored by default because it requires network access. Run with:
    /// ```sh
    /// cargo test -p blazen-model-cache -- --ignored
    /// ```
    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_and_cache() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());

        // Download a tiny file (~600 bytes).
        let path = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("download should succeed");

        // File should exist and have non-zero size.
        assert!(path.is_file(), "downloaded file should exist");
        let meta = std::fs::metadata(&path).expect("metadata");
        assert!(meta.len() > 0, "downloaded file should be non-empty");

        // Second call should return the cached path instantly.
        let path2 = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("cached download should succeed");
        assert_eq!(path, path2);
    }

    /// Integration test verifying progress callback fires.
    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_with_progress() {
        use std::sync::atomic::{AtomicU64, Ordering};

        /// Counts how many times the callback was invoked.
        struct TestProgress {
            calls: AtomicU64,
        }

        impl ProgressCallback for TestProgress {
            fn on_progress(&self, _downloaded_bytes: u64, _total_bytes: Option<u64>) {
                self.calls.fetch_add(1, Ordering::Relaxed);
            }
        }

        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        let progress = Arc::new(TestProgress {
            calls: AtomicU64::new(0),
        });

        // Clone the Arc so we retain a handle for assertions.
        let cb: Arc<dyn ProgressCallback> = Arc::clone(&progress) as Arc<dyn ProgressCallback>;

        cache
            .download(
                "hf-internal-testing/tiny-random-gpt2",
                "config.json",
                Some(cb),
            )
            .await
            .expect("download should succeed");

        assert!(
            progress.calls.load(Ordering::Relaxed) > 0,
            "progress callback should have been called at least once"
        );
    }
}