hf_hub_simple_progress/
sync.rs1use crate::{DownloadState, ProgressEvent};
2use hf_hub::api::Progress;
3
4struct CallbackStorage<C> {
5 download_state: Option<DownloadState>,
6 callback: C,
7}
8
9impl<C> Progress for CallbackStorage<C>
10where
11 C: FnMut(ProgressEvent),
12{
13 fn init(&mut self, size: usize, filename: &str) {
14 self.download_state = Some(DownloadState::new(size, filename));
15 }
16
17 fn update(&mut self, size: usize) {
18 if let Some(delta) = self.download_state.as_mut().unwrap().update(size) {
19 (self.callback)(delta);
20 }
21 }
22
23 fn finish(&mut self) {
24 }
26}
27
28pub fn callback_builder(callback: impl FnMut(ProgressEvent) + 'static) -> impl Progress {
30 CallbackStorage {
31 download_state: None,
32 callback: Box::new(callback),
33 }
34}
35
#[cfg(test)]
mod tests {
    use crate::ProgressEvent;
    use crate::sync::callback_builder;
    use hf_hub::api::sync::ApiBuilder;
    use std::cell::Cell;
    use std::rc::Rc;

    /// End-to-end check that the callback observes a finished download.
    ///
    /// NOTE: this fetches `ggml-tiny-q5_1.bin` from the Hugging Face Hub and
    /// therefore needs network access. If hf_hub serves the file from its local
    /// cache, the progress callback may never fire — TODO confirm.
    #[test]
    fn it_works() {
        // The closure captures an `Rc` (not `Send`), so it is confined to this
        // thread. A plain `Cell` flag is enough; no atomics are needed.
        let done = Rc::new(Cell::new(false));
        let done_in_callback = Rc::clone(&done);
        let api = ApiBuilder::new()
            .build()
            .expect("failed to build hf_hub API client");
        let callback = callback_builder(move |progress: ProgressEvent| {
            // `>=` instead of `==`: exact equality on a computed float is fragile.
            if progress.percentage >= 1.0 {
                done_in_callback.set(true);
            }
        });
        api.model("ggerganov/whisper.cpp".to_string())
            .download_with_progress("ggml-tiny-q5_1.bin", callback)
            .expect("download failed");
        assert!(done.get(), "callback never reported a completed download");
    }
}