1extern crate alloc;
2use super::macros::poll_ok;
3use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
4use alloc::{sync::Arc, vec::Vec};
5use bytes::Bytes;
6use core::{
7 num::{NonZero, NonZeroU64, NonZeroUsize},
8 sync::atomic::{AtomicUsize, Ordering},
9 time::Duration,
10};
11use fast_steal::{Executor, Handle, Task, TaskList};
12use futures::TryStreamExt;
13use tokio::task::AbortHandle;
14
15#[derive(Debug, Clone)]
16pub struct DownloadOptions {
17 pub download_chunks: Vec<ProgressEntry>,
18 pub concurrent: NonZeroUsize,
19 pub retry_gap: Duration,
20 pub push_queue_cap: usize,
21 pub min_chunk_size: NonZeroU64,
22}
23
24pub async fn download_multi<R, W>(
25 puller: R,
26 mut pusher: W,
27 options: DownloadOptions,
28) -> DownloadResult<TokioExecutor<R, W>, R::Error, W::Error>
29where
30 R: RandPuller + 'static + Sync,
31 W: RandPusher + 'static,
32{
33 let (tx, event_chain) = kanal::unbounded_async();
34 let (tx_push, rx_push) =
35 kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
36 let tx_clone = tx.clone();
37 let push_handle = tokio::spawn(async move {
38 while let Ok((id, spin, data)) = rx_push.recv().await {
39 poll_ok!(
40 pusher.push(spin.clone(), &data).await,
41 id @ tx_clone => PushError,
42 options.retry_gap
43 );
44 tx_clone.send(Event::PushProgress(id, spin)).await.unwrap();
45 }
46 poll_ok!(
47 pusher.flush().await,
48 tx_clone => FlushError,
49 options.retry_gap
50 );
51 });
52 let executor: Arc<TokioExecutor<R, W>> = Arc::new(TokioExecutor {
53 tx,
54 tx_push,
55 puller,
56 retry_gap: options.retry_gap,
57 id: Arc::new(AtomicUsize::new(0)),
58 min_chunk_size: options.min_chunk_size,
59 });
60 let task_list = Arc::new(TaskList::run(&options.download_chunks[..], executor));
61 task_list
62 .clone()
63 .set_threads(options.concurrent, options.min_chunk_size);
64 DownloadResult::new(
65 event_chain,
66 push_handle,
67 &task_list.handles(|iter| iter.map(|h| h.0.clone()).collect::<Arc<[_]>>()),
68 Some(Arc::downgrade(&task_list)),
69 )
70}
71
72#[derive(Clone)]
73pub struct TokioHandle(AbortHandle);
74impl Handle for TokioHandle {
75 type Output = ();
76 fn abort(&mut self) -> Self::Output {
77 self.0.abort();
78 }
79}
80pub struct TokioExecutor<R, W>
81where
82 R: RandPuller + 'static,
83 W: RandPusher + 'static,
84{
85 tx: kanal::AsyncSender<Event<R::Error, W::Error>>,
86 tx_push: kanal::AsyncSender<(WorkerId, ProgressEntry, Bytes)>,
87 puller: R,
88 retry_gap: Duration,
89 id: Arc<AtomicUsize>,
90 min_chunk_size: NonZeroU64,
91}
92impl<R, W> Executor for TokioExecutor<R, W>
93where
94 R: RandPuller + 'static + Sync,
95 W: RandPusher + 'static,
96{
97 type Handle = TokioHandle;
98 fn execute(self: Arc<Self>, task: Arc<Task>, task_list: Arc<TaskList<Self>>) -> Self::Handle {
99 let id = self.id.fetch_add(1, Ordering::SeqCst);
100 let mut puller = self.puller.clone();
101 let handle = tokio::spawn(async move {
102 'steal_task: loop {
103 let mut start = task.start();
104 if start >= task.end() {
105 if task_list.steal(&task, NonZero::new(2 * self.min_chunk_size.get()).unwrap())
106 {
107 continue;
108 } else {
109 break;
110 }
111 }
112 self.tx.send(Event::Pulling(id)).await.unwrap();
113 let download_range = start..task.end();
114 let mut stream = puller.pull(&download_range);
115 loop {
116 match stream.try_next().await {
117 Ok(Some(mut chunk)) => {
118 let len = chunk.len() as u64;
119 task.fetch_add_start(len);
120 let range_start = start;
121 start += len;
122 let range_end = start.min(task.end());
123 if range_start >= range_end {
124 continue 'steal_task;
125 }
126 let span = range_start..range_end;
127 chunk.truncate(span.total() as usize);
128 self.tx
129 .send(Event::PullProgress(id, span.clone()))
130 .await
131 .unwrap();
132 self.tx_push.send((id, span, chunk)).await.unwrap();
133 }
134 Ok(None) => break,
135 Err(e) => {
136 self.tx.send(Event::PullError(id, e)).await.unwrap();
137 tokio::time::sleep(self.retry_gap).await;
138 }
139 }
140 }
141 }
142 task_list.remove(&task);
143 self.tx.send(Event::Finished(id)).await.unwrap();
144 });
145 TokioHandle(handle.abort_handle())
146 }
147}
148
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;

    /// Downloads 3 KiB with 32 workers and checks several things:
    /// - every worker both pulled and pushed something;
    /// - merged pull/push progress covers the whole range;
    /// - the pushed bytes equal the source data.
    #[tokio::test]
    async fn test_concurrent_download() {
        const WORKERS: usize = 32;

        let mock_data = build_mock_data(3 * 1024);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = DownloadOptions {
            download_chunks: download_chunks.clone(),
            concurrent: NonZero::new(WORKERS).unwrap(),
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
            min_chunk_size: NonZero::new(1).unwrap(),
        };
        let result = download_multi(MockPuller::new(&mock_data), pusher.clone(), options).await;

        let mut pulled: Vec<ProgressEntry> = Vec::new();
        let mut pushed: Vec<ProgressEntry> = Vec::new();
        let mut pulled_by = [false; WORKERS];
        let mut pushed_by = [false; WORKERS];
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(worker, span) => {
                    pulled_by[worker] = true;
                    pulled.merge_progress(span);
                }
                Event::PushProgress(worker, span) => {
                    pushed_by[worker] = true;
                    pushed.merge_progress(span);
                }
                _ => {}
            }
        }

        dbg!(&pulled);
        dbg!(&pushed);
        assert_eq!(pulled, download_chunks);
        assert_eq!(pushed, download_chunks);
        assert!(pulled_by.iter().all(|&seen| seen), "{pulled_by:?}");
        assert!(pushed_by.iter().all(|&seen| seen), "{pushed_by:?}");

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}