1use crate::{DownloadResult, Event, Puller, PullerError, Pusher, multi::TokioExecutor};
2use core::time::Duration;
3use crossfire::{mpmc, spsc};
4use futures::TryStreamExt;
5
/// Options controlling [`download_single`].
#[derive(Clone, Debug)]
pub struct DownloadOptions {
    /// Pause between retries after a failed pull, push or flush.
    ///
    /// For pull errors this is only the fallback: a retry gap suggested by the
    /// puller alongside the error takes precedence.
    pub retry_gap: Duration,
    /// Capacity of the bounded queue that carries pulled chunks from the pull
    /// task to the push worker.
    pub push_queue_cap: usize,
}
11
12pub fn download_single<R: Puller, W: Pusher>(
13 mut puller: R,
14 mut pusher: W,
15 options: DownloadOptions,
16) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
17 let (tx, event_chain) = mpmc::unbounded_async();
18 let (tx_push, rx_push) = spsc::bounded_async(options.push_queue_cap);
19 let tx_clone = tx.clone();
20 const ID: usize = 0;
21 let rx_push = rx_push.into_blocking();
22 let push_handle = tokio::task::spawn_blocking(move || {
23 while let Ok((spin, mut data)) = rx_push.recv() {
24 loop {
25 match pusher.push(&spin, data) {
26 Ok(_) => break,
27 Err((err, bytes)) => {
28 data = bytes;
29 let _ = tx_clone.send(Event::PushError(ID, err));
30 }
31 }
32 std::thread::sleep(options.retry_gap);
33 }
34 let _ = tx_clone.send(Event::PushProgress(ID, spin));
35 }
36 loop {
37 match pusher.flush() {
38 Ok(_) => break,
39 Err(err) => {
40 let _ = tx_clone.send(Event::FlushError(err));
41 }
42 }
43 std::thread::sleep(options.retry_gap);
44 }
45 });
46 let handle = tokio::spawn(async move {
47 'redownload: loop {
48 let _ = tx.send(Event::Pulling(ID));
49 let mut downloaded: u64 = 0;
50 let mut stream = loop {
51 match puller.pull(None).await {
52 Ok(t) => break t,
53 Err((e, retry_gap)) => {
54 let _ = tx.send(Event::PullError(ID, e));
55 tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
56 }
57 }
58 };
59 loop {
60 match stream.try_next().await {
61 Ok(Some(chunk)) => {
62 let len = chunk.len() as u64;
63 let span = downloaded..(downloaded + len);
64 let _ = tx.send(Event::PullProgress(ID, span.clone()));
65 tx_push.send((span, chunk)).await.unwrap();
66 downloaded += len;
67 }
68 Ok(None) => break 'redownload,
69 Err((e, retry_gap)) => {
70 let is_irrecoverable = e.is_irrecoverable();
71 let _ = tx.send(Event::PullError(ID, e));
72 tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
73 if is_irrecoverable {
74 continue 'redownload;
75 }
76 }
77 }
78 }
79 }
80 let _ = tx.send(Event::Finished(ID));
81 });
82 DownloadResult::new(
83 event_chain,
84 push_handle,
85 Some(&[handle.abort_handle()]),
86 None,
87 )
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec};
    use vec::Vec;

    /// Runs `download_single` over 3 KiB of mock data.
    ///
    /// Checks that the merged pull and push progress each cover exactly the
    /// whole range, and that the in-memory pusher ends up holding a
    /// byte-identical copy of the source.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(3 * 1024);
        #[allow(clippy::single_range_in_vec_init)]
        let expected_ranges = vec![0..mock_data.len() as u64];

        let sink = MemPusher::with_capacity(mock_data.len());
        let options = DownloadOptions {
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_single(MockPuller::new(&mock_data), sink.clone(), options);

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // Drain events until every sender has been dropped and the channel closes.
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PushProgress(_, span) => {
                    push_progress.merge_progress(span);
                }
                Event::PullProgress(_, span) => {
                    pull_progress.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&pull_progress);
        dbg!(&push_progress);
        assert_eq!(pull_progress, expected_ranges);
        assert_eq!(push_progress, expected_ranges);

        result.join().await.unwrap();
        assert_eq!(&**sink.receive.lock(), mock_data);
    }
}