1extern crate alloc;
2use super::macros::{check_running, poll_ok};
3use crate::{DownloadResult, Event, ProgressEntry, RandReader, RandWriter, Total, WorkerId};
4use alloc::{sync::Arc, vec::Vec};
5use bytes::Bytes;
6use core::{num::NonZeroUsize, sync::atomic::AtomicBool, time::Duration};
7use fast_steal::{SplitTask, StealTask, Task, TaskList};
8use futures::TryStreamExt;
9use tokio::sync::Mutex;
10
11#[derive(Debug, Clone)]
12pub struct DownloadOptions {
13 pub download_chunks: Vec<ProgressEntry>,
14 pub concurrent: NonZeroUsize,
15 pub retry_gap: Duration,
16 pub write_queue_cap: usize,
17}
18
19pub async fn download_multi<R, W>(
20 reader: R,
21 mut writer: W,
22 options: DownloadOptions,
23) -> DownloadResult<R::Error, W::Error>
24where
25 R: RandReader + 'static,
26 W: RandWriter + 'static,
27{
28 let (tx, event_chain) = kanal::unbounded_async();
29 let (tx_write, rx_write) =
30 kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.write_queue_cap);
31 let tx_clone = tx.clone();
32 let handle = tokio::spawn(async move {
33 while let Ok((id, spin, data)) = rx_write.recv().await {
34 poll_ok!(
35 {},
36 writer.write(spin.clone(), data.clone()).await,
37 id @ tx_clone => WriteError,
38 options.retry_gap
39 );
40 tx_clone.send(Event::WriteProgress(id, spin)).await.unwrap();
41 }
42 poll_ok!(
43 {},
44 writer.flush().await,
45 tx_clone => FlushError,
46 options.retry_gap
47 );
48 });
49 let mutex = Arc::new(Mutex::new(()));
50 let task_list = Arc::new(TaskList::from(&options.download_chunks[..]));
51 let tasks = Arc::from_iter(
52 Task::from(&*task_list)
53 .split_task(options.concurrent.get() as u64)
54 .map(Arc::new),
55 );
56 let running = Arc::new(AtomicBool::new(true));
57 for (id, task) in tasks.iter().enumerate() {
58 let task = task.clone();
59 let tasks = tasks.clone();
60 let task_list = task_list.clone();
61 let mutex = mutex.clone();
62 let tx = tx.clone();
63 let running = running.clone();
64 let mut reader = reader.clone();
65 let tx_write = tx_write.clone();
66 tokio::spawn(async move {
67 'steal_task: loop {
68 check_running!(id, running, tx);
69 let mut start = task.start();
70 if start >= task.end() {
71 let guard = mutex.lock().await;
72 if task.steal(&tasks, 16 * 1024) {
73 continue;
74 }
75 drop(guard);
76 tx.send(Event::Finished(id)).await.unwrap();
77 return;
78 }
79 let download_range = &task_list.get_range(start..task.end());
80 for range in download_range {
81 check_running!(id, running, tx);
82 tx.send(Event::Reading(id)).await.unwrap();
83 let mut stream = reader.read(range);
84 let mut downloaded = 0;
85 loop {
86 check_running!(id, running, tx);
87 match stream.try_next().await {
88 Ok(Some(mut chunk)) => {
89 let len = chunk.len() as u64;
90 task.fetch_add_start(len);
91 start += len;
92 let range_start = range.start + downloaded;
93 downloaded += len;
94 let range_end = range.start + downloaded;
95 let span = range_start..range_end.min(task_list.get(task.end()));
96 let len = span.total() as usize;
97 tx.send(Event::ReadProgress(id, span.clone()))
98 .await
99 .unwrap();
100 tx_write
101 .send((id, span, chunk.split_to(len)))
102 .await
103 .unwrap();
104 if start >= task.end() {
105 continue 'steal_task;
106 }
107 }
108 Ok(None) => break,
109 Err(e) => tx.send(Event::ReadError(id, e)).await.unwrap(),
110 }
111 }
112 }
113 }
114 });
115 }
116 DownloadResult::new(event_chain, handle, running)
117}
118
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandReader, MockRandWriter, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;

    /// Downloads 3 KiB of mock data with 32 workers and checks four things:
    /// * the merged read progress equals the requested chunk list;
    /// * the merged write progress equals the requested chunk list;
    /// * the writer task joins without error;
    /// * the mock writer's own assertion passes.
    #[tokio::test]
    async fn test_concurrent_download() {
        let data = build_mock_data(3 * 1024);
        let reader = MockRandReader::new(&data);
        let writer = MockRandWriter::new(&data);
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..data.len() as u64];
        let options = DownloadOptions {
            download_chunks: expected.clone(),
            concurrent: NonZeroUsize::new(32).unwrap(),
            retry_gap: Duration::from_secs(1),
            write_queue_cap: 1024,
        };
        let result = download_multi(reader, writer.clone(), options).await;

        // The event channel closes once every worker and the writer task have dropped their senders.
        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::WriteProgress(_, span) => {
                    write_progress.merge_progress(span);
                }
                Event::ReadProgress(_, span) => {
                    download_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, expected);
        assert_eq!(write_progress, expected);

        result.join().await.unwrap();
        writer.assert().await;
    }
}