Skip to main content

fast_pull/core/
multi.rs

1extern crate std;
2use crate::{DownloadResult, Event, ProgressEntry, Puller, Pusher, Total, WorkerId};
3use bytes::Bytes;
4use core::{
5    sync::atomic::{AtomicUsize, Ordering},
6    time::Duration,
7};
8use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
9use fast_steal::{Executor, Handle, Task, TaskQueue};
10use futures::TryStreamExt;
11use std::sync::Arc;
12use tokio::task::AbortHandle;
13
14#[derive(Debug, Clone)]
15pub struct DownloadOptions<'a, I: Iterator<Item = &'a ProgressEntry>> {
16    pub download_chunks: I,
17    pub concurrent: usize,
18    pub retry_gap: Duration,
19    pub push_queue_cap: usize,
20    pub min_chunk_size: u64,
21}
22
23pub fn download_multi<'a, R: Puller, W: Pusher, I: Iterator<Item = &'a ProgressEntry>>(
24    puller: R,
25    mut pusher: W,
26    options: DownloadOptions<'a, I>,
27) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
28    let (tx, event_chain) = mpmc::unbounded_async();
29    let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
30    let tx_clone = tx.clone();
31    let rx_push = rx_push.into_blocking();
32    let push_handle = tokio::task::spawn_blocking(move || {
33        while let Ok((id, spin, mut data)) = rx_push.recv() {
34            loop {
35                match pusher.push(&spin, data) {
36                    Ok(_) => break,
37                    Err((err, bytes)) => {
38                        data = bytes;
39                        let _ = tx_clone.send(Event::PushError(id, err));
40                    }
41                }
42                std::thread::sleep(options.retry_gap);
43            }
44            let _ = tx_clone.send(Event::PushProgress(id, spin));
45        }
46        loop {
47            match pusher.flush() {
48                Ok(_) => break,
49                Err(err) => {
50                    let _ = tx_clone.send(Event::FlushError(err));
51                }
52            }
53            std::thread::sleep(options.retry_gap);
54        }
55    });
56    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
57        tx,
58        tx_push,
59        puller,
60        retry_gap: options.retry_gap,
61        id: AtomicUsize::new(0),
62        min_chunk_size: options.min_chunk_size,
63    });
64    let task_queue = TaskQueue::new(options.download_chunks);
65    task_queue.set_threads(
66        options.concurrent,
67        options.min_chunk_size,
68        Some(executor.as_ref()),
69    );
70    DownloadResult::new(
71        event_chain,
72        push_handle,
73        None,
74        Some((Arc::downgrade(&executor), task_queue)),
75    )
76}
77
78#[derive(Clone)]
79pub struct TokioHandle(AbortHandle);
80impl Handle for TokioHandle {
81    type Output = ();
82    fn abort(&mut self) -> Self::Output {
83        self.0.abort();
84    }
85}
86#[derive(Debug)]
87pub struct TokioExecutor<R, WE>
88where
89    R: Puller,
90    WE: Send + Unpin + 'static,
91{
92    tx: MTx<mpmc::List<Event<R::Error, WE>>>,
93    tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
94    puller: R,
95    retry_gap: Duration,
96    id: AtomicUsize,
97    min_chunk_size: u64,
98}
99impl<R, WE> Executor for TokioExecutor<R, WE>
100where
101    R: Puller,
102    WE: Send + Unpin + 'static,
103{
104    type Handle = TokioHandle;
105    fn execute(&self, task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
106        let id = self.id.fetch_add(1, Ordering::SeqCst);
107        let mut puller = self.puller.clone();
108        let min_chunk_size = self.min_chunk_size;
109        let cfg_retry_gap = self.retry_gap;
110        let tx = self.tx.clone();
111        let tx_push = self.tx_push.clone();
112        let handle = tokio::spawn(async move {
113            loop {
114                let mut start = task.start();
115                if start >= task.end() {
116                    if task_queue.steal(&task, min_chunk_size) {
117                        continue;
118                    } else {
119                        break;
120                    }
121                }
122                let _ = tx.send(Event::Pulling(id));
123                let download_range = start..task.end();
124                let mut stream = loop {
125                    match puller.pull(Some(&download_range)).await {
126                        Ok(t) => break t,
127                        Err((e, retry_gap)) => {
128                            let _ = tx.send(Event::PullError(id, e));
129                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
130                        }
131                    }
132                };
133                loop {
134                    match stream.try_next().await {
135                        Ok(Some(mut chunk)) => {
136                            let len = chunk.len() as u64;
137                            if task.fetch_add_start(len).is_err() {
138                                break;
139                            }
140                            let range_start = start;
141                            start += len;
142                            let range_end = start.min(task.end());
143                            if range_start >= range_end {
144                                break;
145                            }
146                            let span = range_start..range_end;
147                            chunk.truncate(span.total() as usize);
148                            let _ = tx.send(Event::PullProgress(id, span.clone()));
149                            let tx_push = tx_push.clone();
150                            let _ = tokio::spawn(async move {
151                                tx_push.send((id, span, chunk)).await.unwrap();
152                            })
153                            .await;
154                        }
155                        Ok(None) => break,
156                        Err((e, retry_gap)) => {
157                            let _ = tx.send(Event::PullError(id, e));
158                            tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)).await;
159                        }
160                    }
161                }
162            }
163            task_queue.finish_work(&task);
164            let _ = tx.send(Event::Finished(id));
165        });
166        TokioHandle(handle.abort_handle())
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec, vec::Vec};

    /// Downloads 3 KiB of mock data with 32 workers and checks that the
    /// merged pull and push progress both equal the requested range, that
    /// every worker id reported pull and push progress, and that the pusher
    /// ends up holding exactly the mock data.
    #[tokio::test]
    async fn test_concurrent_download() {
        const WORKERS: usize = 32;

        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let options = DownloadOptions {
            download_chunks: download_chunks.iter(),
            concurrent: WORKERS,
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
            min_chunk_size: 1,
        };
        let result = download_multi(puller, pusher.clone(), options);

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_seen = [false; WORKERS];
        let mut push_seen = [false; WORKERS];
        // The loop ends when `recv` fails on the event channel.
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(worker, range) => {
                    pull_seen[worker] = true;
                    pull_progress.merge_progress(range);
                }
                Event::PushProgress(worker, range) => {
                    push_seen[worker] = true;
                    push_progress.merge_progress(range);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(pull_seen, [true; WORKERS]);
        assert_eq!(push_seen, [true; WORKERS]);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}
228}