extern crate alloc;
use super::macros::poll_ok;
use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
use alloc::{sync::Arc, vec::Vec};
use bytes::Bytes;
use core::{
    num::NonZeroUsize,
    sync::atomic::{AtomicUsize, Ordering},
    time::Duration,
};
use fast_steal::{Executor, Handle, Task, TaskList};
use futures::TryStreamExt;
use tokio::task::AbortHandle;
10
11#[derive(Debug, Clone)]
12pub struct DownloadOptions {
13 pub download_chunks: Vec<ProgressEntry>,
14 pub concurrent: NonZeroUsize,
15 pub retry_gap: Duration,
16 pub push_queue_cap: usize,
17}
18
19pub async fn download_multi<R, W>(
20 puller: R,
21 mut pusher: W,
22 options: DownloadOptions,
23) -> DownloadResult<R::Error, W::Error>
24where
25 R: RandPuller + 'static + Sync,
26 W: RandPusher + 'static,
27{
28 let (tx, event_chain) = kanal::unbounded_async();
29 let (tx_push, rx_push) =
30 kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
31 let tx_clone = tx.clone();
32 let push_handle = tokio::spawn(async move {
33 while let Ok((id, spin, data)) = rx_push.recv().await {
34 poll_ok!(
35 {},
36 pusher.push(spin.clone(), data.clone()).await,
37 id @ tx_clone => PushError,
38 options.retry_gap
39 );
40 tx_clone.send(Event::PushProgress(id, spin)).await.unwrap();
41 }
42 poll_ok!(
43 {},
44 pusher.flush().await,
45 tx_clone => FlushError,
46 options.retry_gap
47 );
48 });
49 let executor: TokioExecutor<R, W> = TokioExecutor {
50 tx,
51 tx_push,
52 puller,
53 retry_gap: options.retry_gap,
54 };
55 let task_list = TaskList::run(
56 options.concurrent.get(),
57 8 * 1024,
58 &options.download_chunks[..],
59 executor,
60 );
61 DownloadResult::new(
62 event_chain,
63 push_handle,
64 &task_list
65 .handles()
66 .iter()
67 .map(|h| h.0.clone())
68 .collect::<Arc<[_]>>(),
69 )
70}
71
72#[derive(Clone)]
73pub struct TokioHandle(AbortHandle);
74impl Handle for TokioHandle {
75 type Output = ();
76 fn abort(&mut self) -> Self::Output {
77 self.0.abort();
78 }
79}
80pub struct TokioExecutor<R, W>
81where
82 R: RandPuller + 'static,
83 W: RandPusher + 'static,
84{
85 tx: kanal::AsyncSender<Event<R::Error, W::Error>>,
86 tx_push: kanal::AsyncSender<(WorkerId, ProgressEntry, Bytes)>,
87 puller: R,
88 retry_gap: Duration,
89}
90impl<R, W> Executor for TokioExecutor<R, W>
91where
92 R: RandPuller + 'static + Sync,
93 W: RandPusher + 'static,
94{
95 type Handle = TokioHandle;
96 fn execute(self: Arc<Self>, task: Arc<Task>, task_list: Arc<TaskList<Self>>) -> Self::Handle {
97 let id = 1; let handle = tokio::spawn(async move {
99 'steal_task: loop {
100 let mut start = task.start();
101 if start >= task.end() {
102 if task_list.steal(&task, 16 * 1024) {
103 continue;
104 }
105 break;
106 }
107 self.tx.send(Event::Pulling(id)).await.unwrap();
108 let download_range = start..task.end();
109 let mut puller = self.puller.clone();
110 let mut stream = puller.pull(&download_range);
111 loop {
112 match stream.try_next().await {
113 Ok(Some(mut chunk)) => {
114 let len = chunk.len() as u64;
115 task.fetch_add_start(len);
116 let range_start = start;
117 start += len;
118 let range_end = start.min(task.end());
119 if range_start >= range_end {
120 continue 'steal_task;
121 }
122 let span = range_start..range_end;
123 let len = span.total() as usize;
124 self.tx
125 .send(Event::PullProgress(id, span.clone()))
126 .await
127 .unwrap();
128 self.tx_push
129 .send((id, span, chunk.split_to(len)))
130 .await
131 .unwrap();
132 }
133 Ok(None) => break,
134 Err(e) => {
135 self.tx.send(Event::PullError(id, e)).await.unwrap();
136 tokio::time::sleep(self.retry_gap).await;
137 }
138 }
139 }
140 }
141 self.tx.send(Event::Finished(id)).await.unwrap();
142 task_list.remove(&task);
143 });
144 TokioHandle(handle.abort_handle())
145 }
146}
147
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandPuller, MockRandPusher, build_mock_data},
    };
    use alloc::vec;

    /// Downloads a small mock payload end to end.
    ///
    /// Checks that the merged pull and push progress both cover exactly the requested range,
    /// and that the mock pusher received the right bytes.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockRandPuller::new(&mock_data);
        let pusher = MockRandPusher::new(&mock_data);
        #[allow(clippy::single_range_in_vec_init)] // one range covering the whole payload
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(
            pull_progress, download_chunks,
            "pulled ranges should cover the requested chunks exactly"
        );
        assert_eq!(
            push_progress, download_chunks,
            "pushed ranges should cover the requested chunks exactly"
        );

        result.join().await.unwrap();
        pusher.assert().await;
    }
}