// hf_hub_simple_progress/sync.rs

1use crate::{DownloadState, ProgressEvent};
2use hf_hub::api::Progress;
3
4struct CallbackStorage<C> {
5    download_state: Option<DownloadState>,
6    callback: C,
7}
8
9impl<C> Progress for CallbackStorage<C>
10where
11    C: FnMut(ProgressEvent),
12{
13    fn init(&mut self, size: usize, filename: &str) {
14        self.download_state = Some(DownloadState::new(size, filename));
15    }
16
17    fn update(&mut self, size: usize) {
18        if let Some(delta) = self.download_state.as_mut().unwrap().update(size) {
19            (self.callback)(delta);
20        }
21    }
22
23    fn finish(&mut self) {
24        // Nothing to do
25    }
26}
27
28/// Build a [hf_hub::api::Progress] that encapsulate the provided callback
29pub fn callback_builder(callback: impl FnMut(ProgressEvent) + 'static) -> impl Progress {
30    CallbackStorage {
31        download_state: None,
32        callback: Box::new(callback),
33    }
34}
35
#[cfg(test)]
mod tests {
    use crate::ProgressEvent;
    use crate::sync::callback_builder;
    use hf_hub::api::sync::ApiBuilder;
    use std::cell::Cell;
    use std::rc::Rc;

    /// End-to-end check that at least one emitted event reports
    /// `percentage == 1.0` (presumably download completion).
    ///
    /// NOTE: this downloads a real file from the Hugging Face Hub, so it
    /// requires network access.
    #[test]
    fn it_works() {
        // Everything runs on this thread, so a `Cell` behind an `Rc` is
        // sufficient; an atomic adds nothing when the handle is an `Rc`.
        let done = Rc::new(Cell::new(false));
        let done_in_callback = Rc::clone(&done);
        let api = ApiBuilder::new().build().unwrap();
        let callback = callback_builder(move |progress: ProgressEvent| {
            if progress.percentage == 1. {
                done_in_callback.set(true);
            }
        });
        api.model("ggerganov/whisper.cpp".to_string())
            .download_with_progress("ggml-tiny-q5_1.bin", callback)
            .unwrap();
        assert!(done.get(), "callback never reported percentage == 1.0");
    }
}