use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex as AsyncMutex, OwnedMutexGuard};
/// Progress notification describing one step of a model download.
///
/// Serde writes each value as an internally tagged object: the variant name is
/// stored in an `"event"` field in `snake_case` (for example `"file_started"`),
/// next to the variant's own fields, which keep their Rust names.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "event", rename_all = "snake_case")]
pub enum DownloadEvent {
/// Tagged `"started"`. Presumably sent once per model before any file event,
/// with `total_mb` being the combined size in megabytes — the emit sites are
/// not visible here, confirm against the downloader.
Started {
model: String,
total_files: u32,
total_mb: u64,
},
/// Tagged `"file_started"`. `index` is presumably the position of the file
/// within `total_files`; whether it is 0- or 1-based is not established here.
FileStarted {
filename: String,
index: u32,
total_files: u32,
size_mb: u64,
},
/// Tagged `"file_progress"`. Carries the amount downloaded so far and the
/// total for `filename` (units are megabytes judging by the field names).
FileProgress {
filename: String,
downloaded_mb: u64,
total_mb: u64,
},
/// Tagged `"file_completed"`.
FileCompleted { filename: String },
/// Tagged `"completed"`.
Completed { model: String },
/// Tagged `"failed"`. `error` is a free-form message.
Failed { error: String },
}
/// Receiver of [`DownloadEvent`]s.
///
/// The `Send + Sync` bound lets an implementation be stored as
/// `Arc<dyn DownloadProgress>`, which is how [`ProgressSink`] holds it.
pub trait DownloadProgress: Send + Sync {
/// Called with each event; takes `&self`, so implementations that record
/// state need interior mutability.
fn on_event(&self, event: &DownloadEvent);
}
/// Optional handle to a [`DownloadProgress`] listener.
///
/// Cloning only clones the inner `Arc`. The `Default` value carries no
/// listener, so emitting through it does nothing.
#[derive(Clone, Default)]
pub struct ProgressSink(Option<Arc<dyn DownloadProgress>>);

impl ProgressSink {
    /// Creates a sink with no listener attached; every emitted event is dropped.
    pub fn none() -> Self {
        Self(None)
    }

    /// Creates a sink that forwards every emitted event to `sink`.
    pub fn new(sink: Arc<dyn DownloadProgress>) -> Self {
        Self(Some(sink))
    }

    /// Passes `event` to the listener when one is attached, otherwise drops it.
    pub fn emit(&self, event: DownloadEvent) {
        if let Some(listener) = self.0.as_deref() {
            listener.on_event(&event);
        }
    }

    /// Returns `true` when a listener is attached.
    pub fn is_active(&self) -> bool {
        self.0.is_some()
    }
}
/// Returns the process-wide registry mapping a model id to its async lock.
///
/// The registry is created empty on first use and lives in a `static`, so
/// every caller sees the same map.
fn model_locks() -> &'static Mutex<HashMap<String, Arc<AsyncMutex<()>>>> {
    static REGISTRY: OnceLock<Mutex<HashMap<String, Arc<AsyncMutex<()>>>>> = OnceLock::new();
    REGISTRY.get_or_init(Default::default)
}
pub async fn acquire_model_lock(model_id: &str) -> OwnedMutexGuard<()> {
let lock = {
let mut map = model_locks().lock().unwrap();
map.entry(model_id.to_string())
.or_insert_with(|| Arc::new(AsyncMutex::new(())))
.clone()
};
lock.lock_owned().await
}
/// Checks that the filesystem holding `path` can take `needed_mb` megabytes
/// while leaving a further 1024 MB free.
///
/// Returns `Ok(())` without probing when `needed_mb` is zero, and also when
/// the free space cannot be determined (the check is best-effort).
///
/// # Errors
///
/// Returns a human-readable message when the measured free space is below
/// `needed_mb + 1024` (saturating).
pub fn check_disk_space(path: &std::path::Path, needed_mb: u64) -> Result<(), String> {
    if needed_mb == 0 {
        return Ok(());
    }
    match available_disk_mb(path) {
        Some(free_mb) if free_mb < needed_mb.saturating_add(1024) => Err(format!(
            "not enough disk space: need ~{} MB (+1 GB free), but only {} MB available at {}",
            needed_mb,
            free_mb,
            path.display()
        )),
        // Enough room, or the probe could not tell.
        _ => Ok(()),
    }
}
/// Returns the free space, in megabytes, of the filesystem that would hold
/// `path`, or `None` when it cannot be determined.
///
/// `path` does not have to exist yet: the nearest existing ancestor is probed
/// instead. On Unix the value is read from the "Available" column of
/// `df -Pk`; on other platforms the function always returns `None`.
fn available_disk_mb(path: &std::path::Path) -> Option<u64> {
    let mut probe = path;
    while !probe.exists() {
        probe = probe.parent()?;
        if probe.as_os_str().is_empty() {
            // A relative path whose components are all missing bottoms out at
            // the empty path, which never "exists" and has no parent. Its real
            // ancestor is the current directory, so probe that. Stop here
            // unconditionally: the parent of "." is the empty path again.
            probe = std::path::Path::new(".");
            break;
        }
    }
    #[cfg(unix)]
    {
        let out = std::process::Command::new("df")
            .arg("-Pk")
            .arg(probe)
            .output()
            .ok()?;
        let text = String::from_utf8(out.stdout).ok()?;
        // Line 0 is the header; in the POSIX (`-P`) layout the fourth column
        // of the data line is the available space in 1024-byte blocks.
        let avail_kb: u64 = text
            .lines()
            .nth(1)?
            .split_whitespace()
            .nth(3)?
            .parse()
            .ok()?;
        Some(avail_kb / 1024)
    }
    #[cfg(not(unix))]
    {
        let _ = probe;
        None
    }
}
use std::path::Path;
/// Outcome of checking a cached file with `verify_cache_file`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheIntegrity {
/// The file's SHA-256 equals the 64-hex-digit name of the blob it resolves to.
Verified,
/// The file is usable but has no SHA-256 blob name to compare against
/// (it does not resolve into a `blobs` directory, or the blob name is not
/// 64 hex digits).
Unverifiable,
/// The file is missing, not a regular file, empty, could not be hashed, or
/// its hash differs from the blob name.
Corrupt,
}
/// Returns `true` when `path` resolves (following symlinks) to a regular file
/// with at least one byte. Missing paths, dangling symlinks, directories and
/// empty files all yield `false`.
pub fn cache_file_usable(path: &Path) -> bool {
    std::fs::metadata(path)
        .map(|meta| meta.is_file() && meta.len() > 0)
        .unwrap_or(false)
}
/// Classifies the cached file at `pointer_path`.
///
/// * [`CacheIntegrity::Corrupt`] — the file is not usable (see
///   `cache_file_usable`), hashing it failed, or its SHA-256 differs from the
///   blob name.
/// * [`CacheIntegrity::Unverifiable`] — no blob name is available, or the name
///   is not exactly 64 ASCII hex digits, so there is nothing to compare with.
/// * [`CacheIntegrity::Verified`] — the SHA-256 of the content equals the blob
///   name (compared case-insensitively).
pub fn verify_cache_file(pointer_path: &Path) -> CacheIntegrity {
    if !cache_file_usable(pointer_path) {
        return CacheIntegrity::Corrupt;
    }

    let expected = match blob_etag(pointer_path) {
        Some(tag) if tag.len() == 64 && tag.bytes().all(|b| b.is_ascii_hexdigit()) => tag,
        _ => return CacheIntegrity::Unverifiable,
    };

    // A read error counts the same as a mismatch.
    let hash_matches = sha256_hex(pointer_path)
        .map(|actual| actual.eq_ignore_ascii_case(&expected))
        .unwrap_or(false);

    if hash_matches {
        CacheIntegrity::Verified
    } else {
        CacheIntegrity::Corrupt
    }
}
/// Walks `model_dir` recursively, deletes every file that verifies as
/// [`CacheIntegrity::Corrupt`], and returns how many directory entries were
/// removed.
pub fn purge_corrupt_cache_files(model_dir: &Path) -> usize {
    let mut purged: usize = 0;
    purge_corrupt_recurse(model_dir, &mut purged);
    purged
}
/// Recursive worker for `purge_corrupt_cache_files`.
///
/// Descends into real sub-directories of `dir` and removes every entry that
/// `verify_cache_file` classifies as corrupt, adding one to `removed` for each
/// entry successfully unlinked. When the corrupt entry is a symlink, the
/// regular file it resolves to is removed as well, so that a bad
/// content-addressed blob is not simply re-linked later.
///
/// Unreadable directories and entries are skipped silently (best-effort).
fn purge_corrupt_recurse(dir: &Path, removed: &mut usize) {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };
    for entry in entries.filter_map(Result::ok) {
        let path = entry.path();
        if entry.file_type().map(|t| t.is_dir()).unwrap_or(false) {
            purge_corrupt_recurse(&path, removed);
            continue;
        }
        // `DirEntry::file_type` does not follow symlinks, so a link to a
        // directory reaches this point. It is not a cache file: leave it alone
        // instead of letting it be classified as a corrupt (non-regular) file,
        // and do not descend through it either, which avoids link cycles.
        if std::fs::metadata(&path).map(|m| m.is_dir()).unwrap_or(false) {
            continue;
        }
        if verify_cache_file(&path) != CacheIntegrity::Corrupt {
            continue;
        }
        let is_link = std::fs::symlink_metadata(&path)
            .map(|m| m.file_type().is_symlink())
            .unwrap_or(false);
        if is_link {
            if let Ok(real) = std::fs::canonicalize(&path) {
                // Only delete the target when it is a regular file. A link to
                // a device node, FIFO or similar is unusable as a cache entry,
                // but its target is not ours to remove.
                let target_is_file = std::fs::symlink_metadata(&real)
                    .map(|m| m.is_file())
                    .unwrap_or(false);
                if target_is_file {
                    let _ = std::fs::remove_file(&real);
                }
            }
        }
        if std::fs::remove_file(&path).is_ok() {
            *removed += 1;
        }
    }
}
/// Resolves `pointer_path` and, when the resolved file sits directly inside a
/// directory named `blobs`, returns that file's name.
///
/// Returns `None` when the path cannot be canonicalized, when the resolved
/// file is not inside a `blobs` directory, or when a name is not valid UTF-8.
fn blob_etag(pointer_path: &Path) -> Option<String> {
    let resolved = std::fs::canonicalize(pointer_path).ok()?;
    let dir_name = resolved.parent()?.file_name()?.to_str()?;
    if dir_name != "blobs" {
        return None;
    }
    let blob_name = resolved.file_name()?.to_str()?;
    Some(blob_name.to_owned())
}
/// Streams the file at `path` through SHA-256 and returns the digest as a
/// lowercase hex string.
///
/// The file is read in 64 KiB chunks, so memory use does not depend on the
/// file size.
///
/// # Errors
///
/// Returns the underlying I/O error when the file cannot be opened or read.
/// Reads that fail with [`std::io::ErrorKind::Interrupted`] are retried rather
/// than reported: that error is non-fatal, and the caller treats any error
/// from this function as evidence of corruption.
fn sha256_hex(path: &Path) -> std::io::Result<String> {
    use sha2::{Digest, Sha256};
    use std::io::Read;
    let mut file = std::fs::File::open(path)?;
    let mut hasher = Sha256::new();
    let mut buf = [0u8; 64 * 1024];
    loop {
        let n = match file.read(&mut buf) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        hasher.update(&buf[..n]);
    }
    Ok(hex::encode(hasher.finalize()))
}
// Unit tests for progress reporting, per-model locking, the disk-space check
// and cache-file verification/purging.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Mutex;
// Test listener that stores a clone of every event it receives.
#[derive(Default)]
struct Recorder {
events: Mutex<Vec<DownloadEvent>>,
}
impl DownloadProgress for Recorder {
fn on_event(&self, event: &DownloadEvent) {
self.events.lock().unwrap().push(event.clone());
}
}
// A sink without a listener reports inactive, and emitting through it must
// simply return.
#[test]
fn none_sink_is_inert() {
let s = ProgressSink::none();
assert!(!s.is_active());
s.emit(DownloadEvent::Completed { model: "x".into() }); }
// Events reach the listener in the order they were emitted.
#[test]
fn sink_records_events_in_order() {
let rec = Arc::new(Recorder::default());
let sink = ProgressSink::new(rec.clone());
assert!(sink.is_active());
sink.emit(DownloadEvent::Started {
model: "Qwen3-4B".into(),
total_files: 2,
total_mb: 2500,
});
sink.emit(DownloadEvent::FileCompleted {
filename: "model.gguf".into(),
});
sink.emit(DownloadEvent::Completed {
model: "Qwen3-4B".into(),
});
let evs = rec.events.lock().unwrap();
assert_eq!(evs.len(), 3);
assert!(matches!(
evs[0],
DownloadEvent::Started { total_files: 2, .. }
));
assert!(matches!(evs[2], DownloadEvent::Completed { .. }));
}
// The JSON form carries the snake_case variant name in an "event" field and
// keeps field names unchanged.
#[test]
fn event_serializes_with_tag() {
let json = serde_json::to_string(&DownloadEvent::FileStarted {
filename: "model.gguf".into(),
index: 1,
total_files: 2,
size_mb: 2400,
})
.unwrap();
assert!(json.contains("\"event\":\"file_started\""));
assert!(json.contains("\"size_mb\":2400"));
}
// The lock registry is process-global, so this test uses ids of its own.
#[tokio::test]
async fn same_model_lock_is_exclusive_distinct_ids_are_not() {
use std::time::Duration;
let (a, b) = ("lock-test-a", "lock-test-b");
let a1 = acquire_model_lock(a).await;
// A different id must be acquirable while `a` is held.
let _b = tokio::time::timeout(Duration::from_millis(200), acquire_model_lock(b))
.await
.expect("distinct id must not block");
// A second acquisition of `a` must still be pending when the timeout fires.
let contended =
tokio::time::timeout(Duration::from_millis(50), acquire_model_lock(a)).await;
assert!(contended.is_err(), "same-id lock should be contended");
drop(a1);
tokio::time::timeout(Duration::from_millis(500), acquire_model_lock(a))
.await
.expect("acquires after release");
}
// needed_mb == 0 returns Ok before any probing, even for a missing path.
#[test]
fn zero_needed_skips_disk_check() {
assert!(check_disk_space(std::path::Path::new("/nonexistent/x"), 0).is_ok());
}
// Only asserts where the platform probe yields a value; without one,
// check_disk_space always returns Ok. The request stays below u64::MAX so
// the +1024 headroom does not saturate.
#[test]
fn absurd_size_is_rejected_when_probe_succeeds() {
let tmp = std::env::temp_dir();
if available_disk_mb(&tmp).is_some() {
let res = check_disk_space(&tmp, u64::MAX / (1024 * 1024) - 2048);
assert!(res.is_err(), "expected disk-space rejection, got {res:?}");
}
}
use sha2::{Digest, Sha256};
use tempfile::TempDir;
// Hex SHA-256 of an in-memory buffer, used to build matching blob names.
fn sha256_of(bytes: &[u8]) -> String {
let mut h = Sha256::new();
h.update(bytes);
hex::encode(h.finalize())
}
#[test]
fn cache_file_usable_rejects_missing_and_empty() {
let tmp = TempDir::new().unwrap();
let missing = tmp.path().join("nope");
assert!(!cache_file_usable(&missing), "missing file is not usable");
let empty = tmp.path().join("empty");
std::fs::write(&empty, b"").unwrap();
assert!(!cache_file_usable(&empty), "zero-length file is not usable");
let good = tmp.path().join("good");
std::fs::write(&good, b"weights").unwrap();
assert!(cache_file_usable(&good), "non-empty file is usable");
}
// Symlinks are followed: a dangling one is unusable, a resolving one is usable.
#[cfg(unix)]
#[test]
fn cache_file_usable_rejects_dangling_symlink() {
let tmp = TempDir::new().unwrap();
let link = tmp.path().join("ptr");
std::os::unix::fs::symlink(tmp.path().join("does-not-exist"), &link).unwrap();
assert!(
!cache_file_usable(&link),
"dangling symlink (pruned blob) must be unusable"
);
let blob = tmp.path().join("blob");
std::fs::write(&blob, b"data").unwrap();
let live = tmp.path().join("live");
std::os::unix::fs::symlink(&blob, &live).unwrap();
assert!(cache_file_usable(&live), "resolving symlink is usable");
}
// Builds `root/blobs/<etag>` holding `content` plus a relative symlink
// `root/snapshots/deadbeef/<filename>` pointing at it, and returns the
// symlink path.
#[cfg(unix)]
fn hf_pointer(root: &Path, etag: &str, content: &[u8], filename: &str) -> std::path::PathBuf {
let blobs = root.join("blobs");
std::fs::create_dir_all(&blobs).unwrap();
let blob = blobs.join(etag);
std::fs::write(&blob, content).unwrap();
let snap = root.join("snapshots").join("deadbeef");
std::fs::create_dir_all(&snap).unwrap();
let ptr = snap.join(filename);
std::os::unix::fs::symlink(
Path::new("..").join("..").join("blobs").join(etag),
&ptr,
)
.unwrap();
ptr
}
#[cfg(unix)]
#[test]
fn verify_cache_file_confirms_matching_sha256_blob() {
let tmp = TempDir::new().unwrap();
let content = b"the real weights";
let etag = sha256_of(content);
let ptr = hf_pointer(tmp.path(), &etag, content, "model.safetensors");
assert_eq!(verify_cache_file(&ptr), CacheIntegrity::Verified);
}
// The blob is named after one payload but contains another.
#[cfg(unix)]
#[test]
fn verify_cache_file_flags_corrupt_blob() {
let tmp = TempDir::new().unwrap();
let claimed = sha256_of(b"the real weights");
let ptr = hf_pointer(tmp.path(), &claimed, b"trunc", "model.safetensors");
assert_eq!(
verify_cache_file(&ptr),
CacheIntegrity::Corrupt,
"content not matching its etag hash is corrupt"
);
}
#[cfg(unix)]
#[test]
fn verify_cache_file_flags_dangling_and_empty() {
let tmp = TempDir::new().unwrap();
let dangling = tmp.path().join("d");
std::os::unix::fs::symlink(tmp.path().join("gone"), &dangling).unwrap();
assert_eq!(verify_cache_file(&dangling), CacheIntegrity::Corrupt);
let empty = tmp.path().join("e");
std::fs::write(&empty, b"").unwrap();
assert_eq!(verify_cache_file(&empty), CacheIntegrity::Corrupt);
}
// A 40-character blob name is not a SHA-256, so there is nothing to compare.
#[cfg(unix)]
#[test]
fn verify_cache_file_non_lfs_etag_is_unverifiable() {
let tmp = TempDir::new().unwrap();
let git_sha1 = "a".repeat(40);
let ptr = hf_pointer(tmp.path(), &git_sha1, b"{}", "config.json");
assert_eq!(verify_cache_file(&ptr), CacheIntegrity::Unverifiable);
}
// A regular file outside any `blobs` directory has no blob name to check.
#[test]
fn verify_cache_file_plain_file_is_unverifiable() {
let tmp = TempDir::new().unwrap();
let f = tmp.path().join("model.gguf");
std::fs::write(&f, b"weights").unwrap();
assert_eq!(verify_cache_file(&f), CacheIntegrity::Unverifiable);
}
// Layout: `model/` holds symlinks to snapshot pointers (one good, one whose
// blob content does not match its name), a plain config file and a dangling
// symlink. Only the mismatching weight and the dangling link may be removed.
#[cfg(unix)]
#[test]
fn purge_corrupt_only_removes_provably_bad_files() {
let tmp = TempDir::new().unwrap();
let model = tmp.path().join("model");
std::fs::create_dir_all(&model).unwrap();
let good_bytes = b"the real weights";
let good_etag = sha256_of(good_bytes);
let good = hf_pointer(tmp.path(), &good_etag, good_bytes, "good.safetensors");
std::os::unix::fs::symlink(&good, model.join("good.safetensors")).unwrap();
let bad_etag = sha256_of(b"claimed");
let bad = hf_pointer(tmp.path(), &bad_etag, b"actually-truncated", "bad.safetensors");
let bad_blob = tmp.path().join("blobs").join(&bad_etag);
std::os::unix::fs::symlink(&bad, model.join("bad.safetensors")).unwrap();
std::fs::write(model.join("config.json"), b"{}").unwrap();
std::os::unix::fs::symlink(model.join("gone"), model.join("dangling.safetensors")).unwrap();
let removed = purge_corrupt_cache_files(&model);
assert_eq!(removed, 2, "corrupt weight + dangling symlink removed");
assert!(model.join("good.safetensors").exists(), "good weight kept");
assert!(model.join("config.json").exists(), "config kept");
assert!(!model.join("bad.safetensors").exists(), "corrupt weight gone");
assert!(
!bad_blob.exists(),
"the corrupt content-addressed blob must be removed, not just the pointer — \
otherwise the re-pull re-links it instead of re-downloading"
);
assert!(
std::fs::canonicalize(&good).is_ok(),
"the good blob must survive (shared with other models)"
);
// symlink_metadata does not follow the link, so an error here means the
// link itself is gone.
assert!(
std::fs::symlink_metadata(model.join("dangling.safetensors")).is_err(),
"dangling symlink gone"
);
}
}