1use crate::{DownloadResult, Event, Puller, PullerError, Pusher, multi::TokioExecutor};
2use core::time::Duration;
3use crossfire::{mpmc, spsc};
4use futures::TryStreamExt;
5
/// Tuning knobs for [`download_single`].
#[derive(Clone, Copy, Debug)]
pub struct DownloadOptions {
    /// Pause between retries of a failed pull, push or flush.
    /// When a pull error carries its own retry hint, that hint is used instead.
    pub retry_gap: Duration,
    /// Capacity of the bounded queue holding chunks waiting to be pushed.
    pub push_queue_cap: usize,
}
11
12pub fn download_single<R: Puller, W: Pusher>(
13 mut puller: R,
14 mut pusher: W,
15 options: DownloadOptions,
16) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
17 const ID: usize = 0;
18 let (tx, event_chain) = mpmc::unbounded_async();
19 let (tx_push, rx_push) = spsc::bounded_async(options.push_queue_cap);
20 let tx_clone = tx.clone();
21 let rx_push = rx_push.into_blocking();
22 let push_handle = tokio::task::spawn_blocking(move || {
23 while let Ok((spin, mut data)) = rx_push.recv() {
24 loop {
25 match pusher.push(&spin, data) {
26 Ok(()) => break,
27 Err((err, bytes)) => {
28 data = bytes;
29 let _ = tx_clone.send(Event::PushError(ID, err));
30 }
31 }
32 std::thread::sleep(options.retry_gap);
33 }
34 let _ = tx_clone.send(Event::PushProgress(ID, spin));
35 }
36 loop {
37 match pusher.flush() {
38 Ok(()) => break,
39 Err(err) => {
40 let _ = tx_clone.send(Event::FlushError(err));
41 }
42 }
43 std::thread::sleep(options.retry_gap);
44 }
45 });
46 let handle = tokio::spawn(async move {
47 'redownload: loop {
48 let _ = tx.send(Event::Pulling(ID));
49 let mut downloaded: u64 = 0;
50 let mut stream = loop {
51 match puller.pull(None).await {
52 Ok(t) => break t,
53 Err((e, retry_gap)) => {
54 let _ = tx.send(Event::PullError(ID, e));
55 tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
56 }
57 }
58 };
59 loop {
60 match stream.try_next().await {
61 Ok(Some(chunk)) => {
62 let len = chunk.len() as u64;
63 let span = downloaded..(downloaded + len);
64 let _ = tx.send(Event::PullProgress(ID, span.clone()));
65 let _ = tx_push.send((span, chunk)).await;
66 downloaded += len;
67 }
68 Ok(None) => break 'redownload,
69 Err((e, retry_gap)) => {
70 let is_irrecoverable = e.is_irrecoverable();
71 let _ = tx.send(Event::PullError(ID, e));
72 tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
73 if is_irrecoverable {
74 continue 'redownload;
75 }
76 }
77 }
78 }
79 }
80 let _ = tx.send(Event::Finished(ID));
81 });
82 DownloadResult::new(
83 event_chain,
84 push_handle,
85 Some(&[handle.abort_handle()]),
86 None,
87 )
88}
89
#[cfg(test)]
#[cfg(feature = "mem")]
mod tests {
    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    // `vec!`/`Vec` are brought in from `std` explicitly, as the original module did.
    use std::vec;
    use vec::Vec;

    /// End-to-end check of `download_single` against in-memory mocks.
    ///
    /// Verifies that:
    /// * pull and push progress events together cover the whole payload;
    /// * the worker reports completion exactly once;
    /// * the pusher ends up holding an exact copy of the source data.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        // The expected progress is one range covering the whole payload; clippy
        // would otherwise suspect a mistyped `vec![0; n]`.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        );

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut finished_count = 0usize;
        // The channel closes once both workers have dropped their senders.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                Event::Finished(_) => {
                    finished_count += 1;
                }
                _ => {}
            }
        }
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);
        assert_eq!(finished_count, 1);

        #[allow(clippy::unwrap_used)]
        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}