1use super::DownloadResult;
2use crate::{ConnectErrorKind, Event, ProgressEntry, RandWriter, Total};
3use bytes::Bytes;
4use fast_steal::{SplitTask, StealTask, Task, TaskList};
5use reqwest::{header, Client, IntoUrl, StatusCode};
6use std::{
7 sync::{
8 atomic::{AtomicBool, Ordering},
9 Arc,
10 },
11 time::Duration,
12};
13use tokio::sync::{mpsc, Mutex};
14
15#[derive(Debug, Clone)]
16pub struct DownloadOptions {
17 pub threads: usize,
18 pub client: Client,
19 pub download_chunks: Vec<ProgressEntry>,
20 pub retry_gap: Duration,
21}
22
23pub async fn download(
24 url: impl IntoUrl,
25 mut writer: impl RandWriter + 'static,
26 options: DownloadOptions,
27) -> Result<DownloadResult, reqwest::Error> {
28 let url = url.into_url()?;
29 let (tx, event_chain) = mpsc::channel(1024);
30 let (tx_write, mut rx_write) = mpsc::channel::<(ProgressEntry, Bytes)>(1024);
31 let tx_clone = tx.clone();
32 let handle = tokio::spawn(async move {
33 while let Some((spin, data)) = rx_write.recv().await {
34 loop {
35 match writer.write_randomly(spin.clone(), &data).await {
36 Ok(_) => break,
37 Err(e) => tx_clone.send(Event::WriteError(e)).await.unwrap(),
38 }
39 tokio::time::sleep(options.retry_gap).await;
40 }
41 tx_clone.send(Event::WriteProgress(spin)).await.unwrap();
42 }
43 loop {
44 match writer.flush().await {
45 Ok(_) => break,
46 Err(e) => tx_clone.send(Event::WriteError(e)).await.unwrap(),
47 };
48 tokio::time::sleep(options.retry_gap).await;
49 }
50 });
51 let mutex = Arc::new(Mutex::new(()));
52 let task_list = Arc::new(TaskList::from(options.download_chunks));
53 let tasks = Arc::new(
54 Task::from(&*task_list)
55 .split_task(options.threads as u64)
56 .map(|t| Arc::new(t))
57 .collect::<Vec<_>>(),
58 );
59 let running = Arc::new(AtomicBool::new(true));
60 let running_clone = running.clone();
61 let client = Arc::new(options.client);
62 let url = Arc::new(url);
63 for (id, task) in tasks.iter().enumerate() {
64 let task = task.clone();
65 let tasks = tasks.clone();
66 let task_list = task_list.clone();
67 let mutex = mutex.clone();
68 let tx = tx.clone();
69 let running = running.clone();
70 let client = client.clone();
71 let url = url.clone();
72 let tx_write = tx_write.clone();
73 tokio::spawn(async move {
74 'a: loop {
75 if !running.load(Ordering::Relaxed) {
76 tx.send(Event::Abort(id)).await.unwrap();
77 return;
78 }
79 let mut start = task.start();
80 if start >= task.end() {
81 let guard = mutex.lock().await;
82 if task.steal(&tasks, 2) {
83 continue;
84 }
85 drop(guard);
86 tx.send(Event::Finished(id)).await.unwrap();
87 return;
88 }
89 let download_range = &task_list.get_range(start..task.end());
90 for range in download_range {
91 let header_range_value = format!("bytes={}-{}", range.start, range.end - 1);
92 let mut response = loop {
93 if !running.load(Ordering::Relaxed) {
94 tx.send(Event::Abort(id)).await.unwrap();
95 return;
96 }
97 tx.send(Event::Connecting(id)).await.unwrap();
98 match client
99 .get(url.as_str())
100 .header(header::RANGE, &header_range_value)
101 .send()
102 .await
103 {
104 Ok(response) if response.status() == StatusCode::PARTIAL_CONTENT => {
105 break response
106 }
107 Ok(response) => tx.send(Event::ConnectError(
108 id,
109 ConnectErrorKind::StatusCode(response.status()),
110 )),
111 Err(e) => {
112 tx.send(Event::ConnectError(id, ConnectErrorKind::Reqwest(e)))
113 }
114 }
115 .await
116 .unwrap();
117 tokio::time::sleep(options.retry_gap).await;
118 };
119 tx.send(Event::Downloading(id)).await.unwrap();
120 let mut downloaded = 0;
121 loop {
122 let chunk = loop {
123 if !running.load(Ordering::Relaxed) {
124 tx.send(Event::Abort(id)).await.unwrap();
125 return;
126 }
127 match response.chunk().await {
128 Ok(chunk) => break chunk,
129 Err(e) => tx.send(Event::DownloadError(id, e)).await.unwrap(),
130 }
131 tokio::time::sleep(options.retry_gap).await;
132 };
133 if chunk.is_none() {
134 break;
135 }
136 let mut chunk = chunk.unwrap();
137 let len = chunk.len() as u64;
138 task.fetch_add_start(len);
139 start += len;
140 let range_start = range.start + downloaded;
141 downloaded += len;
142 let range_end = range.start + downloaded;
143 let span = range_start..range_end.min(task_list.get(task.end()));
144 let len = span.total();
145 tx.send(Event::DownloadProgress(span.clone()))
146 .await
147 .unwrap();
148 tx_write
149 .send((span, chunk.split_to(len as usize)))
150 .await
151 .unwrap();
152 if start >= task.end() {
153 continue 'a;
154 }
155 }
156 }
157 }
158 });
159 }
160 Ok(DownloadResult::new(
161 event_chain,
162 handle,
163 Box::new(move || {
164 running_clone.store(false, Ordering::Relaxed);
165 }),
166 ))
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "file")]
    use crate::writer::file::rand_file_writer_mmap::RandFileWriter;
    use crate::{MergeProgress, ProgressEntry};
    use tempfile::NamedTempFile;

    /// Deterministic test payload of `size` bytes where byte `i` equals `i % 256`.
    fn build_mock_data(size: usize) -> Vec<u8> {
        (0..size).map(|i| (i % 256) as u8).collect()
    }

    /// Returns the parts of `0..total_size` that are NOT covered by `progress`.
    ///
    /// Assumes `progress` is sorted and non-overlapping (as produced by
    /// `merge_progress` in these tests — confirm before reusing elsewhere).
    /// A `total_size` of 0 yields no ranges.
    pub fn reverse_progress(progress: &[ProgressEntry], total_size: u64) -> Vec<ProgressEntry> {
        let mut result = Vec::with_capacity(progress.len() + 1);
        let mut prev_end = 0;
        for range in progress {
            if range.start > prev_end {
                result.push(prev_end..range.start);
            }
            prev_end = range.end;
        }
        if prev_end < total_size {
            result.push(prev_end..total_size);
        }
        result
    }

    /// Mounts a `GET path` mock that always answers with status 206.
    ///
    /// A `Range: bytes=a-b[, c-d…]` request gets the concatenation of the inclusive
    /// slices `body[a..=b]`, `body[c..=d]`, …; a request without `Range` gets `body`.
    #[cfg(feature = "file")]
    async fn mount_range_mock(server: &mut mockito::ServerGuard, path: &str, body: Vec<u8>) {
        server
            .mock("GET", path)
            .with_status(206)
            .with_body_from_request(move |request| {
                if !request.has_header("Range") {
                    return body.clone();
                }
                let range = request.header("Range")[0];
                println!("range: {:?}", range);
                range
                    .to_str()
                    .unwrap()
                    .rsplit('=')
                    .next()
                    .unwrap()
                    .split(',')
                    .map(|p| p.trim().splitn(2, '-'))
                    .map(|mut p| {
                        let start = p.next().unwrap().parse::<usize>().unwrap();
                        let end = p.next().unwrap().parse::<usize>().unwrap();
                        start..=end
                    })
                    .flat_map(|p| body[p].iter().copied())
                    .collect()
            })
            .create_async()
            .await;
    }

    /// Options shared by all tests: fresh client, 1 s retry gap.
    #[cfg(feature = "file")]
    fn download_options(threads: usize, download_chunks: Vec<ProgressEntry>) -> DownloadOptions {
        DownloadOptions {
            client: Client::new(),
            threads,
            download_chunks,
            retry_gap: Duration::from_secs(1),
        }
    }

    /// Drains the event channel until all senders are gone and returns the merged
    /// `(download_progress, write_progress)`.
    #[cfg(feature = "file")]
    async fn collect_progress(result: &DownloadResult) -> (Vec<ProgressEntry>, Vec<ProgressEntry>) {
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        let mut rx = result.event_chain.lock().await;
        while let Some(event) = rx.recv().await {
            match event {
                Event::DownloadProgress(p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        (download_progress, write_progress)
    }

    /// Expected file contents: `body`'s bytes inside `ranges`, zero everywhere else.
    #[cfg(feature = "file")]
    fn expected_output(body: &[u8], ranges: &[ProgressEntry]) -> Vec<u8> {
        let mut data = vec![0; body.len()];
        for range in ranges {
            let span = range.start as usize..range.end as usize;
            data[span.clone()].copy_from_slice(&body[span]);
        }
        data
    }

    /// Downloads `download_chunks` of a `body_len`-byte mock body served at `path` with
    /// 32 tasks, then checks the reported progress and the bytes on disk.
    #[cfg(feature = "file")]
    async fn check_download(path: &str, body_len: usize, download_chunks: Vec<ProgressEntry>) {
        let mock_body = build_mock_data(body_len);
        let mut server = mockito::Server::new_async().await;
        mount_range_mock(&mut server, path, mock_body.clone()).await;

        let temp_file = NamedTempFile::new().unwrap();
        let file = temp_file.reopen().unwrap().into();
        let result = download(
            format!("{}{}", server.url(), path),
            RandFileWriter::new(file, mock_body.len() as u64, 8 * 1024 * 1024)
                .await
                .unwrap(),
            download_options(32, download_chunks.clone()),
        )
        .await
        .unwrap();

        let (download_progress, write_progress) = collect_progress(&result).await;
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();

        let file_content = tokio::fs::read(temp_file.path()).await.unwrap();
        assert_eq!(file_content, expected_output(&mock_body, &download_chunks));
    }

    #[cfg(feature = "file")]
    #[tokio::test]
    async fn test_multi_thread_regular_download() {
        let body_len = 3 * 1024;
        check_download("/mutli-2", body_len, vec![0..body_len as u64]).await;
    }

    #[cfg(feature = "file")]
    #[tokio::test]
    async fn test_multi_thread_download_chunk() {
        check_download("/multi-2", 3 * 1024, vec![10..80, 100..300, 1000..2000]).await;
    }

    #[cfg(feature = "file")]
    #[tokio::test]
    async fn test_multi_thread_break_point() {
        // Large enough that the 1 s cancellation normally lands mid-download.
        let mock_body = build_mock_data(200 * 1024 * 1024);
        let total = mock_body.len() as u64;
        let mut server = mockito::Server::new_async().await;
        mount_range_mock(&mut server, "/mutli-3", mock_body.clone()).await;
        let url = format!("{}/mutli-3", server.url());
        let temp_file = NamedTempFile::new().unwrap();

        // First pass: cancel after one second and record what actually reached disk.
        let write_progress = {
            let file = temp_file.reopen().unwrap().into();
            let result = download(
                url.clone(),
                RandFileWriter::new(file, total, 8 * 1024 * 1024)
                    .await
                    .unwrap(),
                download_options(32, vec![0..total]),
            )
            .await
            .unwrap();
            let result_clone = result.clone();
            tokio::spawn(async move {
                tokio::time::sleep(Duration::from_millis(1000)).await;
                result_clone.cancel().await;
            });
            let (download_progress, write_progress) = collect_progress(&result).await;
            dbg!(&download_progress);
            dbg!(&write_progress);
            assert_eq!(download_progress, write_progress);
            result.join().await.unwrap();

            let file_content = tokio::fs::read(temp_file.path()).await.unwrap();
            assert_eq!(file_content, expected_output(&mock_body, &write_progress));
            write_progress
        };

        // Second pass: resume with exactly the ranges that were not written.
        println!("开始续传");
        let download_chunks = reverse_progress(&write_progress, total);
        let file = temp_file.reopen().unwrap().into();
        let result = download(
            url,
            RandFileWriter::new(file, total, 8 * 1024 * 1024)
                .await
                .unwrap(),
            download_options(8, download_chunks.clone()),
        )
        .await
        .unwrap();

        let (download_progress, write_progress) = collect_progress(&result).await;
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();

        let file_content = tokio::fs::read(temp_file.path()).await.unwrap();
        assert_eq!(file_content, mock_body);
    }
}