use std::path::{Path, PathBuf};
use anyhow::{Context, Result};
use sha2::{Digest, Sha256};
/// Handle to an on-disk model cache rooted at a single directory.
///
/// The handle only stores the root path; every method derives the per-model
/// sub-paths from it on demand, so creating or cloning a handle performs no I/O.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HfCache {
    /// Root directory under which one `models--…` folder per model is kept.
    cache_dir: PathBuf,
}
impl Default for HfCache {
fn default() -> Self {
Self::new()
}
}
impl HfCache {
pub fn new() -> Self {
let cache_dir = if let Ok(hf_home) = std::env::var("HF_HOME") {
PathBuf::from(hf_home).join("hub")
} else {
dirs::home_dir()
.expect("Could not determine home directory")
.join(".cache")
.join("huggingface")
.join("hub")
};
Self { cache_dir }
}
pub fn new_with_dir(cache_dir: PathBuf) -> Self {
Self { cache_dir }
}
pub fn repo_dir(&self, model_id: &str) -> PathBuf {
let folder = format!("models--{}", model_id.replace('/', "--"));
self.cache_dir.join(folder)
}
pub fn refs_dir(&self, model_id: &str) -> PathBuf {
self.repo_dir(model_id).join("refs")
}
pub fn blobs_dir(&self, model_id: &str) -> PathBuf {
self.repo_dir(model_id).join("blobs")
}
pub fn snapshots_dir(&self, model_id: &str) -> PathBuf {
self.repo_dir(model_id).join("snapshots")
}
pub fn snapshot_dir(&self, model_id: &str, commit_sha: &str) -> PathBuf {
self.snapshots_dir(model_id).join(commit_sha)
}
pub fn write_ref(&self, model_id: &str, ref_name: &str, commit_sha: &str) -> Result<()> {
let refs = self.refs_dir(model_id);
std::fs::create_dir_all(&refs)
.with_context(|| format!("Failed to create refs dir: {}", refs.display()))?;
let ref_path = refs.join(ref_name);
std::fs::write(&ref_path, commit_sha)
.with_context(|| format!("Failed to write ref: {}", ref_path.display()))?;
Ok(())
}
pub fn read_ref(&self, model_id: &str, ref_name: &str) -> Option<String> {
let ref_path = self.refs_dir(model_id).join(ref_name);
std::fs::read_to_string(ref_path)
.ok()
.map(|s| s.trim().to_string())
}
pub fn store_blob(&self, model_id: &str, data: &[u8]) -> Result<PathBuf> {
let hash = {
let mut hasher = Sha256::new();
hasher.update(data);
hex::encode(hasher.finalize())
};
let blobs = self.blobs_dir(model_id);
std::fs::create_dir_all(&blobs)
.with_context(|| format!("Failed to create blobs dir: {}", blobs.display()))?;
let blob_path = blobs.join(&hash);
if !blob_path.exists() {
std::fs::write(&blob_path, data)
.with_context(|| format!("Failed to write blob: {}", blob_path.display()))?;
}
Ok(blob_path)
}
pub fn link_snapshot(
&self,
model_id: &str,
commit_sha: &str,
filename: &str,
blob_path: &Path,
) -> Result<PathBuf> {
let snap_dir = self.snapshot_dir(model_id, commit_sha);
let snap_file = snap_dir.join(filename);
if let Some(parent) = snap_file.parent() {
std::fs::create_dir_all(parent)
.with_context(|| format!("Failed to create snapshot dir: {}", parent.display()))?;
}
if !snap_file.exists() {
#[cfg(unix)]
{
std::os::unix::fs::symlink(blob_path, &snap_file).with_context(|| {
format!(
"Failed to symlink {} -> {}",
snap_file.display(),
blob_path.display()
)
})?;
}
#[cfg(not(unix))]
{
std::fs::copy(blob_path, &snap_file).with_context(|| {
format!(
"Failed to copy {} -> {}",
blob_path.display(),
snap_file.display()
)
})?;
}
}
Ok(snap_file)
}
pub fn is_cached(&self, model_id: &str, commit_sha: &str) -> bool {
self.snapshot_dir(model_id, commit_sha).is_dir()
}
pub fn cached_snapshot_path(&self, model_id: &str, ref_name: &str) -> Option<PathBuf> {
let sha = self.read_ref(model_id, ref_name)?;
let snap = self.snapshot_dir(model_id, &sha);
if snap.is_dir() {
Some(snap)
} else {
None
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a cache rooted in a fresh temporary directory. The `TempDir` must be
    /// kept alive by the caller, since dropping it deletes the directory.
    fn test_cache() -> (TempDir, HfCache) {
        let tmp = TempDir::new().unwrap();
        let cache = HfCache::new_with_dir(tmp.path().to_path_buf());
        (tmp, cache)
    }

    #[test]
    fn repo_dir_replaces_slashes() {
        let (_tmp, cache) = test_cache();
        let dir = cache.repo_dir("meta-llama/Llama-2-7b");
        assert!(dir.ends_with("models--meta-llama--Llama-2-7b"));
    }

    #[test]
    fn repo_dir_without_namespace() {
        let (tmp, cache) = test_cache();
        assert_eq!(cache.repo_dir("gpt2"), tmp.path().join("models--gpt2"));
    }

    #[test]
    fn subdirectories_under_repo() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        assert!(cache.refs_dir(model).ends_with("models--org--model/refs"));
        assert!(cache.blobs_dir(model).ends_with("models--org--model/blobs"));
        assert!(cache
            .snapshots_dir(model)
            .ends_with("models--org--model/snapshots"));
        assert!(cache
            .snapshot_dir(model, "abc123")
            .ends_with("models--org--model/snapshots/abc123"));
    }

    #[test]
    fn write_and_read_ref() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        cache.write_ref(model, "main", "deadbeef").unwrap();
        assert_eq!(cache.read_ref(model, "main"), Some("deadbeef".to_string()));
    }

    #[test]
    fn write_ref_overwrites_previous_value() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        cache.write_ref(model, "main", "old").unwrap();
        cache.write_ref(model, "main", "new").unwrap();
        assert_eq!(cache.read_ref(model, "main"), Some("new".to_string()));
    }

    #[test]
    fn read_ref_trims_surrounding_whitespace() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        cache.write_ref(model, "main", "  deadbeef\n").unwrap();
        assert_eq!(cache.read_ref(model, "main"), Some("deadbeef".to_string()));
    }

    #[test]
    fn read_ref_missing_returns_none() {
        let (_tmp, cache) = test_cache();
        assert_eq!(cache.read_ref("org/model", "main"), None);
    }

    #[test]
    fn store_blob_creates_file_with_correct_hash() {
        let (_tmp, cache) = test_cache();
        let data = b"hello world";
        let blob = cache.store_blob("org/model", data).unwrap();
        assert!(blob.exists());
        assert_eq!(std::fs::read(&blob).unwrap(), data);
        let expected_hash = hex::encode(Sha256::digest(data));
        assert_eq!(blob.file_name().unwrap().to_str().unwrap(), expected_hash);
    }

    #[test]
    fn store_blob_skips_existing() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        let data = b"same content";
        let p1 = cache.store_blob(model, data).unwrap();
        let p2 = cache.store_blob(model, data).unwrap();
        assert_eq!(p1, p2);
        // Identical content must be stored once, with nothing else left behind.
        let entries = std::fs::read_dir(cache.blobs_dir(model)).unwrap().count();
        assert_eq!(entries, 1);
        assert_eq!(std::fs::read(&p2).unwrap(), data);
    }

    #[test]
    fn store_blob_distinct_content_gets_distinct_paths() {
        let (_tmp, cache) = test_cache();
        let a = cache.store_blob("org/model", b"a").unwrap();
        let b = cache.store_blob("org/model", b"b").unwrap();
        assert_ne!(a, b);
    }

    #[test]
    fn link_snapshot_creates_symlink() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        let blob = cache.store_blob(model, b"payload").unwrap();
        let snap = cache
            .link_snapshot(model, "abc123", "config.json", &blob)
            .unwrap();
        assert!(snap.exists());
        assert_eq!(std::fs::read(&snap).unwrap(), b"payload");
        // On Unix the entry must be a link to the blob, not a second copy.
        #[cfg(unix)]
        {
            let meta = std::fs::symlink_metadata(&snap).unwrap();
            assert!(meta.file_type().is_symlink());
            assert_eq!(std::fs::read_link(&snap).unwrap(), blob);
        }
    }

    #[test]
    fn link_snapshot_nested_filename() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        let blob = cache.store_blob(model, b"nested").unwrap();
        let snap = cache
            .link_snapshot(model, "abc123", "subdir/deep/file.bin", &blob)
            .unwrap();
        assert!(snap.exists());
        assert!(snap.parent().unwrap().ends_with("subdir/deep"));
        assert_eq!(std::fs::read(&snap).unwrap(), b"nested");
    }

    #[test]
    fn link_snapshot_skips_existing() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        let blob = cache.store_blob(model, b"data").unwrap();
        let p1 = cache.link_snapshot(model, "abc", "f.bin", &blob).unwrap();
        let p2 = cache.link_snapshot(model, "abc", "f.bin", &blob).unwrap();
        assert_eq!(p1, p2);
        assert_eq!(std::fs::read(&p2).unwrap(), b"data");
    }

    #[test]
    fn is_cached_returns_false_then_true() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        assert!(!cache.is_cached(model, "abc123"));
        let blob = cache.store_blob(model, b"x").unwrap();
        cache
            .link_snapshot(model, "abc123", "file.txt", &blob)
            .unwrap();
        assert!(cache.is_cached(model, "abc123"));
        // A different commit of the same model is still not cached.
        assert!(!cache.is_cached(model, "other"));
    }

    #[test]
    fn cached_snapshot_path_roundtrip() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        assert!(cache.cached_snapshot_path(model, "main").is_none());
        let blob = cache.store_blob(model, b"file content").unwrap();
        cache
            .link_snapshot(model, "sha999", "readme.md", &blob)
            .unwrap();
        cache.write_ref(model, "main", "sha999").unwrap();
        let snap = cache.cached_snapshot_path(model, "main").unwrap();
        assert!(snap.is_dir());
        assert!(snap.ends_with("snapshots/sha999"));
    }

    #[test]
    fn cached_snapshot_path_none_when_snapshot_missing() {
        let (_tmp, cache) = test_cache();
        let model = "org/model";
        // The ref exists but no snapshot was ever materialised for it.
        cache.write_ref(model, "main", "sha-without-files").unwrap();
        assert!(cache.cached_snapshot_path(model, "main").is_none());
    }
}