blazen_model_cache/lib.rs
1//! Shared model download and cache layer for Blazen local-inference backends.
2//!
3//! Provides [`ModelCache`] for downloading and caching ML models from
4//! [`HuggingFace` Hub](https://huggingface.co). Designed to be shared by all
5//! local-inference backends (fastembed, mistral.rs, whisper.cpp, etc.).
6//!
7//! ## wasm32 support
8//!
9//! On `wasm32-*` targets the underlying download stack (`hf-hub`, `dirs`,
10//! `tokio::fs`) is not available, so [`ModelCache`] is a stub that always
11//! returns [`CacheError::Unsupported`]. Browser/Worker callers should obtain
12//! model bytes through a different mechanism (e.g. `fetch()` on the JS side,
13//! pre-bundled assets, or a manually populated cache directory).
14
15use std::path::{Path, PathBuf};
16use std::sync::Arc;
17#[cfg(not(target_arch = "wasm32"))]
18use std::sync::LazyLock;
19
/// Per-destination-path mutexes that serialize concurrent [`ModelCache::download`]
/// calls for the same file. Different files download in parallel; same-file
/// callers wait so hf-hub's internal blob lock isn't raced.
///
/// The map is keyed by the final cache path (`{cache_dir}/{repo_id}/{filename}`).
/// Each mutex is a `tokio::sync::Mutex` because `download` keeps its guard
/// alive across `.await` points, and it is wrapped in an `Arc` so a caller can
/// clone the handle out of the map before waiting on the lock.
///
/// NOTE(review): nothing in this file ever removes an entry, so the map keeps
/// one small entry per distinct destination path for the life of the process.
#[cfg(not(target_arch = "wasm32"))]
static DOWNLOAD_LOCKS: LazyLock<dashmap::DashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>> =
    LazyLock::new(dashmap::DashMap::new);
26
/// Errors that can occur during model cache operations.
///
/// The `Display` text of each variant is produced by the `#[error(...)]`
/// attribute; the wrapped `String` carries the underlying detail.
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
    /// A model download failed.
    ///
    /// Produced by `ModelCache::download` when the hf-hub client cannot be
    /// built or the fetch itself fails; holds the stringified hf-hub error.
    #[error("failed to download model: {0}")]
    Download(String),

    /// The cache directory could not be resolved or created.
    ///
    /// Produced by `ModelCache::new` when no platform cache directory can be
    /// determined and no override is available.
    #[error("cache directory error: {0}")]
    CacheDir(String),

    /// An underlying I/O error.
    ///
    /// `#[from]` lets `?` convert any `std::io::Error` into this variant.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The requested operation is not supported on this target (e.g. WASM,
    /// where the underlying `HuggingFace` Hub download stack is unavailable).
    #[error("model cache operation not supported on this target: {0}")]
    Unsupported(String),
}
47
/// Callback trait for receiving download progress updates.
///
/// Implement this on your own type to get notified as bytes are downloaded.
/// Implementors must be `Send + Sync` because the callback is shared as an
/// `Arc<dyn ProgressCallback>`.
///
/// The method is synchronous and is invoked inline from the async download
/// path, so implementations should return quickly.
pub trait ProgressCallback: Send + Sync {
    /// Called periodically during download.
    ///
    /// * `downloaded_bytes` - Total bytes downloaded so far.
    /// * `total_bytes` - Total file size if known by the server.
    fn on_progress(&self, downloaded_bytes: u64, total_bytes: Option<u64>);
}
58
59/// Adapter that bridges our [`ProgressCallback`] trait to `hf_hub`'s
60/// [`Progress`](hf_hub::api::tokio::Progress) trait.
61///
62/// Uses `Arc` so the adapter is `Clone + Send + Sync + 'static` as
63/// required by [`hf_hub::api::tokio::ApiRepo::download_with_progress`].
64#[cfg(not(target_arch = "wasm32"))]
65#[derive(Clone)]
66struct HfProgressAdapter {
67 callback: Arc<dyn ProgressCallback>,
68 downloaded: u64,
69 total: Option<u64>,
70}
71
72#[cfg(not(target_arch = "wasm32"))]
73impl HfProgressAdapter {
74 fn new(callback: Arc<dyn ProgressCallback>) -> Self {
75 Self {
76 callback,
77 downloaded: 0,
78 total: None,
79 }
80 }
81}
82
83#[cfg(not(target_arch = "wasm32"))]
84impl hf_hub::api::tokio::Progress for HfProgressAdapter {
85 async fn init(&mut self, size: usize, _filename: &str) {
86 self.total = Some(size as u64);
87 self.downloaded = 0;
88 self.callback.on_progress(0, self.total);
89 }
90
91 async fn update(&mut self, size: usize) {
92 self.downloaded += size as u64;
93 self.callback.on_progress(self.downloaded, self.total);
94 }
95
96 async fn finish(&mut self) {
97 self.callback
98 .on_progress(self.downloaded.max(1), self.total);
99 }
100}
101
/// A no-op progress implementation used when no callback is provided.
///
/// It is a zero-sized unit struct, so cloning it (hf-hub requires `Clone`)
/// costs nothing.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Clone)]
struct NoProgress;
106
// Every hook is intentionally empty: this type exists only to satisfy the
// `Progress` parameter of `download_with_progress` for silent downloads.
#[cfg(not(target_arch = "wasm32"))]
impl hf_hub::api::tokio::Progress for NoProgress {
    async fn init(&mut self, _size: usize, _filename: &str) {}
    async fn update(&mut self, _size: usize) {}
    async fn finish(&mut self) {}
}
113
/// On-disk cache of model files fetched from `HuggingFace` Hub.
///
/// Every file lives at `{cache_dir}/{repo_id}/{filename}`. A file is fetched
/// at most once; later requests for the same file resolve straight to the
/// path that is already on disk.
///
/// # Examples
///
/// ```no_run
/// # async fn example() -> Result<(), blazen_model_cache::CacheError> {
/// use blazen_model_cache::ModelCache;
///
/// let cache = ModelCache::new()?;
/// let path = cache.download("bert-base-uncased", "config.json", None).await?;
/// println!("model config at: {}", path.display());
/// # Ok(())
/// # }
/// ```
pub struct ModelCache {
    /// Root directory under which all repositories are stored.
    cache_dir: PathBuf,
}

impl ModelCache {
    /// Build a cache whose root is the given directory.
    ///
    /// The directory is not touched here and may not exist yet; it is created
    /// by the first download.
    ///
    /// This constructor works on every target — only [`Self::new`] and
    /// [`Self::download`] need the native `HuggingFace` Hub download stack.
    #[must_use]
    pub fn with_dir(cache_dir: impl Into<PathBuf>) -> Self {
        let cache_dir = cache_dir.into();
        Self { cache_dir }
    }

    /// Borrow the root directory of this cache.
    #[must_use]
    pub fn cache_dir(&self) -> &Path {
        self.cache_dir.as_path()
    }

    /// Report whether a file is already in the cache, without downloading it.
    ///
    /// On wasm32 the answer is always `false` — there is no filesystem to
    /// inspect, so callers should not rely on cache hits in the browser.
    #[must_use]
    pub fn is_cached(&self, repo_id: &str, filename: &str) -> bool {
        #[cfg(target_arch = "wasm32")]
        {
            let _ = (repo_id, filename);
            false
        }
        #[cfg(not(target_arch = "wasm32"))]
        {
            let candidate = self.cached_path(repo_id, filename);
            candidate.is_file()
        }
    }

    /// Build the path at which a repo/file pair is (or would be) cached.
    fn cached_path(&self, repo_id: &str, filename: &str) -> PathBuf {
        let mut location = self.cache_dir.join(repo_id);
        location.push(filename);
        location
    }
}
178
179// -- Native-only methods (filesystem + HuggingFace Hub) -----------------------
180
181#[cfg(not(target_arch = "wasm32"))]
182impl ModelCache {
183 /// Create a cache in the default location.
184 ///
185 /// Uses `$BLAZEN_CACHE_DIR/models/` if the `BLAZEN_CACHE_DIR` environment
186 /// variable is set, otherwise falls back to `~/.cache/blazen/models/`.
187 ///
188 /// # Errors
189 ///
190 /// Returns [`CacheError::CacheDir`] if the home directory cannot be
191 /// determined and `BLAZEN_CACHE_DIR` is not set.
192 pub fn new() -> Result<Self, CacheError> {
193 let cache_dir = if let Ok(dir) = std::env::var("BLAZEN_CACHE_DIR") {
194 PathBuf::from(dir).join("models")
195 } else {
196 dirs::cache_dir()
197 .ok_or_else(|| {
198 CacheError::CacheDir(
199 "could not determine home cache directory; \
200 set BLAZEN_CACHE_DIR to override"
201 .to_string(),
202 )
203 })?
204 .join("blazen")
205 .join("models")
206 };
207
208 Ok(Self { cache_dir })
209 }
210
211 /// Download a file from `HuggingFace` Hub if it is not already cached.
212 ///
213 /// Returns the local filesystem path to the cached file.
214 ///
215 /// The file is first downloaded via `hf-hub` into its own managed cache,
216 /// then hard-linked (or copied as fallback) into our
217 /// `{cache_dir}/{repo_id}/{filename}` layout so that callers get a stable,
218 /// predictable path.
219 ///
220 /// # Progress
221 ///
222 /// Pass an `Arc<dyn ProgressCallback>` to receive byte-level progress
223 /// updates during the download. Pass `None` to download silently.
224 ///
225 /// # Errors
226 ///
227 /// Returns [`CacheError::Download`] if the `HuggingFace` API request fails,
228 /// or [`CacheError::Io`] if filesystem operations fail.
229 pub async fn download(
230 &self,
231 repo_id: &str,
232 filename: &str,
233 progress: Option<Arc<dyn ProgressCallback>>,
234 ) -> Result<PathBuf, CacheError> {
235 let dest = self.cached_path(repo_id, filename);
236
237 // Serialize concurrent callers for the same destination path. Different
238 // files download in parallel; same-file callers wait so hf-hub's internal
239 // blob lock isn't raced.
240 let lock = DOWNLOAD_LOCKS
241 .entry(dest.clone())
242 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
243 .value()
244 .clone();
245 let _guard = lock.lock().await;
246
247 // Already cached -- return immediately (re-check inside the lock so
248 // subsequent callers observe the file created by whoever went first).
249 if dest.is_file() {
250 return Ok(dest);
251 }
252
253 // Ensure the target directory exists.
254 if let Some(parent) = dest.parent() {
255 tokio::fs::create_dir_all(parent).await?;
256 }
257
258 // Build the hf-hub async API.
259 let api = hf_hub::api::tokio::ApiBuilder::new()
260 .with_progress(false) // We handle progress ourselves.
261 .build()
262 .map_err(|e| CacheError::Download(e.to_string()))?;
263
264 let repo = api.model(repo_id.to_string());
265
266 // Download through hf-hub (it manages its own cache under ~/.cache/huggingface).
267 let hf_path = if let Some(cb) = progress {
268 let adapter = HfProgressAdapter::new(Arc::clone(&cb));
269 repo.download_with_progress(filename, adapter)
270 .await
271 .map_err(|e| CacheError::Download(e.to_string()))?
272 } else {
273 let noop = NoProgress;
274 repo.download_with_progress(filename, noop)
275 .await
276 .map_err(|e| CacheError::Download(e.to_string()))?
277 };
278
279 // hf-hub returns `snapshots/main/<filename>` as a symlink into
280 // `blobs/<hash>`. Resolve to the real blob so hard-linking targets the
281 // actual file; otherwise on some filesystems we would hard-link the
282 // symlink itself, which can behave unexpectedly if the snapshot is
283 // pruned later. If canonicalization fails (e.g. broken chain), fall
284 // back to the original path and let the copy fallback handle it.
285 let hf_path_resolved = tokio::fs::canonicalize(&hf_path)
286 .await
287 .unwrap_or_else(|_| hf_path.clone());
288
289 // Link or copy the file into our own cache layout.
290 if dest != hf_path_resolved {
291 // Try hard link first (instant, no extra disk space).
292 if tokio::fs::hard_link(&hf_path_resolved, &dest)
293 .await
294 .is_err()
295 {
296 // Cross-device or unsupported -- fall back to copy.
297 tokio::fs::copy(&hf_path_resolved, &dest).await?;
298 }
299 }
300
301 // Postcondition: dest must exist after a successful download.
302 if !dest.is_file() {
303 return Err(CacheError::Io(std::io::Error::new(
304 std::io::ErrorKind::NotFound,
305 format!(
306 "download completed but cache path is missing: {}",
307 dest.display()
308 ),
309 )));
310 }
311
312 Ok(dest)
313 }
314}
315
316// -- wasm32 stubs -------------------------------------------------------------
317
#[cfg(target_arch = "wasm32")]
impl ModelCache {
    /// wasm32 stand-in for the default constructor; it never succeeds.
    ///
    /// On wasm32 there is no `dirs::cache_dir()` and no `BLAZEN_CACHE_DIR`
    /// environment lookup that would produce a usable path. Construct the
    /// cache with [`Self::with_dir`] and an explicit virtual path instead.
    ///
    /// # Errors
    ///
    /// Always returns [`CacheError::Unsupported`].
    pub fn new() -> Result<Self, CacheError> {
        let reason =
            "ModelCache::new() is not supported on wasm32; use ModelCache::with_dir() instead";
        Err(CacheError::Unsupported(reason.to_string()))
    }

    /// wasm32 stand-in for the downloader; it never succeeds.
    ///
    /// The `HuggingFace` Hub client (`hf-hub`) and `tokio::fs` are not
    /// compatible with the `wasm32-*` targets, so model files cannot be
    /// downloaded from this side. Browser/Worker callers should fetch model
    /// bytes via the JavaScript `fetch()` API and pass them in directly.
    ///
    /// # Errors
    ///
    /// Always returns [`CacheError::Unsupported`].
    pub async fn download(
        &self,
        repo_id: &str,
        filename: &str,
        progress: Option<Arc<dyn ProgressCallback>>,
    ) -> Result<PathBuf, CacheError> {
        // The arguments exist only so the signature matches the native build.
        let _ = (repo_id, filename);
        drop(progress);

        let cache_dir = self.cache_dir.display();
        let reason = format!(
            "ModelCache::download() is not supported on wasm32 (cache_dir={cache_dir})"
        );
        Err(CacheError::Unsupported(reason))
    }
}
360
#[cfg(all(test, not(target_arch = "wasm32")))]
#[allow(unsafe_code)] // env::set_var / env::remove_var require unsafe in edition 2024
mod tests {
    use super::*;

    /// Name of the environment variable consulted by [`ModelCache::new`].
    const CACHE_DIR_VAR: &str = "BLAZEN_CACHE_DIR";

    /// Serializes every test that mutates [`CACHE_DIR_VAR`].
    ///
    /// The test harness runs `#[test]` functions on several threads by
    /// default, so without this lock one test could remove the variable while
    /// another has just set it, making both flaky.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// RAII guard that overrides [`CACHE_DIR_VAR`] for the duration of a test.
    ///
    /// Holds [`ENV_LOCK`] for its whole lifetime and restores the previous
    /// value on drop, including when the test panics on a failed assertion.
    struct CacheDirVarGuard {
        /// Value the variable had before the override, if any.
        previous: Option<std::ffi::OsString>,
        /// Keeps other env-mutating tests out until the value is restored.
        /// Declared last so it is released after `Drop::drop` has run.
        _serial: std::sync::MutexGuard<'static, ()>,
    }

    impl CacheDirVarGuard {
        /// Set the variable to `value`, or remove it when `value` is `None`.
        fn apply(value: Option<&str>) -> Self {
            // A poisoned lock only means another env test failed; the lock
            // itself is still perfectly usable for mutual exclusion.
            let serial = ENV_LOCK
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let previous = std::env::var_os(CACHE_DIR_VAR);

            // SAFETY: `ENV_LOCK` is held, so no other test in this module
            // mutates the environment concurrently. NOTE(review): this relies
            // on no foreign (non-Rust) code reading the environment while the
            // tests run.
            unsafe {
                match value {
                    Some(val) => std::env::set_var(CACHE_DIR_VAR, val),
                    None => std::env::remove_var(CACHE_DIR_VAR),
                }
            }

            Self {
                previous,
                _serial: serial,
            }
        }
    }

    impl Drop for CacheDirVarGuard {
        fn drop(&mut self) {
            // SAFETY: `ENV_LOCK` is still held (the guard field is dropped
            // only after this function returns); see `CacheDirVarGuard::apply`.
            unsafe {
                match self.previous.take() {
                    Some(val) => std::env::set_var(CACHE_DIR_VAR, val),
                    None => std::env::remove_var(CACHE_DIR_VAR),
                }
            }
        }
    }

    #[test]
    fn test_default_cache_dir() {
        // When BLAZEN_CACHE_DIR is not set, the cache should live under the
        // platform cache directory (e.g. ~/.cache/blazen/models/ on Linux).
        // The guard removes the variable and restores it afterwards.
        let _env = CacheDirVarGuard::apply(None);

        let cache = ModelCache::new().expect("default cache should succeed");
        let path = cache.cache_dir();

        // Should end with blazen/models
        assert!(
            path.ends_with("blazen/models"),
            "expected path ending with blazen/models, got: {}",
            path.display()
        );
    }

    #[test]
    fn test_default_cache_dir_from_env() {
        let _env = CacheDirVarGuard::apply(Some("/tmp/blazen-test-cache"));

        let cache = ModelCache::new().expect("env-based cache should succeed");
        assert_eq!(
            cache.cache_dir(),
            Path::new("/tmp/blazen-test-cache/models")
        );
    }

    #[test]
    fn test_with_dir() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert_eq!(cache.cache_dir(), dir.path());
    }

    #[test]
    fn test_is_cached_false_initially() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert!(!cache.is_cached("foo/bar", "model.gguf"));
    }

    #[test]
    fn test_is_cached_true_after_manual_placement() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());

        // Manually create the file to simulate a cached download.
        let file_dir = dir.path().join("my-org/my-model");
        std::fs::create_dir_all(&file_dir).unwrap();
        std::fs::write(file_dir.join("config.json"), b"{}").unwrap();

        assert!(cache.is_cached("my-org/my-model", "config.json"));
    }

    #[test]
    fn test_cached_path_layout() {
        let cache = ModelCache::with_dir("/fake/cache");
        let path = cache.cached_path("org/model", "weights.bin");
        assert_eq!(path, PathBuf::from("/fake/cache/org/model/weights.bin"));
    }

    /// Verifies that the per-path lock actually serializes concurrent callers
    /// targeting the same destination. We can't easily mock hf-hub inside
    /// `download()`, so this test exercises the serialization primitive
    /// directly: if the lock map misbehaves (e.g. hands out independent
    /// mutexes for the same path), more than one task will sit inside the
    /// critical section at the same time and the counter will exceed zero
    /// when observed by another task.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_downloads_serialize_same_path() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(tmp.path().to_path_buf());
        let dest = cache.cached_path("test/repo", "file.bin");

        let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..4 {
            let dest_clone = dest.clone();
            let counter_clone = Arc::clone(&counter);
            handles.push(tokio::spawn(async move {
                let lock = DOWNLOAD_LOCKS
                    .entry(dest_clone)
                    .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
                    .value()
                    .clone();
                let _guard = lock.lock().await;
                // If another task already holds the lock, it would have
                // incremented the counter before us; the assertion below
                // would then catch the violation.
                let prev = counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                // Simulate in-flight work to widen the race window.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                counter_clone.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                prev
            }));
        }

        let results = futures_util::future::join_all(handles).await;
        for r in results {
            let prev = r.expect("task panicked");
            assert_eq!(prev, 0, "another task held the lock concurrently");
        }
    }

    /// Integration test that actually downloads from `HuggingFace` Hub.
    ///
    /// Ignored by default because it requires network access. Run with:
    /// ```sh
    /// cargo test -p blazen-model-cache -- --ignored
    /// ```
    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_and_cache() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());

        // Download a tiny file (~600 bytes).
        let path = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("download should succeed");

        // File should exist and have non-zero size.
        assert!(path.is_file(), "downloaded file should exist");
        let meta = std::fs::metadata(&path).expect("metadata");
        assert!(meta.len() > 0, "downloaded file should be non-empty");

        // Second call should return the cached path instantly.
        let path2 = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("cached download should succeed");
        assert_eq!(path, path2);
    }

    /// Integration test verifying progress callback fires.
    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_with_progress() {
        use std::sync::atomic::{AtomicU64, Ordering};

        /// Counts how many times the progress callback is invoked.
        struct TestProgress {
            calls: AtomicU64,
        }

        impl ProgressCallback for TestProgress {
            fn on_progress(&self, _downloaded_bytes: u64, _total_bytes: Option<u64>) {
                self.calls.fetch_add(1, Ordering::Relaxed);
            }
        }

        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        let progress = Arc::new(TestProgress {
            calls: AtomicU64::new(0),
        });

        // Clone the Arc so we retain a handle for assertions.
        let cb: Arc<dyn ProgressCallback> = Arc::clone(&progress) as Arc<dyn ProgressCallback>;

        cache
            .download(
                "hf-internal-testing/tiny-random-gpt2",
                "config.json",
                Some(cb),
            )
            .await
            .expect("download should succeed");

        assert!(
            progress.calls.load(Ordering::Relaxed) > 0,
            "progress callback should have been called at least once"
        );
    }
}