1extern crate std;
2use crate::{DownloadResult, Event, ProgressEntry, Puller, Pusher, Total, WorkerId};
3use bytes::Bytes;
4use core::{
5 sync::atomic::{AtomicUsize, Ordering},
6 time::Duration,
7};
8use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
9use fast_steal::{Executor, Handle, Task, TaskQueue};
10use futures::TryStreamExt;
11use std::sync::Arc;
12use tokio::task::AbortHandle;
13
14#[derive(Debug, Clone)]
15pub struct DownloadOptions<'a, I: Iterator<Item = &'a ProgressEntry>> {
16 pub download_chunks: I,
17 pub concurrent: usize,
18 pub retry_gap: Duration,
19 pub push_queue_cap: usize,
20 pub min_chunk_size: u64,
21}
22
23pub fn download_multi<'a, R: Puller, W: Pusher, I: Iterator<Item = &'a ProgressEntry>>(
24 puller: R,
25 mut pusher: W,
26 options: DownloadOptions<'a, I>,
27) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
28 let (tx, event_chain) = mpmc::unbounded_async();
29 let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
30 let tx_clone = tx.clone();
31 let rx_push = rx_push.into_blocking();
32 let push_handle = tokio::task::spawn_blocking(move || {
33 while let Ok((id, spin, mut data)) = rx_push.recv() {
34 loop {
35 match pusher.push(&spin, data) {
36 Ok(_) => break,
37 Err((err, bytes)) => {
38 data = bytes;
39 let _ = tx_clone.send(Event::PushError(id, err));
40 }
41 }
42 std::thread::sleep(options.retry_gap);
43 }
44 let _ = tx_clone.send(Event::PushProgress(id, spin));
45 }
46 loop {
47 match pusher.flush() {
48 Ok(_) => break,
49 Err(err) => {
50 let _ = tx_clone.send(Event::FlushError(err));
51 }
52 }
53 std::thread::sleep(options.retry_gap);
54 }
55 });
56 let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
57 tx,
58 tx_push,
59 puller,
60 retry_gap: options.retry_gap,
61 id: AtomicUsize::new(0),
62 min_chunk_size: options.min_chunk_size,
63 });
64 let task_queue = TaskQueue::new(options.download_chunks);
65 task_queue.set_threads(
66 options.concurrent,
67 options.min_chunk_size,
68 Some(executor.as_ref()),
69 );
70 DownloadResult::new(
71 event_chain,
72 push_handle,
73 None,
74 Some((Arc::downgrade(&executor), task_queue)),
75 )
76}
77
78#[derive(Clone)]
79pub struct TokioHandle(AbortHandle);
80impl Handle for TokioHandle {
81 type Output = ();
82 fn abort(&mut self) -> Self::Output {
83 self.0.abort();
84 }
85}
86#[derive(Debug)]
87pub struct TokioExecutor<R, WE>
88where
89 R: Puller,
90 WE: Send + Unpin + 'static,
91{
92 tx: MTx<mpmc::List<Event<R::Error, WE>>>,
93 tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
94 puller: R,
95 retry_gap: Duration,
96 id: AtomicUsize,
97 min_chunk_size: u64,
98}
impl<R, WE> Executor for TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    type Handle = TokioHandle;
    /// Spawns a tokio task that downloads `task` and then keeps stealing work.
    ///
    /// The worker repeatedly:
    /// 1. pulls the remaining range `task.start()..task.end()`, retrying failed `pull`
    ///    calls after the error's own retry gap or, if none, `self.retry_gap`;
    /// 2. advances the task start by each received chunk, clips the chunk to the task's
    ///    current end, emits `Event::PullProgress`, and forwards the chunk to the push queue;
    /// 3. when the range is exhausted, calls `task_queue.steal`; if nothing can be stolen
    ///    it calls `finish_work` and emits `Event::Finished`.
    ///
    /// Returns a [`TokioHandle`] wrapping the spawned task's abort handle.
    fn execute(&self, task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
        // Distinct id for this worker; tags every event it emits.
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        // Owned copies moved into the spawned task.
        let mut puller = self.puller.clone();
        let min_chunk_size = self.min_chunk_size;
        let cfg_retry_gap = self.retry_gap;
        let tx = self.tx.clone();
        let tx_push = self.tx_push.clone();
        let handle = tokio::spawn(async move {
            loop {
                // Re-read the bounds each round: they may have changed since the last pull.
                let mut start = task.start();
                if start >= task.end() {
                    // Own range done; try to take more work.
                    // NOTE(review): presumably `steal` reassigns `task`'s range on success.
                    if task_queue.steal(&task, min_chunk_size) {
                        continue;
                    } else {
                        break;
                    }
                }
                let _ = tx.send(Event::Pulling(id));
                let download_range = start..task.end();
                // Retry opening the stream until it succeeds.
                let mut stream = loop {
                    match puller.pull(Some(&download_range)).await {
                        Ok(t) => break t,
                        Err((e, retry_gap)) => {
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
                        }
                    }
                };
                loop {
                    match stream.try_next().await {
                        Ok(Some(mut chunk)) => {
                            let len = chunk.len() as u64;
                            // NOTE(review): presumably fails when the range was taken away
                            // (e.g. by a steal); breaking lets the outer loop re-read bounds.
                            if task.fetch_add_start(len).is_err() {
                                break;
                            }
                            let range_start = start;
                            start += len;
                            // The task's end may now lie below `start`; only the part
                            // inside the task counts.
                            let range_end = start.min(task.end());
                            if range_start >= range_end {
                                break;
                            }
                            let span = range_start..range_end;
                            chunk.truncate(span.total() as usize);
                            let _ = tx.send(Event::PullProgress(id, span.clone()));
                            let tx_push = tx_push.clone();
                            // The send runs in its own spawned task, which this worker awaits.
                            // NOTE(review): likely so that aborting this worker cannot cancel
                            // a send whose bytes were already counted by `fetch_add_start`
                            // (a spawned tokio task keeps running when its awaiter is
                            // aborted). The `unwrap` panics inside that task if the push
                            // queue is closed; the resulting JoinError is discarded.
                            let _ = tokio::spawn(async move {
                                tx_push.send((id, span, chunk)).await.unwrap();
                            })
                            .await;
                        }
                        // Stream ended: the outer loop re-pulls if the range isn't finished.
                        Ok(None) => break,
                        Err((e, retry_gap)) => {
                            // Report, back off, then keep polling the same stream.
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
                        }
                    }
                }
            }
            task_queue.finish_work(&task);
            let _ = tx.send(Event::Finished(id));
        });
        TokioHandle(handle.abort_handle())
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec, vec::Vec};

    /// Number of workers requested; every one of them must report pull and push progress.
    const WORKERS: usize = 32;

    /// Downloads 3 KiB of mock data with `WORKERS` workers and checks that:
    /// - the merged pull and push progress each cover the whole range;
    /// - every worker id reported both pull and push progress;
    /// - the pusher ends up holding exactly the mock bytes.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = DownloadOptions {
            download_chunks: download_chunks.iter(),
            concurrent: WORKERS,
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
            min_chunk_size: 1,
        };
        let result = download_multi(puller, pusher.clone(), options);

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pulled_by = [false; WORKERS];
        let mut pushed_by = [false; WORKERS];
        // Drain events until every sender is gone.
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(worker, span) => {
                    pulled_by[worker] = true;
                    pull_progress.merge_progress(span);
                }
                Event::PushProgress(worker, span) => {
                    pushed_by[worker] = true;
                    push_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pulled_by, [true; WORKERS]);
        assert_eq!(pushed_by, [true; WORKERS]);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}