extern crate alloc;
use super::macros::poll_ok;
use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
use alloc::{sync::Arc, vec::Vec};
use bytes::Bytes;
use core::{
    num::{NonZero, NonZeroU64, NonZeroUsize},
    sync::atomic::{AtomicUsize, Ordering},
    time::Duration,
};
use fast_steal::{Executor, Handle, Task, TaskList};
use futures::TryStreamExt;
use tokio::task::AbortHandle;

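/// Options for [`download_multi`]: the byte ranges to download, the number of
/// concurrent workers, the delay between retries, the capacity of the bounded
/// push queue, and the minimum chunk size used when splitting and stealing work.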
#[derive(Debug, Clone)]
pub struct DownloadOptions {
    pub download_chunks: Vec<ProgressEntry>,
    pub concurrent: NonZeroUsize,
    pub retry_gap: Duration,
    pub push_queue_cap: usize,
    pub min_chunk_size: NonZeroU64,
}

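/// Downloads the requested chunks concurrently: a pool of worker tasks pulls
/// byte ranges from `puller` while a dedicated task pushes the received data
/// to `pusher`. Progress, errors, and completion are reported through the
/// event channel of the returned [`DownloadResult`].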
pub async fn download_multi<R, W>(
    puller: R,
    mut pusher: W,
    options: DownloadOptions,
) -> DownloadResult<R::Error, W::Error>
where
    R: RandPuller + 'static + Sync,
    W: RandPusher + 'static,
{
    let (tx, event_chain) = kanal::unbounded_async();
    let (tx_push, rx_push) =
        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
    let tx_clone = tx.clone();
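    // Dedicated push task: data arrives over the bounded channel so pulling and
    // pushing are decoupled. Push errors are reported and retried after
    // `retry_gap`; once every sender is dropped the pusher is flushed.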
    let push_handle = tokio::spawn(async move {
        while let Ok((id, spin, data)) = rx_push.recv().await {
            poll_ok!(
                {},
                pusher.push(spin.clone(), data.clone()).await,
                id @ tx_clone => PushError,
                options.retry_gap
            );
            tx_clone.send(Event::PushProgress(id, spin)).await.unwrap();
        }
        poll_ok!(
            {},
            pusher.flush().await,
            tx_clone => FlushError,
            options.retry_gap
        );
    });
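    // The executor spawns one Tokio task per worker; ids are handed out from a
    // shared atomic counter.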
    let executor: TokioExecutor<R, W> = TokioExecutor {
        tx,
        tx_push,
        puller,
        retry_gap: options.retry_gap,
        id: Arc::new(AtomicUsize::new(0)),
        min_chunk_size: options.min_chunk_size,
    };
    let task_list = TaskList::run(
        options.concurrent,
        options.min_chunk_size,
        &options.download_chunks[..],
        executor,
    );
    DownloadResult::new(
        event_chain,
        push_handle,
        &task_list.handles(|iter| iter.map(|h| h.0.clone()).collect::<Arc<[_]>>()),
    )
}

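/// Handle to a spawned worker task; implements `fast_steal`'s [`Handle`] trait
/// by aborting the underlying Tokio task.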
#[derive(Clone)]
pub struct TokioHandle(AbortHandle);
impl Handle for TokioHandle {
    type Output = ();
    fn abort(&mut self) -> Self::Output {
        self.0.abort();
    }
}
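/// Work-stealing executor that runs each download worker as a Tokio task,
/// reporting events over `tx` and forwarding pulled data over `tx_push`.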
pub struct TokioExecutor<R, W>
where
    R: RandPuller + 'static,
    W: RandPusher + 'static,
{
    tx: kanal::AsyncSender<Event<R::Error, W::Error>>,
    tx_push: kanal::AsyncSender<(WorkerId, ProgressEntry, Bytes)>,
    puller: R,
    retry_gap: Duration,
    id: Arc<AtomicUsize>,
    min_chunk_size: NonZeroU64,
}
impl<R, W> Executor for TokioExecutor<R, W>
where
    R: RandPuller + 'static + Sync,
    W: RandPusher + 'static,
{
    type Handle = TokioHandle;
    fn execute(self: Arc<Self>, task: Arc<Task>, task_list: Arc<TaskList<Self>>) -> Self::Handle {
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        let handle = tokio::spawn(async move {
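            // Worker loop: drain the currently assigned range, then try to
            // steal more work from the task list; stop once nothing is left.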
            'steal_task: loop {
                let mut start = task.start();
                if start >= task.end() {
                    if task_list.steal(&task, NonZero::new(2 * self.min_chunk_size.get()).unwrap())
                    {
                        continue;
                    }
                    break;
                }
                self.tx.send(Event::Pulling(id)).await.unwrap();
                let download_range = start..task.end();
                let mut puller = self.puller.clone();
                let mut stream = puller.pull(&download_range);
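                // Forward each pulled chunk: clamp it to the task range (which
                // may have shrunk if another worker stole from it), report pull
                // progress, and queue the bytes for pushing. Pull errors are
                // reported and retried after `retry_gap`.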
                loop {
                    match stream.try_next().await {
                        Ok(Some(mut chunk)) => {
                            let len = chunk.len() as u64;
                            task.fetch_add_start(len);
                            let range_start = start;
                            start += len;
                            let range_end = start.min(task.end());
                            if range_start >= range_end {
                                continue 'steal_task;
                            }
                            let span = range_start..range_end;
                            let len = span.total() as usize;
                            self.tx
                                .send(Event::PullProgress(id, span.clone()))
                                .await
                                .unwrap();
                            self.tx_push
                                .send((id, span, chunk.split_to(len)))
                                .await
                                .unwrap();
                        }
                        Ok(None) => break,
                        Err(e) => {
                            self.tx.send(Event::PullError(id, e)).await.unwrap();
                            tokio::time::sleep(self.retry_gap).await;
                        }
                    }
                }
            }
            self.tx.send(Event::Finished(id)).await.unwrap();
            task_list.remove(&task);
        });
        TokioHandle(handle.abort_handle())
    }
}

#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandPuller, MockRandPusher, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;

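    // End-to-end check: 3 KiB of mock data downloaded by 32 workers must be
    // fully pulled and pushed, and every worker id must report progress.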
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockRandPuller::new(&mock_data);
        let pusher = MockRandPusher::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: NonZero::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
                min_chunk_size: NonZero::new(1).unwrap(),
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; 32];
        let mut push_ids = [false; 32];
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pull_ids, [true; 32]);
        assert_eq!(push_ids, [true; 32]);

        result.join().await.unwrap();
        pusher.assert().await;
    }
}