1extern crate alloc;
2use super::macros::poll_ok;
3use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
4use alloc::{sync::Arc, vec::Vec};
5use bytes::Bytes;
6use core::{
7 num::{NonZeroU64, NonZeroUsize},
8 sync::atomic::{AtomicUsize, Ordering},
9 time::Duration,
10};
11use fast_steal::{Executor, Handle, Task, TaskList};
12use futures::TryStreamExt;
13use tokio::task::AbortHandle;
14
15#[derive(Debug, Clone)]
16pub struct DownloadOptions {
17 pub download_chunks: Vec<ProgressEntry>,
18 pub concurrent: NonZeroUsize,
19 pub retry_gap: Duration,
20 pub push_queue_cap: usize,
21 pub min_chunk_size: NonZeroU64,
22}
23
24pub async fn download_multi<R, W>(
25 puller: R,
26 mut pusher: W,
27 options: DownloadOptions,
28) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error>
29where
30 R: RandPuller + 'static + Sync,
31 W: RandPusher + 'static,
32{
33 let (tx, event_chain) = kanal::unbounded_async();
34 let (tx_push, rx_push) =
35 kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
36 let tx_clone = tx.clone();
37 let push_handle = tokio::spawn(async move {
38 while let Ok((id, spin, data)) = rx_push.recv().await {
39 poll_ok!(
40 pusher.push(spin.clone(), &data).await,
41 id @ tx_clone => PushError,
42 options.retry_gap
43 );
44 let _ = tx_clone.send(Event::PushProgress(id, spin)).await;
45 }
46 poll_ok!(
47 pusher.flush().await,
48 tx_clone => FlushError,
49 options.retry_gap
50 );
51 });
52 let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
53 tx,
54 tx_push,
55 puller,
56 retry_gap: options.retry_gap,
57 id: Arc::new(AtomicUsize::new(0)),
58 min_chunk_size: options.min_chunk_size,
59 });
60 let task_list = Arc::new(TaskList::run(&options.download_chunks[..], executor));
61 task_list
62 .clone()
63 .set_threads(options.concurrent, options.min_chunk_size);
64 DownloadResult::new(
65 event_chain,
66 push_handle,
67 None,
68 Some(Arc::downgrade(&task_list)),
69 )
70}
71
72#[derive(Clone)]
73pub struct TokioHandle(AbortHandle);
74impl Handle for TokioHandle {
75 type Output = ();
76 fn abort(&mut self) -> Self::Output {
77 self.0.abort();
78 }
79}
80#[derive(Debug)]
81pub struct TokioExecutor<R, WE>
82where
83 R: RandPuller + 'static,
84 WE: Send + 'static,
85{
86 tx: kanal::AsyncSender<Event<R::Error, WE>>,
87 tx_push: kanal::AsyncSender<(WorkerId, ProgressEntry, Bytes)>,
88 puller: R,
89 retry_gap: Duration,
90 id: Arc<AtomicUsize>,
91 min_chunk_size: NonZeroU64,
92}
93impl<R, WE> Executor for TokioExecutor<R, WE>
94where
95 R: RandPuller + 'static + Sync,
96 WE: Send + 'static,
97{
98 type Handle = TokioHandle;
99 fn execute(self: Arc<Self>, task: Arc<Task>, task_list: Arc<TaskList<Self>>) -> Self::Handle {
100 let id = self.id.fetch_add(1, Ordering::SeqCst);
101 let mut puller = self.puller.clone();
102 let handle = tokio::spawn(async move {
103 'steal_task: loop {
104 let mut start = task.start();
105 if start >= task.end() {
106 if task_list.steal(&task, self.min_chunk_size) {
107 continue;
108 } else {
109 break;
110 }
111 }
112 let _ = self.tx.send(Event::Pulling(id)).await;
113 let download_range = start..task.end();
114 let mut stream = loop {
115 match puller.pull(&download_range).await {
116 Ok(t) => break t,
117 Err((e, retry_gap)) => {
118 let _ = self.tx.send(Event::PullError(id, e)).await;
119 tokio::time::sleep(retry_gap.unwrap_or(self.retry_gap)).await;
120 }
121 }
122 };
123 loop {
124 match stream.try_next().await {
125 Ok(Some(mut chunk)) => {
126 let len = chunk.len() as u64;
127 task.fetch_add_start(len);
128 let range_start = start;
129 start += len;
130 let range_end = start.min(task.end());
131 if range_start >= range_end {
132 continue 'steal_task;
133 }
134 let span = range_start..range_end;
135 chunk.truncate(span.total() as usize);
136 let _ = self.tx.send(Event::PullProgress(id, span.clone())).await;
137 self.tx_push.send((id, span, chunk)).await.unwrap();
138 }
139 Ok(None) => break,
140 Err((e, retry_gap)) => {
141 let _ = self.tx.send(Event::PullError(id, e)).await;
142 tokio::time::sleep(retry_gap.unwrap_or(self.retry_gap)).await;
143 }
144 }
145 }
146 }
147 task_list.remove(&task);
148 let _ = self.tx.send(Event::Finished(id)).await;
149 });
150 TokioHandle(handle.abort_handle())
151 }
152}
153
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use alloc::vec;
    use core::num::NonZero;

    /// Worker count used by the test.
    ///
    /// It also sizes the per-worker bookkeeping arrays, so the `concurrent` option and the
    /// arrays indexed by worker id can never drift apart.
    const WORKERS: usize = 32;

    /// Downloads a small mock payload with `WORKERS` workers. It checks that:
    /// * every byte was both pulled and pushed;
    /// * every worker contributed to both sides;
    /// * the pusher received exactly the source data.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)] // one range covering the whole payload
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: NonZero::new(WORKERS).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
                min_chunk_size: NonZero::new(1).unwrap(),
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; WORKERS];
        let mut push_ids = [false; WORKERS];
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pull_ids, [true; WORKERS]);
        assert_eq!(push_ids, [true; WORKERS]);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}