1extern crate std;
2use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, Total, WorkerId};
3use bytes::Bytes;
4use core::{
5 sync::atomic::{AtomicUsize, Ordering},
6 time::Duration,
7};
8use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
9use fast_steal::{Executor, Handle, Task, TaskQueue};
10use futures::TryStreamExt;
11use std::sync::Arc;
12use tokio::task::AbortHandle;
13
14#[derive(Debug, Clone)]
15pub struct DownloadOptions<'a, I: Iterator<Item = &'a ProgressEntry>> {
16 pub download_chunks: I,
17 pub concurrent: usize,
18 pub retry_gap: Duration,
19 pub push_queue_cap: usize,
20 pub min_chunk_size: u64,
21}
22
23pub fn download_multi<'a, R: Puller, W: Pusher, I: Iterator<Item = &'a ProgressEntry>>(
24 puller: R,
25 mut pusher: W,
26 options: DownloadOptions<'a, I>,
27) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
28 let (tx, event_chain) = mpmc::unbounded_async();
29 let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
30 let tx_clone = tx.clone();
31 let rx_push = rx_push.into_blocking();
32 let push_handle = tokio::task::spawn_blocking(move || {
33 while let Ok((id, spin, mut data)) = rx_push.recv() {
34 loop {
35 match pusher.push(&spin, data) {
36 Ok(_) => break,
37 Err((err, bytes)) => {
38 data = bytes;
39 let _ = tx_clone.send(Event::PushError(id, err));
40 }
41 }
42 std::thread::sleep(options.retry_gap);
43 }
44 let _ = tx_clone.send(Event::PushProgress(id, spin));
45 }
46 loop {
47 match pusher.flush() {
48 Ok(_) => break,
49 Err(err) => {
50 let _ = tx_clone.send(Event::FlushError(err));
51 }
52 }
53 std::thread::sleep(options.retry_gap);
54 }
55 });
56 let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
57 tx,
58 tx_push,
59 puller,
60 retry_gap: options.retry_gap,
61 id: AtomicUsize::new(0),
62 min_chunk_size: options.min_chunk_size,
63 });
64 let task_queue = TaskQueue::new(options.download_chunks);
65 task_queue.set_threads(
66 options.concurrent,
67 options.min_chunk_size,
68 Some(executor.as_ref()),
69 );
70 DownloadResult::new(
71 event_chain,
72 push_handle,
73 None,
74 Some((Arc::downgrade(&executor), task_queue)),
75 )
76}
77
78#[derive(Clone)]
79pub struct TokioHandle(AbortHandle);
80impl Handle for TokioHandle {
81 type Output = ();
82 fn abort(&mut self) -> Self::Output {
83 self.0.abort();
84 }
85}
86#[derive(Debug)]
87pub struct TokioExecutor<R, WE>
88where
89 R: Puller,
90 WE: Send + Unpin + 'static,
91{
92 tx: MTx<mpmc::List<Event<R::Error, WE>>>,
93 tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
94 puller: R,
95 retry_gap: Duration,
96 id: AtomicUsize,
97 min_chunk_size: u64,
98}
/// Runs each `fast_steal` task as a Tokio worker that pulls its range and forwards
/// the bytes to the push thread.
impl<R, WE> Executor for TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    type Handle = TokioHandle;

    /// Spawns a worker for `task` and returns a handle that can abort it.
    ///
    /// The worker loops:
    /// 1. When its range is exhausted, it tries `task_queue.steal`. If nothing can
    ///    be stolen, it stops.
    /// 2. It opens a stream for the remaining range, retrying `pull` until it succeeds.
    /// 3. It claims each received chunk via `fetch_add_start`, clamps the chunk to
    ///    the task's current end, reports `PullProgress`, and queues it for pushing.
    ///
    /// When done it calls `finish_work` and sends `Event::Finished`.
    fn execute(&self, task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
        // Distinct id for this worker; tags every event it emits.
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        let mut puller = self.puller.clone();
        let min_chunk_size = self.min_chunk_size;
        let cfg_retry_gap = self.retry_gap;
        let tx = self.tx.clone();
        let tx_push = self.tx_push.clone();
        let handle = tokio::spawn(async move {
            loop {
                let mut start = task.start();
                // Own range is finished: try to take over part of another range,
                // otherwise this worker is done.
                if start >= task.end() {
                    if task_queue.steal(&task, min_chunk_size) {
                        continue;
                    } else {
                        break;
                    }
                }
                let _ = tx.send(Event::Pulling(id));
                let download_range = start..task.end();
                // Opening the stream is retried indefinitely, reporting each failure.
                // NOTE(review): unlike stream errors below, `is_irrecoverable` is not
                // consulted here.
                let mut stream = loop {
                    match puller.pull(Some(&download_range)).await {
                        Ok(t) => break t,
                        Err((e, retry_gap)) => {
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
                        }
                    }
                };
                loop {
                    match stream.try_next().await {
                        Ok(Some(mut chunk)) => {
                            let len = chunk.len() as u64;
                            // Claim the bytes on the shared task. An error presumably means
                            // the range is no longer ours (e.g. stolen); re-check in the outer loop.
                            if task.fetch_add_start(len).is_err() {
                                break;
                            }
                            let range_start = start;
                            start += len;
                            // `task.end()` may have shrunk (stolen tail), so clamp the chunk
                            // to the part this worker still owns.
                            let range_end = start.min(task.end());
                            if range_start >= range_end {
                                break;
                            }
                            let span = range_start..range_end;
                            chunk.truncate(span.total() as usize);
                            let _ = tx.send(Event::PullProgress(id, span.clone()));
                            let tx_push = tx_push.clone();
                            // The send runs in its own task. If this worker is aborted while
                            // waiting on a full push queue, dropping the JoinHandle only
                            // detaches the inner task, so the chunk claimed above is still
                            // delivered. `.unwrap()` panics inside that detached task if the
                            // push receiver is gone.
                            let _ = tokio::spawn(async move {
                                tx_push.send((id, span, chunk)).await.unwrap();
                            })
                            .await;
                        }
                        // Stream ended: the outer loop re-pulls if bytes are still missing.
                        Ok(None) => break,
                        Err((e, retry_gap)) => {
                            let is_irrecoverable = e.is_irrecoverable();
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
                            // Recoverable: keep reading the same stream.
                            // Irrecoverable: drop it and re-pull from `task.start()`.
                            if is_irrecoverable {
                                break;
                            }
                        }
                    }
                }
            }
            task_queue.finish_work(&task);
            let _ = tx.send(Event::Finished(id));
        });
        TokioHandle(handle.abort_handle())
    }
}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{vec, vec::Vec};

    /// Number of workers requested by the test. It also sizes the per-worker
    /// bookkeeping arrays, so the two cannot drift apart.
    const CONCURRENT: usize = 32;

    /// Downloads one range with many workers and checks that:
    /// - every byte was pulled and pushed;
    /// - every worker contributed to both sides;
    /// - the pusher ends up holding exactly the source data.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: CONCURRENT,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter(),
                min_chunk_size: 1,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; CONCURRENT];
        let mut push_ids = [false; CONCURRENT];
        // The channel closes once every worker and the push thread have finished.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    assert!(id < CONCURRENT, "unexpected worker id {id}");
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    assert!(id < CONCURRENT, "unexpected worker id {id}");
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pull_ids, [true; CONCURRENT]);
        assert_eq!(push_ids, [true; CONCURRENT]);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}