fast_pull/core/

extern crate alloc;
use super::macros::poll_ok;
use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
use alloc::{sync::Arc, vec::Vec};
use bytes::Bytes;
use core::{
    num::{NonZeroU64, NonZeroUsize},
    sync::atomic::{AtomicUsize, Ordering},
    time::Duration,
};
use fast_steal::{Executor, Handle, Task, TaskList};
use futures::TryStreamExt;
use tokio::task::AbortHandle;

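/// Configuration for [`download_multi`].
///
/// * `download_chunks`: byte ranges that still need to be pulled.
/// * `concurrent`: number of pull workers to schedule.
/// * `retry_gap`: default delay between retries after a pull or push error.
/// * `push_queue_cap`: capacity of the bounded pull-to-push channel.
/// * `min_chunk_size`: minimum chunk size used when splitting and stealing tasks.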
#[derive(Debug, Clone)]
pub struct DownloadOptions {
    pub download_chunks: Vec<ProgressEntry>,
    pub concurrent: NonZeroUsize,
    pub retry_gap: Duration,
    pub push_queue_cap: usize,
    pub min_chunk_size: NonZeroU64,
}

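/// Downloads `options.download_chunks` concurrently and writes the result
/// through `pusher`.
///
/// A single push task drains a bounded queue (capacity `push_queue_cap`) and
/// retries failed pushes, while `options.concurrent` pull workers are
/// scheduled over the ranges by `fast_steal`'s [`TaskList`], stealing work
/// from each other once their own range runs out. Progress, errors, and
/// completion are reported through the event channel of the returned
/// [`DownloadResult`].
///
/// Minimal sketch, mirroring the test at the bottom of this file (the mock
/// puller/pusher are in-crate test helpers):
///
/// ```ignore
/// let mock_data = build_mock_data(3 * 1024);
/// let result = download_multi(
///     MockPuller::new(&mock_data),
///     MemPusher::with_capacity(mock_data.len()),
///     DownloadOptions {
///         download_chunks: vec![0..mock_data.len() as u64],
///         concurrent: NonZero::new(32).unwrap(),
///         retry_gap: Duration::from_secs(1),
///         push_queue_cap: 1024,
///         min_chunk_size: NonZero::new(1).unwrap(),
///     },
/// )
/// .await;
/// while let Ok(event) = result.event_chain.recv().await {
///     // track pull/push progress here
/// }
/// result.join().await.unwrap();
/// ```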
pub async fn download_multi<R, W>(
    puller: R,
    mut pusher: W,
    options: DownloadOptions,
) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error>
where
    R: RandPuller + 'static + Sync,
    W: RandPusher + 'static,
{
    let (tx, event_chain) = kanal::unbounded_async();
    let (tx_push, rx_push) =
        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
    let tx_clone = tx.clone();
    // Single push task: drains the bounded queue, retries failed pushes, and
    // flushes the pusher once the push channel closes.
    let push_handle = tokio::spawn(async move {
        while let Ok((id, span, data)) = rx_push.recv().await {
            poll_ok!(
                pusher.push(span.clone(), &data).await,
                id @ tx_clone => PushError,
                options.retry_gap
            );
            let _ = tx_clone.send(Event::PushProgress(id, span)).await;
        }
        poll_ok!(
            pusher.flush().await,
            tx_clone => FlushError,
            options.retry_gap
        );
    });
    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
        tx,
        tx_push,
        puller,
        retry_gap: options.retry_gap,
        id: Arc::new(AtomicUsize::new(0)),
        min_chunk_size: options.min_chunk_size,
    });
    // Schedule the requested ranges over `concurrent` work-stealing workers.
    let task_list = Arc::new(TaskList::run(&options.download_chunks[..], executor));
    task_list
        .clone()
        .set_threads(options.concurrent, options.min_chunk_size);
    DownloadResult::new(
        event_chain,
        push_handle,
        None,
        Some(Arc::downgrade(&task_list)),
    )
}

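/// Handle to a spawned pull worker; aborting it cancels the underlying Tokio task.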
#[derive(Clone)]
pub struct TokioHandle(AbortHandle);
impl Handle for TokioHandle {
    type Output = ();
    fn abort(&mut self) -> Self::Output {
        self.0.abort();
    }
}
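/// Shared state handed to every pull worker: the event and push channels, the
/// cloneable puller, the retry configuration, and a counter that hands out
/// worker ids.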
#[derive(Debug)]
pub struct TokioExecutor<R, WE>
where
    R: RandPuller + 'static,
    WE: Send + 'static,
{
    tx: kanal::AsyncSender<Event<R::Error, WE>>,
    tx_push: kanal::AsyncSender<(WorkerId, ProgressEntry, Bytes)>,
    puller: R,
    retry_gap: Duration,
    id: Arc<AtomicUsize>,
    min_chunk_size: NonZeroU64,
}
impl<R, WE> Executor for TokioExecutor<R, WE>
where
    R: RandPuller + 'static + Sync,
    WE: Send + 'static,
{
    type Handle = TokioHandle;
    fn execute(self: Arc<Self>, task: Arc<Task>, task_list: Arc<TaskList<Self>>) -> Self::Handle {
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        let mut puller = self.puller.clone();
        let handle = tokio::spawn(async move {
            'steal_task: loop {
                let mut start = task.start();
                if start >= task.end() {
                    // Own range exhausted: try to steal work from another
                    // task, otherwise this worker is done.
                    if task_list.steal(&task, self.min_chunk_size) {
                        continue;
                    } else {
                        break;
                    }
                }
                let _ = self.tx.send(Event::Pulling(id)).await;
                let download_range = start..task.end();
                // Open the pull stream, retrying until it succeeds.
                let mut stream = loop {
                    match puller.pull(&download_range).await {
                        Ok(t) => break t,
                        Err((e, retry_gap)) => {
                            let _ = self.tx.send(Event::PullError(id, e)).await;
                            tokio::time::sleep(retry_gap.unwrap_or(self.retry_gap)).await;
                        }
                    }
                };
                loop {
                    match stream.try_next().await {
                        Ok(Some(mut chunk)) => {
                            let len = chunk.len() as u64;
                            task.fetch_add_start(len);
                            let range_start = start;
                            start += len;
                            // Clamp to the current task end: the tail of this
                            // range may have been stolen by another worker
                            // while the chunk was in flight.
                            let range_end = start.min(task.end());
                            if range_start >= range_end {
                                continue 'steal_task;
                            }
                            let span = range_start..range_end;
                            chunk.truncate(span.total() as usize);
                            let _ = self.tx.send(Event::PullProgress(id, span.clone())).await;
                            self.tx_push.send((id, span, chunk)).await.unwrap();
                        }
                        Ok(None) => break,
                        Err((e, retry_gap)) => {
                            let _ = self.tx.send(Event::PullError(id, e)).await;
                            tokio::time::sleep(retry_gap.unwrap_or(self.retry_gap)).await;
                        }
                    }
                }
            }
            task_list.remove(&task);
            let _ = self.tx.send(Event::Finished(id)).await;
        });
        TokioHandle(handle.abort_handle())
    }
}

#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use alloc::vec;
    use core::num::NonZero;
    use std::dbg;

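    // End-to-end check: 3 KiB of mock data is pulled by 32 workers; pull and
    // push progress must each cover the requested range, every worker id must
    // report progress, and the pushed bytes must equal the mock data.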
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: NonZero::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
                min_chunk_size: NonZero::new(1).unwrap(),
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; 32];
        let mut push_ids = [false; 32];
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pull_ids, [true; 32]);
        assert_eq!(push_ids, [true; 32]);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}