Skip to main content

fast_pull/core/
multi.rs

1use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, WorkerId};
2use bytes::Bytes;
3use core::{
4    sync::atomic::{AtomicUsize, Ordering},
5    time::Duration,
6};
7use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
8use fast_steal::{Executor, Handle, Task, TaskQueue};
9use futures::TryStreamExt;
10use std::sync::Arc;
11use tokio::sync::Notify;
12
13#[derive(Debug, Clone)]
14pub struct DownloadOptions<I: Iterator<Item = ProgressEntry>> {
15    pub download_chunks: I,
16    pub concurrent: usize,
17    pub retry_gap: Duration,
18    pub pull_timeout: Duration,
19    pub push_queue_cap: usize,
20    pub min_chunk_size: u64,
21    pub max_speculative: usize,
22}
23
24/// # Panics
25/// 当设置线程数,但 executor 意外为空时,panic
26pub fn download_multi<R: Puller, W: Pusher, I: Iterator<Item = ProgressEntry>>(
27    puller: R,
28    mut pusher: W,
29    options: DownloadOptions<I>,
30) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
31    let (tx, event_chain) = mpmc::unbounded_async();
32    let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
33    let tx_clone = tx.clone();
34    let rx_push = rx_push.into_blocking();
35    let push_handle = tokio::task::spawn_blocking(move || {
36        while let Ok((id, spin, mut data)) = rx_push.recv() {
37            loop {
38                match pusher.push(&spin, data) {
39                    Ok(()) => break,
40                    Err((err, bytes)) => {
41                        data = bytes;
42                        let _ = tx_clone.send(Event::PushError(id, err));
43                    }
44                }
45                std::thread::sleep(options.retry_gap);
46            }
47            let _ = tx_clone.send(Event::PushProgress(id, spin));
48        }
49        loop {
50            match pusher.flush() {
51                Ok(()) => break,
52                Err(err) => {
53                    let _ = tx_clone.send(Event::FlushError(err));
54                }
55            }
56            std::thread::sleep(options.retry_gap);
57        }
58    });
59    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
60        tx,
61        tx_push,
62        puller,
63        pull_timeout: options.pull_timeout,
64        retry_gap: options.retry_gap,
65        id: AtomicUsize::new(0),
66        min_chunk_size: options.min_chunk_size,
67        max_speculative: options.max_speculative,
68    });
69    let task_queue = TaskQueue::new(options.download_chunks);
70    #[allow(clippy::unwrap_used)]
71    task_queue
72        .set_threads(
73            options.concurrent,
74            options.min_chunk_size,
75            Some(executor.as_ref()),
76        )
77        .unwrap();
78    DownloadResult::new(
79        event_chain,
80        push_handle,
81        None,
82        Some((Arc::downgrade(&executor), task_queue)),
83    )
84}
85
86#[derive(Debug, Clone)]
87pub struct TokioHandle {
88    id: usize,
89    notify: Arc<Notify>,
90}
91impl Handle for TokioHandle {
92    type Output = ();
93    type Id = usize;
94    fn abort(&mut self) -> Self::Output {
95        self.notify.notify_one();
96    }
97    fn is_self(&mut self, id: &Self::Id) -> bool {
98        self.id == *id
99    }
100}
101#[derive(Debug)]
102pub struct TokioExecutor<R, WE>
103where
104    R: Puller,
105    WE: Send + Unpin + 'static,
106{
107    tx: MTx<mpmc::List<Event<R::Error, WE>>>,
108    tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
109    puller: R,
110    retry_gap: Duration,
111    pull_timeout: Duration,
112    id: AtomicUsize,
113    min_chunk_size: u64,
114    max_speculative: usize,
115}
116impl<R, WE> Executor for TokioExecutor<R, WE>
117where
118    R: Puller,
119    WE: Send + Unpin + 'static,
120{
121    type Handle = TokioHandle;
122    fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
123        let id = self.id.fetch_add(1, Ordering::SeqCst);
124        let mut puller = self.puller.clone();
125        let min_chunk_size = self.min_chunk_size;
126        let pull_timeout = self.pull_timeout;
127        let cfg_retry_gap = self.retry_gap;
128        let max_speculative = self.max_speculative;
129        let tx = self.tx.clone();
130        let tx_push = self.tx_push.clone();
131        let notify = Arc::new(Notify::new());
132        let notify_clone = notify.clone();
133        tokio::spawn(async move {
134            'task: loop {
135                let mut start = task.start();
136                if start >= task.end() {
137                    if task_queue.steal(&mut task, min_chunk_size, max_speculative) {
138                        continue 'task;
139                    }
140                    break;
141                }
142                let _ = tx.send(Event::Pulling(id));
143                let download_range = start..task.end();
144                let mut stream = loop {
145                    let t = tokio::select! {
146                        () = notify.notified() => break 'task,
147                        t = puller.pull(Some(&download_range)) => t
148                    };
149                    match t {
150                        Ok(t) => break t,
151                        Err((e, retry_gap)) => {
152                            let _ = tx.send(Event::PullError(id, e));
153                            tokio::select! {
154                                () = notify.notified() => break 'task,
155                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
156                            };
157                        }
158                    }
159                };
160                loop {
161                    let t = tokio::select! {
162                        () = notify.notified() => break 'task,
163                        t = tokio::time::timeout(pull_timeout, stream.try_next()) => t
164                    };
165                    match t {
166                        Ok(Ok(Some(mut chunk))) => {
167                            if chunk.is_empty() {
168                                continue;
169                            }
170                            let len = chunk.len() as u64;
171                            let Ok(span) = task.safe_add_start(start, len) else {
172                                start += len;
173                                continue;
174                            };
175                            if span.end >= task.end() {
176                                task_queue.cancel_task(&task, &id);
177                            }
178                            #[allow(clippy::cast_possible_truncation)]
179                            let slice_span =
180                                (span.start - start) as usize..(span.end - start) as usize;
181                            chunk = chunk.slice(slice_span);
182                            start = span.end;
183                            let _ = tx.send(Event::PullProgress(id, span.clone()));
184                            let _ = tx_push.send((id, span, chunk)).await;
185                            if start >= task.end() {
186                                continue 'task;
187                            }
188                        }
189                        Ok(Ok(None)) => continue 'task,
190                        Ok(Err((e, retry_gap))) => {
191                            let is_irrecoverable = e.is_irrecoverable();
192                            let _ = tx.send(Event::PullError(id, e));
193                            tokio::select! {
194                                () = notify.notified() => break 'task,
195                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
196                            };
197                            if is_irrecoverable {
198                                continue 'task;
199                            }
200                        }
201                        Err(_) => {
202                            let _ = tx.send(Event::PullTimeout(id));
203                            drop(stream);
204                            puller = puller.clone();
205                            continue 'task;
206                        }
207                    }
208                }
209            }
210            let _ = tx.send(Event::Finished(id));
211        });
212        TokioHandle {
213            id,
214            notify: notify_clone,
215        }
216    }
217}
218
#[cfg(test)]
#[cfg(feature = "mem")]
mod tests {
    use vec::Vec;

    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::vec;

    /// Worker count requested by the test. It also sizes the per-worker
    /// "seen" arrays, which are indexed by worker id. Ids start at 0 and are
    /// handed out by `fetch_add`, so a single constant keeps the arrays in
    /// step with the worker count. This assumes `set_threads` spawns one
    /// worker per requested thread.
    const CONCURRENT: usize = 32;

    /// Downloads a 3 KiB mock payload with `CONCURRENT` workers and checks
    /// that pull and push progress each cover the whole range, that some
    /// worker reported each kind of progress, and that the pusher received
    /// exactly the source bytes.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: CONCURRENT,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter().cloned(),
                pull_timeout: Duration::from_secs(5),
                min_chunk_size: 1,
                max_speculative: 3,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; CONCURRENT];
        let mut push_ids = [false; CONCURRENT];
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        // `assert_eq!` already prints both sides on failure.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert!(pull_ids.contains(&true));
        assert!(push_ids.contains(&true));

        #[allow(clippy::unwrap_used)]
        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}