1use std::{
2 io::{Error, Write},
3 path::{Path, PathBuf},
4 sync::Arc,
5};
6
7use url::Url;
8
9pub enum ProgressEvent<'a> {
11 DownloadStarted {
13 url: &'a Url,
15
16 total_bytes: Option<u64>,
18 },
19
20 DownloadProgress {
22 url: &'a Url,
24
25 bytes: u64,
27
28 total_bytes: Option<u64>,
30 },
31
32 DownloadComplete {
34 url: &'a Url,
36 },
37
38 ExtractStarted {
40 path: &'a Path,
42
43 total_bytes: Option<u64>,
45 },
46
47 ExtractProgress {
49 path: &'a Path,
51
52 bytes: u64,
54
55 total_bytes: Option<u64>,
57 },
58
59 ExtractComplete {
61 path: &'a Path,
63 },
64}
65
/// Shared progress callback.
///
/// The callback receives each [`ProgressEvent`] by value; the event borrows
/// its URL/path only for the duration of the call. The `Send + Sync` bounds
/// together with `Arc` allow the same callback to be cloned and shared
/// across threads.
pub type ProgressFn = Arc<dyn Fn(ProgressEvent<'_>) + Send + Sync>;
68
69pub enum ProgressContext {
71 Download {
73 url: Url,
75 },
76
77 Extract {
79 path: PathBuf,
81 },
82}
83
/// A writer adapter that forwards bytes to an inner writer and reports
/// progress through an optional [`ProgressFn`] callback.
///
/// A "started" event is emitted by the constructors, a "progress" event
/// after every successful `write`, and a "complete" event when the value
/// is dropped.
pub struct ProgressWriter<W> {
    /// The wrapped writer that actually receives the bytes.
    inner: W,

    /// Callback to notify; when `None` the adapter only forwards writes.
    progress: Option<ProgressFn>,

    /// Whether this is a download or an extraction, plus its URL/path.
    context: ProgressContext,

    /// Cumulative number of bytes accepted by `inner` so far.
    written: u64,

    /// Total size passed to the constructor; forwarded unchanged in events.
    total_bytes: Option<u64>,
}
107
108impl<W> ProgressWriter<W> {
109 pub fn for_download(
111 progress: Option<ProgressFn>,
112 inner: W,
113 url: &Url,
114 total_bytes: Option<u64>,
115 ) -> Self {
116 let url = url.clone();
117
118 if let Some(progress) = &progress {
119 progress(ProgressEvent::DownloadStarted {
120 url: &url,
121 total_bytes,
122 });
123 }
124
125 Self {
126 inner,
127 progress,
128 context: ProgressContext::Download { url },
129 written: 0,
130 total_bytes,
131 }
132 }
133
134 pub fn for_extract(
136 progress: Option<ProgressFn>,
137 inner: W,
138 path: impl Into<PathBuf>,
139 total_bytes: Option<u64>,
140 ) -> Self {
141 let path = path.into();
142
143 if let Some(progress) = &progress {
144 progress(ProgressEvent::ExtractStarted {
145 path: &path,
146 total_bytes,
147 });
148 }
149
150 Self {
151 inner,
152 progress,
153 context: ProgressContext::Extract { path },
154 written: 0,
155 total_bytes,
156 }
157 }
158}
159
160impl<W> Write for ProgressWriter<W>
161where
162 W: Write,
163{
164 fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
165 let n = self.inner.write(buf)?;
166 self.written += n as u64;
167
168 if let Some(progress) = &self.progress {
169 match &self.context {
170 ProgressContext::Download { url } => {
171 progress(ProgressEvent::DownloadProgress {
172 url,
173 bytes: self.written,
174 total_bytes: self.total_bytes,
175 });
176 }
177 ProgressContext::Extract { path } => {
178 progress(ProgressEvent::ExtractProgress {
179 path,
180 bytes: self.written,
181 total_bytes: self.total_bytes,
182 });
183 }
184 }
185 }
186
187 Ok(n)
188 }
189
190 fn flush(&mut self) -> Result<(), Error> {
191 self.inner.flush()
192 }
193}
194
195impl<W> Drop for ProgressWriter<W> {
196 fn drop(&mut self) {
197 if let Some(progress) = &self.progress {
198 match &self.context {
199 ProgressContext::Download { url } => {
200 progress(ProgressEvent::DownloadComplete { url });
201 }
202 ProgressContext::Extract { path } => {
203 progress(ProgressEvent::ExtractComplete { path });
204 }
205 }
206 }
207 }
208}
209
#[cfg(test)]
mod tests {
    use std::{
        io::Write,
        path::Path,
        sync::{Arc, Mutex},
    };

    use super::*;

    /// Renders an event as a compact string so tests can compare plain text.
    fn render(event: ProgressEvent<'_>) -> String {
        match event {
            ProgressEvent::DownloadStarted { url, total_bytes } => {
                format!("dl-start:{url}:{total_bytes:?}")
            }
            ProgressEvent::DownloadProgress {
                url,
                bytes,
                total_bytes,
            } => {
                format!("dl-progress:{url}:{bytes}:{total_bytes:?}")
            }
            ProgressEvent::DownloadComplete { url } => {
                format!("dl-complete:{url}")
            }
            ProgressEvent::ExtractStarted { path, total_bytes } => {
                format!("ex-start:{}:{total_bytes:?}", path.display())
            }
            ProgressEvent::ExtractProgress {
                path,
                bytes,
                total_bytes,
            } => {
                format!("ex-progress:{}:{bytes}:{total_bytes:?}", path.display())
            }
            ProgressEvent::ExtractComplete { path } => {
                format!("ex-complete:{}", path.display())
            }
        }
    }

    /// Returns a callback that records every rendered event, together with
    /// a handle to the recorded list.
    fn capture_progress() -> (ProgressFn, Arc<Mutex<Vec<String>>>) {
        let log = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let callback: ProgressFn = Arc::new(move |event: ProgressEvent<'_>| {
            let line = render(event);
            sink.lock().unwrap().push(line);
        });

        (callback, log)
    }

    #[test]
    fn download_emits_start_progress_complete() {
        let (progress, events) = capture_progress();
        let mut buf = Vec::new();
        let url = Url::parse("http://example.com/file.pdb").unwrap();

        let mut writer = ProgressWriter::for_download(Some(progress), &mut buf, &url, Some(10));
        writer.write_all(b"hello").unwrap();
        writer.write_all(b"world").unwrap();
        // Dropping the writer is what triggers the completion event.
        drop(writer);

        assert_eq!(buf, b"helloworld");
        let recorded = events.lock().unwrap();
        assert_eq!(
            *recorded,
            [
                "dl-start:http://example.com/file.pdb:Some(10)",
                "dl-progress:http://example.com/file.pdb:5:Some(10)",
                "dl-progress:http://example.com/file.pdb:10:Some(10)",
                "dl-complete:http://example.com/file.pdb",
            ]
        );
    }

    #[test]
    fn extract_emits_start_progress_complete() {
        let (progress, events) = capture_progress();
        let mut buf = Vec::new();

        let mut writer = ProgressWriter::for_extract(
            Some(progress),
            &mut buf,
            Path::new("/tmp/vmlinux"),
            Some(100),
        );
        writer.write_all(b"data").unwrap();
        drop(writer);

        assert_eq!(buf, b"data");
        let recorded = events.lock().unwrap();
        assert_eq!(
            *recorded,
            [
                "ex-start:/tmp/vmlinux:Some(100)",
                "ex-progress:/tmp/vmlinux:4:Some(100)",
                "ex-complete:/tmp/vmlinux",
            ]
        );
    }

    #[test]
    fn none_progress_is_passthrough() {
        let mut buf = Vec::new();
        let url = Url::parse("http://example.com/file.pdb").unwrap();

        let mut writer = ProgressWriter::for_download(None, &mut buf, &url, Some(10));
        writer.write_all(b"hello").unwrap();
        drop(writer);

        assert_eq!(buf, b"hello");
    }

    #[test]
    fn unknown_total_bytes() {
        let (progress, events) = capture_progress();
        let mut buf = Vec::new();
        let url = Url::parse("http://example.com/x").unwrap();

        let mut writer = ProgressWriter::for_download(Some(progress), &mut buf, &url, None);
        writer.write_all(b"abc").unwrap();
        drop(writer);

        let recorded = events.lock().unwrap();
        // Only the first three events are inspected; the total count is not pinned.
        assert_eq!(
            recorded[..3],
            [
                "dl-start:http://example.com/x:None",
                "dl-progress:http://example.com/x:3:None",
                "dl-complete:http://example.com/x",
            ]
        );
    }
}