Skip to main content

fast_pull/core/
multi.rs

1use crate::{DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, WorkerId};
2use bytes::Bytes;
3use core::{
4    sync::atomic::{AtomicUsize, Ordering},
5    time::Duration,
6};
7use crossfire::{MAsyncTx, MTx, mpmc, mpsc};
8use fast_steal::{Executor, Handle, Task, TaskQueue};
9use futures::TryStreamExt;
10use std::sync::Arc;
11use tokio::sync::Notify;
12
/// Configuration for [`download_multi`].
#[derive(Debug, Clone)]
pub struct DownloadOptions<I: Iterator<Item = ProgressEntry>> {
    /// Ranges to download; handed to `TaskQueue::new` to seed the work queue.
    pub download_chunks: I,
    /// Number of workers requested through `TaskQueue::set_threads`.
    pub concurrent: usize,
    /// Delay between retries of a failed push or flush. Also the fallback delay
    /// between pull retries when the puller's error does not suggest its own delay.
    pub retry_gap: Duration,
    /// Maximum time a worker waits for the next item of a pull stream. When it
    /// expires, the worker emits `Event::PullTimeout` and re-issues the pull.
    pub pull_timeout: Duration,
    /// Capacity of the bounded queue between the pulling workers and the pusher thread.
    pub push_queue_cap: usize,
    /// Minimum chunk size forwarded to `TaskQueue::set_threads` and `TaskQueue::steal`.
    pub min_chunk_size: u64,
    /// Forwarded to `TaskQueue::steal`. Presumably this caps how many workers may
    /// speculatively work on the same range; TODO confirm against `fast_steal`.
    pub max_speculative: usize,
}
23
24/// # Panics
25/// 当设置线程数,但 executor 意外为空时,panic
26pub fn download_multi<R: Puller, W: Pusher, I: Iterator<Item = ProgressEntry>>(
27    puller: R,
28    mut pusher: W,
29    options: DownloadOptions<I>,
30) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
31    let (tx, event_chain) = mpmc::unbounded_async();
32    let (tx_push, rx_push) =
33        mpsc::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.push_queue_cap);
34    let tx_clone = tx.clone();
35    let rx_push = rx_push.into_blocking();
36    let push_handle = tokio::task::spawn_blocking(move || {
37        while let Ok((id, spin, mut data)) = rx_push.recv() {
38            loop {
39                let _ = tx_clone.send(Event::Pushing(id, spin.clone()));
40                match pusher.push(&spin, data) {
41                    Ok(()) => break,
42                    Err((err, bytes)) => {
43                        data = bytes;
44                        let _ = tx_clone.send(Event::PushError(id, spin.clone(), err));
45                    }
46                }
47                std::thread::sleep(options.retry_gap);
48            }
49            let _ = tx_clone.send(Event::PushProgress(id, spin));
50        }
51        loop {
52            let _ = tx_clone.send(Event::Flushing);
53            match pusher.flush() {
54                Ok(()) => break,
55                Err(err) => {
56                    let _ = tx_clone.send(Event::FlushError(err));
57                }
58            }
59            std::thread::sleep(options.retry_gap);
60        }
61    });
62    let executor: Arc<TokioExecutor<R, W::Error>> = Arc::new(TokioExecutor {
63        tx,
64        tx_push,
65        puller,
66        pull_timeout: options.pull_timeout,
67        retry_gap: options.retry_gap,
68        id: AtomicUsize::new(0),
69        min_chunk_size: options.min_chunk_size,
70        max_speculative: options.max_speculative,
71    });
72    let task_queue = TaskQueue::new(options.download_chunks);
73    #[allow(clippy::unwrap_used)]
74    task_queue
75        .set_threads(
76            options.concurrent,
77            options.min_chunk_size,
78            Some(executor.as_ref()),
79        )
80        .unwrap();
81    DownloadResult::new(
82        event_chain,
83        push_handle,
84        None,
85        Some((Arc::downgrade(&executor), task_queue)),
86    )
87}
88
89#[derive(Debug, Clone)]
90pub struct TokioHandle {
91    id: usize,
92    notify: Arc<Notify>,
93}
94impl Handle for TokioHandle {
95    type Output = ();
96    type Id = usize;
97    fn abort(&mut self) -> Self::Output {
98        self.notify.notify_one();
99    }
100    fn is_self(&mut self, id: &Self::Id) -> bool {
101        self.id == *id
102    }
103}
/// [`Executor`] that runs each download task on its own Tokio task.
///
/// A single instance is built by [`download_multi`]. Every spawned worker receives
/// clones of the event sender, the push-queue sender and the puller.
#[derive(Debug)]
pub struct TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    /// Sender for progress and error events. In `download_multi`, `WE` is the
    /// pusher's error type.
    tx: MTx<mpmc::List<Event<R::Error, WE>>>,
    /// Sender into the bounded queue that the blocking pusher thread drains.
    tx_push: MAsyncTx<mpsc::Array<(WorkerId, ProgressEntry, Bytes)>>,
    /// Prototype puller; each worker clones it.
    puller: R,
    /// Fallback delay between pull retries when the error gives no delay of its own.
    retry_gap: Duration,
    /// Maximum idle time between items of a pull stream.
    pull_timeout: Duration,
    /// Counter that hands out worker ids; incremented once per `execute` call.
    id: AtomicUsize,
    /// Minimum chunk size forwarded to `TaskQueue::steal`.
    min_chunk_size: u64,
    /// Speculation limit forwarded to `TaskQueue::steal`.
    max_speculative: usize,
}
impl<R, WE> Executor for TokioExecutor<R, WE>
where
    R: Puller,
    WE: Send + Unpin + 'static,
{
    type Handle = TokioHandle;
    /// Spawns a Tokio worker that downloads `task` and then keeps stealing more work.
    ///
    /// The worker takes a fresh id from `self.id` plus its own clones of the puller
    /// and of both senders. It then loops:
    /// 1. Pull the remaining range of the current task.
    /// 2. Forward every received chunk to the push queue.
    /// 3. Once the task is exhausted, try to steal another one from `task_queue`.
    ///
    /// It stops and sends `Event::Finished` in two cases: nothing can be stolen, or
    /// the notifier fires while it is waiting on a pull, a retry delay or the stream.
    ///
    /// Returns a [`TokioHandle`] holding the worker id and the notifier used to abort it.
    #[allow(clippy::too_many_lines)]
    fn execute(&self, mut task: Task, task_queue: TaskQueue<Self::Handle>) -> Self::Handle {
        let id = self.id.fetch_add(1, Ordering::SeqCst);
        let mut puller = self.puller.clone();
        let min_chunk_size = self.min_chunk_size;
        let pull_timeout = self.pull_timeout;
        let cfg_retry_gap = self.retry_gap;
        let max_speculative = self.max_speculative;
        let tx = self.tx.clone();
        let tx_push = self.tx_push.clone();
        let notify = Arc::new(Notify::new());
        let notify_clone = notify.clone();
        tokio::spawn(async move {
            // Each iteration of 'task starts (or restarts) pulling from `task.start()`.
            'task: loop {
                let mut start = task.start();
                if start >= task.end() {
                    // Current task is done: try to take over another range.
                    if task_queue.steal(&id, &mut task, min_chunk_size, max_speculative) {
                        // Non-blocking drain. The biased select consumes a notification
                        // that is already pending, if any, and otherwise falls through
                        // immediately via `async {}`. Presumably this discards a stale
                        // abort aimed at the previous task so it cannot cancel the
                        // newly stolen one.
                        tokio::select! {
                            biased;
                            () = notify.notified() => {}
                            () = async {} => {}
                        }
                        continue 'task;
                    }
                    break;
                }
                let _ = tx.send(Event::Pulling(id));
                let download_range = start..task.end();
                // Open a stream for the remaining range. On error, report it and retry
                // after the delay the error suggests, or after `cfg_retry_gap`.
                let mut stream = loop {
                    let t = tokio::select! {
                        () = notify.notified() => break 'task,
                        t = puller.pull(Some(&download_range)) => t
                    };
                    match t {
                        Ok(t) => break t,
                        Err((e, retry_gap)) => {
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::select! {
                                () = notify.notified() => break 'task,
                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
                            };
                        }
                    }
                };
                tokio::pin! {
                    let sleep = tokio::time::sleep(pull_timeout);
                }
                loop {
                    // Re-arm the idle timer before every read. `pull_timeout` therefore
                    // bounds the gap between stream items, not the whole transfer.
                    sleep
                        .as_mut()
                        .reset(tokio::time::Instant::now() + pull_timeout);
                    let t = tokio::select! {
                        () = notify.notified() => break 'task,
                        // Stream went idle: drop it, swap in a fresh clone of the
                        // puller, and restart the task from its current start.
                        () = &mut sleep => {
                            let _ = tx.send(Event::PullTimeout(id));
                            drop(stream);
                            puller = puller.clone();
                            continue 'task;
                        },
                        t = stream.try_next() => t,
                    };
                    match t {
                        Ok(Some(mut chunk)) => {
                            if chunk.is_empty() {
                                continue;
                            }
                            let len = chunk.len() as u64;
                            // Claim [start, start + len) on the task. On failure the
                            // bytes are skipped and only the local cursor advances.
                            let Ok(span) = task.safe_add_start(start, len) else {
                                start += len;
                                continue;
                            };
                            // This chunk reaches the end of the task, so ask the queue to
                            // cancel it (presumably stopping other workers on the same task).
                            if span.end >= task.end() {
                                task_queue.cancel_task(&task, &id);
                            }
                            // Keep only the claimed part of the chunk, as offsets relative
                            // to `start`.
                            #[allow(clippy::cast_possible_truncation)]
                            let slice_span =
                                (span.start - start) as usize..(span.end - start) as usize;
                            chunk = chunk.slice(slice_span);
                            start = span.end;
                            let _ = tx.send(Event::PullProgress(id, span.clone()));
                            // Waits while the bounded push queue is full. This await is
                            // not raced against the abort notifier.
                            let _ = tx_push.send((id, span, chunk)).await;
                            if start >= task.end() {
                                continue 'task;
                            }
                        }
                        // Stream ended: restart, which re-pulls the remainder or steals.
                        Ok(None) => continue 'task,
                        Err((e, retry_gap)) => {
                            let is_irrecoverable = e.is_irrecoverable();
                            let _ = tx.send(Event::PullError(id, e));
                            tokio::select! {
                                () = notify.notified() => break 'task,
                                () = tokio::time::sleep(retry_gap.unwrap_or(cfg_retry_gap)) => {}
                            };
                            // Irrecoverable errors abandon this stream and pull again.
                            // Other errors keep reading from the same stream.
                            if is_irrecoverable {
                                continue 'task;
                            }
                        }
                    }
                }
            }
            let _ = tx.send(Event::Finished(id));
        });
        TokioHandle {
            id,
            notify: notify_clone,
        }
    }
}
233
#[cfg(test)]
#[cfg(feature = "mem")]
mod tests {
    use vec::Vec;

    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec};

    /// Worker count requested by the test. It also sizes the id-tracking arrays:
    /// `TokioExecutor::execute` assigns ids sequentially from 0, so every id stays
    /// below this value as long as exactly `CONCURRENT` workers are spawned.
    const CONCURRENT: usize = 32;

    /// Downloads mock data with many workers, then checks three things: the pulled
    /// and pushed ranges cover everything, at least one worker pulled and pushed,
    /// and the pusher received exactly the source bytes.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            puller,
            pusher.clone(),
            DownloadOptions {
                concurrent: CONCURRENT,
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
                download_chunks: download_chunks.iter().cloned(),
                pull_timeout: Duration::from_secs(5),
                min_chunk_size: 1,
                max_speculative: 3,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut pull_ids = [false; CONCURRENT];
        let mut push_ids = [false; CONCURRENT];
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(id, p) => {
                    pull_ids[id] = true;
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(id, p) => {
                    push_ids[id] = true;
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert!(pull_ids.contains(&true));
        assert!(push_ids.contains(&true));

        #[allow(clippy::unwrap_used)]
        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}