Skip to main content

fast_pull/core/
single.rs

1use crate::{DownloadResult, Event, Puller, PullerError, Pusher, multi::TokioExecutor};
2use core::time::Duration;
3use crossfire::{mpmc, spsc};
4use futures::TryStreamExt;
5
/// Tuning knobs for [`download_single`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DownloadOptions {
    /// Pause between retries of a failed push or flush, and the fallback pause
    /// after a pull error that does not carry its own retry delay.
    pub retry_gap: Duration,
    /// Capacity passed to the bounded queue that carries pulled chunks to the
    /// push worker.
    pub push_queue_cap: usize,
}
11
12pub fn download_single<R: Puller, W: Pusher>(
13    mut puller: R,
14    mut pusher: W,
15    options: DownloadOptions,
16) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
17    let (tx, event_chain) = mpmc::unbounded_async();
18    let (tx_push, rx_push) = spsc::bounded_async(options.push_queue_cap);
19    let tx_clone = tx.clone();
20    const ID: usize = 0;
21    let rx_push = rx_push.into_blocking();
22    let push_handle = tokio::task::spawn_blocking(move || {
23        while let Ok((spin, mut data)) = rx_push.recv() {
24            loop {
25                match pusher.push(&spin, data) {
26                    Ok(_) => break,
27                    Err((err, bytes)) => {
28                        data = bytes;
29                        let _ = tx_clone.send(Event::PushError(ID, err));
30                    }
31                }
32                std::thread::sleep(options.retry_gap);
33            }
34            let _ = tx_clone.send(Event::PushProgress(ID, spin));
35        }
36        loop {
37            match pusher.flush() {
38                Ok(_) => break,
39                Err(err) => {
40                    let _ = tx_clone.send(Event::FlushError(err));
41                }
42            }
43            std::thread::sleep(options.retry_gap);
44        }
45    });
46    let handle = tokio::spawn(async move {
47        'redownload: loop {
48            let _ = tx.send(Event::Pulling(ID));
49            let mut downloaded: u64 = 0;
50            let mut stream = loop {
51                match puller.pull(None).await {
52                    Ok(t) => break t,
53                    Err((e, retry_gap)) => {
54                        let _ = tx.send(Event::PullError(ID, e));
55                        tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
56                    }
57                }
58            };
59            loop {
60                match stream.try_next().await {
61                    Ok(Some(chunk)) => {
62                        let len = chunk.len() as u64;
63                        let span = downloaded..(downloaded + len);
64                        let _ = tx.send(Event::PullProgress(ID, span.clone()));
65                        tx_push.send((span, chunk)).await.unwrap();
66                        downloaded += len;
67                    }
68                    Ok(None) => break 'redownload,
69                    Err((e, retry_gap)) => {
70                        let is_irrecoverable = e.is_irrecoverable();
71                        let _ = tx.send(Event::PullError(ID, e));
72                        tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
73                        if is_irrecoverable {
74                            continue 'redownload;
75                        }
76                    }
77                }
78            }
79        }
80        let _ = tx.send(Event::Finished(ID));
81    });
82    DownloadResult::new(
83        event_chain,
84        push_handle,
85        Some(&[handle.abort_handle()]),
86        None,
87    )
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec};
    use vec::Vec;

    /// Runs `download_single` from a `MockPuller` into a `MemPusher` and checks
    /// that both the pull and the push progress merge into one range covering
    /// the whole payload, and that the pusher ends up holding the mock data.
    #[tokio::test]
    async fn test_sequential_download() {
        let expected_bytes = build_mock_data(3 * 1024);
        let source = MockPuller::new(&expected_bytes);
        let sink = MemPusher::with_capacity(expected_bytes.len());
        // A single range is intended here, hence the lint exemption.
        #[allow(clippy::single_range_in_vec_init)]
        let expected_ranges = vec![0..expected_bytes.len() as u64];
        let options = DownloadOptions {
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_single(source, sink.clone(), options);

        let mut pulled: Vec<ProgressEntry> = Vec::new();
        let mut pushed: Vec<ProgressEntry> = Vec::new();
        // Drain events until the channel reports an error (all senders gone).
        loop {
            let Ok(event) = result.event_chain.recv().await else {
                break;
            };
            match event {
                Event::PullProgress(_, range) => {
                    pulled.merge_progress(range);
                }
                Event::PushProgress(_, range) => {
                    pushed.merge_progress(range);
                }
                _ => {}
            }
        }
        dbg!(&pulled);
        dbg!(&pushed);
        assert_eq!(pulled, expected_ranges);
        assert_eq!(pushed, expected_ranges);

        result.join().await.unwrap();
        assert_eq!(&**sink.receive.lock(), expected_bytes);
    }
}