1use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, WorkerId};
2use bytes::Bytes;
3use core::{
4 sync::atomic::{AtomicUsize, Ordering},
5 time::Duration,
6};
7use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
8use fast_steal::{Executor, Handle, Task, TaskQueue};
9use futures::TryStreamExt;
10use std::sync::Arc;
11use tokio::task::AbortHandle;
12
/// Configuration for [`download_multi`].
#[derive(Debug, Clone)]
pub struct DownloadOptions<I: Iterator<Item = ProgressEntry>> {
    /// Byte ranges that still need to be downloaded.
    pub download_chunks: I,
    /// Number of concurrent download workers handed to the task queue.
    pub concurrent: usize,
    /// Default delay between retries (pull, push, and flush) when the
    /// puller does not suggest its own delay.
    pub retry_gap: Duration,
    /// Maximum time to wait for the next chunk before the worker restarts
    /// its pull stream.
    pub pull_timeout: Duration,
    /// Capacity of the bounded queue feeding the blocking push thread.
    pub push_queue_cap: usize,
    /// Minimum size of a stolen work unit, forwarded to the task queue.
    pub min_chunk_size: u64,
    /// Cap on speculative duplicate work when stealing, forwarded to the
    /// task queue — see `TaskQueue::steal`.
    pub max_speculative: usize,
}
23
24pub fn download_multi<R: Puller, W: Pusher, I: Iterator<Item = ProgressEntry>>(
25 puller: R,
26 mut pusher: W,
27 options: DownloadOptions<I>,
28) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
29 let (tx, event_chain) = mpmc::unbounded_async();
30 let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
31 let tx_clone = tx.clone();
32 let rx_push = rx_push.into_blocking();
33 let push_handle = tokio::task::spawn_blocking(move || {
34 while let Ok((id, spin, mut data)) = rx_push.recv() {
35 loop {
36 match pusher.push(&spin, data) {
37 Ok(_) => break,
38 Err((err, bytes)) => {
39 data = bytes;
40 let _ = tx_clone.send(Event::PushError(id, err));
41 }
42 }
43 std::thread::sleep(options.retry_gap);
44 }
45 let _ = tx_clone.send(Event::PushProgress(id, spin));
46 }
47 loop {
48 match pusher.flush() {
49 Ok(_) => break,
50 Err(err) => {
51 let _ = tx_clone.send(Event::FlushError(err));
52 }
53 }
54 std::thread::sleep(options.retry_gap);
55 }
56 });
57 let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
58 tx,
59 tx_push,
60 puller,
61 pull_timeout: options.pull_timeout,
62 retry_gap: options.retry_gap,
63 id: AtomicUsize::new(0),
64 min_chunk_size: options.min_chunk_size,
65 max_speculative: options.max_speculative,
66 });
67 let task_queue = TaskQueue::new(options.download_chunks);
68 task_queue.set_threads(
69 options.concurrent,
70 options.min_chunk_size,
71 Some(executor.as_ref()),
72 );
73 DownloadResult::new(
74 event_chain,
75 push_handle,
76 None,
77 Some((Arc::downgrade(&executor), task_queue)),
78 )
79}
80
/// Handle to one spawned download worker, as required by the
/// [`Executor`] contract.
#[derive(Clone)]
pub struct TokioHandle {
    // Worker id assigned by `TokioExecutor::execute`; matched in `is_self`.
    id: usize,
    // Tokio abort handle used to cancel the worker's spawned task.
    abort_handle: AbortHandle,
}
86impl Handle for TokioHandle {
87 type Output = ();
88 type Id = usize;
89 fn abort(&mut self) -> Self::Output {
90 self.abort_handle.abort();
91 }
92 fn is_self(&mut self, id: &Self::Id) -> bool {
93 self.id == *id
94 }
95}
/// [`Executor`] implementation that runs each download worker as a task on
/// the tokio runtime.
#[derive(Debug)]
pub struct TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    // Event channel shared with the consumer of the `DownloadResult`.
    tx: MTx<mpmc::List<Event<R::Error, WE>>>,
    // Bounded queue feeding the blocking push thread.
    tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
    // Template puller; cloned once per spawned worker.
    puller: R,
    // Default retry delay, used when the puller suggests none.
    retry_gap: Duration,
    // Max time to wait for the next chunk before restarting the pull.
    pull_timeout: Duration,
    // Monotonic counter handing out worker ids.
    id: AtomicUsize,
    // Minimum stolen work unit; forwarded to `TaskQueue::steal`.
    min_chunk_size: u64,
    // Speculation cap; forwarded to `TaskQueue::steal`.
    max_speculative: usize,
}
impl<R, WE> Executor for TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    type Handle = TokioHandle;

    /// Spawns a tokio worker that downloads the byte range described by
    /// `task`, stealing further ranges from `task_queue` once its own range
    /// is exhausted.
    fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
        // Unique id for this worker; tags every event it emits.
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        let mut puller = self.puller.clone();
        let min_chunk_size = self.min_chunk_size;
        let pull_timeout = self.pull_timeout;
        let cfg_retry_gap = self.retry_gap;
        let max_speculative = self.max_speculative;
        let tx = self.tx.clone();
        let tx_push = self.tx_push.clone();
        let handle = tokio::spawn(async move {
            // One iteration of 'task = one (re)start of a pull stream over
            // the worker's current range.
            'task: loop {
                let mut start = task.start();
                if start >= task.end() {
                    // Own range exhausted: try to steal more work; otherwise
                    // this worker is finished.
                    if task_queue.steal(&mut task, min_chunk_size, max_speculative) {
                        continue 'task;
                    } else {
                        break;
                    }
                }
                let _ = tx.send(Event::Pulling(id));
                let download_range = start..task.end();
                // Open the pull stream, retrying forever on error. The
                // puller may suggest its own retry delay; fall back to the
                // configured gap otherwise.
                let mut stream = loop {
                    match puller.pull(Some(&download_range)).await {
                        Ok(t) => break t,
                        Err((e, retry_gap)) => {
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
                        }
                    }
                };
                loop {
                    match tokio::time::timeout(pull_timeout, stream.try_next()).await {
                        Ok(Ok(Some(mut chunk))) => {
                            if chunk.is_empty() {
                                continue;
                            }
                            let len = chunk.len() as u64;
                            // Claim [start, start + len) against the task.
                            // NOTE(review): on Err the bytes are presumably
                            // already owned by another worker, so they are
                            // skipped rather than written — confirm against
                            // fast_steal's `safe_add_start` contract.
                            let Ok(span) = task.safe_add_start(start, len) else {
                                start += len;
                                continue;
                            };
                            if span.end >= task.end() {
                                // This chunk completes the range; presumably
                                // cancels speculative duplicates of this task
                                // — see `TaskQueue::cancel_task`.
                                task_queue.cancel_task(&task, &id);
                            }
                            // Trim the chunk to the claimed span (it may
                            // have been shortened at either end).
                            chunk = chunk
                                .slice((span.start - start) as usize..(span.end - start) as usize);
                            start = span.end;
                            let _ = tx.send(Event::PullProgress(id, span.clone()));
                            let tx_push = tx_push.clone();
                            // Send via a separately spawned task so that
                            // aborting this worker mid-send cannot lose the
                            // chunk: the inner task still runs to completion.
                            // The JoinError (e.g. the unwrap panicking when
                            // the push side is gone) is deliberately ignored.
                            let _ = tokio::spawn(async move {
                                tx_push.send((id, span, chunk)).await.unwrap();
                            })
                            .await;
                            if start >= task.end() {
                                continue 'task;
                            }
                        }
                        // Stream ended before the range was covered; restart
                        // the pull from the current offset.
                        Ok(Ok(None)) => continue 'task,
                        Ok(Err((e, retry_gap))) => {
                            let is_irrecoverable = e.is_irrecoverable();
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
                            if is_irrecoverable {
                                // Irrecoverable stream error: reopen the pull
                                // instead of polling this stream again.
                                continue 'task;
                            }
                        }
                        Err(_) => {
                            // No chunk arrived within `pull_timeout`.
                            let _ = tx.send(Event::PullTimeout(id));
                            drop(stream);
                            // NOTE(review): self-clone looks intended to
                            // reset the puller's connection state — confirm
                            // that `Puller::clone` actually does that.
                            puller = puller.clone();
                            continue 'task;
                        }
                    }
                }
            }
            let _ = tx.send(Event::Finished(id));
        });
        TokioHandle {
            id,
            abort_handle: handle.abort_handle(),
        }
    }
}
200
#[cfg(test)]
mod tests {
    use vec::Vec;

    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec};

    /// End-to-end check: 32 workers download 3 KiB of mock data and every
    /// byte reaches the pusher exactly once, with all worker ids active.
    #[tokio::test]
    async fn test_concurrent_download() {
        let data = build_mock_data(3 * 1024);
        let source = MockPuller::new(&data);
        let sink = MemPusher::with_capacity(data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let chunks = vec![0..data.len() as u64];
        let result = download_multi(
            source,
            sink.clone(),
            DownloadOptions {
                download_chunks: chunks.iter().cloned(),
                concurrent: 32,
                retry_gap: Duration::from_secs(1),
                pull_timeout: Duration::from_secs(5),
                push_queue_cap: 1024,
                min_chunk_size: 1,
                max_speculative: 3,
            },
        );

        // Accumulate progress per direction and record which worker ids
        // reported anything.
        let mut pulled: Vec<ProgressEntry> = Vec::new();
        let mut pushed: Vec<ProgressEntry> = Vec::new();
        let mut seen_pull = [false; 32];
        let mut seen_push = [false; 32];
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(worker, span) => {
                    seen_pull[worker] = true;
                    pulled.merge_progress(span);
                }
                Event::PushProgress(worker, span) => {
                    seen_push[worker] = true;
                    pushed.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&pulled);
        dbg!(&pushed);
        assert_eq!(pulled, chunks);
        assert_eq!(pushed, chunks);
        assert_eq!(seen_pull, [true; 32]);
        assert_eq!(seen_push, [true; 32]);

        result.join().await.unwrap();
        assert_eq!(&**sink.receive.lock(), data);
    }
}