1use crate::{
2 DownloadResult, Event, ProgressEntry, Puller, PullerError, Pusher, multi::TokioExecutor,
3};
4use bytes::Bytes;
5use core::time::Duration;
6use crossfire::{mpmc, spsc};
7use futures::TryStreamExt;
8
/// Tuning parameters consumed by [`download_single`].
#[derive(Debug, Clone, Copy)]
pub struct DownloadOptions {
    /// Delay between retries of a failed push or flush. It is also the fallback
    /// delay after a pull error that does not carry its own retry gap.
    pub retry_gap: Duration,
    /// Capacity passed to the bounded queue that carries `(range, bytes)` pairs
    /// from the pull task to the push task.
    pub push_queue_cap: usize,
}
14
15pub fn download_single<R: Puller, W: Pusher>(
16 mut puller: R,
17 mut pusher: W,
18 options: DownloadOptions,
19) -> DownloadResult<TokioExecutor<R, W::Error>, R::Error, W::Error> {
20 const ID: usize = 0;
21 let (tx, event_chain) = mpmc::unbounded_async();
22 let (tx_push, rx_push) = spsc::bounded_async::<(ProgressEntry, Bytes)>(options.push_queue_cap);
23 let tx_clone = tx.clone();
24 let rx_push = rx_push.into_blocking();
25 let push_handle = tokio::task::spawn_blocking(move || {
26 while let Ok((spin, mut data)) = rx_push.recv() {
27 loop {
28 let _ = tx_clone.send(Event::Pushing(ID, spin.clone()));
29 match pusher.push(&spin, data) {
30 Ok(()) => break,
31 Err((err, bytes)) => {
32 data = bytes;
33 let _ = tx_clone.send(Event::PushError(ID, spin.clone(), err));
34 }
35 }
36 std::thread::sleep(options.retry_gap);
37 }
38 let _ = tx_clone.send(Event::PushProgress(ID, spin));
39 }
40 loop {
41 let _ = tx_clone.send(Event::Flushing);
42 match pusher.flush() {
43 Ok(()) => break,
44 Err(err) => {
45 let _ = tx_clone.send(Event::FlushError(err));
46 }
47 }
48 std::thread::sleep(options.retry_gap);
49 }
50 });
51 let handle = tokio::spawn(async move {
52 'redownload: loop {
53 let _ = tx.send(Event::Pulling(ID));
54 let mut downloaded: u64 = 0;
55 let mut stream = loop {
56 match puller.pull(None).await {
57 Ok(t) => break t,
58 Err((e, retry_gap)) => {
59 let _ = tx.send(Event::PullError(ID, e));
60 tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
61 }
62 }
63 };
64 loop {
65 match stream.try_next().await {
66 Ok(Some(chunk)) => {
67 let len = chunk.len() as u64;
68 let span = downloaded..(downloaded + len);
69 let _ = tx.send(Event::PullProgress(ID, span.clone()));
70 let _ = tx_push.send((span, chunk)).await;
71 downloaded += len;
72 }
73 Ok(None) => break 'redownload,
74 Err((e, retry_gap)) => {
75 let is_irrecoverable = e.is_irrecoverable();
76 let _ = tx.send(Event::PullError(ID, e));
77 tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
78 if is_irrecoverable {
79 continue 'redownload;
80 }
81 }
82 }
83 }
84 }
85 let _ = tx.send(Event::Finished(ID));
86 });
87 DownloadResult::new(
88 event_chain,
89 push_handle,
90 Some(&[handle.abort_handle()]),
91 None,
92 )
93}
94
#[cfg(all(test, feature = "mem"))]
mod tests {
    use super::*;
    use crate::{
        Merge, ProgressEntry,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use std::{dbg, vec};
    use vec::Vec;

    /// Runs `download_single` from a mock puller into an in-memory pusher and
    /// checks that the merged pull and push progress both cover the whole range
    /// and that the pusher ends up holding exactly the mock data.
    #[tokio::test]
    async fn test_sequential_download() {
        let expected_bytes = build_mock_data(3 * 1024);
        let source = MockPuller::new(&expected_bytes);
        let sink = MemPusher::with_capacity(expected_bytes.len());
        #[allow(clippy::single_range_in_vec_init)]
        let expected_ranges = vec![0..expected_bytes.len() as u64];
        let options = DownloadOptions {
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_single(source, sink.clone(), options);

        let mut pulled: Vec<ProgressEntry> = Vec::new();
        let mut pushed: Vec<ProgressEntry> = Vec::new();
        // Drain events until the channel reports an error (no more events).
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(_, range) => pulled.merge_progress(range),
                Event::PushProgress(_, range) => pushed.merge_progress(range),
                _ => {}
            }
        }
        dbg!(&pulled);
        dbg!(&pushed);
        assert_eq!(pulled, expected_ranges);
        assert_eq!(pushed, expected_ranges);

        #[allow(clippy::unwrap_used)]
        result.join().await.unwrap();
        assert_eq!(&**sink.receive.lock(), expected_bytes);
    }
}