use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex as AsyncMutex, OwnedMutexGuard};
/// A progress notification produced while downloading a model.
///
/// Serialized by serde as an internally tagged object: the variant name,
/// converted to snake_case, is stored under the `"event"` key next to the
/// variant's own fields (e.g. `"event":"file_started"` for [`FileStarted`]).
/// Size fields carry an `_mb` suffix and are expressed in megabytes.
///
/// [`FileStarted`]: DownloadEvent::FileStarted
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "event", rename_all = "snake_case")]
pub enum DownloadEvent {
    /// Start of a model download, announcing how many files and megabytes
    /// are expected in total.
    Started {
        model: String,
        total_files: u32,
        total_mb: u64,
    },
    /// Start of one file within the model download.
    FileStarted {
        filename: String,
        // NOTE(review): whether `index` is 0- or 1-based is not established
        // in this file; confirm against the code that emits this event.
        index: u32,
        total_files: u32,
        size_mb: u64,
    },
    /// Progress update for a single file: megabytes received so far out of
    /// the file's total.
    FileProgress {
        filename: String,
        downloaded_mb: u64,
        total_mb: u64,
    },
    /// One file finished downloading.
    FileCompleted { filename: String },
    /// The whole model finished downloading.
    Completed { model: String },
    /// The download failed; `error` carries the failure message text.
    Failed { error: String },
}
/// Observer that receives [`DownloadEvent`] notifications.
///
/// The `Send + Sync` supertraits let an implementor be shared across threads,
/// as with the `Arc<dyn DownloadProgress>` stored inside [`ProgressSink`].
pub trait DownloadProgress: Send + Sync {
    /// Handles one event. The event is only borrowed, so an implementor that
    /// wants to keep it must clone it. `ProgressSink::emit` makes exactly one
    /// call per emitted event.
    fn on_event(&self, event: &DownloadEvent);
}
/// Cheaply clonable, optional handle to a [`DownloadProgress`] observer.
///
/// An inactive sink (from [`ProgressSink::none`] or `Default`) silently
/// discards every event. Clones share the same underlying observer.
#[derive(Clone, Default)]
pub struct ProgressSink(Option<Arc<dyn DownloadProgress>>);

impl ProgressSink {
    /// Returns a sink that drops all events.
    pub fn none() -> Self {
        Self(None)
    }

    /// Returns a sink that forwards every event to `sink`.
    pub fn new(sink: Arc<dyn DownloadProgress>) -> Self {
        Self(Some(sink))
    }

    /// Delivers `event` to the observer, or does nothing for an inactive sink.
    pub fn emit(&self, event: DownloadEvent) {
        if let Some(observer) = self.0.as_deref() {
            observer.on_event(&event);
        }
    }

    /// Reports whether an observer is attached.
    pub fn is_active(&self) -> bool {
        Option::is_some(&self.0)
    }
}
/// Process-wide registry that maps a model id to that model's async mutex.
///
/// The registry is created lazily on first call. Nothing in this module
/// removes entries, so it keeps one entry per distinct model id that was
/// ever locked.
fn model_locks() -> &'static Mutex<HashMap<String, Arc<AsyncMutex<()>>>> {
    type Registry = Mutex<HashMap<String, Arc<AsyncMutex<()>>>>;
    static REGISTRY: OnceLock<Registry> = OnceLock::new();
    REGISTRY.get_or_init(Registry::default)
}
pub async fn acquire_model_lock(model_id: &str) -> OwnedMutexGuard<()> {
let lock = {
let mut map = model_locks().lock().unwrap();
map.entry(model_id.to_string())
.or_insert_with(|| Arc::new(AsyncMutex::new(())))
.clone()
};
lock.lock_owned().await
}
/// Verifies that the filesystem holding `path` can take a `needed_mb`
/// download and still keep about 1 GB (1024 MB) free afterwards.
///
/// The check is best-effort. If the free space cannot be determined
/// (`available_disk_mb` returns `None`), it passes instead of blocking.
/// A `needed_mb` of zero always passes without probing the disk.
///
/// # Errors
/// Returns a human-readable message when fewer than `needed_mb + 1024` MB
/// are available.
pub fn check_disk_space(path: &std::path::Path, needed_mb: u64) -> Result<(), String> {
    const HEADROOM_MB: u64 = 1024;

    if needed_mb == 0 {
        return Ok(());
    }

    match available_disk_mb(path) {
        // Free space unknown: don't fail a download on an unknown.
        None => Ok(()),
        Some(free_mb) if free_mb >= needed_mb.saturating_add(HEADROOM_MB) => Ok(()),
        Some(free_mb) => Err(format!(
            "not enough disk space: need ~{} MB (+1 GB free), but only {} MB available at {}",
            needed_mb,
            free_mb,
            path.display()
        )),
    }
}
/// Best-effort query of the free space, in MiB, on the filesystem that holds `path`.
///
/// `path` does not need to exist yet (for example, a download directory that
/// is about to be created). The nearest existing ancestor is probed instead.
/// A relative path with no existing ancestors is probed against the current
/// directory.
///
/// On Unix this runs `df -Pk` and reads the "Available" column.
///
/// Returns `None` in any of these cases:
/// - no ancestor exists;
/// - `df` cannot be spawned or exits unsuccessfully;
/// - the output cannot be parsed;
/// - the target is not Unix (always `None`).
fn available_disk_mb(path: &std::path::Path) -> Option<u64> {
    // `Path::parent` of a bare relative component (e.g. "models") is the empty
    // path, which never "exists". Map it to "." so relative targets are probed
    // against the current directory instead of silently giving up.
    // `ancestors()` ends after the empty path, so this cannot loop forever.
    let probe = path
        .ancestors()
        .map(|p| {
            if p.as_os_str().is_empty() {
                std::path::Path::new(".")
            } else {
                p
            }
        })
        .find(|p| p.exists())?;

    #[cfg(unix)]
    {
        // "--" ends option parsing, so a relative probe that starts with '-'
        // is treated as an operand rather than as a df flag.
        let out = std::process::Command::new("df")
            .arg("-Pk")
            .arg("--")
            .arg(probe)
            .output()
            .ok()?;
        if !out.status.success() {
            return None;
        }
        // Only a numeric column is needed. Decode lossily so that a non-UTF-8
        // mount point does not make the whole probe fail.
        let text = String::from_utf8_lossy(&out.stdout);
        // Line 0 is the header. With -P the data row for `probe` is on line 1.
        let avail_kb = parse_df_avail_kb(text.lines().nth(1)?)?;
        Some(avail_kb / 1024)
    }
    #[cfg(not(unix))]
    {
        let _ = probe;
        None
    }
}

/// Extracts the "Available" column (KiB) from one data row of `df -Pk` output.
///
/// POSIX `-P` rows look like `<fs> <total> <used> <avail> <capacity>% <mount>`.
/// The filesystem name can itself contain spaces (e.g. network shares), so
/// the column is located just before the capacity field, which is the token
/// ending in `%`. If no such token is printed, it falls back to the fixed
/// fourth column.
#[cfg_attr(not(unix), allow(dead_code))]
fn parse_df_avail_kb(row: &str) -> Option<u64> {
    let fields: Vec<&str> = row.split_whitespace().collect();
    let avail_idx = fields
        .iter()
        .position(|f| f.ends_with('%'))
        .and_then(|pct| pct.checked_sub(1))
        .unwrap_or(3);
    fields.get(avail_idx)?.parse().ok()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Test observer that records every event it receives, in arrival order.
    #[derive(Default)]
    struct Recorder {
        events: Mutex<Vec<DownloadEvent>>,
    }

    impl DownloadProgress for Recorder {
        fn on_event(&self, event: &DownloadEvent) {
            self.events.lock().unwrap().push(event.clone());
        }
    }

    #[test]
    fn none_sink_is_inert() {
        let s = ProgressSink::none();
        assert!(!s.is_active());
        // Emitting on an inactive sink must be a silent no-op, not a panic.
        s.emit(DownloadEvent::Completed { model: "x".into() });
        // The derived `Default` must also produce an inactive sink.
        assert!(!ProgressSink::default().is_active());
    }

    #[test]
    fn sink_records_events_in_order() {
        let rec = Arc::new(Recorder::default());
        let sink = ProgressSink::new(rec.clone());
        assert!(sink.is_active());
        let sent = vec![
            DownloadEvent::Started {
                model: "Qwen3-4B".into(),
                total_files: 2,
                total_mb: 2500,
            },
            DownloadEvent::FileCompleted {
                filename: "model.gguf".into(),
            },
            DownloadEvent::Completed {
                model: "Qwen3-4B".into(),
            },
        ];
        for ev in sent.iter().cloned() {
            sink.emit(ev);
        }
        // Compare the whole sequence so that every event's position is checked.
        assert_eq!(*rec.events.lock().unwrap(), sent);
    }

    #[test]
    fn event_serializes_with_tag() {
        let ev = DownloadEvent::FileStarted {
            filename: "model.gguf".into(),
            index: 1,
            total_files: 2,
            size_mb: 2400,
        };
        let json = serde_json::to_string(&ev).unwrap();
        assert!(json.contains("\"event\":\"file_started\""));
        assert!(json.contains("\"size_mb\":2400"));
        // The internally tagged form must also deserialize back to the same value.
        let back: DownloadEvent = serde_json::from_str(&json).unwrap();
        assert_eq!(back, ev);
    }

    #[tokio::test]
    async fn same_model_lock_is_exclusive_distinct_ids_are_not() {
        use std::time::Duration;
        // Unique ids so this test does not interact with other users of the
        // process-wide lock registry.
        let (a, b) = ("lock-test-a", "lock-test-b");
        let a1 = acquire_model_lock(a).await;
        let _b = tokio::time::timeout(Duration::from_millis(200), acquire_model_lock(b))
            .await
            .expect("distinct id must not block");
        let contended =
            tokio::time::timeout(Duration::from_millis(50), acquire_model_lock(a)).await;
        assert!(contended.is_err(), "same-id lock should be contended");
        drop(a1);
        tokio::time::timeout(Duration::from_millis(500), acquire_model_lock(a))
            .await
            .expect("acquires after release");
    }

    #[test]
    fn zero_needed_skips_disk_check() {
        assert!(check_disk_space(std::path::Path::new("/nonexistent/x"), 0).is_ok());
    }

    #[test]
    fn absurd_size_is_rejected_when_probe_succeeds() {
        let tmp = std::env::temp_dir();
        // Only meaningful where the free-space probe works (e.g. `df` available).
        if available_disk_mb(&tmp).is_some() {
            let res = check_disk_space(&tmp, u64::MAX / (1024 * 1024) - 2048);
            assert!(res.is_err(), "expected disk-space rejection, got {res:?}");
        }
    }
}