hf_hub_simple_progress/
sync.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
use crate::{DownloadState, ProgressEvent};
use hf_hub::api::Progress;

/// Pairs a user-supplied callback with the state of the download it is reporting on.
struct CallbackStorage<C> {
    /// Tracking state for the current download; `None` until `Progress::init` has been called.
    download_state: Option<DownloadState>,
    /// Callable invoked with each `ProgressEvent` produced while the download is updated.
    callback: C,
}

/// Adapts [`Progress`] notifications into calls to the stored callback.
impl<C> Progress for CallbackStorage<C>
where
    C: FnMut(ProgressEvent),
{
    /// Starts tracking a download, replacing any state left over from a previous one.
    fn init(&mut self, size: usize, filename: &str) {
        self.download_state = Some(DownloadState::new(size, filename));
    }

    /// Passes `size` to the current download state and forwards the resulting event, if any,
    /// to the callback.
    ///
    /// Calling this before [`Progress::init`] is a no-op: progress reporting is best-effort
    /// and must not panic in the middle of a download.
    fn update(&mut self, size: usize) {
        if let Some(state) = self.download_state.as_mut() {
            // The state decides whether this update is worth reporting (`None` = no event).
            if let Some(event) = state.update(size) {
                (self.callback)(event);
            }
        }
    }

    /// Called when the download completes; no cleanup is required.
    fn finish(&mut self) {
        // Nothing to do
    }
}

/// Builds a [`hf_hub::api::Progress`] implementation that wraps the provided callback.
///
/// The callback is invoked with every [`ProgressEvent`] emitted while the returned value is
/// driven through `init` / `update` by the downloader.
pub fn callback_builder(callback: impl FnMut(ProgressEvent) + 'static) -> impl Progress {
    // `CallbackStorage` is generic over the callback type, so the closure is stored inline;
    // boxing it would only add a heap allocation.
    CallbackStorage {
        download_state: None,
        callback,
    }
}

#[cfg(test)]
mod tests {
    use super::callback_builder;
    use crate::ProgressEvent;
    use hf_hub::api::sync::ApiBuilder;
    use std::rc::Rc;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Downloads a small model file through `hf_hub` and checks that the callback observed
    /// an event whose `percentage` is exactly `1.`.
    ///
    /// NOTE(review): presumably requires network access to the Hugging Face Hub — confirm
    /// before running in an offline CI environment.
    #[test]
    fn it_works() {
        // Flag shared between the test body and the closure handed to the progress adapter.
        let completed = Rc::new(AtomicBool::new(false));
        let completed_in_callback = Rc::clone(&completed);

        let api = ApiBuilder::new().build().unwrap();
        let progress = callback_builder(move |event: ProgressEvent| {
            if event.percentage == 1. {
                completed_in_callback.store(true, Ordering::Relaxed);
            }
        });

        let repo = api.model("ggerganov/whisper.cpp".to_string());
        repo.download_with_progress("ggml-tiny-q5_1.bin", progress)
            .unwrap();

        assert!(completed.load(Ordering::Relaxed));
    }
}