// fast_pull/core/multi.rs

extern crate alloc;
use super::macros::poll_ok;
use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
use alloc::{sync::Arc, vec::Vec};
use bytes::Bytes;
use core::{
    num::{NonZero, NonZeroU64, NonZeroUsize},
    sync::atomic::{AtomicUsize, Ordering},
    time::Duration,
};
use fast_steal::{Executor, Handle, Task, TaskList};
use futures::TryStreamExt;
use tokio::task::AbortHandle;

15#[derive(Debug, Clone)]
16pub struct DownloadOptions {
17    pub download_chunks: Vec<ProgressEntry>,
18    pub concurrent: NonZeroUsize,
19    pub retry_gap: Duration,
20    pub push_queue_cap: usize,
21    pub min_chunk_size: NonZeroU64,
22}
23
24pub async fn download_multi<R, W>(
25    puller: R,
26    mut pusher: W,
27    options: DownloadOptions,
28) -> DownloadResult<R::Error, W::Error>
29where
30    R: RandPuller + 'static + Sync,
31    W: RandPusher + 'static,
32{
33    let (tx, event_chain) = kanal::unbounded_async();
34    let (tx_push, rx_push) =
35        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
36    let tx_clone = tx.clone();
37    let push_handle = tokio::spawn(async move {
38        while let Ok((id, spin, data)) = rx_push.recv().await {
39            poll_ok!(
40                {},
41                pusher.push(spin.clone(), data.clone()).await,
42                id @ tx_clone => PushError,
43                options.retry_gap
44            );
45            tx_clone.send(Event::PushProgress(id, spin)).await.unwrap();
46        }
47        poll_ok!(
48            {},
49            pusher.flush().await,
50            tx_clone => FlushError,
51            options.retry_gap
52        );
53    });
54    let executor: TokioExecutor<R, W> = TokioExecutor {
55        tx,
56        tx_push,
57        puller,
58        retry_gap: options.retry_gap,
59        id: Arc::new(AtomicUsize::new(0)),
60        min_chunk_size: options.min_chunk_size,
61    };
62    let task_list = TaskList::run(
63        options.concurrent,
64        options.min_chunk_size,
65        &options.download_chunks[..],
66        executor,
67    );
68    DownloadResult::new(
69        event_chain,
70        push_handle,
71        &task_list.handles(|iter| iter.map(|h| h.0.clone()).collect::<Arc<[_]>>()),
72    )
73}
74
75#[derive(Clone)]
76pub struct TokioHandle(AbortHandle);
77impl Handle for TokioHandle {
78    type Output = ();
79    fn abort(&mut self) -> Self::Output {
80        self.0.abort();
81    }
82}
83pub struct TokioExecutor<R, W>
84where
85    R: RandPuller + 'static,
86    W: RandPusher + 'static,
87{
88    tx: kanal::AsyncSender<Event<R::Error, W::Error>>,
89    tx_push: kanal::AsyncSender<(WorkerId, ProgressEntry, Bytes)>,
90    puller: R,
91    retry_gap: Duration,
92    id: Arc<AtomicUsize>,
93    min_chunk_size: NonZeroU64,
94}
95impl<R, W> Executor for TokioExecutor<R, W>
96where
97    R: RandPuller + 'static + Sync,
98    W: RandPusher + 'static,
99{
100    type Handle = TokioHandle;
101    fn execute(self: Arc<Self>, task: Arc<Task>, task_list: Arc<TaskList<Self>>) -> Self::Handle {
102        let id = self.id.fetch_add(1, Ordering::SeqCst);
103        let handle = tokio::spawn(async move {
104            'steal_task: loop {
105                let mut start = task.start();
106                if start >= task.end() {
107                    if task_list.steal(&task, NonZero::new(2 * self.min_chunk_size.get()).unwrap())
108                    {
109                        continue;
110                    }
111                    break;
112                }
113                self.tx.send(Event::Pulling(id)).await.unwrap();
114                let download_range = start..task.end();
115                let mut puller = self.puller.clone();
116                let mut stream = puller.pull(&download_range);
117                loop {
118                    match stream.try_next().await {
119                        Ok(Some(mut chunk)) => {
120                            let len = chunk.len() as u64;
121                            task.fetch_add_start(len);
122                            let range_start = start;
123                            start += len;
124                            let range_end = start.min(task.end());
125                            if range_start >= range_end {
126                                continue 'steal_task;
127                            }
128                            let span = range_start..range_end;
129                            let len = span.total() as usize;
130                            self.tx
131                                .send(Event::PullProgress(id, span.clone()))
132                                .await
133                                .unwrap();
134                            self.tx_push
135                                .send((id, span, chunk.split_to(len)))
136                                .await
137                                .unwrap();
138                        }
139                        Ok(None) => break,
140                        Err(e) => {
141                            self.tx.send(Event::PullError(id, e)).await.unwrap();
142                            tokio::time::sleep(self.retry_gap).await;
143                        }
144                    }
145                }
146            }
147            self.tx.send(Event::Finished(id)).await.unwrap();
148            task_list.remove(&task);
149        });
150        TokioHandle(handle.abort_handle())
151    }
152}
153
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandPuller, MockRandPusher, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;

    /// Downloads 3 KiB of mock data with 32 workers and checks that the merged pull and
    /// push progress both equal the requested range and that every worker id reported
    /// both kinds of progress.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockRandPuller::new(&mock_data);
        let pusher = MockRandPusher::new(&mock_data);
        // A one-element vector holding a single range is exactly what is wanted here.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: NonZero::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
                min_chunk_size: NonZero::new(1).unwrap(),
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // One flag per worker id, set once that worker has reported progress.
        let mut pull_ids = [false; 32];
        let mut push_ids = [false; 32];
        // Drain events until `recv` fails.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pull_ids, [true; 32]);
        assert_eq!(push_ids, [true; 32]);

        result.join().await.unwrap();
        pusher.assert().await;
    }
}
212}