hf_hub/api/mod.rs

1use std::{collections::VecDeque, time::Duration};
2
3use indicatif::{style::ProgressTracker, HumanBytes, ProgressBar, ProgressStyle};
4use serde::Deserialize;
5
6/// The asynchronous version of the API
7#[cfg(feature = "tokio")]
8pub mod tokio;
9
10/// The synchronous version of the API
11#[cfg(feature = "ureq")]
12pub mod sync;
13
/// Name of the `HF_ENDPOINT` environment variable.
///
/// NOTE(review): presumably read by the `sync` / `tokio` API builders to
/// override the default Hub URL — confirm at the use sites.
const HF_ENDPOINT: &str = "HF_ENDPOINT";
15
/// Callback interface for observing file downloads.
///
/// Library users implement this trait to plug their own behavior
/// (logging, UI, metrics, ...) into a download.
pub trait Progress {
    /// Invoked when a download starts.
    ///
    /// `size` is the total size of the file in bytes; `filename` names the
    /// file being downloaded.
    fn init(&mut self, size: usize, filename: &str);

    /// Invoked each time another `size` bytes have been written to the
    /// temporary download file (`size` is an increment, not a running total).
    fn update(&mut self, size: usize);

    /// Invoked at the end of the download.
    fn finish(&mut self);
}
28
29impl Progress for () {
30    fn init(&mut self, _size: usize, _filename: &str) {}
31    fn update(&mut self, _size: usize) {}
32    fn finish(&mut self) {}
33}
34
35impl Progress for ProgressBar {
36    fn init(&mut self, size: usize, filename: &str) {
37        self.set_length(size as u64);
38        self.set_style(
39                ProgressStyle::with_template(
40                    "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec_smoothed} ({eta})",
41                ).unwrap().with_key("bytes_per_sec_smoothed", MovingAvgRate::default())
42                    ,
43            );
44        let maxlength = 30;
45        let message = if filename.len() > maxlength {
46            format!("..{}", &filename[filename.len() - maxlength..])
47        } else {
48            filename.to_string()
49        };
50        self.set_message(message);
51    }
52
53    fn update(&mut self, size: usize) {
54        self.inc(size as u64)
55    }
56
57    fn finish(&mut self) {
58        ProgressBar::finish(self);
59    }
60}
61
62/// Siblings are simplified file descriptions of remote files on the hub
63#[derive(Debug, Clone, Deserialize, PartialEq)]
64pub struct Siblings {
65    /// The path within the repo.
66    pub rfilename: String,
67}
68
69/// The description of the repo given by the hub
70#[derive(Debug, Clone, Deserialize, PartialEq)]
71pub struct RepoInfo {
72    /// See [`Siblings`]
73    pub siblings: Vec<Siblings>,
74
75    /// The commit sha of the repo.
76    pub sha: String,
77}
78
/// Custom `indicatif` template key (registered as `bytes_per_sec_smoothed`)
/// that reports the transfer rate averaged over a short sliding window of
/// recent samples.
#[derive(Default, Clone)]
struct MovingAvgRate {
    /// `(timestamp, bar position)` samples, oldest at the front.
    samples: VecDeque<(std::time::Instant, u64)>,
}
83
84impl ProgressTracker for MovingAvgRate {
85    fn clone_box(&self) -> Box<dyn ProgressTracker> {
86        Box::new(self.clone())
87    }
88
89    fn tick(&mut self, state: &indicatif::ProgressState, now: std::time::Instant) {
90        // sample at most every 20ms
91        if self
92            .samples
93            .back()
94            .is_none_or(|(prev, _)| (now - *prev) > Duration::from_millis(20))
95        {
96            self.samples.push_back((now, state.pos()));
97        }
98
99        while let Some(first) = self.samples.front() {
100            if now - first.0 > Duration::from_secs(1) {
101                self.samples.pop_front();
102            } else {
103                break;
104            }
105        }
106    }
107
108    fn reset(&mut self, _state: &indicatif::ProgressState, _now: std::time::Instant) {
109        self.samples = Default::default();
110    }
111
112    fn write(&self, _state: &indicatif::ProgressState, w: &mut dyn std::fmt::Write) {
113        match (self.samples.front(), self.samples.back()) {
114            (Some((t0, p0)), Some((t1, p1))) if self.samples.len() > 1 => {
115                let elapsed_ms = (*t1 - *t0).as_millis();
116                let rate = ((p1 - p0) as f64 * 1000f64 / elapsed_ms as f64) as u64;
117                write!(w, "{}/s", HumanBytes(rate)).unwrap()
118            }
119            _ => write!(w, "-").unwrap(),
120        }
121    }
122}