1use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, WorkerId};
2use bytes::Bytes;
3use core::{
4 sync::atomic::{AtomicUsize, Ordering},
5 time::Duration,
6};
7use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
8use fast_steal::{Executor, Handle, Task, TaskQueue};
9use futures::TryStreamExt;
10use std::sync::Arc;
11use tokio::sync::Notify;
12
/// Configuration for [`download_multi`].
#[derive(Debug, Clone)]
pub struct DownloadOptions<I: Iterator<Item = ProgressEntry>> {
    /// Byte ranges still to download; used to seed the [`TaskQueue`].
    pub download_chunks: I,
    /// Number of download workers the task queue will spawn.
    pub concurrent: usize,
    /// Pause between retries after a failed pull, push, or flush
    /// (pull errors may override this with their own gap).
    pub retry_gap: Duration,
    /// Max time a worker waits on its pull stream before dropping it and
    /// re-pulling the remaining range.
    pub pull_timeout: Duration,
    /// Capacity of the bounded channel feeding the blocking push thread;
    /// pullers back-pressure on this when the pusher falls behind.
    pub push_queue_cap: usize,
    /// Minimum work unit (bytes) when splitting or stealing ranges.
    pub min_chunk_size: u64,
    /// Upper bound passed to the steal call; presumably limits how many
    /// workers may speculatively share one remaining range.
    /// NOTE(review): exact semantics defined by `fast_steal::TaskQueue::steal` — confirm.
    pub max_speculative: usize,
}
23
/// Starts a concurrent download: `options.concurrent` async pull workers feed
/// a single blocking push thread through a bounded channel, and every state
/// change is reported on the returned event chain.
///
/// Returns a [`DownloadResult`] carrying the event receiver, the push thread's
/// join handle, and a weak reference to the executor plus its task queue.
pub fn download_multi<R: Puller, W: Pusher, I: Iterator<Item = ProgressEntry>>(
    puller: R,
    mut pusher: W,
    options: DownloadOptions<I>,
) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
    // Unbounded event channel: senders must never block inside workers.
    let (tx, event_chain) = mpmc::unbounded_async();
    // Bounded data channel: applies back-pressure to pullers when pushing lags.
    let (tx_push, rx_push) =
        mpsc::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
    let tx_clone = tx.clone();
    let rx_push = rx_push.into_blocking();
    // Single dedicated blocking thread owns the pusher; chunks are written in
    // the order they arrive on the channel. `pusher.push` is retried forever
    // with `retry_gap` pauses until it succeeds.
    let push_handle = tokio::task::spawn_blocking(move || {
        // Runs until every `tx_push` clone is dropped (all workers finished).
        while let Ok((id, spin, mut data)) = rx_push.recv() {
            loop {
                let _ = tx_clone.send(Event::Pushing(id, spin.clone()));
                match pusher.push(&spin, data) {
                    Ok(()) => break,
                    Err((err, bytes)) => {
                        // Pusher hands the chunk back on failure so we can retry
                        // without cloning it.
                        data = bytes;
                        let _ = tx_clone.send(Event::PushError(id, spin.clone(), err));
                    }
                }
                std::thread::sleep(options.retry_gap);
            }
            let _ = tx_clone.send(Event::PushProgress(id, spin));
        }
        // Channel closed: flush, also retried forever with the same gap.
        loop {
            let _ = tx_clone.send(Event::Flushing);
            match pusher.flush() {
                Ok(()) => break,
                Err(err) => {
                    let _ = tx_clone.send(Event::FlushError(err));
                }
            }
            std::thread::sleep(options.retry_gap);
        }
    });
    // Shared executor state; `execute` is invoked once per worker by the queue.
    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
        tx,
        tx_push,
        puller,
        pull_timeout: options.pull_timeout,
        retry_gap: options.retry_gap,
        id: AtomicUsize::new(0),
        min_chunk_size: options.min_chunk_size,
        max_speculative: options.max_speculative,
    });
    let task_queue = TaskQueue::new(options.download_chunks);
    #[allow(clippy::unwrap_used)]
    task_queue
        .set_threads(
            options.concurrent,
            options.min_chunk_size,
            Some(executor.as_ref()),
        )
        .unwrap();
    // Downgrade so dropping the DownloadResult does not keep the executor alive.
    DownloadResult::new(
        event_chain,
        push_handle,
        None,
        Some((Arc::downgrade(&executor), task_queue)),
    )
}
88
/// Handle to one spawned download worker: identifies it by id and lets the
/// task queue request its cancellation.
#[derive(Debug, Clone)]
pub struct TokioHandle {
    // Worker id allocated from `TokioExecutor::id`.
    id: usize,
    // Shared with the worker's spawned task; notifying it asks the worker to
    // stop at its next cancellation checkpoint.
    notify: Arc<Notify>,
}
94impl Handle for TokioHandle {
95 type Output = ();
96 type Id = usize;
97 fn abort(&mut self) -> Self::Output {
98 self.notify.notify_one();
99 }
100 fn is_self(&mut self, id: &Self::Id) -> bool {
101 self.id == *id
102 }
103}
/// Per-download shared state handed to the task queue; `execute` clones what
/// each worker needs out of it.
#[derive(Debug)]
pub struct TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    // Event channel sender (unbounded; send never blocks).
    tx: MTx<mpmc::List<Event<R::Error, WE>>>,
    // Bounded channel to the blocking push thread.
    tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
    // Prototype puller; each worker gets its own clone.
    puller: R,
    // Default pause between pull retries (pull errors may supply their own).
    retry_gap: Duration,
    // Idle timeout on a pull stream before the worker re-pulls.
    pull_timeout: Duration,
    // Monotonic worker-id allocator.
    id: AtomicUsize,
    // Minimum steal/split unit in bytes.
    min_chunk_size: u64,
    // Speculation bound forwarded to `TaskQueue::steal`.
    max_speculative: usize,
}
impl<R, WE> Executor for TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    type Handle = TokioHandle;
    /// Spawns one download worker on the tokio runtime and returns its handle.
    ///
    /// The worker loops: claim/steal a byte range, open a pull stream for it,
    /// forward chunks to the push thread, and react to timeout, errors, and
    /// cancellation. Cancellation (`Handle::abort`) is observed at every
    /// `notify.notified()` checkpoint inside a `select!`.
    #[allow(clippy::too_many_lines)]
    fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
        // Unique worker id for event attribution and steal bookkeeping.
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        // Each worker owns its own puller clone.
        // NOTE(review): presumably Puller::clone yields an independent
        // connection/session — confirm against the Puller contract.
        let mut puller = self.puller.clone();
        let min_chunk_size = self.min_chunk_size;
        let pull_timeout = self.pull_timeout;
        let cfg_retry_gap = self.retry_gap;
        let max_speculative = self.max_speculative;
        let tx = self.tx.clone();
        let tx_push = self.tx_push.clone();
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        tokio::spawn(async move {
            'task: loop {
                let mut start = task.start();
                // Current range exhausted: try to steal more work, else finish.
                if start >= task.end() {
                    if task_queue.steal(&id, &mut task, min_chunk_size, max_speculative) {
                        // Non-blocking cancellation checkpoint: `biased` polls
                        // `notified()` first; the always-ready `async {}` arm
                        // keeps this from waiting when no abort is pending.
                        tokio::select! {
                            biased;
                            () = notify.notified() => {}
                            () = async {} => {}
                        }
                        continue 'task;
                    }
                    break;
                }
                let _ = tx.send(Event::Pulling(id));
                let download_range = start..task.end();
                // Open the pull stream, retrying with the error-supplied gap
                // (or the configured default) until success or abort.
                let mut stream = loop {
                    let t = tokio::select! {
                        () = notify.notified() => break 'task,
                        t = puller.pull(Some(&download_range)) => t
                    };
                    match t {
                        Ok(t) => break t,
                        Err((e, retry_gap)) => {
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::select! {
                                () = notify.notified() => break 'task,
                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
                            };
                        }
                    }
                };
                // Idle-timeout timer; reset before each chunk read below.
                tokio::pin! {
                    let sleep = tokio::time::sleep(pull_timeout);
                }
                loop {
                    sleep
                        .as_mut()
                        .reset(tokio::time::Instant::now() + pull_timeout);
                    let t = tokio::select! {
                        () = notify.notified() => break 'task,
                        () = &mut sleep => {
                            // No data within pull_timeout: drop the stale stream
                            // and restart the outer loop to re-pull what's left.
                            let _ = tx.send(Event::PullTimeout(id));
                            drop(stream);
                            // NOTE(review): self-clone presumably resets the
                            // puller's connection state — confirm.
                            puller = puller.clone();
                            continue 'task;
                        },
                        t = stream.try_next() => t,
                    };
                    match t {
                        Ok(Some(mut chunk)) => {
                            if chunk.is_empty() {
                                continue;
                            }
                            let len = chunk.len() as u64;
                            // NOTE(review): safe_add_start appears to advance the
                            // task start and return the byte span this worker
                            // still owns; Err presumably means the whole span was
                            // taken (stolen), so skip it — confirm with fast_steal.
                            let Ok(span) = task.safe_add_start(start, len) else {
                                start += len;
                                continue;
                            };
                            if span.end >= task.end() {
                                // Range fully covered by this chunk: cancel the
                                // task so speculative duplicates stop.
                                task_queue.cancel_task(&task, &id);
                            }
                            // Trim the chunk to the owned sub-span before pushing.
                            #[allow(clippy::cast_possible_truncation)]
                            let slice_span =
                                (span.start - start) as usize..(span.end - start) as usize;
                            chunk = chunk.slice(slice_span);
                            start = span.end;
                            let _ = tx.send(Event::PullProgress(id, span.clone()));
                            // Bounded send: back-pressure from the push thread.
                            let _ = tx_push.send((id, span, chunk)).await;
                            if start >= task.end() {
                                continue 'task;
                            }
                        }
                        // Stream ended before the range was finished: re-pull.
                        Ok(None) => continue 'task,
                        Err((e, retry_gap)) => {
                            let is_irrecoverable = e.is_irrecoverable();
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::select! {
                                () = notify.notified() => break 'task,
                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
                            };
                            // Irrecoverable: abandon this stream and re-pull the
                            // remainder; otherwise keep reading the same stream.
                            if is_irrecoverable {
                                continue 'task;
                            }
                        }
                    }
                }
            }
            let _ = tx.send(Event::Finished(id));
        });
        TokioHandle {
            id,
            notify: notify_clone,
        }
    }
}
233
#[cfg(test)]
#[cfg(feature = "mem")]
mod tests {
    // NOTE(review): resolves through the `use std::vec;` import below
    // (uniform paths) — suggests a no_std-leaning crate layout; confirm.
    use vec::Vec;

    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec};

    /// End-to-end smoke test: 32 workers download 3 KiB of mock data into an
    /// in-memory pusher, and both pull and push progress must cover the whole
    /// requested range exactly.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: 32,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter().cloned(),
                pull_timeout: Duration::from_secs(5),
                min_chunk_size: 1,
                max_speculative: 3,
            },
        );

        // Merge every progress event; ranges should coalesce back into the
        // single original chunk when the whole file is covered.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // One slot per worker id (ids are 0..concurrent).
        let mut pull_ids = [false; 32];
        let mut push_ids = [false; 32];
        // Drain the event chain until all senders are dropped.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        // Full coverage, no gaps and no overlap beyond the requested range.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        // At least one worker actually did work on each side.
        assert!(pull_ids.iter().any(|x| *x));
        assert!(push_ids.iter().any(|x| *x));

        #[allow(clippy::unwrap_used)]
        result.join().await.unwrap();
        // Byte-for-byte identical output.
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}
296}