fast_pull/core/multi.rs

1extern crate alloc;
2extern crate spin;
3use super::macros::poll_ok;
4use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
5use alloc::{sync::Arc, vec::Vec};
6use bytes::Bytes;
7use core::{num::NonZeroUsize, time::Duration};
8use fast_steal::{SplitTask, StealTask, Task, TaskList};
9use futures::TryStreamExt;
10
11#[derive(Debug, Clone)]
12pub struct DownloadOptions {
13    pub download_chunks: Vec<ProgressEntry>,
14    pub concurrent: NonZeroUsize,
15    pub retry_gap: Duration,
16    pub push_queue_cap: usize,
17}
18
19pub async fn download_multi<R, W>(
20    puller: R,
21    mut pusher: W,
22    options: DownloadOptions,
23) -> DownloadResult<R::Error, W::Error>
24where
25    R: RandPuller + 'static,
26    W: RandPusher + 'static,
27{
28    let (tx, event_chain) = kanal::unbounded_async();
29    let (tx_push, rx_push) =
30        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
31    let tx_clone = tx.clone();
32    let push_handle = tokio::spawn(async move {
33        while let Ok((id, spin, data)) = rx_push.recv().await {
34            poll_ok!(
35                {},
36                pusher.push(spin.clone(), data.clone()).await,
37                id @ tx_clone => PushError,
38                options.retry_gap
39            );
40            tx_clone.send(Event::PushProgress(id, spin)).await.unwrap();
41        }
42        poll_ok!(
43            {},
44            pusher.flush().await,
45            tx_clone => FlushError,
46            options.retry_gap
47        );
48    });
49    let mutex = Arc::new(spin::mutex::SpinMutex::<_>::new(()));
50    let task_list = Arc::new(TaskList::from(&options.download_chunks[..]));
51    let tasks = Arc::from_iter(
52        Task::from(&*task_list)
53            .split_task(options.concurrent.get() as u64)
54            .map(Arc::new),
55    );
56    let mut abort_handles = Vec::with_capacity(tasks.len());
57    for (id, task) in tasks.iter().enumerate() {
58        let task = task.clone();
59        let tasks = tasks.clone();
60        let task_list = task_list.clone();
61        let mutex = mutex.clone();
62        let tx = tx.clone();
63        let mut puller = puller.clone();
64        let tx_push = tx_push.clone();
65        let handle = tokio::spawn(async move {
66            'steal_task: loop {
67                let mut start = task.start();
68                if start >= task.end() {
69                    let guard = mutex.lock();
70                    if task.steal(&tasks, 16 * 1024) {
71                        continue;
72                    }
73                    drop(guard);
74                    tx.send(Event::Finished(id)).await.unwrap();
75                    return;
76                }
77                let download_range = &task_list.get_range(start..task.end());
78                for range in download_range {
79                    tx.send(Event::Pulling(id)).await.unwrap();
80                    let mut stream = puller.pull(range);
81                    let mut downloaded = 0;
82                    loop {
83                        match stream.try_next().await {
84                            Ok(Some(mut chunk)) => {
85                                let len = chunk.len() as u64;
86                                task.fetch_add_start(len);
87                                start += len;
88                                let range_start = range.start + downloaded;
89                                downloaded += len;
90                                let range_end = range.start + downloaded;
91                                let span = range_start..range_end.min(task_list.get(task.end()));
92                                let len = span.total() as usize;
93                                tx.send(Event::PullProgress(id, span.clone()))
94                                    .await
95                                    .unwrap();
96                                tx_push.send((id, span, chunk.split_to(len))).await.unwrap();
97                                if start >= task.end() {
98                                    continue 'steal_task;
99                                }
100                            }
101                            Ok(None) => break,
102                            Err(e) => {
103                                tx.send(Event::PullError(id, e)).await.unwrap();
104                                tokio::time::sleep(options.retry_gap).await;
105                            }
106                        }
107                    }
108                }
109            }
110        });
111        abort_handles.push(handle.abort_handle());
112    }
113    DownloadResult::new(event_chain, push_handle, &abort_handles)
114}
115
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandPuller, MockRandPusher, build_mock_data},
    };
    use alloc::vec;

    /// Downloads a 3 KiB mock resource with 32 workers and checks two things:
    /// the merged pull/push progress equals the requested range, and the
    /// mock pusher received the expected data.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockRandPuller::new(&mock_data);
        let pusher = MockRandPusher::new(&mock_data);
        // Deliberately a single range: download the whole mock buffer as one chunk.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // Drain events until `recv` fails, i.e. once all senders (held by the
        // workers and the writer task) have been dropped.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(
            pull_progress, download_chunks,
            "pulled ranges do not match the requested ranges"
        );
        assert_eq!(
            push_progress, download_chunks,
            "pushed ranges do not match the requested ranges"
        );

        result.join().await.unwrap();
        pusher.assert().await;
    }
}