1extern crate alloc;
2extern crate spin;
3use super::macros::poll_ok;
4use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
5use alloc::{sync::Arc, vec::Vec};
6use bytes::Bytes;
7use core::{num::NonZeroUsize, time::Duration};
8use fast_steal::{SplitTask, StealTask, Task, TaskList};
9use futures::TryStreamExt;
10
11#[derive(Debug, Clone)]
12pub struct DownloadOptions {
13 pub download_chunks: Vec<ProgressEntry>,
14 pub concurrent: NonZeroUsize,
15 pub retry_gap: Duration,
16 pub push_queue_cap: usize,
17}
18
/// Starts a concurrent download of `options.download_chunks`, pulling with `puller` and
/// writing with `pusher`.
///
/// This does not wait for the transfer to finish. It spawns one push task plus one worker
/// task per part produced by `split_task(options.concurrent)`. It then returns a
/// [`DownloadResult`] built from the event receiver, the push task's join handle and the
/// workers' abort handles.
///
/// Data flow:
/// - Workers pull byte ranges and send `(WorkerId, ProgressEntry, Bytes)` into a bounded
///   queue of capacity `options.push_queue_cap`. A slow pusher therefore makes the workers
///   wait on `send`.
/// - The push task writes each item with `pusher.push`, emits `Event::PushProgress`, and
///   calls `pusher.flush` once the queue is closed.
/// - Progress and errors are reported as [`Event`]s on an unbounded channel.
///
/// When a worker exhausts its task it tries to steal part of another worker's remaining
/// work. If that fails, it emits `Event::Finished(id)` and exits.
///
/// # Panics
///
/// The spawned tasks `unwrap()` every channel send. A task panics if a send fails.
/// NOTE(review): presumably that happens once the event receiver (or the push queue's
/// receiver) has been dropped; confirm against `kanal` send semantics.
pub async fn download_multi<R, W>(
    puller: R,
    mut pusher: W,
    options: DownloadOptions,
) -> DownloadResult<R::Error, W::Error>
where
    R: RandPuller + 'static,
    W: RandPusher + 'static,
{
    // `tx` carries events to the caller. `tx_push` feeds pulled bytes to the single push task.
    let (tx, event_chain) = kanal::unbounded_async();
    let (tx_push, rx_push) =
        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
    let tx_clone = tx.clone();
    // Push task: consumes queued chunks until `recv` fails, then flushes the pusher.
    // NOTE(review): `recv` presumably fails once every `tx_push` clone (one per worker plus
    // the local one dropped when this function returns) is gone.
    let push_handle = tokio::spawn(async move {
        while let Ok((id, spin, data)) = rx_push.recv().await {
            // `poll_ok!` (super::macros) receives the `PushError` variant and `retry_gap`.
            // Presumably it reports each failure and retries after the gap until the push
            // succeeds; confirm in the macro. `spin` is cloned because it is still needed
            // for `PushProgress` below.
            poll_ok!(
                {},
                pusher.push(spin.clone(), data.clone()).await,
                id @ tx_clone => PushError,
                options.retry_gap
            );
            tx_clone.send(Event::PushProgress(id, spin)).await.unwrap();
        }
        poll_ok!(
            {},
            pusher.flush().await,
            tx_clone => FlushError,
            options.retry_gap
        );
    });
    // Serialises calls to `Task::steal` across workers. It guards no data, only the
    // critical section.
    let mutex = Arc::new(spin::mutex::SpinMutex::<_>::new(()));
    // Workers translate their task's `start..end` back into concrete ranges through
    // `task_list.get_range` / `task_list.get`. NOTE(review): this looks like a flattened
    // index space over `download_chunks`; confirm in `fast_steal`.
    let task_list = Arc::new(TaskList::from(&options.download_chunks[..]));
    let tasks = Arc::from_iter(
        Task::from(&*task_list)
            .split_task(options.concurrent.get() as u64)
            .map(Arc::new),
    );
    let mut abort_handles = Vec::with_capacity(tasks.len());
    for (id, task) in tasks.iter().enumerate() {
        // Per-worker clones of the shared state moved into the spawned future.
        let task = task.clone();
        let tasks = tasks.clone();
        let task_list = task_list.clone();
        let mutex = mutex.clone();
        let tx = tx.clone();
        let mut puller = puller.clone();
        let tx_push = tx_push.clone();
        let handle = tokio::spawn(async move {
            'steal_task: loop {
                // `start` is this worker's local copy of the task start. `task.end()` is
                // re-read at every use because it may change while the worker runs.
                // Presumably it shrinks when another worker steals from this task.
                let mut start = task.start();
                if start >= task.end() {
                    // Task exhausted: try to take over part of another task. The `16 * 1024`
                    // argument is a minimum steal size (units assumed to match the task
                    // index space; confirm in `fast_steal`).
                    let guard = mutex.lock();
                    if task.steal(&tasks, 16 * 1024) {
                        // `guard` is released as this `if` block's scope ends.
                        continue;
                    }
                    // Release the spin lock before awaiting.
                    drop(guard);
                    tx.send(Event::Finished(id)).await.unwrap();
                    return;
                }
                let download_range = &task_list.get_range(start..task.end());
                for range in download_range {
                    tx.send(Event::Pulling(id)).await.unwrap();
                    let mut stream = puller.pull(range);
                    // Bytes received so far for `range`, used to place each chunk inside it.
                    let mut downloaded = 0;
                    loop {
                        match stream.try_next().await {
                            Ok(Some(mut chunk)) => {
                                let len = chunk.len() as u64;
                                task.fetch_add_start(len);
                                start += len;
                                let range_start = range.start + downloaded;
                                downloaded += len;
                                let range_end = range.start + downloaded;
                                // Clamp to the position of the task's current end. Bytes past
                                // it are not reported or pushed here (they belong to whoever
                                // now owns that part of the work).
                                let span = range_start..range_end.min(task_list.get(task.end()));
                                let len = span.total() as usize;
                                tx.send(Event::PullProgress(id, span.clone()))
                                    .await
                                    .unwrap();
                                // Only the first `len` bytes of the chunk are handed to the pusher.
                                tx_push.send((id, span, chunk.split_to(len))).await.unwrap();
                                if start >= task.end() {
                                    // Task finished (possibly early because its end shrank).
                                    // Drop the stream and go back to stealing.
                                    continue 'steal_task;
                                }
                            }
                            // Stream ended: move on to the next range.
                            Ok(None) => break,
                            Err(e) => {
                                // Report, wait, then poll the same stream again.
                                // NOTE(review): relies on the puller's stream being usable
                                // after yielding `Err` (e.g. reconnecting internally).
                                // Otherwise this loop keeps retrying a dead stream.
                                tx.send(Event::PullError(id, e)).await.unwrap();
                                tokio::time::sleep(options.retry_gap).await;
                            }
                        }
                    }
                }
            }
        });
        abort_handles.push(handle.abort_handle());
    }
    DownloadResult::new(event_chain, push_handle, &abort_handles)
}
115
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandPuller, MockRandPusher, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;

    /// End-to-end run: 32 workers fetch 3 KiB of mock data.
    ///
    /// The merged pull and push progress must each cover exactly the requested range. The
    /// test then joins the download and runs `MockRandPusher::assert` (presumably checking
    /// the written bytes against the mock data).
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let source = MockRandPuller::new(&mock_data);
        let sink = MockRandPusher::new(&mock_data);
        let total_len = mock_data.len() as u64;
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..total_len];

        let options = DownloadOptions {
            download_chunks: download_chunks.clone(),
            concurrent: NonZeroUsize::new(32).unwrap(),
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_multi(source, sink.clone(), options).await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // Fold every progress event until the event channel closes. Other events are ignored.
        while let Ok(event) = result.event_chain.recv().await {
            if let Event::PullProgress(_, span) = event {
                pull_progress.merge_progress(span);
            } else if let Event::PushProgress(_, span) = event {
                push_progress.merge_progress(span);
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        sink.assert().await;
    }
}