// fast_pull/core/multi.rs

use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, WorkerId};
use bytes::Bytes;
use core::{
    sync::atomic::{AtomicUsize, Ordering},
    time::Duration,
};
use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
use fast_steal::{Executor, Handle, Task, TaskQueue};
use futures::TryStreamExt;
use std::sync::Arc;
use tokio::task::AbortHandle;
12
13#[derive(Debug, Clone)]
14pub struct DownloadOptions<'a, I: Iterator<Item = &'a ProgressEntry>> {
15    pub download_chunks: I,
16    pub concurrent: usize,
17    pub retry_gap: Duration,
18    pub pull_timeout: Duration,
19    pub push_queue_cap: usize,
20    pub min_chunk_size: u64,
21    pub max_speculative: usize,
22}
23
24pub fn download_multi<'a, R: Puller, W: Pusher, I: Iterator<Item = &'a ProgressEntry>>(
25    puller: R,
26    mut pusher: W,
27    options: DownloadOptions<'a, I>,
28) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
29    let (tx, event_chain) = mpmc::unbounded_async();
30    let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
31    let tx_clone = tx.clone();
32    let rx_push = rx_push.into_blocking();
33    let push_handle = tokio::task::spawn_blocking(move || {
34        while let Ok((id, spin, mut data)) = rx_push.recv() {
35            loop {
36                match pusher.push(&spin, data) {
37                    Ok(_) => break,
38                    Err((err, bytes)) => {
39                        data = bytes;
40                        let _ = tx_clone.send(Event::PushError(id, err));
41                    }
42                }
43                std::thread::sleep(options.retry_gap);
44            }
45            let _ = tx_clone.send(Event::PushProgress(id, spin));
46        }
47        loop {
48            match pusher.flush() {
49                Ok(_) => break,
50                Err(err) => {
51                    let _ = tx_clone.send(Event::FlushError(err));
52                }
53            }
54            std::thread::sleep(options.retry_gap);
55        }
56    });
57    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
58        tx,
59        tx_push,
60        puller,
61        pull_timeout: options.pull_timeout,
62        retry_gap: options.retry_gap,
63        id: AtomicUsize::new(0),
64        min_chunk_size: options.min_chunk_size,
65        max_speculative: options.max_speculative,
66    });
67    let task_queue = TaskQueue::new(options.download_chunks);
68    task_queue.set_threads(
69        options.concurrent,
70        options.min_chunk_size,
71        Some(executor.as_ref()),
72    );
73    DownloadResult::new(
74        event_chain,
75        push_handle,
76        None,
77        Some((Arc::downgrade(&executor), task_queue)),
78    )
79}
80
81#[derive(Clone)]
82pub struct TokioHandle {
83    id: usize,
84    abort_handle: AbortHandle,
85}
86impl Handle for TokioHandle {
87    type Output = ();
88    type Id = usize;
89    fn abort(&mut self) -> Self::Output {
90        self.abort_handle.abort();
91    }
92    fn is_self(&mut self, id: &Self::Id) -> bool {
93        self.id == *id
94    }
95}
96#[derive(Debug)]
97pub struct TokioExecutor<R, WE>
98where
99    R: Puller,
100    WE: Send + Unpin + 'static,
101{
102    tx: MTx<mpmc::List<Event<R::Error, WE>>>,
103    tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
104    puller: R,
105    retry_gap: Duration,
106    pull_timeout: Duration,
107    id: AtomicUsize,
108    min_chunk_size: u64,
109    max_speculative: usize,
110}
111impl<R, WE> Executor for TokioExecutor<R, WE>
112where
113    R: Puller,
114    WE: Send + Unpin + 'static,
115{
116    type Handle = TokioHandle;
117    fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
118        let id = self.id.fetch_add(1, Ordering::SeqCst);
119        let mut puller = self.puller.clone();
120        let min_chunk_size = self.min_chunk_size;
121        let pull_timeout = self.pull_timeout;
122        let cfg_retry_gap = self.retry_gap;
123        let max_speculative = self.max_speculative;
124        let tx = self.tx.clone();
125        let tx_push = self.tx_push.clone();
126        let handle = tokio::spawn(async move {
127            'task: loop {
128                let mut start = task.start();
129                if start >= task.end() {
130                    if task_queue.steal(&mut task, min_chunk_size, max_speculative) {
131                        continue 'task;
132                    } else {
133                        break;
134                    }
135                }
136                let _ = tx.send(Event::Pulling(id));
137                let download_range = start..task.end();
138                let mut stream = loop {
139                    match puller.pull(Some(&download_range)).await {
140                        Ok(t) => break t,
141                        Err((e, retry_gap)) => {
142                            let _ = tx.send(Event::PullError(id, e));
143                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
144                        }
145                    }
146                };
147                loop {
148                    match tokio::time::timeout(pull_timeout, stream.try_next()).await {
149                        Ok(Ok(Some(mut chunk))) => {
150                            if chunk.is_empty() {
151                                continue;
152                            }
153                            let len = chunk.len() as u64;
154                            let Ok(span) = task.safe_add_start(start, len) else {
155                                start += len;
156                                continue;
157                            };
158                            if span.end >= task.end() {
159                                task_queue.cancel_task(&task, &id);
160                            }
161                            chunk = chunk
162                                .slice((span.start - start) as usize..(span.end - start) as usize);
163                            start = span.end;
164                            let _ = tx.send(Event::PullProgress(id, span.clone()));
165                            let tx_push = tx_push.clone();
166                            let _ = tokio::spawn(async move {
167                                tx_push.send((id, span, chunk)).await.unwrap();
168                            })
169                            .await;
170                            if start >= task.end() {
171                                continue 'task;
172                            }
173                        }
174                        Ok(Ok(None)) => continue 'task,
175                        Ok(Err((e, retry_gap))) => {
176                            let is_irrecoverable = e.is_irrecoverable();
177                            let _ = tx.send(Event::PullError(id, e));
178                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
179                            if is_irrecoverable {
180                                continue 'task;
181                            }
182                        }
183                        Err(_) => {
184                            let _ = tx.send(Event::PullTimeout(id));
185                            drop(stream);
186                            puller = puller.clone();
187                            continue 'task;
188                        }
189                    }
190                }
191            }
192            let _ = tx.send(Event::Finished(id));
193        });
194        TokioHandle {
195            id,
196            abort_handle: handle.abort_handle(),
197        }
198    }
199}
200
#[cfg(test)]
mod tests {
    use vec::Vec;

    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    // NOTE(review): `dbg` and `vec` are imported from `std` explicitly,
    // presumably because the crate does not pull in the std prelude — confirm.
    use std::{dbg, vec};

    /// Downloads 3 KiB of mock data with 32 workers and checks that the pull
    /// and push progress each cover the whole range, that every worker id
    /// 0..32 reported both kinds of progress, and that the pushed bytes equal
    /// the mock data.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single range covering the whole payload is intended here.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: 32,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter(),
                pull_timeout: Duration::from_secs(5),
                min_chunk_size: 1,
                max_speculative: 3,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; 32];
        let mut push_ids = [false; 32];
        // Drain events until the channel reports an error on `recv`.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pull_ids, [true; 32]);
        assert_eq!(push_ids, [true; 32]);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}