1use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, WorkerId};
2use bytes::Bytes;
3use core::{
4 sync::atomic::{AtomicUsize, Ordering},
5 time::Duration,
6};
7use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
8use fast_steal::{Executor, Handle, Task, TaskQueue};
9use futures::TryStreamExt;
10use std::sync::Arc;
11use tokio::sync::Notify;
12
/// Configuration for [`download_multi`].
#[derive(Debug, Clone)]
pub struct DownloadOptions<I: Iterator<Item = ProgressEntry>> {
    /// Ranges still to be downloaded; seeds the work-stealing task queue.
    pub download_chunks: I,
    /// Number of concurrent pull workers to spawn.
    pub concurrent: usize,
    /// Default delay between retries (used by pull, push and flush loops).
    pub retry_gap: Duration,
    /// Maximum wait for the next chunk before a worker abandons its stream
    /// and re-pulls.
    pub pull_timeout: Duration,
    /// Capacity of the puller-to-pusher chunk queue (provides back-pressure).
    pub push_queue_cap: usize,
    /// Minimum chunk size when splitting/stealing work — presumably bytes;
    /// TODO confirm against `TaskQueue` docs.
    pub min_chunk_size: u64,
    /// Speculative-steal limit forwarded to `TaskQueue::steal` — semantics
    /// defined by the fast_steal crate; verify there.
    pub max_speculative: usize,
}
23
/// Starts a multi-worker download: one blocking push thread plus
/// `options.concurrent` async pull workers driven by a work-stealing
/// [`TaskQueue`].
///
/// Progress and error events are reported through the returned
/// [`DownloadResult`]'s event chain; the caller is expected to drain it.
pub fn download_multi<R: Puller, W: Pusher, I: Iterator<Item = ProgressEntry>>(
    puller: R,
    mut pusher: W,
    options: DownloadOptions<I>,
) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
    // Unbounded event channel: workers must never block on reporting.
    let (tx, event_chain) = mpmc::unbounded_async();
    // Bounded chunk channel: back-pressure between pullers and the pusher.
    let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
    let tx_clone = tx.clone();
    let rx_push = rx_push.into_blocking();
    // Dedicated blocking thread: `Pusher::push` is synchronous I/O.
    let push_handle = tokio::task::spawn_blocking(move || {
        while let Ok((id, spin, mut data)) = rx_push.recv() {
            // Retry the same chunk until accepted; on failure the pusher
            // hands the bytes back, so nothing is lost between attempts.
            // (`break` on Ok skips the sleep — it only runs after errors.)
            loop {
                match pusher.push(&spin, data) {
                    Ok(()) => break,
                    Err((err, bytes)) => {
                        data = bytes;
                        let _ = tx_clone.send(Event::PushError(id, err));
                    }
                }
                std::thread::sleep(options.retry_gap);
            }
            let _ = tx_clone.send(Event::PushProgress(id, spin));
        }
        // All senders dropped: flush with the same retry policy, then exit.
        loop {
            match pusher.flush() {
                Ok(()) => break,
                Err(err) => {
                    let _ = tx_clone.send(Event::FlushError(err));
                }
            }
            std::thread::sleep(options.retry_gap);
        }
    });
    // Shared per-download state; each spawned worker gets clones of the
    // channels and the prototype puller.
    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
        tx,
        tx_push,
        puller,
        pull_timeout: options.pull_timeout,
        retry_gap: options.retry_gap,
        id: AtomicUsize::new(0),
        min_chunk_size: options.min_chunk_size,
        max_speculative: options.max_speculative,
    });
    let task_queue = TaskQueue::new(options.download_chunks);
    // NOTE(review): unwrap assumes `set_threads` cannot fail with these
    // arguments — confirm the fast_steal error conditions.
    #[allow(clippy::unwrap_used)]
    task_queue
        .set_threads(
            options.concurrent,
            options.min_chunk_size,
            Some(executor.as_ref()),
        )
        .unwrap();
    // Downgrade so the executor is dropped once all workers finish.
    DownloadResult::new(
        event_chain,
        push_handle,
        None,
        Some((Arc::downgrade(&executor), task_queue)),
    )
}
85
/// Control handle for one spawned pull worker.
#[derive(Debug, Clone)]
pub struct TokioHandle {
    /// Worker id, allocated from `TokioExecutor::id`.
    id: usize,
    /// Abort signal shared with the worker's `select!` arms.
    notify: Arc<Notify>,
}
91impl Handle for TokioHandle {
92 type Output = ();
93 type Id = usize;
94 fn abort(&mut self) -> Self::Output {
95 self.notify.notify_one();
96 }
97 fn is_self(&mut self, id: &Self::Id) -> bool {
98 self.id == *id
99 }
100}
/// Per-download shared state for the tokio-based pull workers.
///
/// One instance is shared (via `Arc`) across every worker spawned through
/// [`Executor::execute`].
#[derive(Debug)]
pub struct TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    /// Event sink for progress and errors.
    tx: MTx<mpmc::List<Event<R::Error, WE>>>,
    /// Hands pulled chunks to the blocking push thread (bounded).
    tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
    /// Prototype puller; cloned once per worker.
    puller: R,
    /// Fallback retry delay when an error carries no gap of its own.
    retry_gap: Duration,
    /// Max wait for the next stream chunk before re-pulling.
    pull_timeout: Duration,
    /// Monotonic worker-id allocator.
    id: AtomicUsize,
    /// Minimum split size for work stealing — presumably bytes; see TaskQueue.
    min_chunk_size: u64,
    /// Speculative-steal bound forwarded to `TaskQueue::steal`.
    max_speculative: usize,
}
impl<R, WE> Executor for TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    type Handle = TokioHandle;

    /// Spawns one async pull worker for `task`.
    ///
    /// The worker repeatedly pulls its current range, forwards chunks to the
    /// push thread, and steals more work from `task_queue` when its own range
    /// is exhausted. The returned [`TokioHandle`] can abort it at any await
    /// point via the shared `Notify`.
    fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
        // Unique worker id for event reporting and speculative cancellation.
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        let mut puller = self.puller.clone();
        let min_chunk_size = self.min_chunk_size;
        let pull_timeout = self.pull_timeout;
        let cfg_retry_gap = self.retry_gap;
        let max_speculative = self.max_speculative;
        let tx = self.tx.clone();
        let tx_push = self.tx_push.clone();
        // Abort signal shared between the worker and the returned handle.
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        tokio::spawn(async move {
            // Each 'task iteration is one (re)started pull over the worker's
            // current range.
            'task: loop {
                let mut start = task.start();
                if start >= task.end() {
                    // Own range done: try to steal more work, else finish.
                    if task_queue.steal(&mut task, min_chunk_size, max_speculative) {
                        continue 'task;
                    }
                    break;
                }
                let _ = tx.send(Event::Pulling(id));
                let download_range = start..task.end();
                // Open the stream, retrying on error; abort wins every race.
                let mut stream = loop {
                    let t = tokio::select! {
                        () = notify.notified() => break 'task,
                        t = puller.pull(Some(&download_range)) => t
                    };
                    match t {
                        Ok(t) => break t,
                        Err((e, retry_gap)) => {
                            let _ = tx.send(Event::PullError(id, e));
                            // The error may carry its own retry gap; fall back
                            // to the configured one.
                            tokio::select! {
                                () = notify.notified() => break 'task,
                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
                            };
                        }
                    }
                };
                // Watchdog: reset before every chunk; if it fires the stream
                // stalled for `pull_timeout` and we re-pull from scratch.
                tokio::pin! {
                    let sleep = tokio::time::sleep(pull_timeout);
                }
                loop {
                    sleep
                        .as_mut()
                        .reset(tokio::time::Instant::now() + pull_timeout);
                    let t = tokio::select! {
                        () = notify.notified() => break 'task,
                        () = &mut sleep => {
                            let _ = tx.send(Event::PullTimeout(id));
                            drop(stream);
                            // NOTE(review): self-clone presumably resets the
                            // puller's connection state — confirm the intended
                            // `Puller: Clone` semantics.
                            puller = puller.clone();
                            continue 'task;
                        },
                        t = stream.try_next() => t,
                    };
                    match t {
                        Ok(Some(mut chunk)) => {
                            if chunk.is_empty() {
                                continue;
                            }
                            let len = chunk.len() as u64;
                            // Claim [start, start+len) on the task. On failure
                            // the bytes are skipped — presumably that range was
                            // already claimed elsewhere (speculative overlap);
                            // TODO confirm `safe_add_start` contract.
                            let Ok(span) = task.safe_add_start(start, len) else {
                                start += len;
                                continue;
                            };
                            if span.end >= task.end() {
                                // Range complete: cancel competing speculative
                                // workers on this task (all but `id`).
                                task_queue.cancel_task(&task, &id);
                            }
                            // Trim the chunk down to the claimed sub-span.
                            #[allow(clippy::cast_possible_truncation)]
                            let slice_span =
                                (span.start - start) as usize..(span.end - start) as usize;
                            chunk = chunk.slice(slice_span);
                            start = span.end;
                            let _ = tx.send(Event::PullProgress(id, span.clone()));
                            // Bounded send: back-pressure from the push thread.
                            let _ = tx_push.send((id, span, chunk)).await;
                            if start >= task.end() {
                                continue 'task;
                            }
                        }
                        // Stream ended early: restart (steal or re-pull).
                        Ok(None) => continue 'task,
                        Err((e, retry_gap)) => {
                            let is_irrecoverable = e.is_irrecoverable();
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::select! {
                                () = notify.notified() => break 'task,
                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
                            };
                            // Irrecoverable stream error: re-open the stream;
                            // otherwise keep polling the same one.
                            if is_irrecoverable {
                                continue 'task;
                            }
                        }
                    }
                }
            }
            let _ = tx.send(Event::Finished(id));
        });
        TokioHandle {
            id,
            notify: notify_clone,
        }
    }
}
224
#[cfg(test)]
#[cfg(feature = "mem")]
mod tests {
    use vec::Vec;

    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec};

    /// End-to-end check: 32 workers pull 3 KiB of mock data into a
    /// `MemPusher`; the pull/push progress events must each merge back into
    /// the full requested range, and the pushed bytes must equal the source.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: 32,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter().cloned(),
                pull_timeout: Duration::from_secs(5),
                min_chunk_size: 1,
                max_speculative: 3,
            },
        );

        // Accumulate progress per event kind and record which worker ids
        // actually reported (ids are 0..32 because `concurrent` is 32).
        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; 32];
        let mut push_ids = [false; 32];
        // The event chain closes once all senders drop, ending this loop.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        // Merged progress must cover exactly the requested range, with no
        // gaps and no overlaps.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert!(pull_ids.iter().any(|x| *x));
        assert!(push_ids.iter().any(|x| *x));

        #[allow(clippy::unwrap_used)]
        result.join().await.unwrap();
        // Final byte-for-byte comparison of the pushed buffer.
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}