use std::path::{Path, PathBuf};
#[cfg(not(target_arch = "wasm32"))]
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use std::sync::LazyLock;
#[cfg(all(target_arch = "wasm32", target_os = "wasi"))]
pub mod wasi;
#[cfg(all(target_arch = "wasm32", target_os = "wasi"))]
pub use wasi::{WasiHttpFetch, WasiModelCache};
/// Process-wide registry of per-destination async mutexes.
///
/// `ModelCache::download` looks up (or creates) the mutex for the final cache
/// path it is about to fill and holds it while it works, so two tasks asking
/// for the same file are serialized. Values are `Arc`-wrapped so a task can
/// keep its mutex after the map's internal shard lock is released.
#[cfg(not(target_arch = "wasm32"))]
static DOWNLOAD_LOCKS: LazyLock<dashmap::DashMap<PathBuf, Arc<tokio::sync::Mutex<()>>>> =
    LazyLock::new(Default::default);
/// Errors produced by the model cache.
#[derive(Debug, thiserror::Error)]
pub enum CacheError {
    /// Fetching a file from the remote hub failed; the upstream error is
    /// carried as rendered text.
    #[error("failed to download model: {0}")]
    Download(String),

    /// Locating or preparing the cache directory failed.
    #[error("cache directory error: {0}")]
    CacheDir(String),

    /// An underlying filesystem operation failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The requested operation is not available on the current target.
    #[error("model cache operation not supported on this target: {0}")]
    Unsupported(String),
}
/// Receives byte-level progress notifications while a model file downloads.
///
/// The callback is shared behind an `Arc` and invoked from async download
/// code, hence the `Send + Sync` bound.
pub trait ProgressCallback: Send + Sync {
    /// Reports `downloaded_bytes` received so far, together with the total
    /// file size in bytes when it is known.
    fn on_progress(
        &self,
        downloaded_bytes: u64,
        total_bytes: Option<u64>,
    );
}
/// Bridges `hf_hub`'s async `Progress` trait to a user-supplied
/// [`ProgressCallback`].
///
/// `download_with_progress` requires the progress handle to be `Clone`, and
/// the download may drive several clones (e.g. one per parallel chunk —
/// NOTE(review): confirm against the pinned hf-hub version). The byte counter
/// therefore lives behind an `Arc`, so every clone accumulates into the same
/// running total instead of each clone keeping a private count that the
/// others (and the final `finish` call) never see.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Clone)]
struct HfProgressAdapter {
    /// Caller-provided sink that receives `(downloaded, total)` updates.
    callback: Arc<dyn ProgressCallback>,
    /// Bytes received so far, shared by every clone of this adapter.
    downloaded: Arc<std::sync::atomic::AtomicU64>,
    /// File size announced by `init`; `None` until `init` has run.
    total: Option<u64>,
}

#[cfg(not(target_arch = "wasm32"))]
impl HfProgressAdapter {
    /// Wraps `callback` with a fresh, zeroed byte counter.
    fn new(callback: Arc<dyn ProgressCallback>) -> Self {
        Self {
            callback,
            downloaded: Arc::new(std::sync::atomic::AtomicU64::new(0)),
            total: None,
        }
    }
}

// `Relaxed` is sufficient below: the counter is a standalone statistic that
// guards no other memory, and `fetch_add` is atomic regardless of ordering.
#[cfg(not(target_arch = "wasm32"))]
impl hf_hub::api::tokio::Progress for HfProgressAdapter {
    async fn init(&mut self, size: usize, _filename: &str) {
        self.total = Some(size as u64);
        self.downloaded
            .store(0, std::sync::atomic::Ordering::Relaxed);
        self.callback.on_progress(0, self.total);
    }

    async fn update(&mut self, size: usize) {
        let chunk = size as u64;
        // `fetch_add` returns the previous value; add the chunk to get the
        // running total including this update.
        let so_far = self
            .downloaded
            .fetch_add(chunk, std::sync::atomic::Ordering::Relaxed)
            + chunk;
        self.callback.on_progress(so_far, self.total);
    }

    async fn finish(&mut self) {
        // With the shared counter this is the true number of bytes received,
        // so no clamping is needed (and none that could exceed `total`).
        let so_far = self
            .downloaded
            .load(std::sync::atomic::Ordering::Relaxed);
        self.callback.on_progress(so_far, self.total);
    }
}
/// Progress sink that ignores every notification.
///
/// Lets `ModelCache::download` go through `download_with_progress` even when
/// the caller supplied no callback.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Clone)]
struct NoProgress;

#[cfg(not(target_arch = "wasm32"))]
impl hf_hub::api::tokio::Progress for NoProgress {
    async fn init(&mut self, _total_len: usize, _name: &str) {
        // Intentionally empty: nobody is listening.
    }

    async fn update(&mut self, _chunk_len: usize) {}

    async fn finish(&mut self) {}
}
/// On-disk cache of model files, laid out as
/// `<cache_dir>/<repo_id>/<filename>`.
#[derive(Debug, Clone)]
pub struct ModelCache {
    /// Root directory under which all cached files are stored.
    cache_dir: PathBuf,
}

impl ModelCache {
    /// Creates a cache rooted at `cache_dir`.
    ///
    /// This only records the path; nothing is created or checked on disk.
    #[must_use]
    pub fn with_dir(cache_dir: impl Into<PathBuf>) -> Self {
        Self {
            cache_dir: cache_dir.into(),
        }
    }

    /// Returns the cache root directory.
    #[must_use]
    pub fn cache_dir(&self) -> &Path {
        &self.cache_dir
    }

    /// Reports whether `filename` from `repo_id` is present in the cache as a
    /// regular file (symlinks are followed).
    ///
    /// On `wasm32` targets this does not consult the filesystem and always
    /// returns `false`.
    #[must_use]
    pub fn is_cached(&self, repo_id: &str, filename: &str) -> bool {
        #[cfg(not(target_arch = "wasm32"))]
        {
            self.cached_path(repo_id, filename).is_file()
        }
        #[cfg(target_arch = "wasm32")]
        {
            let _ = (repo_id, filename);
            false
        }
    }

    /// Path at which `filename` from `repo_id` is (or will be) stored.
    #[cfg(not(target_arch = "wasm32"))]
    fn cached_path(&self, repo_id: &str, filename: &str) -> PathBuf {
        self.cache_dir.join(repo_id).join(filename)
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl ModelCache {
    /// Creates a cache rooted at the default location.
    ///
    /// The root is `$BLAZEN_CACHE_DIR/models` when `BLAZEN_CACHE_DIR` is set
    /// to a non-empty value (non-UTF-8 paths are accepted); otherwise it is
    /// `<platform cache dir>/blazen/models`, using `dirs::cache_dir()`.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::CacheDir`] when `BLAZEN_CACHE_DIR` is unset or
    /// empty and the platform cache directory cannot be determined.
    pub fn new() -> Result<Self, CacheError> {
        // `var_os` keeps non-UTF-8 overrides that `var` would silently drop,
        // and an empty value is treated as unset instead of becoming the
        // relative path "models".
        let override_dir = std::env::var_os("BLAZEN_CACHE_DIR").filter(|dir| !dir.is_empty());
        let cache_dir = match override_dir {
            Some(dir) => PathBuf::from(dir).join("models"),
            None => dirs::cache_dir()
                .ok_or_else(|| {
                    CacheError::CacheDir(
                        "could not determine home cache directory; \
                         set BLAZEN_CACHE_DIR to override"
                            .to_string(),
                    )
                })?
                .join("blazen")
                .join("models"),
        };
        Ok(Self { cache_dir })
    }

    /// Downloads `filename` from the Hugging Face Hub repo `repo_id` into the
    /// cache and returns its path, `<cache_dir>/<repo_id>/<filename>`.
    ///
    /// If the file is already cached it is returned without any network
    /// access. Concurrent calls for the same destination within this process
    /// are serialized, so only the first one performs the download.
    /// `progress`, when given, receives byte-level progress updates.
    ///
    /// # Errors
    ///
    /// - [`CacheError::Download`] if building the hub client or fetching the
    ///   file fails.
    /// - [`CacheError::Io`] if creating directories or placing the file in the
    ///   cache fails, or if the file is missing after the download completed.
    pub async fn download(
        &self,
        repo_id: &str,
        filename: &str,
        progress: Option<Arc<dyn ProgressCallback>>,
    ) -> Result<PathBuf, CacheError> {
        let dest = self.cached_path(repo_id, filename);
        let lock = DOWNLOAD_LOCKS
            .entry(dest.clone())
            .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
            .value()
            .clone();
        let guard = lock.lock().await;

        let result = Self::download_locked(repo_id, filename, &dest, progress).await;

        // Drop this path's registry entry when nobody else holds a handle to
        // its mutex (the map's copy + ours == 2), so the map does not grow by
        // one entry per distinct file ever requested. The predicate runs under
        // the map's shard lock, which `entry()` above also takes, so no new
        // waiter can clone the mutex between the check and the removal.
        DOWNLOAD_LOCKS.remove_if(&dest, |_, mutex| Arc::strong_count(mutex) == 2);
        drop(guard);
        result
    }

    /// Body of [`ModelCache::download`]. Must only be called while holding
    /// the per-destination lock for `dest`.
    async fn download_locked(
        repo_id: &str,
        filename: &str,
        dest: &Path,
        progress: Option<Arc<dyn ProgressCallback>>,
    ) -> Result<PathBuf, CacheError> {
        // Another caller may have completed the download while we waited.
        if Self::is_regular_file(dest).await {
            return Ok(dest.to_path_buf());
        }
        if let Some(parent) = dest.parent() {
            tokio::fs::create_dir_all(parent).await?;
        }

        let api = hf_hub::api::tokio::ApiBuilder::new()
            .with_progress(false)
            .build()
            .map_err(|e| CacheError::Download(e.to_string()))?;
        let repo = api.model(repo_id.to_string());
        let fetched = match progress {
            Some(cb) => {
                repo.download_with_progress(filename, HfProgressAdapter::new(cb))
                    .await
            }
            None => repo.download_with_progress(filename, NoProgress).await,
        };
        let fetched = fetched.map_err(|e| CacheError::Download(e.to_string()))?;

        // The returned path lives in hf-hub's own cache and may be a symlink;
        // resolve it so we link/copy the actual file contents.
        let resolved = tokio::fs::canonicalize(&fetched).await;
        let source = resolved.unwrap_or(fetched);
        if source.as_path() != dest {
            Self::place_file(&source, dest).await?;
        }

        if !Self::is_regular_file(dest).await {
            return Err(CacheError::Io(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                format!(
                    "download completed but cache path is missing: {}",
                    dest.display()
                ),
            )));
        }
        Ok(dest.to_path_buf())
    }

    /// Makes the file at `source` available at `dest`.
    ///
    /// A hard link is tried first; it creates the complete entry in one step.
    /// If linking fails (e.g. `source` is on another filesystem), the bytes
    /// are copied into a temporary sibling of `dest` and then renamed into
    /// place. An interrupted copy therefore never leaves a truncated file at
    /// `dest` that would later pass the "already cached" check.
    async fn place_file(source: &Path, dest: &Path) -> Result<(), CacheError> {
        if tokio::fs::hard_link(source, dest).await.is_ok() {
            return Ok(());
        }
        let file_name = dest
            .file_name()
            .map(|name| name.to_string_lossy().into_owned())
            .unwrap_or_default();
        // The pid keeps concurrent processes from sharing a temp file; within
        // this process the per-destination lock already serializes writers.
        let tmp = dest.with_file_name(format!(".{file_name}.{}.partial", std::process::id()));
        let placed = match tokio::fs::copy(source, &tmp).await {
            Ok(_) => tokio::fs::rename(&tmp, dest).await,
            Err(e) => Err(e),
        };
        if let Err(e) = placed {
            // Best effort: don't leave the partial file behind.
            let _ = tokio::fs::remove_file(&tmp).await;
            return Err(e.into());
        }
        Ok(())
    }

    /// Non-blocking equivalent of `Path::is_file`: follows symlinks and
    /// treats any metadata error as "not a file".
    async fn is_regular_file(path: &Path) -> bool {
        tokio::fs::metadata(path)
            .await
            .is_ok_and(|meta| meta.is_file())
    }
}
#[cfg(all(target_arch = "wasm32", not(target_os = "wasi")))]
impl ModelCache {
    /// Not available on non-WASI `wasm32` targets, which have no default cache
    /// directory. Use [`ModelCache::with_dir`] instead.
    ///
    /// # Errors
    ///
    /// Always returns [`CacheError::Unsupported`].
    pub fn new() -> Result<Self, CacheError> {
        Err(CacheError::Unsupported(
            "ModelCache::new() is not supported on wasm32; use ModelCache::with_dir() instead"
                .to_string(),
        ))
    }

    /// Stub that mirrors the native signature so callers compile on this
    /// target. Downloading is not implemented here.
    ///
    /// # Errors
    ///
    /// Always returns [`CacheError::Unsupported`].
    pub async fn download(
        &self,
        repo_id: &str,
        filename: &str,
        // `Arc` is only imported for non-wasm32 targets at the top of the
        // file, so spell out the full path to keep this impl compiling.
        progress: Option<std::sync::Arc<dyn ProgressCallback>>,
    ) -> Result<PathBuf, CacheError> {
        let _ = (repo_id, filename, progress);
        Err(CacheError::Unsupported(format!(
            "ModelCache::download() is not supported on wasm32 \
             (cache_dir={})",
            self.cache_dir.display()
        )))
    }
}
#[cfg(all(test, not(target_arch = "wasm32")))]
// `std::env::set_var` / `remove_var` are `unsafe` because they race with
// concurrent environment access; the env-mutating tests serialize on
// `ENV_LOCK` and restore state via `EnvVarGuard`.
#[allow(unsafe_code)]
mod tests {
    use super::*;

    /// Serializes tests that mutate `BLAZEN_CACHE_DIR`. The harness runs tests
    /// on parallel threads; without this, the two env tests can observe each
    /// other's value and fail spuriously.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquires `ENV_LOCK`, ignoring poisoning: the guarded data is `()`, so a
    /// panic in another env test leaves no broken invariant behind.
    fn lock_env() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Restores an environment variable to its captured value on drop, even
    /// if the test panics partway through.
    struct EnvVarGuard {
        key: &'static str,
        prev: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn capture(key: &'static str) -> Self {
            Self {
                key,
                prev: std::env::var_os(key),
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: every user of this guard holds `ENV_LOCK`, so no other
            // env-mutating test runs concurrently. NOTE(review): other tests
            // may still read the environment (e.g. temp-dir lookup); switch to
            // `--test-threads=1` if that ever proves flaky.
            unsafe {
                match self.prev.take() {
                    Some(val) => std::env::set_var(self.key, val),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    #[test]
    fn test_default_cache_dir() {
        // Declared first so it is dropped last, after the env is restored.
        let _env_lock = lock_env();
        let _restore = EnvVarGuard::capture("BLAZEN_CACHE_DIR");
        // SAFETY: `ENV_LOCK` is held; see `EnvVarGuard::drop`.
        unsafe {
            std::env::remove_var("BLAZEN_CACHE_DIR");
        }
        let cache = ModelCache::new().expect("default cache should succeed");
        let path = cache.cache_dir();
        assert!(
            path.ends_with("blazen/models"),
            "expected path ending with blazen/models, got: {}",
            path.display()
        );
    }

    #[test]
    fn test_default_cache_dir_from_env() {
        let _env_lock = lock_env();
        let _restore = EnvVarGuard::capture("BLAZEN_CACHE_DIR");
        // SAFETY: `ENV_LOCK` is held; see `EnvVarGuard::drop`.
        unsafe {
            std::env::set_var("BLAZEN_CACHE_DIR", "/tmp/blazen-test-cache");
        }
        let cache = ModelCache::new().expect("env-based cache should succeed");
        assert_eq!(
            cache.cache_dir(),
            Path::new("/tmp/blazen-test-cache/models")
        );
    }

    #[test]
    fn test_with_dir() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert_eq!(cache.cache_dir(), dir.path());
    }

    #[test]
    fn test_is_cached_false_initially() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        assert!(!cache.is_cached("foo/bar", "model.gguf"));
    }

    #[test]
    fn test_is_cached_true_after_manual_placement() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        let file_dir = dir.path().join("my-org/my-model");
        std::fs::create_dir_all(&file_dir).unwrap();
        std::fs::write(file_dir.join("config.json"), b"{}").unwrap();
        assert!(cache.is_cached("my-org/my-model", "config.json"));
    }

    #[test]
    fn test_cached_path_layout() {
        let cache = ModelCache::with_dir("/fake/cache");
        let path = cache.cached_path("org/model", "weights.bin");
        assert_eq!(path, PathBuf::from("/fake/cache/org/model/weights.bin"));
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn concurrent_downloads_serialize_same_path() {
        let tmp = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(tmp.path().to_path_buf());
        let dest = cache.cached_path("test/repo", "file.bin");
        let counter = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let mut handles = Vec::new();
        for _ in 0..4 {
            let dest_clone = dest.clone();
            let counter_clone = Arc::clone(&counter);
            handles.push(tokio::spawn(async move {
                let lock = DOWNLOAD_LOCKS
                    .entry(dest_clone)
                    .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
                    .value()
                    .clone();
                let _guard = lock.lock().await;
                let prev = counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                counter_clone.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
                prev
            }));
        }
        let results = futures_util::future::join_all(handles).await;
        for r in results {
            let prev = r.expect("task panicked");
            assert_eq!(prev, 0, "another task held the lock concurrently");
        }
    }

    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_and_cache() {
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        let path = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("download should succeed");
        assert!(path.is_file(), "downloaded file should exist");
        let meta = std::fs::metadata(&path).expect("metadata");
        assert!(meta.len() > 0, "downloaded file should be non-empty");
        let path2 = cache
            .download("hf-internal-testing/tiny-random-gpt2", "config.json", None)
            .await
            .expect("cached download should succeed");
        assert_eq!(path, path2);
    }

    #[tokio::test]
    #[ignore = "requires network access to HuggingFace Hub"]
    async fn test_download_with_progress() {
        use std::sync::atomic::{AtomicU64, Ordering};
        struct TestProgress {
            calls: AtomicU64,
        }
        impl ProgressCallback for TestProgress {
            fn on_progress(&self, _downloaded_bytes: u64, _total_bytes: Option<u64>) {
                self.calls.fetch_add(1, Ordering::Relaxed);
            }
        }
        let dir = tempfile::tempdir().expect("tempdir");
        let cache = ModelCache::with_dir(dir.path());
        let progress = Arc::new(TestProgress {
            calls: AtomicU64::new(0),
        });
        let cb: Arc<dyn ProgressCallback> = Arc::clone(&progress) as Arc<dyn ProgressCallback>;
        cache
            .download(
                "hf-internal-testing/tiny-random-gpt2",
                "config.json",
                Some(cb),
            )
            .await
            .expect("download should succeed");
        assert!(
            progress.calls.load(Ordering::Relaxed) > 0,
            "progress callback should have been called at least once"
        );
    }
}