use std::path::{Path, PathBuf};
use crate::AsrError;
/// Snapshot of byte-level progress for one file of a multi-file download.
///
/// Values of this type are handed to the caller's progress callback by
/// [`download_files_with_progress`].
#[derive(Clone, Debug)]
pub struct DownloadProgress {
    /// Repository-relative name of the file currently being fetched.
    pub file: String,
    /// 1-based position of `file` within the requested batch.
    pub file_index: usize,
    /// Number of files in the requested batch.
    pub file_count: usize,
    /// Bytes of `file` received so far.
    pub bytes_done: u64,
    /// Size of `file` in bytes, once the downloader has reported it.
    pub bytes_total: Option<u64>,
}
/// Downloads `files` from the HuggingFace model repo `repo_id` into
/// `dest_dir`, reporting byte-level progress through `on_progress`.
///
/// Each file is fetched with `hf_hub`'s blocking API, then copied from the
/// path `hf_hub` returns to `dest_dir.join(name)`. A name may contain `/` to
/// address a file inside a repo subdirectory (e.g. `onnx/model.onnx`). The
/// matching subdirectories are created under `dest_dir` before copying.
///
/// `on_progress` receives a [`DownloadProgress`] when a file's size becomes
/// known, after every chunk, and once more on completion if the reported
/// byte count fell short of the size. `file_index` is 1-based and
/// `file_count` is `files.len()`.
///
/// # Errors
///
/// Returns [`AsrError::Backend`] in any of these cases:
/// - `dest_dir` or a needed subdirectory cannot be created.
/// - The `hf_hub` API cannot be initialised.
/// - A download fails.
/// - A downloaded file cannot be copied into place.
///
/// Files already copied before the failure are left in place.
pub fn download_files_with_progress<F>(
    repo_id: &str,
    files: &[&str],
    dest_dir: &Path,
    mut on_progress: F,
) -> Result<(), AsrError>
where
    F: FnMut(DownloadProgress),
{
    use hf_hub::api::sync::Api;

    std::fs::create_dir_all(dest_dir).map_err(|e| {
        AsrError::Backend(format!(
            "creating model directory {}: {e}",
            dest_dir.display()
        ))
    })?;

    let file_count = files.len();
    let api = Api::new().map_err(|e| AsrError::Backend(format!("hf-hub init failed: {e}")))?;
    let repo = api.model(repo_id.to_string());

    for (idx, &name) in files.iter().enumerate() {
        // Reported to the callback as "file N of M", so start at 1.
        let file_index = idx + 1;
        // A fresh adapter per file; it reborrows the single user callback.
        let adapter = CallbackProgress {
            file: name.to_string(),
            file_index,
            file_count,
            bytes_done: 0,
            bytes_total: None,
            on_progress: &mut on_progress,
        };
        tracing::debug!(
            repo_id,
            file = name,
            file_index,
            file_count,
            "fetching from HuggingFace with progress"
        );
        let src = repo
            .download_with_progress(name, adapter)
            .map_err(|e| AsrError::Backend(format!("hf-hub download of {name} failed: {e}")))?;

        let dest = dest_dir.join(name);
        // `dest_dir` exists, but a nested repo path such as "onnx/model.onnx"
        // also needs its intermediate directories, or the copy below fails
        // with NotFound.
        if let Some(parent) = dest.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                AsrError::Backend(format!("creating directory {}: {e}", parent.display()))
            })?;
        }
        std::fs::copy(&src, &dest).map_err(|e| {
            AsrError::Backend(format!(
                "copying {} to {}: {e}",
                src.display(),
                dest.display()
            ))
        })?;
    }
    Ok(())
}
/// Per-file adapter that accumulates `hf_hub` progress callbacks and forwards
/// them to a caller-supplied closure.
///
/// The struct itself carries no trait bound on `F`. Only the `impl` blocks
/// that actually invoke the callback need `F: FnMut(DownloadProgress)`.
struct CallbackProgress<'a, F> {
    /// File name copied into every emitted event.
    file: String,
    /// Position of this file within the batch (the downloader passes 1-based values).
    file_index: usize,
    /// Number of files in the batch.
    file_count: usize,
    /// Bytes received so far for this file.
    bytes_done: u64,
    /// File size announced via `init`, if it has been called.
    bytes_total: Option<u64>,
    /// Borrowed user callback; the borrow lasts for a single file's download.
    on_progress: &'a mut F,
}
impl<F> CallbackProgress<'_, F>
where
    F: FnMut(DownloadProgress),
{
    /// Builds an owned [`DownloadProgress`] from the current counters and
    /// passes it to the user callback.
    fn emit(&mut self) {
        let snapshot = DownloadProgress {
            file: self.file.clone(),
            file_index: self.file_index,
            file_count: self.file_count,
            bytes_done: self.bytes_done,
            bytes_total: self.bytes_total,
        };
        (self.on_progress)(snapshot);
    }
}
impl<F> hf_hub::api::Progress for CallbackProgress<'_, F>
where
    F: FnMut(DownloadProgress),
{
    /// Records the announced file size, resets the running count, and
    /// reports the zero-progress starting state.
    fn init(&mut self, size: usize, _filename: &str) {
        self.bytes_done = 0;
        self.bytes_total = Some(size as u64);
        self.emit();
    }

    /// Adds `size` newly received bytes to the running count (saturating)
    /// and reports the new state.
    fn update(&mut self, size: usize) {
        self.bytes_done = self.bytes_done.saturating_add(size as u64);
        self.emit();
    }

    /// Reports one final "complete" event if the running count fell short
    /// of the announced size. Emits nothing when the size is unknown or
    /// was already reached.
    fn finish(&mut self) {
        match self.bytes_total {
            Some(total) if self.bytes_done < total => {
                self.bytes_done = total;
                self.emit();
            }
            _ => {}
        }
    }
}
/// Returns `true` when some snapshot of `repo_id` in the local HuggingFace
/// cache already contains every entry of `files`.
///
/// This only inspects the filesystem and never contacts the network. If no
/// cache location can be determined from the environment, it returns `false`.
pub fn is_repo_cached(repo_id: &str, files: &[&str]) -> bool {
    match hf_hub_cache_root() {
        Some(cache_root) => snapshot_satisfies_files(&cache_root, repo_id, files),
        None => false,
    }
}
/// Locates the HuggingFace hub cache directory.
///
/// Uses `$HF_HOME/hub` when `HF_HOME` is set. Otherwise it falls back to
/// `<home>/.cache/huggingface/hub`, where `<home>` comes from `HOME` or, if
/// that is unset, `USERPROFILE`. Returns `None` when none of these variables
/// are set.
fn hf_hub_cache_root() -> Option<PathBuf> {
    let base = match std::env::var_os("HF_HOME") {
        Some(hf_home) => PathBuf::from(hf_home),
        None => {
            let user_home =
                std::env::var_os("HOME").or_else(|| std::env::var_os("USERPROFILE"))?;
            PathBuf::from(user_home).join(".cache").join("huggingface")
        }
    };
    Some(base.join("hub"))
}
/// Checks whether any snapshot directory of `repo_id` under `cache_root`
/// contains all of `files`.
///
/// The repo directory is `models--<repo_id with '/' replaced by "-->`, and
/// candidate snapshots are the subdirectories of its `snapshots/` folder.
/// Unreadable directory entries and non-directory entries are skipped.
/// `Path::exists` follows symlinks, so a dangling link counts as missing.
fn snapshot_satisfies_files(cache_root: &Path, repo_id: &str, files: &[&str]) -> bool {
    let flattened_id = repo_id.replace('/', "--");
    let snapshots_dir = cache_root
        .join(format!("models--{flattened_id}"))
        .join("snapshots");
    match std::fs::read_dir(&snapshots_dir) {
        Ok(listing) => listing
            .filter_map(Result::ok)
            .map(|dir_entry| dir_entry.path())
            .filter(|candidate| candidate.is_dir())
            .any(|candidate| files.iter().all(|wanted| candidate.join(wanted).exists())),
        Err(_) => false,
    }
}
#[cfg(test)]
mod tests {
    //! Offline unit tests. The progress adapter is driven by hand, and the
    //! cache probe runs against a throwaway directory tree under the system
    //! temp dir, so no test touches the network.
    use super::*;

    /// A full init -> updates -> finish sequence that reaches the announced
    /// size emits one event per init/update and nothing extra on finish.
    #[test]
    fn callback_progress_reports_init_update_and_finish() {
        use hf_hub::api::Progress;
        let events = std::cell::RefCell::new(Vec::<DownloadProgress>::new());
        let mut on_progress = |p: DownloadProgress| events.borrow_mut().push(p);
        let mut adapter = CallbackProgress {
            file: "encoder.onnx".to_string(),
            file_index: 1,
            file_count: 4,
            bytes_done: 0,
            bytes_total: None,
            on_progress: &mut on_progress,
        };
        adapter.init(1000, "encoder.onnx");
        adapter.update(400);
        adapter.update(600);
        adapter.finish();
        let got = events.into_inner();
        assert_eq!(got.len(), 3);
        assert_eq!(got[0].bytes_done, 0);
        assert_eq!(got[0].bytes_total, Some(1000));
        assert_eq!(got[0].file_index, 1);
        assert_eq!(got[0].file_count, 4);
        assert_eq!(got[0].file, "encoder.onnx");
        assert_eq!(got[1].bytes_done, 400);
        assert_eq!(got[2].bytes_done, 1000);
    }

    /// When updates stop short of the announced size, `finish` emits a
    /// final event showing the file as complete.
    #[test]
    fn callback_progress_finish_synthesizes_completion() {
        use hf_hub::api::Progress;
        let events = std::cell::RefCell::new(Vec::<DownloadProgress>::new());
        let mut on_progress = |p: DownloadProgress| events.borrow_mut().push(p);
        let mut adapter = CallbackProgress {
            file: "tokens.txt".to_string(),
            file_index: 4,
            file_count: 4,
            bytes_done: 0,
            bytes_total: None,
            on_progress: &mut on_progress,
        };
        adapter.init(500, "tokens.txt");
        adapter.update(400);
        adapter.finish();
        let got = events.into_inner();
        assert_eq!(got.len(), 3);
        assert_eq!(got.last().unwrap().bytes_done, 500);
        assert_eq!(got.last().unwrap().bytes_total, Some(500));
    }

    /// Without a prior `init` there is no known total, so `finish` has
    /// nothing to complete and must not emit.
    #[test]
    fn callback_progress_finish_without_init_is_silent() {
        use hf_hub::api::Progress;
        let events = std::cell::RefCell::new(Vec::<DownloadProgress>::new());
        let mut on_progress = |p: DownloadProgress| events.borrow_mut().push(p);
        let mut adapter = CallbackProgress {
            file: "model.onnx".to_string(),
            file_index: 1,
            file_count: 1,
            bytes_done: 0,
            bytes_total: None,
            on_progress: &mut on_progress,
        };
        adapter.finish();
        assert!(events.into_inner().is_empty());
    }

    /// The probe requires a single snapshot holding every requested file.
    /// A partial snapshot added later must not hide a complete one.
    #[test]
    fn snapshot_probe_matches_hf_hub_layout() {
        let root = unique_temp_dir("snapshot-probe");
        let _cleanup = scopeguard_remove(root.clone());
        std::fs::create_dir_all(&root).unwrap();
        let repo_id = "csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20";
        let files = ["encoder.onnx", "decoder.onnx", "tokens.txt"];
        assert!(!snapshot_satisfies_files(&root, repo_id, &files));
        let repo_dir = root.join(format!("models--{}", repo_id.replace('/', "--")));
        let snap_a = repo_dir.join("snapshots").join("abc123");
        std::fs::create_dir_all(&snap_a).unwrap();
        assert!(!snapshot_satisfies_files(&root, repo_id, &files));
        std::fs::write(snap_a.join("encoder.onnx"), b"").unwrap();
        std::fs::write(snap_a.join("decoder.onnx"), b"").unwrap();
        assert!(!snapshot_satisfies_files(&root, repo_id, &files));
        std::fs::write(snap_a.join("tokens.txt"), b"").unwrap();
        assert!(snapshot_satisfies_files(&root, repo_id, &files));
        let snap_b = repo_dir.join("snapshots").join("def456");
        std::fs::create_dir_all(&snap_b).unwrap();
        std::fs::write(snap_b.join("encoder.onnx"), b"").unwrap();
        assert!(snapshot_satisfies_files(&root, repo_id, &files));
    }

    /// `owner/repo` maps to the `models--owner--repo` cache directory.
    #[test]
    fn snapshot_probe_flattens_slashes_in_repo_id() {
        let root = unique_temp_dir("slash-flatten");
        let _cleanup = scopeguard_remove(root.clone());
        let repo_id = "owner/repo";
        let snap = root
            .join("models--owner--repo")
            .join("snapshots")
            .join("c0");
        std::fs::create_dir_all(&snap).unwrap();
        std::fs::write(snap.join("model.onnx"), b"").unwrap();
        assert!(snapshot_satisfies_files(&root, repo_id, &["model.onnx"]));
    }

    /// Builds a per-process, per-call temp path so parallel tests don't
    /// collide. The directory is not created here.
    fn unique_temp_dir(label: &str) -> std::path::PathBuf {
        use std::time::{SystemTime, UNIX_EPOCH};
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let pid = std::process::id();
        std::env::temp_dir().join(format!("wavekat-asr-{label}-{pid}-{nanos}"))
    }

    /// Returns a guard that removes `path` recursively when dropped.
    /// Removal errors are ignored because cleanup is best-effort.
    fn scopeguard_remove(path: std::path::PathBuf) -> impl Drop {
        struct Guard(std::path::PathBuf);
        impl Drop for Guard {
            fn drop(&mut self) {
                let _ = std::fs::remove_dir_all(&self.0);
            }
        }
        Guard(path)
    }
}