1#![deny(unsafe_code)]
2#![deny(missing_docs)]
3
4pub use indicatif::ProgressFinish;
9
10use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
11use reqwest::Request;
12use tokio::io::AsyncWriteExt;
13use tokio::sync::Semaphore;
14
/// Result type used throughout this crate.
///
/// Errors are boxed trait objects; the `Send + Sync` bounds allow a result to be returned
/// from a spawned task.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Sync + Send>>;
16
17pub struct RetrieverBuilder {
51 workers: usize,
52 show_progress_bar: bool,
53 pb_style: Option<ProgressStyle>,
54 pb_finish: ProgressFinish,
55}
56
57impl Default for RetrieverBuilder {
58 fn default() -> Self {
60 Self {
61 workers: 10,
62 show_progress_bar: false,
63 pb_style: Some(
64 ProgressStyle::with_template(
65 "[{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} {msg}",
66 )
67 .expect("progress bar template should compile")
68 .progress_chars("=>-"),
69 ),
70 pb_finish: ProgressFinish::AndLeave,
71 }
72 }
73}
74
75impl RetrieverBuilder {
76 pub fn new() -> Self {
78 Self::default()
79 }
80
81 pub fn show_progress(mut self, show_progress_bar: bool) -> Self {
83 self.show_progress_bar = show_progress_bar;
84 self
85 }
86
87 pub fn progress_style(mut self, pb_style: ProgressStyle) -> Self {
89 self.pb_style = Some(pb_style);
90 self
91 }
92
93 pub fn with_finish(mut self, pb_finish: ProgressFinish) -> Self {
95 self.pb_finish = pb_finish;
96 self
97 }
98
99 pub fn workers(mut self, workers: usize) -> Self {
101 self.workers = workers;
102 self
103 }
104
105 pub fn build(self) -> Retriever {
107 Retriever {
108 client: reqwest::Client::new(),
109 job_semaphore: Semaphore::new(self.workers),
110 mp: if self.show_progress_bar {
111 Some(MultiProgress::new())
112 } else {
113 None
114 },
115 pb_style: self.pb_style,
116 pb_finish: self.pb_finish,
117 }
118 }
119}
120
121pub struct Retriever {
152 client: reqwest::Client,
153 job_semaphore: Semaphore,
154 mp: Option<MultiProgress>,
155 pb_style: Option<ProgressStyle>,
156 pb_finish: ProgressFinish,
157}
158
159impl Default for Retriever {
160 fn default() -> Self {
162 RetrieverBuilder::new().build()
163 }
164}
165
166impl Retriever {
167 pub fn with_progress_bar() -> Self {
169 RetrieverBuilder::new().show_progress(true).build()
170 }
171
172 pub async fn download_file<W>(&self, request: Request, mut writer: W) -> Result<()>
174 where
175 W: AsyncWriteExt + Unpin,
176 {
177 let _permit = self.job_semaphore.acquire().await?;
178
179 let path = String::from(request.url().path());
180 let mut resp = self.client.execute(request).await?;
181
182 let mut pb = ProgressBar::hidden();
183 if let Some(m) = &self.mp {
184 if let Some(pb_style) = &self.pb_style {
185 pb = m.add(
186 ProgressBar::no_length()
187 .with_style(pb_style.clone())
188 .with_message(path)
189 .with_finish(self.pb_finish.clone()),
190 );
191
192 if let Some(total_size) = resp.content_length() {
193 pb.set_length(total_size);
194 }
195 }
196 }
197
198 while let Some(chunk) = resp.chunk().await? {
199 writer.write_all(chunk.as_ref()).await?;
200 writer.flush().await?;
201
202 pb.inc(chunk.len() as u64);
203 }
204
205 pb.set_length(pb.position());
206 pb.finish_using_style();
207
208 drop(_permit);
209
210 Ok(())
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use mockito::Matcher;
    use reqwest::Client;
    use tokio::fs::OpenOptions;

    /// Returns a path in the system temp directory that is unique to this process and `name`.
    ///
    /// Tests therefore neither collide with each other nor assume that `/tmp` exists.
    fn temp_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "retriever_test_{}_{}",
            std::process::id(),
            name
        ))
    }

    /// Opens `path` for writing, creating it or truncating any previous contents.
    async fn create_truncated(path: &std::path::Path) -> tokio::fs::File {
        OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(path)
            .await
            .expect("failed to open file for writing")
    }

    #[tokio::test]
    async fn download_single() {
        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", Matcher::Regex(r"/\d".to_string()))
            .with_status(200)
            .with_body("hello")
            .create();

        let retriever = RetrieverBuilder::new()
            .show_progress(false)
            .workers(1)
            .build();

        let req = Client::new()
            .get(format!("{}/1", server.url()))
            .build()
            .expect("failed to build request");

        let file_path = temp_path("single");
        let file = create_truncated(&file_path).await;

        retriever
            .download_file(req, file)
            .await
            .expect("failed to download");

        let contents = tokio::fs::read_to_string(&file_path)
            .await
            .expect("failed to read file");
        // Best-effort cleanup; a leftover temp file must not fail the test.
        let _ = tokio::fs::remove_file(&file_path).await;

        assert_eq!(contents, "hello");

        mock.assert();
    }

    #[tokio::test]
    async fn download_multi() {
        use std::sync::Arc;
        use tokio::task::JoinSet;

        let mut server = mockito::Server::new_async().await;

        let mock = server
            .mock("GET", Matcher::Regex(r"/\d".to_string()))
            .with_status(200)
            .with_body("hello")
            .expect(10)
            .create();

        let retriever = Arc::new(
            RetrieverBuilder::new()
                .show_progress(true)
                .progress_style(
                    ProgressStyle::with_template("{bytes}/{total_bytes} {msg}")
                        .expect("progress bar template should compile"),
                )
                .with_finish(ProgressFinish::WithMessage("done".into()))
                .build(),
        );

        let mut set = JoinSet::new();
        let mut paths = Vec::with_capacity(10);

        for i in 0..10 {
            let ret = Arc::clone(&retriever);
            let req = Client::new()
                .get(format!("{}/{}", server.url(), i))
                .build()
                .expect("request should build");

            let path = temp_path(&format!("multi_{}", i));
            let file = create_truncated(&path).await;
            paths.push(path);

            set.spawn(async move { ret.download_file(req, file).await });
        }

        while let Some(joined) = set.join_next().await {
            // The outer `Result` only reports task panics; the inner one is the download result.
            joined
                .expect("download task panicked")
                .expect("download failed");
        }

        for path in &paths {
            let contents = tokio::fs::read_to_string(path)
                .await
                .expect("failed to read file");
            // Best-effort cleanup; a leftover temp file must not fail the test.
            let _ = tokio::fs::remove_file(path).await;
            assert_eq!(contents, "hello");
        }

        mock.assert();
    }
}