Skip to main content

fast_pull/core/
multi.rs

1use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, WorkerId};
2use bytes::Bytes;
3use core::{
4    sync::atomic::{AtomicUsize, Ordering},
5    time::Duration,
6};
7use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
8use fast_steal::{Executor, Handle, Task, TaskQueue};
9use futures::TryStreamExt;
10use std::sync::Arc;
11use tokio::sync::Notify;
12
13#[derive(Debug, Clone)]
14pub struct DownloadOptions<I: Iterator<Item = ProgressEntry>> {
15    pub download_chunks: I,
16    pub concurrent: usize,
17    pub retry_gap: Duration,
18    pub pull_timeout: Duration,
19    pub push_queue_cap: usize,
20    pub min_chunk_size: u64,
21    pub max_speculative: usize,
22}
23
24/// # Panics
25/// 当设置线程数,但 executor 意外为空时,panic
26pub fn download_multi<R: Puller, W: Pusher, I: Iterator<Item = ProgressEntry>>(
27    puller: R,
28    mut pusher: W,
29    options: DownloadOptions<I>,
30) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
31    let (tx, event_chain) = mpmc::unbounded_async();
32    let (tx_push, rx_push) = mpsc::bounded_async(options.push_queue_cap);
33    let tx_clone = tx.clone();
34    let rx_push = rx_push.into_blocking();
35    let push_handle = tokio::task::spawn_blocking(move || {
36        while let Ok((id, spin, mut data)) = rx_push.recv() {
37            loop {
38                match pusher.push(&spin, data) {
39                    Ok(()) => break,
40                    Err((err, bytes)) => {
41                        data = bytes;
42                        let _ = tx_clone.send(Event::PushError(id, err));
43                    }
44                }
45                std::thread::sleep(options.retry_gap);
46            }
47            let _ = tx_clone.send(Event::PushProgress(id, spin));
48        }
49        loop {
50            match pusher.flush() {
51                Ok(()) => break,
52                Err(err) => {
53                    let _ = tx_clone.send(Event::FlushError(err));
54                }
55            }
56            std::thread::sleep(options.retry_gap);
57        }
58    });
59    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
60        tx,
61        tx_push,
62        puller,
63        pull_timeout: options.pull_timeout,
64        retry_gap: options.retry_gap,
65        id: AtomicUsize::new(0),
66        min_chunk_size: options.min_chunk_size,
67        max_speculative: options.max_speculative,
68    });
69    let task_queue = TaskQueue::new(options.download_chunks);
70    #[allow(clippy::unwrap_used)]
71    task_queue
72        .set_threads(
73            options.concurrent,
74            options.min_chunk_size,
75            Some(executor.as_ref()),
76        )
77        .unwrap();
78    DownloadResult::new(
79        event_chain,
80        push_handle,
81        None,
82        Some((Arc::downgrade(&executor), task_queue)),
83    )
84}
85
86#[derive(Debug, Clone)]
87pub struct TokioHandle {
88    id: usize,
89    notify: Arc<Notify>,
90}
91impl Handle for TokioHandle {
92    type Output = ();
93    type Id = usize;
94    fn abort(&mut self) -> Self::Output {
95        self.notify.notify_one();
96    }
97    fn is_self(&mut self, id: &Self::Id) -> bool {
98        self.id == *id
99    }
100}
101#[derive(Debug)]
102pub struct TokioExecutor<R, WE>
103where
104    R: Puller,
105    WE: Send + Unpin + 'static,
106{
107    tx: MTx<mpmc::List<Event<R::Error, WE>>>,
108    tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
109    puller: R,
110    retry_gap: Duration,
111    pull_timeout: Duration,
112    id: AtomicUsize,
113    min_chunk_size: u64,
114    max_speculative: usize,
115}
/// Runs each worker as a detached Tokio task (the `JoinHandle` returned by `tokio::spawn` is
/// discarded). Workers are stopped through the returned [`TokioHandle`].
impl<R, WE> Executor for TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    type Handle = TokioHandle;
    /// Spawns a worker that downloads `task`, then keeps calling `task_queue.steal` for more
    /// work until it returns `false` or the worker is aborted.
    ///
    /// For each (re)start the worker:
    /// 1. pulls a stream for `start..task.end()`, retrying after a delay on error;
    /// 2. reads chunks, commits each one through `task.safe_add_start`, and forwards the
    ///    committed part to the push queue;
    /// 3. restarts (re-pull or steal) on stream end, idle timeout, or an irrecoverable error.
    ///
    /// Every exit from the worker loop, whether by abort or normal completion, is followed
    /// by `Event::Finished(id)`.
    fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
        // One counter increment per spawned worker gives each worker a distinct id.
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        // Owned copies of the config, puller and senders, moved into the worker task.
        let mut puller = self.puller.clone();
        let min_chunk_size = self.min_chunk_size;
        let pull_timeout = self.pull_timeout;
        let cfg_retry_gap = self.retry_gap;
        let max_speculative = self.max_speculative;
        let tx = self.tx.clone();
        let tx_push = self.tx_push.clone();
        // Signalled by `TokioHandle::abort`. The pulls, reads and retry sleeps below all race
        // against it; the push-queue send does not.
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        tokio::spawn(async move {
            // One iteration per (re)started pull of the task's remaining range.
            'task: loop {
                // Local read cursor for the current stream; advanced by every non-empty chunk,
                // whether or not its bytes could be committed.
                let mut start = task.start();
                if start >= task.end() {
                    // Own range exhausted. Presumably `steal` reassigns `task` to part of
                    // another worker's range (NOTE(review): confirm `fast_steal` semantics);
                    // `false` ends this worker.
                    if task_queue.steal(&mut task, min_chunk_size, max_speculative) {
                        continue 'task;
                    }
                    break;
                }
                let _ = tx.send(Event::Pulling(id));
                let download_range = start..task.end();
                // Retry `pull` until it yields a stream. The delay comes from the error if it
                // supplies one, otherwise `cfg_retry_gap`. An abort wins over both the pull
                // and the sleep.
                // NOTE(review): `pull_timeout` does not bound this call, only the stream reads.
                let mut stream = loop {
                    let t = tokio::select! {
                        () = notify.notified() => break 'task,
                        t = puller.pull(Some(&download_range)) => t
                    };
                    match t {
                        Ok(t) => break t,
                        Err((e, retry_gap)) => {
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::select! {
                                () = notify.notified() => break 'task,
                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
                            };
                        }
                    }
                };
                tokio::pin! {
                    let sleep = tokio::time::sleep(pull_timeout);
                }
                loop {
                    // Idle timeout: re-armed before every read, so it bounds the gap between
                    // stream items rather than the total transfer time.
                    sleep
                        .as_mut()
                        .reset(tokio::time::Instant::now() + pull_timeout);
                    // On timeout, drop the stalled stream and use a fresh clone of the
                    // puller for the next attempt (NOTE(review): presumably to reset
                    // connection state; confirm).
                    let t = tokio::select! {
                        () = notify.notified() => break 'task,
                        () = &mut sleep => {
                            let _ = tx.send(Event::PullTimeout(id));
                            drop(stream);
                            puller = puller.clone();
                            continue 'task;
                        },
                        t = stream.try_next() => t,
                    };
                    match t {
                        Ok(Some(mut chunk)) => {
                            // Empty chunks carry no data, but they still re-arm the idle
                            // timeout above.
                            if chunk.is_empty() {
                                continue;
                            }
                            let len = chunk.len() as u64;
                            // These bytes could not be committed to the task (NOTE(review):
                            // presumably already written by another worker or no longer part
                            // of this task). Skip them but keep the cursor in step with the
                            // stream.
                            let Ok(span) = task.safe_add_start(start, len) else {
                                start += len;
                                continue;
                            };
                            // This chunk reaches the end of the task. Presumably this lets the
                            // queue cancel other workers sharing it (NOTE(review): confirm
                            // `cancel_task` semantics).
                            if span.end >= task.end() {
                                task_queue.cancel_task(&task, &id);
                            }
                            // Keep only the committed part of the chunk. The casts assume `span`
                            // lies within `start..start + len`, which keeps both offsets
                            // <= `chunk.len()`.
                            #[allow(clippy::cast_possible_truncation)]
                            let slice_span =
                                (span.start - start) as usize..(span.end - start) as usize;
                            chunk = chunk.slice(slice_span);
                            start = span.end;
                            let _ = tx.send(Event::PullProgress(id, span.clone()));
                            // Back-pressure from the bounded push queue. An abort is not
                            // observed while this await is pending.
                            let _ = tx_push.send((id, span, chunk)).await;
                            if start >= task.end() {
                                continue 'task;
                            }
                        }
                        // Stream ended: restart to re-pull whatever remains, or steal.
                        Ok(None) => continue 'task,
                        // After the delay, recoverable errors keep reading the same stream;
                        // irrecoverable ones restart with a fresh pull.
                        Err((e, retry_gap)) => {
                            let is_irrecoverable = e.is_irrecoverable();
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::select! {
                                () = notify.notified() => break 'task,
                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
                            };
                            if is_irrecoverable {
                                continue 'task;
                            }
                        }
                    }
                }
            }
            let _ = tx.send(Event::Finished(id));
        });
        // The handle shares `notify` with the worker so that `abort` can reach it.
        TokioHandle {
            id,
            notify: notify_clone,
        }
    }
}
224
#[cfg(test)]
#[cfg(feature = "mem")]
mod tests {
    use vec::Vec;

    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::vec;

    /// Number of workers requested from `download_multi`.
    ///
    /// Worker ids come from a counter starting at 0, so this also sizes the per-worker
    /// bookkeeping arrays below. This assumes one `execute` call per requested thread.
    const CONCURRENT: usize = 32;

    /// Downloads the mock data with `CONCURRENT` workers and checks four things:
    /// - pulled progress covers exactly the requested range;
    /// - pushed progress covers exactly the requested range;
    /// - at least one worker reported each kind of progress;
    /// - the pusher received the original bytes.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single whole-file range is intended here, not a vector of offsets.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: CONCURRENT,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter().cloned(),
                pull_timeout: Duration::from_secs(5),
                min_chunk_size: 1,
                max_speculative: 3,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; CONCURRENT];
        let mut push_ids = [false; CONCURRENT];
        // The loop ends when the event channel closes, i.e. once every sender is dropped.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        // `assert_eq!` prints both sides on failure, so no separate debug dump is needed.
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert!(pull_ids.iter().any(|x| *x));
        assert!(push_ids.iter().any(|x| *x));

        #[allow(clippy::unwrap_used)]
        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}