1use super::DownloadResult;
2use crate::{ConnectErrorKind, Event, SeqWriter};
3use reqwest::{Client, IntoUrl};
4use std::{
5 sync::{
6 atomic::{AtomicBool, Ordering},
7 Arc,
8 },
9 time::Duration,
10};
11use tokio::sync::mpsc;
12
13#[derive(Debug, Clone)]
14pub struct DownloadOptions {
15 pub client: Client,
16 pub retry_gap: Duration,
17}
18
19pub async fn download(
20 url: impl IntoUrl,
21 mut writer: impl SeqWriter + 'static,
22 options: DownloadOptions,
23) -> Result<DownloadResult, reqwest::Error> {
24 let url = url.into_url()?;
25 let (tx, event_chain) = mpsc::channel(1024);
26 let (tx_write, mut rx_write) = mpsc::channel(1024);
27 let tx_clone = tx.clone();
28 let handle = tokio::spawn(async move {
29 while let Some((spin, data)) = rx_write.recv().await {
30 loop {
31 match writer.write_sequentially(&data).await {
32 Ok(_) => break,
33 Err(e) => tx_clone.send(Event::WriteError(e)).await.unwrap(),
34 }
35 tokio::time::sleep(options.retry_gap).await;
36 }
37 tx_clone.send(Event::WriteProgress(spin)).await.unwrap();
38 }
39 loop {
40 match writer.flush().await {
41 Ok(_) => break,
42 Err(e) => tx_clone.send(Event::WriteError(e)).await.unwrap(),
43 };
44 tokio::time::sleep(options.retry_gap).await;
45 }
46 });
47 let running = Arc::new(AtomicBool::new(true));
48 let running_clone = running.clone();
49 tokio::spawn(async move {
50 let mut downloaded: u64 = 0;
51 let mut response = loop {
52 if !running.load(Ordering::Relaxed) {
53 tx.send(Event::Abort(0)).await.unwrap();
54 return;
55 }
56 tx.send(Event::Connecting(0)).await.unwrap();
57 match options.client.get(url.clone()).send().await {
58 Ok(response) => break response,
59 Err(e) => tx
60 .send(Event::ConnectError(0, ConnectErrorKind::Reqwest(e)))
61 .await
62 .unwrap(),
63 }
64 tokio::time::sleep(options.retry_gap).await;
65 };
66 tx.send(Event::Downloading(0)).await.unwrap();
67 loop {
68 let chunk = loop {
69 if !running.load(Ordering::Relaxed) {
70 tx.send(Event::Abort(0)).await.unwrap();
71 return;
72 }
73 match response.chunk().await {
74 Ok(chunk) => break chunk,
75 Err(e) => tx.send(Event::DownloadError(0, e)).await.unwrap(),
76 }
77 tokio::time::sleep(options.retry_gap).await;
78 };
79 if chunk.is_none() {
80 break;
81 }
82 let chunk = chunk.unwrap();
83 let len = chunk.len() as u64;
84 let span = downloaded..(downloaded + len);
85 tx.send(Event::DownloadProgress(span.clone()))
86 .await
87 .unwrap();
88 tx_write.send((span, chunk)).await.unwrap();
89 downloaded += len as u64;
90 }
91 tx.send(Event::Finished(0)).await.unwrap();
92 });
93 Ok(DownloadResult::new(
94 event_chain,
95 handle,
96 Box::new(move || {
97 running_clone.store(false, Ordering::Relaxed);
98 }),
99 ))
100}
101
#[cfg(test)]
#[cfg(feature = "file")]
mod tests {
    use super::*;
    use crate::writer::file::SeqFileWriter;
    use crate::Total;
    use tempfile::NamedTempFile;
    use tokio::fs::File;
    use tokio::io::AsyncReadExt;

    /// Builds `size` bytes following the repeating pattern 0, 1, ..., 255.
    fn build_mock_data(size: usize) -> Vec<u8> {
        (0..size).map(|i| (i % 256) as u8).collect()
    }

    /// Serves `body` at `path` from a mock HTTP server, downloads it into a
    /// temporary file and checks that
    ///
    /// * the file content equals `body`,
    /// * the `DownloadProgress` ranges add up to `body.len()`,
    /// * the `WriteProgress` ranges add up to `body.len()`,
    /// * the mock endpoint was hit as expected.
    async fn assert_download_roundtrip(path: &str, body: &[u8]) {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", path)
            .with_status(200)
            .with_body(body)
            .create_async()
            .await;

        let temp_file = NamedTempFile::new().unwrap();
        let file = temp_file.reopen().unwrap().into();

        let result = download(
            format!("{}{}", server.url(), path),
            SeqFileWriter::new(file, 8 * 1024 * 1024),
            DownloadOptions {
                client: Client::new(),
                retry_gap: Duration::from_secs(1),
            },
        )
        .await
        .unwrap();

        // Drain events until both tasks have dropped their senders.
        let mut progress_events = Vec::new();
        let mut rx = result.event_chain.lock().await;
        while let Some(event) = rx.recv().await {
            progress_events.push(event);
        }
        result.join().await.unwrap();

        let mut file_content = Vec::new();
        File::open(temp_file.path())
            .await
            .unwrap()
            .read_to_end(&mut file_content)
            .await
            .unwrap();
        assert_eq!(file_content, body);

        let expected_len = body.len() as u64;
        let downloaded = progress_events
            .iter()
            .map(|m| {
                if let Event::DownloadProgress(p) = m {
                    p.total()
                } else {
                    0
                }
            })
            .sum::<u64>();
        assert_eq!(downloaded, expected_len);

        let written = progress_events
            .iter()
            .map(|m| {
                if let Event::WriteProgress(p) = m {
                    p.total()
                } else {
                    0
                }
            })
            .sum::<u64>();
        assert_eq!(written, expected_len);

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_downloads_small_file_correctly() {
        assert_download_roundtrip("/small", b"test data").await;
    }

    #[tokio::test]
    async fn test_downloads_empty_file_correctly() {
        assert_download_roundtrip("/empty", b"").await;
    }

    #[tokio::test]
    async fn test_downloads_large_file_correctly() {
        assert_download_roundtrip("/large", &build_mock_data(5000)).await;
    }

    #[tokio::test]
    async fn test_downloads_exact_buffer_size_file() {
        assert_download_roundtrip("/exact_buffer_size_file", &build_mock_data(4096)).await;
    }
}