file_retriever/
lib.rs

1#![deny(unsafe_code)]
2#![deny(missing_docs)]
3
//! Asynchronous downloads with an optional progress bar and a limited number of workers.
//!
//! Retriever is built on the tokio and reqwest crates, dancing together in a beautiful tango.
7
8pub use indicatif::ProgressFinish;
9
10use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
11use reqwest::Request;
12use tokio::io::AsyncWriteExt;
13use tokio::sync::Semaphore;
14
15type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
16
17/// Factory which is used to configure the properties of a new Retriever.
18///
19/// # Examples
20///
21/// ```
22/// use reqwest::Client;
23/// use file_retriever::RetrieverBuilder;
24/// use tokio::fs::OpenOptions;
25///
26/// #[tokio::main]
27/// async fn main() {
28///     // build a retriever
29///     let retriever = RetrieverBuilder::new()
30///         .show_progress(true)
31///         .workers(42)
32///         .build();
33///
34///     // open a file to write to
35///     let file = OpenOptions::new()
36///         .create(true)
37///         .write(true)
38///         .truncate(true)
39///         .open("index.html")
40///         .await
41///         .expect("should return file");
42///
43///     // setup a request to retrieve the file
44///     let req = Client::new().get("https://example.com").build().unwrap();
45///
46///     // download a file
47///     let _  = retriever.download_file(req, file).await;
48/// }
49/// ```
50pub struct RetrieverBuilder {
51    workers: usize,
52    show_progress_bar: bool,
53    pb_style: Option<ProgressStyle>,
54    pb_finish: ProgressFinish,
55}
56
57impl Default for RetrieverBuilder {
58    /// Creates a new Retriever builder.
59    fn default() -> Self {
60        Self {
61            workers: 10,
62            show_progress_bar: false,
63            pb_style: Some(
64                ProgressStyle::with_template(
65                    "[{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} {msg}",
66                )
67                .expect("progress bar template should compile")
68                .progress_chars("=>-"),
69            ),
70            pb_finish: ProgressFinish::AndLeave,
71        }
72    }
73}
74
75impl RetrieverBuilder {
76    /// Creates a new Retriever builder.
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    /// Sets if progress bar will be shown.
82    pub fn show_progress(mut self, show_progress_bar: bool) -> Self {
83        self.show_progress_bar = show_progress_bar;
84        self
85    }
86
87    /// Sets progress bar style.
88    pub fn progress_style(mut self, pb_style: ProgressStyle) -> Self {
89        self.pb_style = Some(pb_style);
90        self
91    }
92
93    /// Sets progress bar finish behavior.
94    pub fn with_finish(mut self, pb_finish: ProgressFinish) -> Self {
95        self.pb_finish = pb_finish;
96        self
97    }
98
99    /// Sets the number of workers.
100    pub fn workers(mut self, workers: usize) -> Self {
101        self.workers = workers;
102        self
103    }
104
105    /// Creates a Retriever with the configured options.
106    pub fn build(self) -> Retriever {
107        Retriever {
108            client: reqwest::Client::new(),
109            job_semaphore: Semaphore::new(self.workers),
110            mp: if self.show_progress_bar {
111                Some(MultiProgress::new())
112            } else {
113                None
114            },
115            pb_style: self.pb_style,
116            pb_finish: self.pb_finish,
117        }
118    }
119}
120
121/// Provides an easy interface for parallel downloads with limited workers and progress bar
122///
123/// # Examples
124///
125/// ```
126/// use reqwest::Client;
127/// use file_retriever::Retriever;
128/// use tokio::fs::OpenOptions;
129///
130/// #[tokio::main]
131/// async fn main() {
132///     // create a retriever
133///     let retriever = Retriever::with_progress_bar();
134///
135///     // open a file to write to
136///     let file = OpenOptions::new()
137///         .create(true)
138///         .write(true)
139///         .truncate(true)
140///         .open("index.html")
141///         .await
142///         .expect("should return file");
143///
144///     // setup a request to retrieve the file
145///     let req = Client::new().get("https://example.com").build().unwrap();
146///
147///     // download a file
148///     let _  = retriever.download_file(req, file).await;
149/// }
150/// ```
151pub struct Retriever {
152    client: reqwest::Client,
153    job_semaphore: Semaphore,
154    mp: Option<MultiProgress>,
155    pb_style: Option<ProgressStyle>,
156    pb_finish: ProgressFinish,
157}
158
159impl Default for Retriever {
160    /// Create a default retriever with 10 workers
161    fn default() -> Self {
162        RetrieverBuilder::new().build()
163    }
164}
165
166impl Retriever {
167    /// Same as default retriever but showing progress bar
168    pub fn with_progress_bar() -> Self {
169        RetrieverBuilder::new().show_progress(true).build()
170    }
171
172    /// Makes a request using a request and writes output into writer
173    pub async fn download_file<W>(&self, request: Request, mut writer: W) -> Result<()>
174    where
175        W: AsyncWriteExt + Unpin,
176    {
177        let _permit = self.job_semaphore.acquire().await?;
178
179        let path = String::from(request.url().path());
180        let mut resp = self.client.execute(request).await?;
181
182        let mut pb = ProgressBar::hidden();
183        if let Some(m) = &self.mp {
184            if let Some(pb_style) = &self.pb_style {
185                pb = m.add(
186                    ProgressBar::no_length()
187                        .with_style(pb_style.clone())
188                        .with_message(path)
189                        .with_finish(self.pb_finish.clone()),
190                );
191
192                if let Some(total_size) = resp.content_length() {
193                    pb.set_length(total_size);
194                }
195            }
196        }
197
198        while let Some(chunk) = resp.chunk().await? {
199            writer.write_all(chunk.as_ref()).await?;
200            writer.flush().await?;
201
202            pb.inc(chunk.len() as u64);
203        }
204
205        pb.set_length(pb.position());
206        pb.finish_using_style();
207
208        drop(_permit);
209
210        Ok(())
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Matcher;
    use reqwest::Client;
    use tokio::fs::{self, OpenOptions};

    /// Returns a path inside the platform's temporary directory that is unique to this
    /// test process, so concurrent test runs do not overwrite each other's files.
    fn scratch_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "file_retriever_{}_{}",
            std::process::id(),
            name
        ))
    }

    /// Downloads one mocked resource and checks the bytes that end up in the file.
    #[tokio::test]
    async fn download_single() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", Matcher::Regex(r"/\d".to_string()))
            .with_status(200)
            .with_body("hello")
            .create();

        let retriever = RetrieverBuilder::new()
            .show_progress(false)
            .workers(1)
            .build();

        let req = Client::new()
            .get(format!("{}/1", server.url()))
            .build()
            .expect("failed to build request");

        let file_path = scratch_path("single");
        let file = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(&file_path)
            .await
            .expect("failed to open file for writing");

        retriever
            .download_file(req, file)
            .await
            .expect("failed to download");

        let contents = fs::read_to_string(&file_path)
            .await
            .expect("failed to read file");

        // Best-effort cleanup; a leftover scratch file must not fail the test.
        let _ = fs::remove_file(&file_path).await;

        assert_eq!(contents, "hello");

        mock.assert();
    }

    /// Runs ten downloads through one shared retriever and checks every result and file.
    #[tokio::test]
    async fn download_multi() {
        use std::sync::Arc;
        use tokio::task::JoinSet;

        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", Matcher::Regex(r"/\d".to_string()))
            .with_status(200)
            .with_body("hello")
            .expect(10)
            .create();

        let retriever = Arc::new(
            RetrieverBuilder::new()
                .show_progress(true)
                .progress_style(
                    ProgressStyle::with_template("{bytes}/{total_bytes} {msg}")
                        .expect("progress bar template should compile"),
                )
                .with_finish(ProgressFinish::WithMessage("done".into()))
                .build(),
        );

        let mut set = JoinSet::new();
        let mut paths = Vec::with_capacity(10);

        for i in 0..10 {
            let ret = Arc::clone(&retriever);
            let req = Client::new()
                .get(format!("{}/{}", server.url(), i))
                .build()
                .expect("request should build");

            let path = scratch_path(&format!("multi_{}", i));
            let file = OpenOptions::new()
                .create(true)
                .write(true)
                .truncate(true)
                .open(&path)
                .await
                .expect("file should be accessible");
            paths.push(path);

            set.spawn(async move { ret.download_file(req, file).await });
        }

        while let Some(joined) = set.join_next().await {
            // The outer result reports a panicked or cancelled task, the inner one the
            // download itself; both have to succeed.
            joined
                .expect("download task should not panic")
                .expect("download should succeed");
        }

        for path in &paths {
            let contents = fs::read_to_string(path)
                .await
                .expect("downloaded file should be readable");

            // Best-effort cleanup; a leftover scratch file must not fail the test.
            let _ = fs::remove_file(path).await;

            assert_eq!(contents, "hello");
        }

        mock.assert();
    }
}