use std::path::Path;
use crate::AsrError;
/// Snapshot of a multi-file download, handed to the progress callback of
/// [`download_files_with_progress`].
///
/// Derives `PartialEq`/`Eq` so whole snapshots can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DownloadProgress {
    /// Name of the file this snapshot refers to, as passed in the `files` list.
    pub file: String,
    /// 1-based position of `file` within the requested file list.
    pub file_index: usize,
    /// Total number of files requested in this download call.
    pub file_count: usize,
    /// Bytes reported as transferred so far for `file`.
    pub bytes_done: u64,
    /// Announced size of `file` in bytes; `None` until the size has been reported.
    pub bytes_total: Option<u64>,
}
/// Downloads `files` from the HuggingFace model repository `repo_id` and copies each
/// one into `dest_dir`, reporting per-file byte progress through `on_progress`.
///
/// Files are processed in the order given. Every [`DownloadProgress`] passed to the
/// callback carries the file name, its 1-based index and the total file count.
///
/// File names may contain sub-directories (e.g. `onnx/encoder.onnx`); the matching
/// directory below `dest_dir` is created before the file is copied.
///
/// # Errors
///
/// Returns [`AsrError::Backend`] and stops at the first failure when:
/// * `dest_dir` (or a nested parent directory of a file) cannot be created,
/// * the hf-hub API cannot be initialised,
/// * a download fails, or
/// * copying a downloaded file into `dest_dir` fails.
pub fn download_files_with_progress<F>(
    repo_id: &str,
    files: &[&str],
    dest_dir: &Path,
    mut on_progress: F,
) -> Result<(), AsrError>
where
    F: FnMut(DownloadProgress),
{
    use hf_hub::api::sync::Api;

    std::fs::create_dir_all(dest_dir).map_err(|e| {
        AsrError::Backend(format!(
            "creating model directory {}: {e}",
            dest_dir.display()
        ))
    })?;

    let file_count = files.len();
    let api = Api::new().map_err(|e| AsrError::Backend(format!("hf-hub init failed: {e}")))?;
    let repo = api.model(repo_id.to_string());

    for (idx, &name) in files.iter().enumerate() {
        // Progress is reported with human-friendly 1-based indices.
        let file_index = idx + 1;
        let adapter = CallbackProgress {
            file: name.to_string(),
            file_index,
            file_count,
            bytes_done: 0,
            bytes_total: None,
            // Re-borrowed per file so the same callback serves every download.
            on_progress: &mut on_progress,
        };
        tracing::debug!(
            repo_id,
            file = name,
            file_index,
            file_count,
            "fetching from HuggingFace with progress"
        );

        let src = repo
            .download_with_progress(name, adapter)
            .map_err(|e| AsrError::Backend(format!("hf-hub download of {name} failed: {e}")))?;

        let dest = dest_dir.join(name);
        // A repo-relative name such as `onnx/encoder.onnx` needs its parent directory
        // to exist, otherwise the copy below fails with "not found".
        if let Some(parent) = dest.parent() {
            std::fs::create_dir_all(parent).map_err(|e| {
                AsrError::Backend(format!(
                    "creating model directory {}: {e}",
                    parent.display()
                ))
            })?;
        }
        std::fs::copy(&src, &dest).map_err(|e| {
            AsrError::Backend(format!(
                "copying {} to {}: {e}",
                src.display(),
                dest.display()
            ))
        })?;
    }

    Ok(())
}
/// Adapter state that turns per-chunk download notifications into callback invocations
/// for a single file.
///
/// The callback requirement (`FnMut(DownloadProgress)`) is deliberately stated on the
/// `impl` blocks that invoke the callback rather than on the struct itself, so the
/// data definition stays free of bounds it does not need.
struct CallbackProgress<'a, F> {
    /// Name of the file being downloaded.
    file: String,
    /// 1-based position of the file in the requested list.
    file_index: usize,
    /// Number of files in the requested list.
    file_count: usize,
    /// Bytes accounted for so far.
    bytes_done: u64,
    /// Announced file size; `None` until it has been reported.
    bytes_total: Option<u64>,
    /// Borrowed user callback that receives each progress snapshot.
    on_progress: &'a mut F,
}
impl<F: FnMut(DownloadProgress)> CallbackProgress<'_, F> {
    /// Builds a [`DownloadProgress`] snapshot from the adapter's current state and
    /// passes it to the wrapped callback.
    fn emit(&mut self) {
        let snapshot = DownloadProgress {
            bytes_total: self.bytes_total,
            bytes_done: self.bytes_done,
            file_count: self.file_count,
            file_index: self.file_index,
            // The callback takes ownership of the snapshot, so the name is cloned.
            file: self.file.clone(),
        };
        (self.on_progress)(snapshot);
    }
}
impl<F: FnMut(DownloadProgress)> hf_hub::api::Progress for CallbackProgress<'_, F> {
    /// Records the announced size, resets the byte counter and emits the first snapshot.
    fn init(&mut self, size: usize, _filename: &str) {
        self.bytes_done = 0;
        self.bytes_total = Some(size as u64);
        self.emit();
    }

    /// Adds `size` bytes to the running counter (saturating) and emits a snapshot.
    fn update(&mut self, size: usize) {
        let advanced = self.bytes_done.saturating_add(size as u64);
        self.bytes_done = advanced;
        self.emit();
    }

    /// Emits one final snapshot at 100% if the counter has not yet reached the
    /// announced total; emits nothing when the total is unknown or already reached.
    fn finish(&mut self) {
        let total = match self.bytes_total {
            Some(total) => total,
            None => return,
        };
        if self.bytes_done >= total {
            return;
        }
        self.bytes_done = total;
        self.emit();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use hf_hub::api::Progress;

    /// Runs a fresh adapter through `init(total)`, one `update` per entry of `chunks`,
    /// then `finish`, and returns every snapshot the callback received, in order.
    fn record(
        file: &str,
        file_index: usize,
        file_count: usize,
        total: usize,
        chunks: &[usize],
    ) -> Vec<DownloadProgress> {
        let mut seen = Vec::new();
        {
            let mut sink = |snapshot: DownloadProgress| seen.push(snapshot);
            let mut adapter = CallbackProgress {
                file: file.to_string(),
                file_index,
                file_count,
                bytes_done: 0,
                bytes_total: None,
                on_progress: &mut sink,
            };
            adapter.init(total, file);
            for &chunk in chunks {
                adapter.update(chunk);
            }
            adapter.finish();
        }
        seen
    }

    #[test]
    fn callback_progress_reports_init_update_and_finish() {
        // The chunks add up to the total, so `finish` must not add an extra event.
        let got = record("encoder.onnx", 1, 4, 1000, &[400, 600]);
        assert_eq!(got.len(), 3);

        let first = &got[0];
        assert_eq!(first.bytes_done, 0);
        assert_eq!(first.bytes_total, Some(1000));
        assert_eq!(first.file_index, 1);
        assert_eq!(first.file_count, 4);
        assert_eq!(first.file, "encoder.onnx");

        assert_eq!(got[1].bytes_done, 400);
        assert_eq!(got[2].bytes_done, 1000);
    }

    #[test]
    fn callback_progress_finish_synthesizes_completion() {
        // Only 400 of 500 bytes are reported; `finish` has to supply the 100% event.
        let got = record("tokens.txt", 4, 4, 500, &[400]);
        assert_eq!(got.len(), 3);

        let last = got.last().unwrap();
        assert_eq!(last.bytes_done, 500);
        assert_eq!(last.bytes_total, Some(500));
    }
}