fast_pull/core/
multi.rs

1extern crate alloc;
2use super::macros::poll_ok;
3use crate::{DownloadResult, Event, ProgressEntry, RandPuller, RandPusher, Total, WorkerId};
4use alloc::{sync::Arc, vec::Vec};
5use bytes::Bytes;
6use core::{
7    num::{NonZero, NonZeroU64, NonZeroUsize},
8    sync::atomic::{AtomicUsize, Ordering},
9    time::Duration,
10};
11use fast_steal::{Executor, Handle, Task, TaskList};
12use futures::TryStreamExt;
13use tokio::task::AbortHandle;
14
15#[derive(Debug, Clone)]
16pub struct DownloadOptions {
17    pub download_chunks: Vec<ProgressEntry>,
18    pub concurrent: NonZeroUsize,
19    pub retry_gap: Duration,
20    pub push_queue_cap: usize,
21    pub min_chunk_size: NonZeroU64,
22}
23
24pub async fn download_multi<R, W>(
25    puller: R,
26    mut pusher: W,
27    options: DownloadOptions,
28) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error>
29where
30    R: RandPuller + 'static + Sync,
31    W: RandPusher + 'static,
32{
33    let (tx, event_chain) = kanal::unbounded_async();
34    let (tx_push, rx_push) =
35        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
36    let tx_clone = tx.clone();
37    let push_handle = tokio::spawn(async move {
38        while let Ok((id, spin, data)) = rx_push.recv().await {
39            poll_ok!(
40                pusher.push(spin.clone(), &data).await,
41                id @ tx_clone => PushError,
42                options.retry_gap
43            );
44            tx_clone.send(Event::PushProgress(id, spin)).await.unwrap();
45        }
46        poll_ok!(
47            pusher.flush().await,
48            tx_clone => FlushError,
49            options.retry_gap
50        );
51    });
52    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
53        tx,
54        tx_push,
55        puller,
56        retry_gap: options.retry_gap,
57        id: Arc::new(AtomicUsize::new(0)),
58        min_chunk_size: options.min_chunk_size,
59    });
60    let task_list = Arc::new(TaskList::run(&options.download_chunks[..], executor));
61    task_list
62        .clone()
63        .set_threads(options.concurrent, options.min_chunk_size);
64    DownloadResult::new(
65        event_chain,
66        push_handle,
67        None,
68        Some(Arc::downgrade(&task_list)),
69    )
70}
71
72#[derive(Clone)]
73pub struct TokioHandle(AbortHandle);
74impl Handle for TokioHandle {
75    type Output = ();
76    fn abort(&mut self) -> Self::Output {
77        self.0.abort();
78    }
79}
80#[derive(Debug)]
81pub struct TokioExecutor<R, WE>
82where
83    R: RandPuller + 'static,
84    WE: Send + 'static,
85{
86    tx: kanal::AsyncSender<Event<R::Error, WE>>,
87    tx_push: kanal::AsyncSender<(WorkerId, ProgressEntry, Bytes)>,
88    puller: R,
89    retry_gap: Duration,
90    id: Arc<AtomicUsize>,
91    min_chunk_size: NonZeroU64,
92}
93impl<R, WE> Executor for TokioExecutor<R, WE>
94where
95    R: RandPuller + 'static + Sync,
96    WE: Send + 'static,
97{
98    type Handle = TokioHandle;
99    fn execute(self: Arc<Self>, task: Arc<Task>, task_list: Arc<TaskList<Self>>) -> Self::Handle {
100        let id = self.id.fetch_add(1, Ordering::SeqCst);
101        let steal_min_chunk_size = NonZero::new(2 * self.min_chunk_size.get()).unwrap();
102        let mut puller = self.puller.clone();
103        let handle = tokio::spawn(async move {
104            'steal_task: loop {
105                let mut start = task.start();
106                if start >= task.end() {
107                    if task_list.steal(&task, steal_min_chunk_size) {
108                        continue;
109                    } else {
110                        break;
111                    }
112                }
113                self.tx.send(Event::Pulling(id)).await.unwrap();
114                let download_range = start..task.end();
115                let mut stream = loop {
116                    match puller.pull(&download_range).await {
117                        Ok(t) => break t,
118                        Err((e, retry_gap)) => {
119                            self.tx.send(Event::PullError(id, e)).await.unwrap();
120                            tokio::time::sleep(retry_gap.unwrap_or(self.retry_gap)).await;
121                        }
122                    }
123                };
124                loop {
125                    match stream.try_next().await {
126                        Ok(Some(mut chunk)) => {
127                            let len = chunk.len() as u64;
128                            task.fetch_add_start(len);
129                            let range_start = start;
130                            start += len;
131                            let range_end = start.min(task.end());
132                            if range_start >= range_end {
133                                continue 'steal_task;
134                            }
135                            let span = range_start..range_end;
136                            chunk.truncate(span.total() as usize);
137                            self.tx
138                                .send(Event::PullProgress(id, span.clone()))
139                                .await
140                                .unwrap();
141                            self.tx_push.send((id, span, chunk)).await.unwrap();
142                        }
143                        Ok(None) => break,
144                        Err((e, retry_gap)) => {
145                            self.tx.send(Event::PullError(id, e)).await.unwrap();
146                            tokio::time::sleep(retry_gap.unwrap_or(self.retry_gap)).await;
147                        }
148                    }
149                }
150            }
151            task_list.remove(&task);
152            self.tx.send(Event::Finished(id)).await.unwrap();
153        });
154        TokioHandle(handle.abort_handle())
155    }
156}
157
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;

    /// Number of workers requested; also the size of the per-worker "seen" tables.
    const WORKERS: usize = 32;

    /// End-to-end check of `download_multi` over 3 KiB of mock data with `WORKERS` workers.
    ///
    /// It asserts that:
    /// - the whole range was both pulled and pushed;
    /// - every worker id reported pull and push progress;
    /// - the pusher received exactly the source bytes.
    #[tokio::test]
    async fn test_concurrent_download() {
        let source = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&source);
        let sink = MemPusher::with_capacity(source.len());
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..source.len() as u64];
        let options = DownloadOptions {
            concurrent: NonZero::new(WORKERS).unwrap(),
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
            download_chunks: expected.clone(),
            min_chunk_size: NonZero::new(1).unwrap(),
        };
        let result = download_multi(puller, sink.clone(), options).await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_seen = [false; WORKERS];
        let mut push_seen = [false; WORKERS];
        // `recv` fails once every sender has been dropped, which ends the loop.
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(worker, span) => {
                    pull_seen[worker] = true;
                    pull_progress.merge_progress(span);
                }
                Event::PushProgress(worker, span) => {
                    push_seen[worker] = true;
                    push_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, expected);
        assert_eq!(push_progress, expected);
        assert_eq!(pull_seen, [true; WORKERS]);
        assert_eq!(push_seen, [true; WORKERS]);

        result.join().await.unwrap();
        assert_eq!(&**sink.receive.lock().await, source);
    }
}