fast_pull/core/
multi.rs

1extern crate alloc;
2use super::macros::poll_ok;
3use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
4use alloc::{sync::Arc, vec::Vec};
5use bytes::Bytes;
6use core::{num::NonZeroUsize, time::Duration};
7use fast_steal::{Executor, Handle, Task, TaskList};
8use futures::TryStreamExt;
9use tokio::task::AbortHandle;
10
11#[derive(Debug, Clone)]
12pub struct DownloadOptions {
13    pub download_chunks: Vec<ProgressEntry>,
14    pub concurrent: NonZeroUsize,
15    pub retry_gap: Duration,
16    pub push_queue_cap: usize,
17}
18
19pub async fn download_multi<R, W>(
20    puller: R,
21    mut pusher: W,
22    options: DownloadOptions,
23) -> DownloadResult<R::Error, W::Error>
24where
25    R: RandPuller + 'static + Sync,
26    W: RandPusher + 'static,
27{
28    let (tx, event_chain) = kanal::unbounded_async();
29    let (tx_push, rx_push) =
30        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
31    let tx_clone = tx.clone();
32    let push_handle = tokio::spawn(async move {
33        while let Ok((id, spin, data)) = rx_push.recv().await {
34            poll_ok!(
35                {},
36                pusher.push(spin.clone(), data.clone()).await,
37                id @ tx_clone => PushError,
38                options.retry_gap
39            );
40            tx_clone.send(Event::PushProgress(id, spin)).await.unwrap();
41        }
42        poll_ok!(
43            {},
44            pusher.flush().await,
45            tx_clone => FlushError,
46            options.retry_gap
47        );
48    });
49    let executor: TokioExecutor<R, W> = TokioExecutor {
50        tx,
51        tx_push,
52        puller,
53        retry_gap: options.retry_gap,
54    };
55    let task_list = TaskList::run(
56        options.concurrent.get(),
57        8 * 1024,
58        &options.download_chunks[..],
59        executor,
60    );
61    DownloadResult::new(
62        event_chain,
63        push_handle,
64        &task_list
65            .handles()
66            .iter()
67            .map(|h| h.0.clone())
68            .collect::<Arc<[_]>>(),
69    )
70}
71
72#[derive(Clone)]
73pub struct TokioHandle(AbortHandle);
74impl Handle for TokioHandle {
75    type Output = ();
76    fn abort(&mut self) -> Self::Output {
77        self.0.abort();
78    }
79}
80pub struct TokioExecutor<R, W>
81where
82    R: RandPuller + 'static,
83    W: RandPusher + 'static,
84{
85    tx: kanal::AsyncSender<Event<R::Error, W::Error>>,
86    tx_push: kanal::AsyncSender<(WorkerId, ProgressEntry, Bytes)>,
87    puller: R,
88    retry_gap: Duration,
89}
90impl<R, W> Executor for TokioExecutor<R, W>
91where
92    R: RandPuller + 'static + Sync,
93    W: RandPusher + 'static,
94{
95    type Handle = TokioHandle;
96    fn execute(self: Arc<Self>, task: Arc<Task>, task_list: Arc<TaskList<Self>>) -> Self::Handle {
97        let id = 1; // TODO: worker id
98        let handle = tokio::spawn(async move {
99            'steal_task: loop {
100                let mut start = task.start();
101                if start >= task.end() {
102                    if task_list.steal(&task, 16 * 1024) {
103                        continue;
104                    }
105                    break;
106                }
107                self.tx.send(Event::Pulling(id)).await.unwrap();
108                let download_range = start..task.end();
109                let mut puller = self.puller.clone();
110                let mut stream = puller.pull(&download_range);
111                loop {
112                    match stream.try_next().await {
113                        Ok(Some(mut chunk)) => {
114                            let len = chunk.len() as u64;
115                            task.fetch_add_start(len);
116                            let range_start = start;
117                            start += len;
118                            let range_end = start.min(task.end());
119                            if range_start >= range_end {
120                                continue 'steal_task;
121                            }
122                            let span = range_start..range_end;
123                            let len = span.total() as usize;
124                            self.tx
125                                .send(Event::PullProgress(id, span.clone()))
126                                .await
127                                .unwrap();
128                            self.tx_push
129                                .send((id, span, chunk.split_to(len)))
130                                .await
131                                .unwrap();
132                        }
133                        Ok(None) => break,
134                        Err(e) => {
135                            self.tx.send(Event::PullError(id, e)).await.unwrap();
136                            tokio::time::sleep(self.retry_gap).await;
137                        }
138                    }
139                }
140            }
141            self.tx.send(Event::Finished(id)).await.unwrap();
142            task_list.remove(&task);
143        });
144        TokioHandle(handle.abort_handle())
145    }
146}
147
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandPuller, MockRandPusher, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;

    /// Downloads 3 KiB of mock data with 32 workers, checks that merged pull and push
    /// progress each equal the requested range, then lets the mock pusher verify its bytes.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockRandPuller::new(&mock_data);
        let pusher = MockRandPusher::new(&mock_data);
        // Intentional: a Vec holding one whole-file range, not a Vec of the range's items.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = DownloadOptions {
            concurrent: NonZeroUsize::new(32).unwrap(),
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
            download_chunks: download_chunks.clone(),
        };
        let result = download_multi(puller, pusher.clone(), options).await;

        let mut pulled: Vec<ProgressEntry> = Vec::new();
        let mut pushed: Vec<ProgressEntry> = Vec::new();
        // Drain events until the channel reports it can no longer receive.
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(_, span) => {
                    pulled.merge_progress(span);
                }
                Event::PushProgress(_, span) => {
                    pushed.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&pulled);
        dbg!(&pushed);
        assert_eq!(pulled, download_chunks);
        assert_eq!(pushed, download_chunks);

        result.join().await.unwrap();
        pusher.assert().await;
    }
}