1use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, WorkerId};
2use bytes::Bytes;
3use core::{
4 sync::atomic::{AtomicUsize, Ordering},
5 time::Duration,
6};
7use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
8use fast_steal::{Executor, Handle, Task, TaskQueue};
9use futures::TryStreamExt;
10use std::sync::Arc;
11use tokio::sync::Notify;
12
/// Configuration for a multi-connection download started via `download_multi`.
#[derive(Debug, Clone)]
pub struct DownloadOptions<I: Iterator<Item = ProgressEntry>> {
    /// Byte ranges that still need to be downloaded.
    pub download_chunks: I,
    /// Number of concurrent pull workers handed to the task queue.
    pub concurrent: usize,
    /// Default delay between retries after a pull/push/flush error.
    pub retry_gap: Duration,
    /// Maximum wait for the next stream chunk before the worker treats the
    /// pull as stalled and restarts it.
    pub pull_timeout: Duration,
    /// Capacity of the bounded channel feeding the blocking push thread.
    pub push_queue_cap: usize,
    /// Minimum range size (bytes) used when splitting/stealing work.
    pub min_chunk_size: u64,
    /// Limit on speculative work-stealing (forwarded to `TaskQueue::steal`).
    pub max_speculative: usize,
}
23
/// Downloads `options.download_chunks` concurrently using clones of `puller`,
/// writing the resulting chunks to `pusher` on a dedicated blocking thread.
///
/// Progress, errors, and timeouts are reported as `Event`s on the channel
/// inside the returned [`DownloadResult`].
pub fn download_multi<R: Puller, W: Pusher, I: Iterator<Item = ProgressEntry>>(
    puller: R,
    mut pusher: W,
    options: DownloadOptions<I>,
) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
    // Unbounded event channel: workers/pusher produce, the caller consumes.
    let (tx, event_chain) = mpmc::unbounded_async();
    // Bounded data channel: async pull workers -> blocking push thread.
    let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
    let tx_clone = tx.clone();
    let rx_push = rx_push.into_blocking();
    // `Pusher::push`/`flush` are synchronous, so the push side runs on a
    // blocking thread instead of an async task.
    let push_handle = tokio::task::spawn_blocking(move || {
        // Loop ends once every async sender (held by the executor/workers)
        // has been dropped.
        while let Ok((id, spin, mut data)) = rx_push.recv() {
            // Retry the same chunk until accepted; a failed push returns the
            // bytes so no data is lost between attempts. The sleep only runs
            // on the error path (Ok breaks out first).
            loop {
                match pusher.push(&spin, data) {
                    Ok(()) => break,
                    Err((err, bytes)) => {
                        data = bytes;
                        let _ = tx_clone.send(Event::PushError(id, err));
                    }
                }
                std::thread::sleep(options.retry_gap);
            }
            let _ = tx_clone.send(Event::PushProgress(id, spin));
        }
        // All chunks received: flush, retrying on failure.
        loop {
            match pusher.flush() {
                Ok(()) => break,
                Err(err) => {
                    let _ = tx_clone.send(Event::FlushError(err));
                }
            }
            std::thread::sleep(options.retry_gap);
        }
    });
    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
        tx,
        tx_push,
        puller,
        pull_timeout: options.pull_timeout,
        retry_gap: options.retry_gap,
        id: AtomicUsize::new(0),
        min_chunk_size: options.min_chunk_size,
        max_speculative: options.max_speculative,
    });
    let task_queue = TaskQueue::new(options.download_chunks);
    #[allow(clippy::unwrap_used)]
    task_queue
        .set_threads(
            options.concurrent,
            options.min_chunk_size,
            Some(executor.as_ref()),
        )
        .unwrap();
    // Hand back only a Weak executor reference so the result does not keep
    // the executor (and its channel senders) alive after workers finish.
    DownloadResult::new(
        event_chain,
        push_handle,
        None,
        Some((Arc::downgrade(&executor), task_queue)),
    )
}
85
/// Per-worker handle the task queue uses to identify and abort a worker.
#[derive(Debug, Clone)]
pub struct TokioHandle {
    // Worker id, allocated from `TokioExecutor::id`.
    id: usize,
    // Signalled to ask the spawned worker task to stop at its next await.
    notify: Arc<Notify>,
}
91impl Handle for TokioHandle {
92 type Output = ();
93 type Id = usize;
94 fn abort(&mut self) -> Self::Output {
95 self.notify.notify_one();
96 }
97 fn is_self(&mut self, id: &Self::Id) -> bool {
98 self.id == *id
99 }
100}
/// Tokio-based `Executor` that spawns one async pull worker per task.
#[derive(Debug)]
pub struct TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    // Event channel sender (progress / errors / timeouts / completion).
    tx: MTx<mpmc::List<Event<R::Error, WE>>>,
    // Sends downloaded chunks to the blocking push thread.
    tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
    // Prototype puller; each worker gets its own clone.
    puller: R,
    // Fallback retry delay when an error does not supply its own gap.
    retry_gap: Duration,
    // Maximum wait for the next stream chunk before restarting the pull.
    pull_timeout: Duration,
    // Monotonically increasing worker-id counter.
    id: AtomicUsize,
    // Minimum range size used when stealing work.
    min_chunk_size: u64,
    // Speculative-stealing limit forwarded to `TaskQueue::steal`.
    max_speculative: usize,
}
impl<R, WE> Executor for TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    type Handle = TokioHandle;

    /// Spawns an async worker that downloads `task`'s range, forwards chunks
    /// to the blocking pusher, and steals more work once its range is done.
    /// Returns a handle through which the task queue can abort the worker.
    fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
        // Unique worker id, used to attribute events and cancel duplicates.
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        let mut puller = self.puller.clone();
        let min_chunk_size = self.min_chunk_size;
        let pull_timeout = self.pull_timeout;
        let cfg_retry_gap = self.retry_gap;
        let max_speculative = self.max_speculative;
        let tx = self.tx.clone();
        let tx_push = self.tx_push.clone();
        // `notify` is the abort signal; its clone lives in the returned handle.
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        tokio::spawn(async move {
            'task: loop {
                let mut start = task.start();
                // Current range exhausted: try to steal more work, else stop.
                if start >= task.end() {
                    if task_queue.steal(&mut task, min_chunk_size, max_speculative) {
                        continue 'task;
                    }
                    break;
                }
                let _ = tx.send(Event::Pulling(id));
                let download_range = start..task.end();
                // Open the pull stream, retrying on error until success or abort.
                let mut stream = loop {
                    let t = tokio::select! {
                        () = notify.notified() => break 'task,
                        t = puller.pull(Some(&download_range)) => t
                    };
                    match t {
                        Ok(t) => break t,
                        Err((e, retry_gap)) => {
                            let _ = tx.send(Event::PullError(id, e));
                            // The error may suggest its own retry gap; fall back
                            // to the configured one. Sleep is abortable.
                            tokio::select! {
                                () = notify.notified() => break 'task,
                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
                            };
                        }
                    }
                };
                // Consume the stream chunk by chunk.
                loop {
                    let t = tokio::select! {
                        () = notify.notified() => break 'task,
                        t = tokio::time::timeout(pull_timeout, stream.try_next()) => t
                    };
                    match t {
                        Ok(Ok(Some(mut chunk))) => {
                            if chunk.is_empty() {
                                continue;
                            }
                            let len = chunk.len() as u64;
                            // Try to claim [start, start+len). On Err the bytes
                            // are skipped — presumably the range was already
                            // claimed elsewhere (verify against
                            // `Task::safe_add_start`).
                            let Ok(span) = task.safe_add_start(start, len) else {
                                start += len;
                                continue;
                            };
                            // Reaching the task end: cancel speculative
                            // duplicates of this task owned by other workers.
                            if span.end >= task.end() {
                                task_queue.cancel_task(&task, &id);
                            }
                            #[allow(clippy::cast_possible_truncation)]
                            let slice_span =
                                (span.start - start) as usize..(span.end - start) as usize;
                            // Keep only the accepted sub-range of the chunk.
                            chunk = chunk.slice(slice_span);
                            start = span.end;
                            let _ = tx.send(Event::PullProgress(id, span.clone()));
                            let _ = tx_push.send((id, span, chunk)).await;
                            if start >= task.end() {
                                continue 'task;
                            }
                        }
                        // Stream ended before the range completed: restart the
                        // outer loop (re-pull the remainder or steal).
                        Ok(Ok(None)) => continue 'task,
                        Ok(Err((e, retry_gap))) => {
                            let is_irrecoverable = e.is_irrecoverable();
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::select! {
                                () = notify.notified() => break 'task,
                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
                            };
                            // Irrecoverable stream error: abandon this stream
                            // and reopen via the outer loop; recoverable errors
                            // keep reading the same stream.
                            if is_irrecoverable {
                                continue 'task;
                            }
                        }
                        // Timed out waiting for data: drop the stalled stream
                        // and restart with a fresh puller clone.
                        Err(_) => {
                            let _ = tx.send(Event::PullTimeout(id));
                            drop(stream);
                            puller = puller.clone();
                            continue 'task;
                        }
                    }
                }
            }
            let _ = tx.send(Event::Finished(id));
        });
        TokioHandle {
            id,
            notify: notify_clone,
        }
    }
}
218
#[cfg(test)]
#[cfg(feature = "mem")]
mod tests {
    use vec::Vec;

    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec};

    /// End-to-end test: 32 workers download 3 KiB of mock data; pull and push
    /// progress must both merge back into the full requested range, and the
    /// pushed bytes must match the mock data exactly.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: 32,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter().cloned(),
                pull_timeout: Duration::from_secs(5),
                min_chunk_size: 1,
                max_speculative: 3,
            },
        );

        // Accumulate progress events until every sender is dropped and the
        // event channel closes.
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // Track which of the 32 worker ids actually reported progress.
        let mut pull_ids = [false; 32];
        let mut push_ids = [false; 32];
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        // Merged progress must cover exactly the requested range.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        // At least one worker did work on each side.
        assert!(pull_ids.iter().any(|x| *x));
        assert!(push_ids.iter().any(|x| *x));

        #[allow(clippy::unwrap_used)]
        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}
281}