// isr_dl/progress.rs

1use std::{
2    io::{Error, Write},
3    path::{Path, PathBuf},
4    sync::Arc,
5};
6
7use url::Url;
8
9/// Progress event emitted during download and extraction operations.
10pub enum ProgressEvent<'a> {
11    /// An HTTP download has started.
12    DownloadStarted {
13        /// URL being downloaded.
14        url: &'a Url,
15
16        /// Total size in bytes, if known from the `Content-Length` header.
17        total_bytes: Option<u64>,
18    },
19
20    /// Bytes have been received.
21    DownloadProgress {
22        /// URL being downloaded.
23        url: &'a Url,
24
25        /// Number of bytes received so far.
26        bytes: u64,
27
28        /// Total size in bytes, if known from the `Content-Length` header.
29        total_bytes: Option<u64>,
30    },
31
32    /// An HTTP download has completed.
33    DownloadComplete {
34        /// URL that was downloaded.
35        url: &'a Url,
36    },
37
38    /// Extraction of a file from an archive has started.
39    ExtractStarted {
40        /// Path of the file being extracted.
41        path: &'a Path,
42
43        /// Total uncompressed size in bytes, if known.
44        total_bytes: Option<u64>,
45    },
46
47    /// Bytes have been extracted from an archive.
48    ExtractProgress {
49        /// Path of the file being extracted.
50        path: &'a Path,
51
52        /// Number of bytes extracted so far.
53        bytes: u64,
54
55        /// Total uncompressed size in bytes, if known.
56        total_bytes: Option<u64>,
57    },
58
59    /// Extraction of a file from an archive has completed.
60    ExtractComplete {
61        /// Path of the file that was extracted.
62        path: &'a Path,
63    },
64}
65
66/// Shared, cloneable progress callback.
67pub type ProgressFn = Arc<dyn Fn(ProgressEvent<'_>) + Send + Sync>;
68
69/// Distinguishes download vs extraction context for [`ProgressWriter`].
70pub enum ProgressContext {
71    /// HTTP download context.
72    Download {
73        /// URL being downloaded.
74        url: Url,
75    },
76
77    /// Archive extraction context.
78    Extract {
79        /// Path of the file being extracted.
80        path: PathBuf,
81    },
82}
83
84/// A [`Write`](std::io::Write) adapter that optionally reports progress.
85///
86/// When `progress` is `None`, the writer is a transparent passthrough.
87/// When `progress` is `Some`:
88/// - Construction emits `DownloadStarted` or `ExtractStarted`.
89/// - Each `write()` emits `DownloadProgress` or `ExtractProgress`.
90/// - `Drop` emits `DownloadComplete` or `ExtractComplete`.
91pub struct ProgressWriter<W> {
92    /// Underlying writer receiving the bytes.
93    inner: W,
94
95    /// Optional progress callback; `None` disables reporting.
96    progress: Option<ProgressFn>,
97
98    /// Download vs extraction discriminator, carries the URL/path.
99    context: ProgressContext,
100
101    /// Running total of bytes written.
102    written: u64,
103
104    /// Expected total size, if known.
105    total_bytes: Option<u64>,
106}
107
108impl<W> ProgressWriter<W> {
109    /// Creates a writer for an HTTP download.
110    pub fn for_download(
111        progress: Option<ProgressFn>,
112        inner: W,
113        url: &Url,
114        total_bytes: Option<u64>,
115    ) -> Self {
116        let url = url.clone();
117
118        if let Some(progress) = &progress {
119            progress(ProgressEvent::DownloadStarted {
120                url: &url,
121                total_bytes,
122            });
123        }
124
125        Self {
126            inner,
127            progress,
128            context: ProgressContext::Download { url },
129            written: 0,
130            total_bytes,
131        }
132    }
133
134    /// Creates a writer for archive extraction.
135    pub fn for_extract(
136        progress: Option<ProgressFn>,
137        inner: W,
138        path: impl Into<PathBuf>,
139        total_bytes: Option<u64>,
140    ) -> Self {
141        let path = path.into();
142
143        if let Some(progress) = &progress {
144            progress(ProgressEvent::ExtractStarted {
145                path: &path,
146                total_bytes,
147            });
148        }
149
150        Self {
151            inner,
152            progress,
153            context: ProgressContext::Extract { path },
154            written: 0,
155            total_bytes,
156        }
157    }
158}
159
160impl<W> Write for ProgressWriter<W>
161where
162    W: Write,
163{
164    fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
165        let n = self.inner.write(buf)?;
166        self.written += n as u64;
167
168        if let Some(progress) = &self.progress {
169            match &self.context {
170                ProgressContext::Download { url } => {
171                    progress(ProgressEvent::DownloadProgress {
172                        url,
173                        bytes: self.written,
174                        total_bytes: self.total_bytes,
175                    });
176                }
177                ProgressContext::Extract { path } => {
178                    progress(ProgressEvent::ExtractProgress {
179                        path,
180                        bytes: self.written,
181                        total_bytes: self.total_bytes,
182                    });
183                }
184            }
185        }
186
187        Ok(n)
188    }
189
190    fn flush(&mut self) -> Result<(), Error> {
191        self.inner.flush()
192    }
193}
194
195impl<W> Drop for ProgressWriter<W> {
196    fn drop(&mut self) {
197        if let Some(progress) = &self.progress {
198            match &self.context {
199                ProgressContext::Download { url } => {
200                    progress(ProgressEvent::DownloadComplete { url });
201                }
202                ProgressContext::Extract { path } => {
203                    progress(ProgressEvent::ExtractComplete { path });
204                }
205            }
206        }
207    }
208}
209
#[cfg(test)]
mod tests {
    use std::{
        io::Write,
        path::Path,
        sync::{Arc, Mutex},
    };

    use super::*;

    /// Builds a callback that records every event as a compact string, and
    /// returns it together with the shared log it appends to.
    fn capture_progress() -> (ProgressFn, Arc<Mutex<Vec<String>>>) {
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);
        let progress = Arc::new(move |event: ProgressEvent<'_>| {
            let v = match event {
                ProgressEvent::DownloadStarted { url, total_bytes } => {
                    format!("dl-start:{url}:{total_bytes:?}")
                }
                ProgressEvent::DownloadProgress {
                    url,
                    bytes,
                    total_bytes,
                } => {
                    format!("dl-progress:{url}:{bytes}:{total_bytes:?}")
                }
                ProgressEvent::DownloadComplete { url } => {
                    format!("dl-complete:{url}")
                }
                ProgressEvent::ExtractStarted { path, total_bytes } => {
                    format!("ex-start:{}:{total_bytes:?}", path.display())
                }
                ProgressEvent::ExtractProgress {
                    path,
                    bytes,
                    total_bytes,
                } => {
                    format!("ex-progress:{}:{bytes}:{total_bytes:?}", path.display())
                }
                ProgressEvent::ExtractComplete { path } => {
                    format!("ex-complete:{}", path.display())
                }
            };
            events_clone.lock().unwrap().push(v);
        });

        (progress, events)
    }

    /// Writer that rejects every write, used to exercise the error path.
    struct FailingWriter;

    impl Write for FailingWriter {
        fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> {
            Err(std::io::Error::from(std::io::ErrorKind::BrokenPipe))
        }

        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    #[test]
    fn download_emits_start_progress_complete() {
        let (progress, events) = capture_progress();
        let mut buf = Vec::new();
        let url = Url::parse("http://example.com/file.pdb").unwrap();

        {
            let mut w = ProgressWriter::for_download(Some(progress), &mut buf, &url, Some(10));
            w.write_all(b"hello").unwrap();
            w.write_all(b"world").unwrap();
        }

        assert_eq!(buf, b"helloworld");
        let events = events.lock().unwrap();
        assert_eq!(events.len(), 4);
        assert_eq!(events[0], "dl-start:http://example.com/file.pdb:Some(10)");
        assert_eq!(
            events[1],
            "dl-progress:http://example.com/file.pdb:5:Some(10)"
        );
        assert_eq!(
            events[2],
            "dl-progress:http://example.com/file.pdb:10:Some(10)"
        );
        assert_eq!(events[3], "dl-complete:http://example.com/file.pdb");
    }

    #[test]
    fn extract_emits_start_progress_complete() {
        let (progress, events) = capture_progress();
        let mut buf = Vec::new();

        {
            let mut w = ProgressWriter::for_extract(
                Some(progress),
                &mut buf,
                Path::new("/tmp/vmlinux"),
                Some(100),
            );
            w.write_all(b"data").unwrap();
        }

        assert_eq!(buf, b"data");
        let events = events.lock().unwrap();
        assert_eq!(events.len(), 3);
        assert_eq!(events[0], "ex-start:/tmp/vmlinux:Some(100)");
        assert_eq!(events[1], "ex-progress:/tmp/vmlinux:4:Some(100)");
        assert_eq!(events[2], "ex-complete:/tmp/vmlinux");
    }

    #[test]
    fn none_progress_is_passthrough() {
        let mut buf = Vec::new();
        let url = Url::parse("http://example.com/file.pdb").unwrap();

        {
            let mut w = ProgressWriter::for_download(None, &mut buf, &url, Some(10));
            w.write_all(b"hello").unwrap();
        }

        assert_eq!(buf, b"hello");
    }

    #[test]
    fn unknown_total_bytes() {
        let (progress, events) = capture_progress();
        let mut buf = Vec::new();
        let url = Url::parse("http://example.com/x").unwrap();
        {
            let mut w = ProgressWriter::for_download(Some(progress), &mut buf, &url, None);
            w.write_all(b"abc").unwrap();
        }
        let events = events.lock().unwrap();
        // Pin the count so a spurious extra event would fail the test.
        assert_eq!(events.len(), 3);
        assert_eq!(events[0], "dl-start:http://example.com/x:None");
        assert_eq!(events[1], "dl-progress:http://example.com/x:3:None");
        assert_eq!(events[2], "dl-complete:http://example.com/x");
    }

    #[test]
    fn inner_error_emits_no_progress_but_still_completes() {
        let (progress, events) = capture_progress();
        {
            let mut w = ProgressWriter::for_extract(
                Some(progress),
                FailingWriter,
                Path::new("/tmp/x"),
                None,
            );
            assert!(w.write_all(b"abc").is_err());
        }
        let events = events.lock().unwrap();
        // A failed write reports no progress, yet Drop still signals completion.
        assert_eq!(*events, ["ex-start:/tmp/x:None", "ex-complete:/tmp/x"]);
    }
}