blazen_model_cache/lib.rs
1//! Shared model download and cache layer for Blazen local-inference backends.
2//!
3//! Provides [`ModelCache`] for downloading and caching ML models from
4//! [`HuggingFace` Hub](https://huggingface.co). Designed to be shared by all
5//! local-inference backends (fastembed, mistral.rs, whisper.cpp, etc.).
6//!
7//! ## wasm32 support
8//!
//! On `wasm32-*` targets the underlying download stack (`hf-hub`, `dirs`,
//! `tokio::fs`) is not available. On browser-style targets (wasm32 without
//! WASI) [`ModelCache`]'s `new()` and `download()` are stubs that always
//! return [`CacheError::Unsupported`]; on WASI the `wasi` module exports
//! `WasiModelCache` instead. Browser/Worker callers should obtain model bytes
//! through a different mechanism (e.g. `fetch()` on the JS side, pre-bundled
//! assets, or a manually populated cache directory).
14
15use std::path::{Path, PathBuf};
16#[cfg(not(target_arch = "wasm32"))]
17use std::sync::Arc;
18#[cfg(not(target_arch = "wasm32"))]
19use std::sync::LazyLock;
20
21#[cfg(all(target_arch = "wasm32", target_os = "wasi"))]
22pub mod wasi;
23#[cfg(all(target_arch = "wasm32", target_os = "wasi"))]
24pub use wasi::{WasiHttpFetch, WasiModelCache};
25
26/// Per-destination-path mutexes that serialize concurrent [`ModelCache::download`]
27/// calls for the same file. Different files download in parallel; same-file
28/// callers wait so hf-hub's internal blob lock isn't raced.
29#[cfg(not(target_arch = "wasm32"))]
30static DOWNLOAD_LOCKS: LazyLock<dashmap::DashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>> =
31 LazyLock::new(dashmap::DashMap::new);
32
33/// Errors that can occur during model cache operations.
34#[derive(Debug, thiserror::Error)]
35pub enum CacheError {
36 /// A model download failed.
37 #[error("failed to download model: {0}")]
38 Download(String),
39
40 /// The cache directory could not be resolved or created.
41 #[error("cache directory error: {0}")]
42 CacheDir(String),
43
44 /// An underlying I/O error.
45 #[error("IO error: {0}")]
46 Io(#[from] std::io::Error),
47
48 /// The requested operation is not supported on this target (e.g. WASM,
49 /// where the underlying `HuggingFace` Hub download stack is unavailable).
50 #[error("model cache operation not supported on this target: {0}")]
51 Unsupported(String),
52}
53
/// Observer interface for download progress.
///
/// Implement it on your own type and hand it to the cache to be told how many
/// bytes have arrived. Implementations must be `Send + Sync` because the
/// callback is shared behind an `Arc` and takes `&self`.
pub trait ProgressCallback: Sync + Send {
    /// Invoked repeatedly while a download is running.
    ///
    /// * `downloaded_bytes` - Total bytes downloaded so far.
    /// * `total_bytes` - Size of the whole file, when the server reported one.
    fn on_progress(&self, downloaded_bytes: u64, total_bytes: Option<u64>);
}
64
65/// Adapter that bridges our [`ProgressCallback`] trait to `hf_hub`'s
66/// [`Progress`](hf_hub::api::tokio::Progress) trait.
67///
68/// Uses `Arc` so the adapter is `Clone + Send + Sync + 'static` as
69/// required by [`hf_hub::api::tokio::ApiRepo::download_with_progress`].
70#[cfg(not(target_arch = "wasm32"))]
71#[derive(Clone)]
72struct HfProgressAdapter {
73 callback: Arc<dyn ProgressCallback>,
74 downloaded: u64,
75 total: Option<u64>,
76}
77
78#[cfg(not(target_arch = "wasm32"))]
79impl HfProgressAdapter {
80 fn new(callback: Arc<dyn ProgressCallback>) -> Self {
81 Self {
82 callback,
83 downloaded: 0,
84 total: None,
85 }
86 }
87}
88
89#[cfg(not(target_arch = "wasm32"))]
90impl hf_hub::api::tokio::Progress for HfProgressAdapter {
91 async fn init(&mut self, size: usize, _filename: &str) {
92 self.total = Some(size as u64);
93 self.downloaded = 0;
94 self.callback.on_progress(0, self.total);
95 }
96
97 async fn update(&mut self, size: usize) {
98 self.downloaded += size as u64;
99 self.callback.on_progress(self.downloaded, self.total);
100 }
101
102 async fn finish(&mut self) {
103 self.callback
104 .on_progress(self.downloaded.max(1), self.total);
105 }
106}
107
108/// A no-op progress implementation used when no callback is provided.
109#[cfg(not(target_arch = "wasm32"))]
110#[derive(Clone)]
111struct NoProgress;
112
113#[cfg(not(target_arch = "wasm32"))]
114impl hf_hub::api::tokio::Progress for NoProgress {
115 async fn init(&mut self, _size: usize, _filename: &str) {}
116 async fn update(&mut self, _size: usize) {}
117 async fn finish(&mut self) {}
118}
119
/// On-disk cache of ML model files fetched from `HuggingFace` Hub.
///
/// Every file lives at `{cache_dir}/{repo_id}/{filename}`. A file is fetched
/// at most once; later requests for it resolve to the already cached path.
///
/// # Examples
///
/// ```no_run
/// # async fn example() -> Result<(), blazen_model_cache::CacheError> {
/// use blazen_model_cache::ModelCache;
///
/// let cache = ModelCache::new()?;
/// let path = cache.download("bert-base-uncased", "config.json", None).await?;
/// println!("model config at: {}", path.display());
/// # Ok(())
/// # }
/// ```
pub struct ModelCache {
    /// Root directory under which `{repo_id}/{filename}` entries are stored.
    cache_dir: PathBuf,
}

impl ModelCache {
    /// Build a cache rooted at `cache_dir`.
    ///
    /// The directory may be absent at this point; the first download creates
    /// it.
    ///
    /// Available on every target — only [`Self::new`] and [`Self::download`]
    /// require the native `HuggingFace` Hub download stack.
    #[must_use]
    pub fn with_dir(cache_dir: impl Into<PathBuf>) -> Self {
        let cache_dir = cache_dir.into();
        Self { cache_dir }
    }

    /// Root directory of this cache.
    #[must_use]
    pub fn cache_dir(&self) -> &Path {
        self.cache_dir.as_path()
    }

    /// Report whether `filename` from `repo_id` is already cached, without
    /// triggering a download.
    ///
    /// On wasm32 this always returns `false` — there is no filesystem to
    /// inspect, so callers should not rely on cache hits in the browser.
    #[must_use]
    pub fn is_cached(&self, repo_id: &str, filename: &str) -> bool {
        #[cfg(not(target_arch = "wasm32"))]
        {
            // True only when the path exists and is a regular file.
            std::fs::metadata(self.cached_path(repo_id, filename))
                .is_ok_and(|meta| meta.is_file())
        }
        #[cfg(target_arch = "wasm32")]
        {
            let _ = (repo_id, filename);
            false
        }
    }

    /// Location a repo/file pair occupies inside the cache:
    /// `{cache_dir}/{repo_id}/{filename}`.
    #[cfg(not(target_arch = "wasm32"))]
    fn cached_path(&self, repo_id: &str, filename: &str) -> PathBuf {
        let mut location = self.cache_dir.join(repo_id);
        location.push(filename);
        location
    }
}
185
186// -- Native-only methods (filesystem + HuggingFace Hub) -----------------------
187
188#[cfg(not(target_arch = "wasm32"))]
189impl ModelCache {
190 /// Create a cache in the default location.
191 ///
192 /// Uses `$BLAZEN_CACHE_DIR/models/` if the `BLAZEN_CACHE_DIR` environment
193 /// variable is set, otherwise falls back to `~/.cache/blazen/models/`.
194 ///
195 /// # Errors
196 ///
197 /// Returns [`CacheError::CacheDir`] if the home directory cannot be
198 /// determined and `BLAZEN_CACHE_DIR` is not set.
199 pub fn new() -> Result<Self, CacheError> {
200 let cache_dir = if let Ok(dir) = std::env::var("BLAZEN_CACHE_DIR") {
201 PathBuf::from(dir).join("models")
202 } else {
203 dirs::cache_dir()
204 .ok_or_else(|| {
205 CacheError::CacheDir(
206 "could not determine home cache directory; \
207 set BLAZEN_CACHE_DIR to override"
208 .to_string(),
209 )
210 })?
211 .join("blazen")
212 .join("models")
213 };
214
215 Ok(Self { cache_dir })
216 }
217
218 /// Download a file from `HuggingFace` Hub if it is not already cached.
219 ///
220 /// Returns the local filesystem path to the cached file.
221 ///
222 /// The file is first downloaded via `hf-hub` into its own managed cache,
223 /// then hard-linked (or copied as fallback) into our
224 /// `{cache_dir}/{repo_id}/{filename}` layout so that callers get a stable,
225 /// predictable path.
226 ///
227 /// # Progress
228 ///
229 /// Pass an `Arc<dyn ProgressCallback>` to receive byte-level progress
230 /// updates during the download. Pass `None` to download silently.
231 ///
232 /// # Errors
233 ///
234 /// Returns [`CacheError::Download`] if the `HuggingFace` API request fails,
235 /// or [`CacheError::Io`] if filesystem operations fail.
236 pub async fn download(
237 &self,
238 repo_id: &str,
239 filename: &str,
240 progress: Option<Arc<dyn ProgressCallback>>,
241 ) -> Result<PathBuf, CacheError> {
242 let dest = self.cached_path(repo_id, filename);
243
244 // Serialize concurrent callers for the same destination path. Different
245 // files download in parallel; same-file callers wait so hf-hub's internal
246 // blob lock isn't raced.
247 let lock = DOWNLOAD_LOCKS
248 .entry(dest.clone())
249 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
250 .value()
251 .clone();
252 let _guard = lock.lock().await;
253
254 // Already cached -- return immediately (re-check inside the lock so
255 // subsequent callers observe the file created by whoever went first).
256 if dest.is_file() {
257 return Ok(dest);
258 }
259
260 // Ensure the target directory exists.
261 if let Some(parent) = dest.parent() {
262 tokio::fs::create_dir_all(parent).await?;
263 }
264
265 // Build the hf-hub async API.
266 let api = hf_hub::api::tokio::ApiBuilder::new()
267 .with_progress(false) // We handle progress ourselves.
268 .build()
269 .map_err(|e| CacheError::Download(e.to_string()))?;
270
271 let repo = api.model(repo_id.to_string());
272
273 // Download through hf-hub (it manages its own cache under ~/.cache/huggingface).
274 let hf_path = if let Some(cb) = progress {
275 let adapter = HfProgressAdapter::new(Arc::clone(&cb));
276 repo.download_with_progress(filename, adapter)
277 .await
278 .map_err(|e| CacheError::Download(e.to_string()))?
279 } else {
280 let noop = NoProgress;
281 repo.download_with_progress(filename, noop)
282 .await
283 .map_err(|e| CacheError::Download(e.to_string()))?
284 };
285
286 // hf-hub returns `snapshots/main/<filename>` as a symlink into
287 // `blobs/<hash>`. Resolve to the real blob so hard-linking targets the
288 // actual file; otherwise on some filesystems we would hard-link the
289 // symlink itself, which can behave unexpectedly if the snapshot is
290 // pruned later. If canonicalization fails (e.g. broken chain), fall
291 // back to the original path and let the copy fallback handle it.
292 let hf_path_resolved = tokio::fs::canonicalize(&hf_path)
293 .await
294 .unwrap_or_else(|_| hf_path.clone());
295
296 // Link or copy the file into our own cache layout.
297 if dest != hf_path_resolved {
298 // Try hard link first (instant, no extra disk space).
299 if tokio::fs::hard_link(&hf_path_resolved, &dest)
300 .await
301 .is_err()
302 {
303 // Cross-device or unsupported -- fall back to copy.
304 tokio::fs::copy(&hf_path_resolved, &dest).await?;
305 }
306 }
307
308 // Postcondition: dest must exist after a successful download.
309 if !dest.is_file() {
310 return Err(CacheError::Io(std::io::Error::new(
311 std::io::ErrorKind::NotFound,
312 format!(
313 "download completed but cache path is missing: {}",
314 dest.display()
315 ),
316 )));
317 }
318
319 Ok(dest)
320 }
321}
322
323// -- wasm32 stubs (browser only; wasi gets a real in-memory cache via `wasi::WasiModelCache`)
324
#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))]
impl ModelCache {
    /// Stub that returns [`CacheError::Unsupported`] on wasm32.
    ///
    /// On wasm32 there is no `dirs::cache_dir()` and no `BLAZEN_CACHE_DIR`
    /// environment lookup that would produce a usable path. Use
    /// [`Self::with_dir`] with an explicit virtual path instead.
    ///
    /// # Errors
    ///
    /// Always returns [`CacheError::Unsupported`].
    pub fn new() -> Result<Self, CacheError> {
        Err(CacheError::Unsupported(
            "ModelCache::new() is not supported on wasm32; use ModelCache::with_dir() instead"
                .to_string(),
        ))
    }

    /// Stub that always returns [`CacheError::Unsupported`] on wasm32.
    ///
    /// The `HuggingFace` Hub client (`hf-hub`) and `tokio::fs` are not
    /// compatible with the `wasm32-*` targets, so model files cannot be
    /// downloaded from this side. Browser/Worker callers should fetch model
    /// bytes via the JavaScript `fetch()` API and pass them in directly.
    ///
    /// The signature mirrors the native `download` so call sites compile
    /// unchanged on both targets. All arguments are ignored.
    ///
    /// # Errors
    ///
    /// Always returns [`CacheError::Unsupported`].
    pub async fn download(
        &self,
        repo_id: &str,
        filename: &str,
        // Fully qualified on purpose: the file-level `use std::sync::Arc` is
        // gated to non-wasm32 targets, so a bare `Arc` does not resolve here.
        progress: Option<std::sync::Arc<dyn ProgressCallback>>,
    ) -> Result<PathBuf, CacheError> {
        let _ = (repo_id, filename, progress);
        Err(CacheError::Unsupported(format!(
            "ModelCache::download() is not supported on wasm32 \
             (cache_dir={})",
            self.cache_dir.display()
        )))
    }
}
367
368#[cfg(all(test, not(target_arch = "wasm32")))]
369#[allow(unsafe_code)] // env::set_var / env::remove_var require unsafe in edition 2024
370mod tests {
371 use super::*;
372
373 #[test]
374 fn test_default_cache_dir() {
375 // When BLAZEN_CACHE_DIR is not set, the cache should live under the
376 // platform cache directory (e.g. ~/.cache/blazen/models/ on Linux).
377 // We temporarily remove the env var to test the default path.
378 let had_var = std::env::var("BLAZEN_CACHE_DIR").ok();
379
380 // SAFETY: This test runs single-threaded and restores the variable
381 // immediately after the assertion. No other thread reads this var
382 // concurrently in the test suite (env-var tests are inherently racy
383 // but acceptable in `#[test]` which defaults to `--test-threads=1`
384 // per binary).
385 unsafe {
386 std::env::remove_var("BLAZEN_CACHE_DIR");
387 }
388
389 let cache = ModelCache::new().expect("default cache should succeed");
390 let path = cache.cache_dir();
391
392 // Should end with blazen/models
393 assert!(
394 path.ends_with("blazen/models"),
395 "expected path ending with blazen/models, got: {}",
396 path.display()
397 );
398
399 // Restore env var if it was set.
400 if let Some(val) = had_var {
401 // SAFETY: restoring the original value.
402 unsafe {
403 std::env::set_var("BLAZEN_CACHE_DIR", val);
404 }
405 }
406 }
407
408 #[test]
409 fn test_default_cache_dir_from_env() {
410 let prev = std::env::var("BLAZEN_CACHE_DIR").ok();
411
412 // SAFETY: see `test_default_cache_dir`.
413 unsafe {
414 std::env::set_var("BLAZEN_CACHE_DIR", "/tmp/blazen-test-cache");
415 }
416
417 let cache = ModelCache::new().expect("env-based cache should succeed");
418 assert_eq!(
419 cache.cache_dir(),
420 Path::new("/tmp/blazen-test-cache/models")
421 );
422
423 // Restore.
424 // SAFETY: see `test_default_cache_dir`.
425 unsafe {
426 match prev {
427 Some(val) => std::env::set_var("BLAZEN_CACHE_DIR", val),
428 None => std::env::remove_var("BLAZEN_CACHE_DIR"),
429 }
430 }
431 }
432
433 #[test]
434 fn test_with_dir() {
435 let dir = tempfile::tempdir().expect("tempdir");
436 let cache = ModelCache::with_dir(dir.path());
437 assert_eq!(cache.cache_dir(), dir.path());
438 }
439
440 #[test]
441 fn test_is_cached_false_initially() {
442 let dir = tempfile::tempdir().expect("tempdir");
443 let cache = ModelCache::with_dir(dir.path());
444 assert!(!cache.is_cached("foo/bar", "model.gguf"));
445 }
446
447 #[test]
448 fn test_is_cached_true_after_manual_placement() {
449 let dir = tempfile::tempdir().expect("tempdir");
450 let cache = ModelCache::with_dir(dir.path());
451
452 // Manually create the file to simulate a cached download.
453 let file_dir = dir.path().join("my-org/my-model");
454 std::fs::create_dir_all(&file_dir).unwrap();
455 std::fs::write(file_dir.join("config.json"), b"{}").unwrap();
456
457 assert!(cache.is_cached("my-org/my-model", "config.json"));
458 }
459
460 #[test]
461 fn test_cached_path_layout() {
462 let cache = ModelCache::with_dir("/fake/cache");
463 let path = cache.cached_path("org/model", "weights.bin");
464 assert_eq!(path, PathBuf::from("/fake/cache/org/model/weights.bin"));
465 }
466
467 /// Verifies that the per-path lock actually serializes concurrent callers
468 /// targeting the same destination. We can't easily mock hf-hub inside
469 /// `download()`, so this test exercises the serialization primitive
470 /// directly: if the lock map misbehaves (e.g. hands out independent
471 /// mutexes for the same path), more than one task will sit inside the
472 /// critical section at the same time and the counter will exceed zero
473 /// when observed by another task.
474 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
475 async fn concurrent_downloads_serialize_same_path() {
476 let tmp = tempfile::tempdir().expect("tempdir");
477 let cache = ModelCache::with_dir(tmp.path().to_path_buf());
478 let dest = cache.cached_path("test/repo", "file.bin");
479
480 let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
481 let mut handles = Vec::new();
482 for _ in 0..4 {
483 let dest_clone = dest.clone();
484 let counter_clone = Arc::clone(&counter);
485 handles.push(tokio::spawn(async move {
486 let lock = DOWNLOAD_LOCKS
487 .entry(dest_clone)
488 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
489 .value()
490 .clone();
491 let _guard = lock.lock().await;
492 // If another task already holds the lock, it would have
493 // incremented the counter before us; the assertion below
494 // would then catch the violation.
495 let prev = counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
496 // Simulate in-flight work to widen the race window.
497 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
498 counter_clone.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
499 prev
500 }));
501 }
502
503 let results = futures_util::future::join_all(handles).await;
504 for r in results {
505 let prev = r.expect("task panicked");
506 assert_eq!(prev, 0, "another task held the lock concurrently");
507 }
508 }
509
510 /// Integration test that actually downloads from `HuggingFace` Hub.
511 ///
512 /// Ignored by default because it requires network access. Run with:
513 /// ```sh
514 /// cargo test -p blazen-model-cache -- --ignored
515 /// ```
516 #[tokio::test]
517 #[ignore = "requires network access to HuggingFace Hub"]
518 async fn test_download_and_cache() {
519 let dir = tempfile::tempdir().expect("tempdir");
520 let cache = ModelCache::with_dir(dir.path());
521
522 // Download a tiny file (~600 bytes).
523 let path = cache
524 .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
525 .await
526 .expect("download should succeed");
527
528 // File should exist and have non-zero size.
529 assert!(path.is_file(), "downloaded file should exist");
530 let meta = std::fs::metadata(&path).expect("metadata");
531 assert!(meta.len() > 0, "downloaded file should be non-empty");
532
533 // Second call should return the cached path instantly.
534 let path2 = cache
535 .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
536 .await
537 .expect("cached download should succeed");
538 assert_eq!(path, path2);
539 }
540
541 /// Integration test verifying progress callback fires.
542 #[tokio::test]
543 #[ignore = "requires network access to HuggingFace Hub"]
544 async fn test_download_with_progress() {
545 use std::sync::atomic::{AtomicU64, Ordering};
546
547 struct TestProgress {
548 calls: AtomicU64,
549 }
550
551 impl ProgressCallback for TestProgress {
552 fn on_progress(&self, _downloaded_bytes: u64, _total_bytes: Option<u64>) {
553 self.calls.fetch_add(1, Ordering::Relaxed);
554 }
555 }
556
557 let dir = tempfile::tempdir().expect("tempdir");
558 let cache = ModelCache::with_dir(dir.path());
559 let progress = Arc::new(TestProgress {
560 calls: AtomicU64::new(0),
561 });
562
563 // Clone the Arc so we retain a handle for assertions.
564 let cb: Arc<dyn ProgressCallback> = Arc::clone(&progress) as Arc<dyn ProgressCallback>;
565
566 cache
567 .download(
568 "hf-internal-testing/tiny-random-gpt2",
569 "config.json",
570 Some(cb),
571 )
572 .await
573 .expect("download should succeed");
574
575 assert!(
576 progress.calls.load(Ordering::Relaxed) > 0,
577 "progress callback should have been called at least once"
578 );
579 }
580}