1extern crate alloc;
2use super::macros::poll_ok;
3use crate::{DownloadResult, Event, ProgressEntry, SeqPuller, SeqPusher};
4use alloc::sync::Arc;
5use bytes::Bytes;
6use core::time::Duration;
7use fast_steal::{Executor, Handle};
8use futures::TryStreamExt;
9
/// Tuning knobs for [`download_single`].
#[derive(Clone, Debug)]
pub struct DownloadOptions {
    /// Delay handed to the push/flush retry macro, and the fallback sleep
    /// before retrying a failed pull when the puller's error does not carry
    /// its own retry gap.
    pub retry_gap: Duration,
    /// Capacity of the bounded queue that carries pulled chunks from the
    /// pull task to the push task.
    pub push_queue_cap: usize,
}
15
16#[derive(Clone)]
17pub struct EmptyHandle;
18impl Handle for EmptyHandle {
19 type Output = ();
20 fn abort(&mut self) -> Self::Output {}
21}
22
23#[derive(Debug, Clone)]
24pub struct EmptyExecutor;
25impl Executor for EmptyExecutor {
26 type Handle = EmptyHandle;
27 fn execute(
28 self: Arc<Self>,
29 _: Arc<fast_steal::Task>,
30 _: Arc<fast_steal::TaskList<Self>>,
31 ) -> Self::Handle {
32 EmptyHandle
33 }
34}
35
36pub async fn download_single<R, W>(
37 mut puller: R,
38 mut pusher: W,
39 options: DownloadOptions,
40) -> DownloadResult<EmptyExecutor, R::Error, W::Error>
41where
42 R: SeqPuller + 'static,
43 W: SeqPusher + 'static,
44{
45 let (tx, event_chain) = kanal::unbounded_async();
46 let (tx_push, rx_push) = kanal::bounded_async::<(ProgressEntry, Bytes)>(options.push_queue_cap);
47 let tx_clone = tx.clone();
48 const ID: usize = 0;
49 let push_handle = tokio::spawn(async move {
50 while let Ok((spin, data)) = rx_push.recv().await {
51 poll_ok!(
52 pusher.push(&data).await,
53 ID @ tx_clone => PushError,
54 options.retry_gap
55 );
56 let _ = tx_clone.send(Event::PushProgress(ID, spin)).await;
57 }
58 poll_ok!(
59 pusher.flush().await,
60 tx_clone => FlushError,
61 options.retry_gap
62 );
63 });
64 let handle = tokio::spawn(async move {
65 let _ = tx.send(Event::Pulling(ID)).await;
66 let mut downloaded: u64 = 0;
67 let mut stream = loop {
68 match puller.pull().await {
69 Ok(t) => break t,
70 Err((e, retry_gap)) => {
71 let _ = tx.send(Event::PullError(ID, e)).await;
72 tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
73 }
74 }
75 };
76 loop {
77 match stream.try_next().await {
78 Ok(Some(chunk)) => {
79 let len = chunk.len() as u64;
80 let span = downloaded..(downloaded + len);
81 let _ = tx.send(Event::PullProgress(ID, span.clone())).await;
82 tx_push.send((span, chunk)).await.unwrap();
83 downloaded += len;
84 }
85 Ok(None) => break,
86 Err((e, retry_gap)) => {
87 let _ = tx.send(Event::PullError(ID, e)).await;
88 tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await
89 }
90 }
91 }
92 let _ = tx.send(Event::Finished(ID)).await;
93 });
94 DownloadResult::new(
95 event_chain,
96 push_handle,
97 Some(&[handle.abort_handle()]),
98 None,
99 )
100}
101
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use alloc::vec;
    use alloc::vec::Vec;

    /// Downloads 3 KiB of mock data through `download_single` and checks
    /// that the merged pull and push progress both cover the whole payload,
    /// and that the pusher received exactly the mock bytes.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A one-element Vec holding a single range is what we want here:
        // the merged progress is expected to be exactly one contiguous span.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        // The loop ends once both tasks have dropped their event senders.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}