Skip to main content

fast_pull/core/
multi.rs

1extern crate std;
2use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, WorkerId};
3use bytes::Bytes;
4use core::{
5    sync::atomic::{AtomicUsize, Ordering},
6    time::Duration,
7};
8use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
9use fast_steal::{Executor, Handle, Task, TaskQueue};
10use futures::TryStreamExt;
11use std::sync::Arc;
12use tokio::task::AbortHandle;
13
14#[derive(Debug, Clone)]
15pub struct DownloadOptions<'a, I: Iterator<Item = &'a ProgressEntry>> {
16    pub download_chunks: I,
17    pub concurrent: usize,
18    pub retry_gap: Duration,
19    pub pull_timeout: Duration,
20    pub push_queue_cap: usize,
21    pub min_chunk_size: u64,
22}
23
24pub fn download_multi<'a, R: Puller, W: Pusher, I: Iterator<Item = &'a ProgressEntry>>(
25    puller: R,
26    mut pusher: W,
27    options: DownloadOptions<'a, I>,
28) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
29    let (tx, event_chain) = mpmc::unbounded_async();
30    let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
31    let tx_clone = tx.clone();
32    let rx_push = rx_push.into_blocking();
33    let push_handle = tokio::task::spawn_blocking(move || {
34        while let Ok((id, spin, mut data)) = rx_push.recv() {
35            loop {
36                match pusher.push(&spin, data) {
37                    Ok(_) => break,
38                    Err((err, bytes)) => {
39                        data = bytes;
40                        let _ = tx_clone.send(Event::PushError(id, err));
41                    }
42                }
43                std::thread::sleep(options.retry_gap);
44            }
45            let _ = tx_clone.send(Event::PushProgress(id, spin));
46        }
47        loop {
48            match pusher.flush() {
49                Ok(_) => break,
50                Err(err) => {
51                    let _ = tx_clone.send(Event::FlushError(err));
52                }
53            }
54            std::thread::sleep(options.retry_gap);
55        }
56    });
57    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
58        tx,
59        tx_push,
60        puller,
61        pull_timeout: options.pull_timeout,
62        retry_gap: options.retry_gap,
63        id: AtomicUsize::new(0),
64        min_chunk_size: options.min_chunk_size,
65    });
66    let task_queue = TaskQueue::new(options.download_chunks);
67    task_queue.set_threads(
68        options.concurrent,
69        options.min_chunk_size,
70        Some(executor.as_ref()),
71    );
72    DownloadResult::new(
73        event_chain,
74        push_handle,
75        None,
76        Some((Arc::downgrade(&executor), task_queue)),
77    )
78}
79
80#[derive(Clone)]
81pub struct TokioHandle(AbortHandle);
82impl Handle for TokioHandle {
83    type Output = ();
84    fn abort(&mut self) -> Self::Output {
85        self.0.abort();
86    }
87}
88#[derive(Debug)]
89pub struct TokioExecutor<R, WE>
90where
91    R: Puller,
92    WE: Send + Unpin + 'static,
93{
94    tx: MTx<mpmc::List<Event<R::Error, WE>>>,
95    tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
96    puller: R,
97    retry_gap: Duration,
98    pull_timeout: Duration,
99    id: AtomicUsize,
100    min_chunk_size: u64,
101}
102impl<R, WE> Executor for TokioExecutor<R, WE>
103where
104    R: Puller,
105    WE: Send + Unpin + 'static,
106{
107    type Handle = TokioHandle;
108    fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
109        let id = self.id.fetch_add(1, Ordering::SeqCst);
110        let mut puller = self.puller.clone();
111        let min_chunk_size = self.min_chunk_size;
112        let pull_timeout = self.pull_timeout;
113        let cfg_retry_gap = self.retry_gap;
114        let tx = self.tx.clone();
115        let tx_push = self.tx_push.clone();
116        let handle = tokio::spawn(async move {
117            loop {
118                let mut start = task.start();
119                if start >= task.end() {
120                    if task_queue.steal(&mut task, min_chunk_size, true) {
121                        continue;
122                    } else {
123                        break;
124                    }
125                }
126                let _ = tx.send(Event::Pulling(id));
127                let download_range = start..task.end();
128                let mut stream = loop {
129                    match puller.pull(Some(&download_range)).await {
130                        Ok(t) => break t,
131                        Err((e, retry_gap)) => {
132                            let _ = tx.send(Event::PullError(id, e));
133                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
134                        }
135                    }
136                };
137                loop {
138                    match tokio::time::timeout(pull_timeout, stream.try_next()).await {
139                        Ok(Ok(Some(mut chunk))) => {
140                            if chunk.is_empty() {
141                                continue;
142                            }
143                            let len = chunk.len() as u64;
144                            let Ok(span) = task.safe_add_start(start, len) else {
145                                break;
146                            };
147                            if span.end >= task.end() {
148                                task_queue.cancel_tasks(&task);
149                            }
150                            chunk = chunk
151                                .slice((span.start - start) as usize..(span.end - start) as usize);
152                            start = span.end;
153                            let _ = tx.send(Event::PullProgress(id, span.clone()));
154                            let tx_push = tx_push.clone();
155                            let _ = tokio::spawn(async move {
156                                tx_push.send((id, span, chunk)).await.unwrap();
157                            })
158                            .await;
159                            if start >= task.end() {
160                                break;
161                            }
162                        }
163                        Ok(Ok(None)) => break,
164                        Ok(Err((e, retry_gap))) => {
165                            let is_irrecoverable = e.is_irrecoverable();
166                            let _ = tx.send(Event::PullError(id, e));
167                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
168                            if is_irrecoverable {
169                                break;
170                            }
171                        }
172                        Err(_) => {
173                            let _ = tx.send(Event::PullTimeout(id));
174                            drop(stream);
175                            puller = puller.clone();
176                            break;
177                        }
178                    }
179                }
180            }
181            let _ = tx.send(Event::Finished(id));
182        });
183        TokioHandle(handle.abort_handle())
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec, vec::Vec};

    /// Number of workers requested, and therefore of worker ids expected.
    const WORKERS: usize = 32;

    /// Downloads mock data with 32 workers and checks that pull and push
    /// progress each cover the whole range, that every worker id reported
    /// both kinds of progress, and that the pushed bytes equal the source.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single range in the vec is intended: the whole file is one chunk.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = DownloadOptions {
            concurrent: WORKERS,
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
            download_chunks: download_chunks.iter(),
            pull_timeout: Duration::from_secs(5),
            min_chunk_size: 1,
        };
        let result = download_multi(puller, pusher.clone(), options);

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; WORKERS];
        let mut push_ids = [false; WORKERS];
        // Drain events until the channel closes.
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(worker, span) => {
                    pull_ids[worker] = true;
                    pull_progress.merge_progress(span);
                }
                Event::PushProgress(worker, span) => {
                    push_ids[worker] = true;
                    push_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pull_ids, [true; WORKERS]);
        assert_eq!(push_ids, [true; WORKERS]);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}