Skip to main content

fast_pull/core/
single.rs

1use crate::{
2    DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, multi::TokioExecutor,
3};
4use bytes::Bytes;
5use core::time::Duration;
6use crossfire::{mpmc, spsc};
7use futures::TryStreamExt;
8
/// Tuning knobs for [`download_single`].
#[derive(Clone, Copy, Debug)]
pub struct DownloadOptions {
    /// Pause between retries.
    ///
    /// `download_single` sleeps this long after a failed push or flush attempt, and uses it
    /// as the fallback wait after a pull error when the puller does not report its own gap.
    pub retry_gap: Duration,
    /// Capacity handed to the bounded queue that carries `(ProgressEntry, Bytes)` pairs from
    /// the pulling task to the pushing task.
    pub push_queue_cap: usize,
}
14
15pub fn download_single<R: Puller, W: Pusher>(
16    mut puller: R,
17    mut pusher: W,
18    options: DownloadOptions,
19) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
20    const ID: usize = 0;
21    let (tx, event_chain) = mpmc::unbounded_async();
22    let (tx_push, rx_push) = spsc::bounded_async::<(ProgressEntry, Bytes)>(options.push_queue_cap);
23    let tx_clone = tx.clone();
24    let rx_push = rx_push.into_blocking();
25    let push_handle = tokio::task::spawn_blocking(move || {
26        while let Ok((spin, mut data)) = rx_push.recv() {
27            loop {
28                let _ = tx_clone.send(Event::Pushing(ID, spin.clone()));
29                match pusher.push(&spin, data) {
30                    Ok(()) => break,
31                    Err((err, bytes)) => {
32                        data = bytes;
33                        let _ = tx_clone.send(Event::PushError(ID, spin.clone(), err));
34                    }
35                }
36                std::thread::sleep(options.retry_gap);
37            }
38            let _ = tx_clone.send(Event::PushProgress(ID, spin));
39        }
40        loop {
41            let _ = tx_clone.send(Event::Flushing);
42            match pusher.flush() {
43                Ok(()) => break,
44                Err(err) => {
45                    let _ = tx_clone.send(Event::FlushError(err));
46                }
47            }
48            std::thread::sleep(options.retry_gap);
49        }
50    });
51    let handle = tokio::spawn(async move {
52        'redownload: loop {
53            let _ = tx.send(Event::Pulling(ID));
54            let mut downloaded: u64 = 0;
55            let mut stream = loop {
56                match puller.pull(None).await {
57                    Ok(t) => break t,
58                    Err((e, retry_gap)) => {
59                        let _ = tx.send(Event::PullError(ID, e));
60                        tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
61                    }
62                }
63            };
64            loop {
65                match stream.try_next().await {
66                    Ok(Some(chunk)) => {
67                        let len = chunk.len() as u64;
68                        let span = downloaded..(downloaded + len);
69                        let _ = tx.send(Event::PullProgress(ID, span.clone()));
70                        let _ = tx_push.send((span, chunk)).await;
71                        downloaded += len;
72                    }
73                    Ok(None) => break 'redownload,
74                    Err((e, retry_gap)) => {
75                        let is_irrecoverable = e.is_irrecoverable();
76                        let _ = tx.send(Event::PullError(ID, e));
77                        tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
78                        if is_irrecoverable {
79                            continue 'redownload;
80                        }
81                    }
82                }
83            }
84        }
85        let _ = tx.send(Event::Finished(ID));
86    });
87    DownloadResult::new(
88        event_chain,
89        push_handle,
90        Some(&[handle.abort_handle()]),
91        None,
92    )
93}
94
#[cfg(test)]
#[cfg(feature = "mem")]
mod tests {
    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec, vec::Vec};

    /// Runs `download_single` against a mock puller and an in-memory pusher, then checks that
    /// the merged pull and push progress each cover the whole payload and that the pusher
    /// ended up holding exactly the mock data.
    #[tokio::test]
    async fn test_sequential_download() {
        let payload = build_mock_data(3 * 1024);
        let pusher = MemPusher::with_capacity(payload.len());
        // A single range spanning the entire payload is the expected merged progress.
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..payload.len() as u64];
        let options = DownloadOptions {
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_single(MockPuller::new(&payload), pusher.clone(), options);

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // The loop ends once `recv` reports an error instead of another event.
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PushProgress(_, span) => push_progress.merge_progress(span),
                Event::PullProgress(_, span) => pull_progress.merge_progress(span),
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, expected);
        assert_eq!(push_progress, expected);

        #[allow(clippy::unwrap_used)]
        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), payload);
    }
}
145}