Skip to main content

hf_fetch_model/
progress.rs

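// SPDX-License-Identifier: MIT OR Apache-2.0

//! Progress reporting for model downloads.
//!
//! [`ProgressEvent`] carries per-file byte progress together with the number
//! of files still remaining. Events are delivered to a callback or, through
//! [`ProgressReceiver`], via a watch channel that keeps only the latest event.
//! When the `indicatif` feature is enabled, `IndicatifProgress`
//! provides multi-progress bars out of the box.
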
/// A progress event emitted during download.
///
/// Passed to the `on_progress` callback on [`crate::FetchConfig`], and
/// published through [`ProgressReceiver`] when a progress channel is used.
///
/// `PartialEq` is derived so events can be compared directly (for example in
/// tests, or to skip redundant updates). `Eq` is not derivable because
/// `percent` is an `f64`.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct ProgressEvent {
    /// The filename currently being downloaded.
    pub filename: String,
    /// Bytes downloaded so far for this file.
    pub bytes_downloaded: u64,
    /// Total size of this file in bytes (0 if unknown).
    pub bytes_total: u64,
    /// Download percentage for this file (0.0–100.0); 0.0 while the total
    /// size is unknown.
    pub percent: f64,
    /// Number of files still remaining (after this one).
    pub files_remaining: usize,
}
25
26/// Creates a [`ProgressEvent`] for a completed file.
27#[must_use]
28pub(crate) fn completed_event(filename: &str, size: u64, files_remaining: usize) -> ProgressEvent {
29    ProgressEvent {
30        filename: filename.to_owned(),
31        bytes_downloaded: size,
32        bytes_total: size,
33        percent: 100.0,
34        files_remaining,
35    }
36}
37
38/// Creates a [`ProgressEvent`] for an in-progress file (streaming update).
39#[must_use]
40pub(crate) fn streaming_event(
41    filename: &str,
42    bytes_downloaded: u64,
43    bytes_total: u64,
44    files_remaining: usize,
45) -> ProgressEvent {
46    #[allow(clippy::cast_precision_loss, clippy::as_conversions)]
47    // CAST: u64 → f64, precision loss acceptable; values are display-only percentage scalars
48    let percent = if bytes_total > 0 {
49        (bytes_downloaded as f64 / bytes_total as f64) * 100.0
50    } else {
51        0.0
52    };
53    ProgressEvent {
54        filename: filename.to_owned(),
55        bytes_downloaded,
56        bytes_total,
57        percent,
58        files_remaining,
59    }
60}
61
62/// A watch-based receiver for [`ProgressEvent`] updates.
63///
64/// Obtained from
65/// [`FetchConfigBuilder::progress_channel()`](crate::FetchConfigBuilder::progress_channel).
66/// Call `.changed().await` to wait for the next update, then `.borrow()` to read
67/// the latest event. Only the most recent event is retained — intermediate
68/// updates that arrive between `.changed()` polls are coalesced.
69pub type ProgressReceiver = tokio::sync::watch::Receiver<ProgressEvent>;
70
/// Multi-progress bar display using `indicatif`.
///
/// Available only when the `indicatif` feature is enabled. Each file that is
/// still downloading gets a byte-level bar, inserted above an "Overall" bar
/// that counts completed files. Feed events to it with
/// [`IndicatifProgress::handle`]; the overall bar is finished when the value
/// is dropped (or earlier via [`IndicatifProgress::finish`]).
///
/// # Examples
///
/// ```rust,no_run
/// # fn example() -> Result<(), hf_fetch_model::FetchError> {
/// use hf_fetch_model::progress::IndicatifProgress;
/// use hf_fetch_model::FetchConfig;
///
/// let progress = IndicatifProgress::new();
/// let config = FetchConfig::builder()
///     .on_progress(move |e| progress.handle(e))
///     .build()?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "indicatif")]
pub struct IndicatifProgress {
    // Multi-progress container for all bars.
    multi: indicatif::MultiProgress,
    // Overall file-count bar (always the last bar in the display).
    overall: indicatif::ProgressBar,
    // Per-file progress bars, keyed by filename; removed once the file completes.
    file_bars: std::sync::Mutex<std::collections::HashMap<String, indicatif::ProgressBar>>,
    // Filenames already counted as complete (deduplicates chunked + orchestrator events).
    completed_files: std::sync::Mutex<std::collections::HashSet<String>>,
    // Guards against double-finish on drop.
    finished: std::sync::atomic::AtomicBool,
}
107
#[cfg(feature = "indicatif")]
impl IndicatifProgress {
    /// Creates a new multi-progress bar display.
    ///
    /// The display initially holds a single "Overall" bar of length 0.
    /// Call [`IndicatifProgress::set_total_files`] once the file count is known;
    /// completion events passed to [`IndicatifProgress::handle`] also re-derive
    /// the total from their `files_remaining` count.
    #[must_use]
    pub fn new() -> Self {
        let multi = indicatif::MultiProgress::new();
        let overall = multi.add(indicatif::ProgressBar::new(0));
        overall.set_style(
            indicatif::ProgressStyle::default_bar()
                .template("{msg} [{bar:40.cyan/blue}] {pos}/{len} files")
                .ok()
                .unwrap_or_else(indicatif::ProgressStyle::default_bar)
                .progress_chars("=> "),
        );
        overall.set_message("Overall");
        Self {
            multi,
            overall,
            file_bars: std::sync::Mutex::new(std::collections::HashMap::new()),
            completed_files: std::sync::Mutex::new(std::collections::HashSet::new()),
            finished: std::sync::atomic::AtomicBool::new(false),
        }
    }

    /// Returns a reference to the underlying [`indicatif::MultiProgress`].
    ///
    /// Useful for adding custom progress bars alongside the built-in ones.
    #[must_use]
    pub fn multi(&self) -> &indicatif::MultiProgress {
        &self.multi
    }

    /// Sets the total number of files to download.
    ///
    /// Each newly completed file handled by [`IndicatifProgress::handle`]
    /// overwrites this with `completed + 1 + files_remaining`.
    pub fn set_total_files(&self, total: u64) {
        self.overall.set_length(total);
    }

    /// Handles a [`ProgressEvent`], updating progress bars.
    ///
    /// Events with `percent >= 100.0` finish and remove the file's bar and,
    /// the first time that filename completes, advance the overall file
    /// counter. Other events with a known `bytes_total` create or update a
    /// per-file bar showing bytes downloaded, throughput, and ETA. Incomplete
    /// events with an unknown size (`bytes_total == 0`) are ignored.
    ///
    /// A poisoned internal lock is recovered, so a panic in one update does
    /// not silently disable all later updates.
    pub fn handle(&self, event: &ProgressEvent) {
        if event.percent >= 100.0 {
            self.handle_completed(event);
        } else if event.bytes_total > 0 {
            self.handle_streaming(event);
        }
    }

    /// Finishes the progress bar, ensuring the final state is rendered.
    ///
    /// Called automatically on drop, but can be called explicitly for
    /// immediate visual feedback. Only the first call has an effect.
    pub fn finish(&self) {
        if !self
            .finished
            .swap(true, std::sync::atomic::Ordering::Relaxed)
        {
            self.overall.finish();
        }
    }

    // Clears the per-file bar of a completed file and counts the file once.
    fn handle_completed(&self, event: &ProgressEvent) {
        // The guard is a temporary of this statement, so the map lock is
        // released before the bar is cleared.
        let finished_bar = Self::lock_recovering(&self.file_bars).remove(&event.filename);
        if let Some(bar) = finished_bar {
            bar.finish_and_clear();
        }

        // Deduplicate: chunked downloads fire a streaming 100% event,
        // then the orchestrator fires a completed_event for the same file.
        let mut completed = Self::lock_recovering(&self.completed_files);
        if !completed.insert(event.filename.clone()) {
            return;
        }
        // Derive total: completed so far + this file + remaining.
        // EXPLICIT: try_from for usize → u64 (infallible on 64-bit); saturating
        // adds keep the u64::MAX fallback from overflowing.
        let remaining = u64::try_from(event.files_remaining).unwrap_or(u64::MAX);
        let total = self
            .overall
            .position()
            .saturating_add(1)
            .saturating_add(remaining);
        self.overall.set_length(total);
        self.overall.inc(1);
        // Released only now so concurrent completions read and advance
        // `position()` one at a time.
        drop(completed);
    }

    // Creates (on first sight) or advances the byte-level bar of an in-progress file.
    fn handle_streaming(&self, event: &ProgressEvent) {
        let mut bars = Self::lock_recovering(&self.file_bars);
        let bar = bars
            .entry(event.filename.clone())
            .or_insert_with(|| self.new_file_bar(event));
        bar.set_position(event.bytes_downloaded);
    }

    // Builds a per-file byte bar inserted just above the overall bar.
    fn new_file_bar(&self, event: &ProgressEvent) -> indicatif::ProgressBar {
        let pb = self.multi.insert_before(
            &self.overall,
            indicatif::ProgressBar::new(event.bytes_total),
        );
        pb.set_style(
            indicatif::ProgressStyle::default_bar()
                .template(
                    "{msg} [{bar:40.green/dim}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
                )
                .ok()
                .unwrap_or_else(indicatif::ProgressStyle::default_bar)
                .progress_chars("=> "),
        );
        pb.set_message(event.filename.clone());
        pb
    }

    // Locks `mutex`, taking the guard even if a previous holder panicked: the
    // guarded map/set remain structurally valid, so display state stays usable.
    fn lock_recovering<T>(mutex: &std::sync::Mutex<T>) -> std::sync::MutexGuard<'_, T> {
        mutex
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
212
#[cfg(feature = "indicatif")]
impl Drop for IndicatifProgress {
    /// Renders the final state of the overall bar, unless
    /// [`IndicatifProgress::finish`] has already done so.
    fn drop(&mut self) {
        IndicatifProgress::finish(self);
    }
}
219
#[cfg(feature = "indicatif")]
impl Default for IndicatifProgress {
    /// Same as [`IndicatifProgress::new`].
    fn default() -> Self {
        IndicatifProgress::new()
    }
}