extern crate alloc;
use super::macros::poll_ok;
use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
use alloc::{sync::Arc, vec::Vec};
use bytes::Bytes;
use core::{
    num::{NonZero, NonZeroU64, NonZeroUsize},
    sync::atomic::{AtomicUsize, Ordering},
    time::Duration,
};
use fast_steal::{Executor, Handle, Task, TaskList};
use futures::TryStreamExt;
use tokio::task::AbortHandle;

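/// Options controlling a multi-worker download.
///
/// * `download_chunks`: the byte ranges that need to be downloaded.
/// * `concurrent`: how many pull workers to run in parallel.
/// * `retry_gap`: default delay before retrying a failed pull or push.
/// * `push_queue_cap`: capacity of the bounded queue between pull workers and the push task.
/// * `min_chunk_size`: minimum range size used when splitting and stealing work.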
#[derive(Debug, Clone)]
pub struct DownloadOptions {
    pub download_chunks: Vec<ProgressEntry>,
    pub concurrent: NonZeroUsize,
    pub retry_gap: Duration,
    pub push_queue_cap: usize,
    pub min_chunk_size: NonZeroU64,
}

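/// Downloads `options.download_chunks` from `puller` using `options.concurrent`
/// work-stealing workers and forwards every received chunk to `pusher`.
///
/// Progress, errors, and completion are reported through the event channel of
/// the returned [`DownloadResult`].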
pub async fn download_multi<R, W>(
    puller: R,
    mut pusher: W,
    options: DownloadOptions,
) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error>
where
    R: RandPuller + 'static + Sync,
    W: RandPusher + 'static,
{
    let (tx, event_chain) = kanal::unbounded_async();
    let (tx_push, rx_push) =
        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
    let tx_clone = tx.clone();
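    // Dedicated push task: drains the bounded queue of pulled chunks, writes each
    // one via `pusher` (reporting `PushError` and retrying until it succeeds), and
    // emits `PushProgress`; once the queue closes it flushes the pusher. The retry
    // gap is copied out first so the async block does not take ownership of `options`.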
    let retry_gap = options.retry_gap;
    let push_handle = tokio::spawn(async move {
        while let Ok((id, span, data)) = rx_push.recv().await {
            poll_ok!(
                pusher.push(span.clone(), &data).await,
                id @ tx_clone => PushError,
                retry_gap
            );
            tx_clone.send(Event::PushProgress(id, span)).await.unwrap();
        }
        poll_ok!(
            pusher.flush().await,
            tx_clone => FlushError,
            retry_gap
        );
    });
    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
        tx,
        tx_push,
        puller,
        retry_gap: options.retry_gap,
        id: Arc::new(AtomicUsize::new(0)),
        min_chunk_size: options.min_chunk_size,
    });
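    // Hand the requested ranges to the work-stealing task list, which schedules
    // `concurrent` workers through the executor above.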
    let task_list = Arc::new(TaskList::run(&options.download_chunks[..], executor));
    task_list
        .clone()
        .set_threads(options.concurrent, options.min_chunk_size);
    DownloadResult::new(
        event_chain,
        push_handle,
        None,
        Some(Arc::downgrade(&task_list)),
    )
}

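/// Handle to a spawned worker; aborting it cancels the underlying Tokio task.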
#[derive(Clone)]
pub struct TokioHandle(AbortHandle);
impl Handle for TokioHandle {
    type Output = ();
    fn abort(&mut self) -> Self::Output {
        self.0.abort();
    }
}
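/// `fast_steal` executor that runs each worker as a Tokio task: every worker
/// pulls with its own clone of `puller` and forwards chunks to the push queue.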
#[derive(Debug)]
pub struct TokioExecutor<R, WE>
where
    R: RandPuller + 'static,
    WE: Send + 'static,
{
    tx: kanal::AsyncSender<Event<R::Error, WE>>,
    tx_push: kanal::AsyncSender<(WorkerId, ProgressEntry, Bytes)>,
    puller: R,
    retry_gap: Duration,
    id: Arc<AtomicUsize>,
    min_chunk_size: NonZeroU64,
}
impl<R, WE> Executor for TokioExecutor<R, WE>
where
    R: RandPuller + 'static + Sync,
    WE: Send + 'static,
{
    type Handle = TokioHandle;
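    /// Spawns one pull worker for `task` and returns an abortable handle to it.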
    fn execute(self: Arc<Self>, task: Arc<Task>, task_list: Arc<TaskList<Self>>) -> Self::Handle {
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        let steal_min_chunk_size = NonZero::new(2 * self.min_chunk_size.get()).unwrap();
        let mut puller = self.puller.clone();
        let handle = tokio::spawn(async move {
            'steal_task: loop {
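                // When the assigned range is exhausted, try to steal more work
                // from another task; stop the worker once nothing is left to steal.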
                let mut start = task.start();
                if start >= task.end() {
                    if task_list.steal(&task, steal_min_chunk_size) {
                        continue;
                    } else {
                        break;
                    }
                }
                self.tx.send(Event::Pulling(id)).await.unwrap();
                let download_range = start..task.end();
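                // Open a byte stream for the remaining range, retrying failed
                // pulls after `retry_gap` (or the gap suggested by the puller).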
                let mut stream = loop {
                    match puller.pull(&download_range).await {
                        Ok(t) => break t,
                        Err((e, retry_gap)) => {
                            self.tx.send(Event::PullError(id, e)).await.unwrap();
                            tokio::time::sleep(retry_gap.unwrap_or(self.retry_gap)).await;
                        }
                    }
                };
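                // Drain the stream: advance the shared task start for every chunk,
                // clamp to the current task end (it may have shrunk if part of the
                // range was stolen), truncate the chunk accordingly, and hand it to
                // the push task.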
                loop {
                    match stream.try_next().await {
                        Ok(Some(mut chunk)) => {
                            let len = chunk.len() as u64;
                            task.fetch_add_start(len);
                            let range_start = start;
                            start += len;
                            let range_end = start.min(task.end());
                            if range_start >= range_end {
                                continue 'steal_task;
                            }
                            let span = range_start..range_end;
                            chunk.truncate(span.total() as usize);
                            self.tx
                                .send(Event::PullProgress(id, span.clone()))
                                .await
                                .unwrap();
                            self.tx_push.send((id, span, chunk)).await.unwrap();
                        }
                        Ok(None) => break,
                        Err((e, retry_gap)) => {
                            self.tx.send(Event::PullError(id, e)).await.unwrap();
                            tokio::time::sleep(retry_gap.unwrap_or(self.retry_gap)).await;
                        }
                    }
                }
            }
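            // Nothing left to pull or steal: deregister the task and report completion.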
            task_list.remove(&task);
            self.tx.send(Event::Finished(id)).await.unwrap();
        });
        TokioHandle(handle.abort_handle())
    }
}

#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;

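    // Downloads 3 KiB of mock data with 32 workers and checks that pull and push
    // progress cover the whole range, every worker participated, and the pushed
    // bytes match the source.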
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: NonZero::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
                min_chunk_size: NonZero::new(1).unwrap(),
            },
        )
        .await;

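        // Collect progress events until the event channel closes, merging the
        // reported ranges and recording which worker ids were seen.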
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; 32];
        let mut push_ids = [false; 32];
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pull_ids, [true; 32]);
        assert_eq!(push_ids, [true; 32]);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock().await, mock_data);
    }
}