1extern crate alloc;
2use super::macros::poll_ok;
3use crate::{DownloadResult, Event, ProgressEntry, SeqPuller, SeqPusher};
4use bytes::Bytes;
5use core::time::Duration;
6use futures::TryStreamExt;
7
/// Tuning knobs for [`download_single`].
#[derive(Clone, Debug)]
pub struct DownloadOptions {
    /// Pause inserted before retrying after a pull, push or flush error.
    pub retry_gap: Duration,
    /// Capacity of the bounded queue that buffers pulled chunks until the push task
    /// writes them; the pull task waits when the queue is full.
    pub push_queue_cap: usize,
}
13
14pub async fn download_single<R, W>(
15 mut puller: R,
16 mut pusher: W,
17 options: DownloadOptions,
18) -> DownloadResult<R::Error, W::Error>
19where
20 R: SeqPuller + 'static,
21 W: SeqPusher + 'static,
22{
23 let (tx, event_chain) = kanal::unbounded_async();
24 let (tx_push, rx_push) = kanal::bounded_async::<(ProgressEntry, Bytes)>(options.push_queue_cap);
25 let tx_clone = tx.clone();
26 const ID: usize = 0;
27 let push_handle = tokio::spawn(async move {
28 while let Ok((spin, data)) = rx_push.recv().await {
29 poll_ok!(
30 {},
31 pusher.push(data.clone()).await,
32 ID @ tx_clone => PushError,
33 options.retry_gap
34 );
35 tx_clone.send(Event::PushProgress(ID, spin)).await.unwrap();
36 }
37 poll_ok!(
38 {},
39 pusher.flush().await,
40 tx_clone => FlushError,
41 options.retry_gap
42 );
43 });
44 let handle = tokio::spawn(async move {
45 tx.send(Event::Pulling(ID)).await.unwrap();
46 let mut downloaded: u64 = 0;
47 let mut stream = puller.pull();
48 loop {
49 match stream.try_next().await {
50 Ok(Some(chunk)) => {
51 let len = chunk.len() as u64;
52 let span = downloaded..(downloaded + len);
53 tx.send(Event::PullProgress(ID, span.clone()))
54 .await
55 .unwrap();
56 tx_push.send((span, chunk)).await.unwrap();
57 downloaded += len;
58 }
59 Ok(None) => break,
60 Err(e) => {
61 tx.send(Event::PullError(ID, e)).await.unwrap();
62 tokio::time::sleep(options.retry_gap).await;
63 }
64 }
65 }
66 tx.send(Event::Finished(ID)).await.unwrap();
67 });
68 DownloadResult::new(event_chain, push_handle, &[handle.abort_handle()])
69}
70
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress,
        core::mock::{MockSeqPuller, MockSeqPusher, build_mock_data},
    };
    use alloc::vec;
    use vec::Vec;

    /// End-to-end check of [`download_single`] with mock endpoints.
    ///
    /// Pulls 3 KiB of mock data and asserts three things:
    /// - the pulled progress merges into one range covering the whole payload;
    /// - the pushed progress merges into the same single range;
    /// - the pusher received the expected bytes.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockSeqPuller::new(mock_data.clone());
        let pusher = MockSeqPusher::new(&mock_data);
        // A one-element Vec holding the full range is intended here; it is not a
        // mistaken attempt at `(0..len).collect()`, which is what the lint guards against.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // The event channel closes once both worker tasks have dropped their senders.
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(_, span) => pull_progress.merge_progress(span),
                Event::PushProgress(_, span) => push_progress.merge_progress(span),
                _ => {}
            }
        }
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        pusher.assert().await;
    }
}