1extern crate alloc;
2use super::macros::{check_running, poll_ok};
3use crate::{DownloadResult, Event, ProgressEntry, RandReader, RandWriter, Total, WorkerId};
4use alloc::sync::Arc;
5use bytes::Bytes;
6use core::{num::NonZeroUsize, sync::atomic::AtomicBool, time::Duration};
7use fast_steal::{SplitTask, StealTask, Task, TaskList};
8use futures::{TryStreamExt, lock::Mutex};
9
10#[derive(Debug, Clone)]
11pub struct DownloadOptions {
12 pub download_chunks: Vec<ProgressEntry>,
13 pub concurrent: NonZeroUsize,
14 pub retry_gap: Duration,
15 pub write_queue_cap: usize,
16}
17
18pub async fn download_multi<R, W>(
19 reader: R,
20 mut writer: W,
21 options: DownloadOptions,
22) -> DownloadResult<R::Error, W::Error>
23where
24 R: RandReader + 'static,
25 W: RandWriter + 'static,
26{
27 let (tx, event_chain) = kanal::unbounded_async();
28 let (tx_write, rx_write) =
29 kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.write_queue_cap);
30 let tx_clone = tx.clone();
31 let handle = tokio::spawn(async move {
32 while let Ok((id, spin, data)) = rx_write.recv().await {
33 poll_ok!(
34 {},
35 writer.write(spin.clone(), data.clone()).await,
36 id @ tx_clone => WriteError,
37 options.retry_gap
38 );
39 tx_clone.send(Event::WriteProgress(id, spin)).await.unwrap();
40 }
41 poll_ok!(
42 {},
43 writer.flush().await,
44 tx_clone => FlushError,
45 options.retry_gap
46 );
47 });
48 let mutex = Arc::new(Mutex::new(()));
49 let task_list = Arc::new(TaskList::from(&options.download_chunks[..]));
50 let tasks = Arc::from_iter(
51 Task::from(&*task_list)
52 .split_task(options.concurrent.get() as u64)
53 .map(Arc::new),
54 );
55 let running = Arc::new(AtomicBool::new(true));
56 for (id, task) in tasks.iter().enumerate() {
57 let task = task.clone();
58 let tasks = tasks.clone();
59 let task_list = task_list.clone();
60 let mutex = mutex.clone();
61 let tx = tx.clone();
62 let running = running.clone();
63 let mut reader = reader.clone();
64 let tx_write = tx_write.clone();
65 tokio::spawn(async move {
66 'steal_task: loop {
67 check_running!(id, running, tx);
68 let mut start = task.start();
69 if start >= task.end() {
70 let guard = mutex.lock().await;
71 if task.steal(&tasks, 16 * 1024) {
72 continue;
73 }
74 drop(guard);
75 tx.send(Event::Finished(id)).await.unwrap();
76 return;
77 }
78 let download_range = &task_list.get_range(start..task.end());
79 for range in download_range {
80 check_running!(id, running, tx);
81 tx.send(Event::Reading(id)).await.unwrap();
82 let mut stream = reader.read(range);
83 let mut downloaded = 0;
84 loop {
85 check_running!(id, running, tx);
86 match stream.try_next().await {
87 Ok(Some(mut chunk)) => {
88 let len = chunk.len() as u64;
89 task.fetch_add_start(len);
90 start += len;
91 let range_start = range.start + downloaded;
92 downloaded += len;
93 let range_end = range.start + downloaded;
94 let span = range_start..range_end.min(task_list.get(task.end()));
95 let len = span.total() as usize;
96 tx.send(Event::ReadProgress(id, span.clone()))
97 .await
98 .unwrap();
99 tx_write
100 .send((id, span, chunk.split_to(len)))
101 .await
102 .unwrap();
103 if start >= task.end() {
104 continue 'steal_task;
105 }
106 }
107 Ok(None) => break,
108 Err(e) => tx.send(Event::ReadError(id, e)).await.unwrap(),
109 }
110 }
111 }
112 }
113 });
114 }
115 DownloadResult::new(event_chain, handle, running)
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandReader, MockRandWriter, build_mock_data},
    };

    /// Downloads 3 KiB of mock data with 32 workers and checks that the merged
    /// read progress and the merged write progress reported through events both
    /// equal the requested range. It then joins the writer task and calls
    /// `MockRandWriter::assert` (presumably verifying the written bytes).
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let reader = MockRandReader::new(&mock_data);
        let writer = MockRandWriter::new(&mock_data);
        // One range covering the whole file is intended here, not a vector of
        // the range's individual values.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            reader,
            writer.clone(),
            DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
            },
        )
        .await;

        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        // The event channel closes once every worker and the writer task have
        // dropped their senders.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        // `assert_eq!` prints both sides on failure, so no debug output is needed.
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        writer.assert().await;
    }
}