1extern crate std;
2use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, WorkerId};
3use bytes::Bytes;
4use core::{
5 sync::atomic::{AtomicUsize, Ordering},
6 time::Duration,
7};
8use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
9use fast_steal::{Executor, Handle, Task, TaskQueue};
10use futures::TryStreamExt;
11use std::sync::Arc;
12use tokio::task::AbortHandle;
13
14#[derive(Debug, Clone)]
15pub struct DownloadOptions<'a, I: Iterator<Item = &'a ProgressEntry>> {
16 pub download_chunks: I,
17 pub concurrent: usize,
18 pub retry_gap: Duration,
19 pub pull_timeout: Duration,
20 pub push_queue_cap: usize,
21 pub min_chunk_size: u64,
22}
23
24pub fn download_multi<'a, R: Puller, W: Pusher, I: Iterator<Item = &'a ProgressEntry>>(
25 puller: R,
26 mut pusher: W,
27 options: DownloadOptions<'a, I>,
28) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
29 let (tx, event_chain) = mpmc::unbounded_async();
30 let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
31 let tx_clone = tx.clone();
32 let rx_push = rx_push.into_blocking();
33 let push_handle = tokio::task::spawn_blocking(move || {
34 while let Ok((id, spin, mut data)) = rx_push.recv() {
35 loop {
36 match pusher.push(&spin, data) {
37 Ok(_) => break,
38 Err((err, bytes)) => {
39 data = bytes;
40 let _ = tx_clone.send(Event::PushError(id, err));
41 }
42 }
43 std::thread::sleep(options.retry_gap);
44 }
45 let _ = tx_clone.send(Event::PushProgress(id, spin));
46 }
47 loop {
48 match pusher.flush() {
49 Ok(_) => break,
50 Err(err) => {
51 let _ = tx_clone.send(Event::FlushError(err));
52 }
53 }
54 std::thread::sleep(options.retry_gap);
55 }
56 });
57 let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
58 tx,
59 tx_push,
60 puller,
61 pull_timeout: options.pull_timeout,
62 retry_gap: options.retry_gap,
63 id: AtomicUsize::new(0),
64 min_chunk_size: options.min_chunk_size,
65 });
66 let task_queue = TaskQueue::new(options.download_chunks);
67 task_queue.set_threads(
68 options.concurrent,
69 options.min_chunk_size,
70 Some(executor.as_ref()),
71 );
72 DownloadResult::new(
73 event_chain,
74 push_handle,
75 None,
76 Some((Arc::downgrade(&executor), task_queue)),
77 )
78}
79
80#[derive(Clone)]
81pub struct TokioHandle(AbortHandle);
82impl Handle for TokioHandle {
83 type Output = ();
84 fn abort(&mut self) -> Self::Output {
85 self.0.abort();
86 }
87}
88#[derive(Debug)]
89pub struct TokioExecutor<R, WE>
90where
91 R: Puller,
92 WE: Send + Unpin + 'static,
93{
94 tx: MTx<mpmc::List<Event<R::Error, WE>>>,
95 tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
96 puller: R,
97 retry_gap: Duration,
98 pull_timeout: Duration,
99 id: AtomicUsize,
100 min_chunk_size: u64,
101}
102impl<R, WE> Executor for TokioExecutor<R, WE>
103where
104 R: Puller,
105 WE: Send + Unpin + 'static,
106{
107 type Handle = TokioHandle;
108 fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
109 let id = self.id.fetch_add(1, Ordering::SeqCst);
110 let mut puller = self.puller.clone();
111 let min_chunk_size = self.min_chunk_size;
112 let pull_timeout = self.pull_timeout;
113 let cfg_retry_gap = self.retry_gap;
114 let tx = self.tx.clone();
115 let tx_push = self.tx_push.clone();
116 let handle = tokio::spawn(async move {
117 loop {
118 let mut start = task.start();
119 if start >= task.end() {
120 if task_queue.steal(&mut task, min_chunk_size, true) {
121 continue;
122 } else {
123 break;
124 }
125 }
126 let _ = tx.send(Event::Pulling(id));
127 let download_range = start..task.end();
128 let mut stream = loop {
129 match puller.pull(Some(&download_range)).await {
130 Ok(t) => break t,
131 Err((e, retry_gap)) => {
132 let _ = tx.send(Event::PullError(id, e));
133 tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
134 }
135 }
136 };
137 loop {
138 match tokio::time::timeout(pull_timeout, stream.try_next()).await {
139 Ok(Ok(Some(mut chunk))) => {
140 if chunk.is_empty() {
141 continue;
142 }
143 let len = chunk.len() as u64;
144 let Ok(span) = task.safe_add_start(start, len) else {
145 break;
146 };
147 if span.end >= task.end() {
148 task_queue.cancel_tasks(&task);
149 }
150 chunk = chunk
151 .slice((span.start - start) as usize..(span.end - start) as usize);
152 start = span.end;
153 let _ = tx.send(Event::PullProgress(id, span.clone()));
154 let tx_push = tx_push.clone();
155 let _ = tokio::spawn(async move {
156 tx_push.send((id, span, chunk)).await.unwrap();
157 })
158 .await;
159 if start >= task.end() {
160 break;
161 }
162 }
163 Ok(Ok(None)) => break,
164 Ok(Err((e, retry_gap))) => {
165 let is_irrecoverable = e.is_irrecoverable();
166 let _ = tx.send(Event::PullError(id, e));
167 tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
168 if is_irrecoverable {
169 break;
170 }
171 }
172 Err(_) => {
173 let _ = tx.send(Event::PullTimeout(id));
174 drop(stream);
175 puller = puller.clone();
176 break;
177 }
178 }
179 }
180 }
181 let _ = tx.send(Event::Finished(id));
182 });
183 TokioHandle(handle.abort_handle())
184 }
185}
186
#[cfg(test)]
mod tests {
    use vec::Vec;

    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::vec;

    /// Number of workers requested. Worker ids are assigned sequentially from 0, so
    /// this is also the length of the per-worker bookkeeping arrays.
    const CONCURRENT: usize = 32;

    /// Downloads 3 KiB of mock data with `CONCURRENT` workers and checks that:
    /// - the merged pull and push progress each cover the whole requested range;
    /// - every worker id reported both pull and push progress;
    /// - the pusher ends up holding exactly the mock data.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single full-length range is intended here, not a per-byte list.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: CONCURRENT,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter(),
                pull_timeout: Duration::from_secs(5),
                min_chunk_size: 1,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; CONCURRENT];
        let mut push_ids = [false; CONCURRENT];
        // The event channel closes once every worker and the push thread are done.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        // `assert_eq!` prints both sides on failure, so no extra debug output is needed.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pull_ids, [true; CONCURRENT]);
        assert_eq!(push_ids, [true; CONCURRENT]);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}