// fast_down/core/multi.rs

1use super::DownloadResult;
2use crate::{ConnectErrorKind, Event, ProgressEntry, RandWriter, Total};
3use bytes::Bytes;
4use fast_steal::{SplitTask, StealTask, Task, TaskList};
5use reqwest::{header, Client, IntoUrl, StatusCode};
6use std::{
7    sync::{
8        atomic::{AtomicBool, Ordering},
9        Arc,
10    },
11    time::Duration,
12};
13use tokio::sync::{mpsc, Mutex};
14
15#[derive(Debug, Clone)]
16pub struct DownloadOptions {
17    pub threads: usize,
18    pub client: Client,
19    pub download_chunks: Vec<ProgressEntry>,
20    pub retry_gap: Duration,
21}
22
23pub async fn download(
24    url: impl IntoUrl,
25    mut writer: impl RandWriter + 'static,
26    options: DownloadOptions,
27) -> Result<DownloadResult, reqwest::Error> {
28    let url = url.into_url()?;
29    let (tx, event_chain) = mpsc::channel(1024);
30    let (tx_write, mut rx_write) = mpsc::channel::<(ProgressEntry, Bytes)>(1024);
31    let tx_clone = tx.clone();
32    let handle = tokio::spawn(async move {
33        while let Some((spin, data)) = rx_write.recv().await {
34            loop {
35                match writer.write_randomly(spin.clone(), &data).await {
36                    Ok(_) => break,
37                    Err(e) => tx_clone.send(Event::WriteError(e)).await.unwrap(),
38                }
39                tokio::time::sleep(options.retry_gap).await;
40            }
41            tx_clone.send(Event::WriteProgress(spin)).await.unwrap();
42        }
43        loop {
44            match writer.flush().await {
45                Ok(_) => break,
46                Err(e) => tx_clone.send(Event::WriteError(e)).await.unwrap(),
47            };
48            tokio::time::sleep(options.retry_gap).await;
49        }
50    });
51    let mutex = Arc::new(Mutex::new(()));
52    let task_list = Arc::new(TaskList::from(options.download_chunks));
53    let tasks = Arc::new(
54        Task::from(&*task_list)
55            .split_task(options.threads as u64)
56            .map(|t| Arc::new(t))
57            .collect::<Vec<_>>(),
58    );
59    let running = Arc::new(AtomicBool::new(true));
60    let running_clone = running.clone();
61    let client = Arc::new(options.client);
62    let url = Arc::new(url);
63    for (id, task) in tasks.iter().enumerate() {
64        let task = task.clone();
65        let tasks = tasks.clone();
66        let task_list = task_list.clone();
67        let mutex = mutex.clone();
68        let tx = tx.clone();
69        let running = running.clone();
70        let client = client.clone();
71        let url = url.clone();
72        let tx_write = tx_write.clone();
73        tokio::spawn(async move {
74            'a: loop {
75                if !running.load(Ordering::Relaxed) {
76                    tx.send(Event::Abort(id)).await.unwrap();
77                    return;
78                }
79                let mut start = task.start();
80                if start >= task.end() {
81                    let guard = mutex.lock().await;
82                    if task.steal(&tasks, 2) {
83                        continue;
84                    }
85                    drop(guard);
86                    tx.send(Event::Finished(id)).await.unwrap();
87                    return;
88                }
89                let download_range = &task_list.get_range(start..task.end());
90                for range in download_range {
91                    let header_range_value = format!("bytes={}-{}", range.start, range.end - 1);
92                    let mut response = loop {
93                        if !running.load(Ordering::Relaxed) {
94                            tx.send(Event::Abort(id)).await.unwrap();
95                            return;
96                        }
97                        tx.send(Event::Connecting(id)).await.unwrap();
98                        match client
99                            .get(url.as_str())
100                            .header(header::RANGE, &header_range_value)
101                            .send()
102                            .await
103                        {
104                            Ok(response) if response.status() == StatusCode::PARTIAL_CONTENT => {
105                                break response
106                            }
107                            Ok(response) => tx.send(Event::ConnectError(
108                                id,
109                                ConnectErrorKind::StatusCode(response.status()),
110                            )),
111                            Err(e) => {
112                                tx.send(Event::ConnectError(id, ConnectErrorKind::Reqwest(e)))
113                            }
114                        }
115                        .await
116                        .unwrap();
117                        tokio::time::sleep(options.retry_gap).await;
118                    };
119                    tx.send(Event::Downloading(id)).await.unwrap();
120                    let mut downloaded = 0;
121                    loop {
122                        let chunk = loop {
123                            if !running.load(Ordering::Relaxed) {
124                                tx.send(Event::Abort(id)).await.unwrap();
125                                return;
126                            }
127                            match response.chunk().await {
128                                Ok(chunk) => break chunk,
129                                Err(e) => tx.send(Event::DownloadError(id, e)).await.unwrap(),
130                            }
131                            tokio::time::sleep(options.retry_gap).await;
132                        };
133                        if chunk.is_none() {
134                            break;
135                        }
136                        let mut chunk = chunk.unwrap();
137                        let len = chunk.len() as u64;
138                        task.fetch_add_start(len);
139                        start += len;
140                        let range_start = range.start + downloaded;
141                        downloaded += len;
142                        let range_end = range.start + downloaded;
143                        let span = range_start..range_end.min(task_list.get(task.end()));
144                        let len = span.total();
145                        tx.send(Event::DownloadProgress(span.clone()))
146                            .await
147                            .unwrap();
148                        tx_write
149                            .send((span, chunk.split_to(len as usize)))
150                            .await
151                            .unwrap();
152                        if start >= task.end() {
153                            continue 'a;
154                        }
155                    }
156                }
157            }
158        });
159    }
160    Ok(DownloadResult::new(
161        event_chain,
162        handle,
163        Box::new(move || {
164            running_clone.store(false, Ordering::Relaxed);
165        }),
166    ))
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "file")]
    use crate::writer::file::rand_file_writer_mmap::RandFileWriter;
    use crate::{MergeProgress, ProgressEntry};
    use tempfile::NamedTempFile;

    /// Deterministic test payload: byte `i` is `i % 256`.
    fn build_mock_data(size: usize) -> Vec<u8> {
        (0..size).map(|i| (i % 256) as u8).collect()
    }

    /// Returns the gaps of `progress` inside `0..total_size`.
    ///
    /// `progress` is expected to be sorted and non-overlapping, as built by
    /// `merge_progress`.
    pub fn reverse_progress(progress: &[ProgressEntry], total_size: u64) -> Vec<ProgressEntry> {
        if progress.is_empty() {
            return vec![0..total_size];
        }
        let mut result = Vec::with_capacity(progress.len());
        let mut prev_end = 0;
        for range in progress {
            if range.start > prev_end {
                result.push(prev_end..range.start);
            }
            prev_end = range.end;
        }
        if prev_end < total_size {
            result.push(prev_end..total_size);
        }
        result
    }

    /// Answers a request with the parts of `body` named by its `Range: bytes=a-b[,c-d...]`
    /// header (inclusive bounds), or with the whole body when there is no such header.
    #[cfg(feature = "file")]
    fn serve_range(body: &[u8], request: &mockito::Request) -> Vec<u8> {
        if !request.has_header("Range") {
            return body.to_vec();
        }
        let range = request.header("Range")[0];
        println!("range: {:?}", range);
        range
            .to_str()
            .unwrap()
            .rsplit('=')
            .next()
            .unwrap()
            .split(',')
            .map(|p| p.trim().splitn(2, '-'))
            .flat_map(|mut p| {
                let start = p.next().unwrap().parse::<usize>().unwrap();
                let end = p.next().unwrap().parse::<usize>().unwrap();
                body[start..=end].iter().copied()
            })
            .collect()
    }

    /// Starts a mock server answering `GET path` with status 206 via [`serve_range`].
    /// The returned guard must stay alive for as long as the server is needed.
    #[cfg(feature = "file")]
    async fn range_server(path: &str, body: Vec<u8>) -> mockito::ServerGuard {
        let mut server = mockito::Server::new_async().await;
        server
            .mock("GET", path)
            .with_status(206)
            .with_body_from_request(move |request| serve_range(&body, request))
            .create_async()
            .await;
        server
    }

    /// Starts [`download`] of `url` into `temp_file` (preallocated to `size` bytes).
    #[cfg(feature = "file")]
    async fn start_download(
        url: String,
        temp_file: &NamedTempFile,
        size: u64,
        threads: usize,
        download_chunks: Vec<ProgressEntry>,
    ) -> DownloadResult {
        let file = temp_file.reopen().unwrap().into();
        download(
            url,
            RandFileWriter::new(file, size, 8 * 1024 * 1024)
                .await
                .unwrap(),
            DownloadOptions {
                client: Client::new(),
                threads,
                download_chunks,
                retry_gap: Duration::from_secs(1),
            },
        )
        .await
        .unwrap()
    }

    /// Drains every event of `result`; returns the merged (download, write) progress.
    #[cfg(feature = "file")]
    async fn collect_progress(result: &DownloadResult) -> (Vec<ProgressEntry>, Vec<ProgressEntry>) {
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        let mut rx = result.event_chain.lock().await;
        while let Some(e) = rx.recv().await {
            match e {
                Event::DownloadProgress(p) => download_progress.merge_progress(p),
                Event::WriteProgress(p) => write_progress.merge_progress(p),
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        (download_progress, write_progress)
    }

    /// Expected file contents when only `chunks` of `body` were written: bytes inside
    /// `chunks` come from `body`, all other bytes are zero.
    #[cfg(feature = "file")]
    fn expected_output(body: &[u8], chunks: &[ProgressEntry]) -> Vec<u8> {
        let mut data = vec![0; body.len()];
        for chunk in chunks {
            let range = chunk.start as usize..chunk.end as usize;
            data[range.clone()].copy_from_slice(&body[range]);
        }
        data
    }

    #[cfg(feature = "file")]
    async fn read_file(path: &std::path::Path) -> Vec<u8> {
        tokio::fs::read(path).await.unwrap()
    }

    #[cfg(feature = "file")]
    #[tokio::test]
    async fn test_multi_thread_regular_download() {
        let mock_body = build_mock_data(3 * 1024);
        let server = range_server("/multi-1", mock_body.clone()).await;
        let temp_file = NamedTempFile::new().unwrap();
        let download_chunks = vec![0..mock_body.len() as u64];

        let result = start_download(
            format!("{}/multi-1", server.url()),
            &temp_file,
            mock_body.len() as u64,
            32,
            download_chunks.clone(),
        )
        .await;
        let (download_progress, write_progress) = collect_progress(&result).await;
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);
        result.join().await.unwrap();

        assert_eq!(
            read_file(temp_file.path()).await,
            expected_output(&mock_body, &download_chunks)
        );
    }

    #[cfg(feature = "file")]
    #[tokio::test]
    async fn test_multi_thread_download_chunk() {
        let mock_body = build_mock_data(3 * 1024);
        let server = range_server("/multi-2", mock_body.clone()).await;
        let temp_file = NamedTempFile::new().unwrap();
        let download_chunks = vec![10..80, 100..300, 1000..2000];

        let result = start_download(
            format!("{}/multi-2", server.url()),
            &temp_file,
            mock_body.len() as u64,
            32,
            download_chunks.clone(),
        )
        .await;
        let (download_progress, write_progress) = collect_progress(&result).await;
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);
        result.join().await.unwrap();

        assert_eq!(
            read_file(temp_file.path()).await,
            expected_output(&mock_body, &download_chunks)
        );
    }

    #[cfg(feature = "file")]
    #[tokio::test]
    async fn test_multi_thread_break_point() {
        let mock_body = build_mock_data(200 * 1024 * 1024);
        let total = mock_body.len() as u64;
        let server = range_server("/multi-3", mock_body.clone()).await;
        let url = format!("{}/multi-3", server.url());
        let temp_file = NamedTempFile::new().unwrap();

        // First run: cancel after one second to simulate an interrupted download.
        let write_progress = {
            let result = start_download(url.clone(), &temp_file, total, 32, vec![0..total]).await;
            let result_clone = result.clone();
            tokio::spawn(async move {
                tokio::time::sleep(Duration::from_millis(1000)).await;
                result_clone.cancel().await;
            });
            let (download_progress, write_progress) = collect_progress(&result).await;
            assert_eq!(download_progress, write_progress);
            result.join().await.unwrap();
            assert_eq!(
                read_file(temp_file.path()).await,
                expected_output(&mock_body, &write_progress)
            );
            write_progress
        };

        // Resume: fetch only what the first run did not write.
        println!("开始续传");
        let download_chunks = reverse_progress(&write_progress, total);
        let result = start_download(url, &temp_file, total, 8, download_chunks.clone()).await;
        let (download_progress, write_progress) = collect_progress(&result).await;
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);
        result.join().await.unwrap();

        assert_eq!(read_file(temp_file.path()).await, mock_body);
    }
}