1extern crate alloc;
2use super::macros::poll_ok;
3use crate::{DownloadResult, Event, ProgressEntry, SeqPuller, SeqPusher};
4use bytes::Bytes;
5use core::time::Duration;
6use fast_steal::{Executor, Handle};
7use futures::TryStreamExt;
8
/// Tuning knobs for [`download_single`].
///
/// Both fields are plain values, so the options can be cloned and compared
/// freely (for example in tests or when caching a configuration).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DownloadOptions {
    /// Fallback delay before a failed pull is retried.
    ///
    /// `download_single` sleeps for this long after a pull error that does not
    /// carry its own retry gap. The same value is handed to the retry macro
    /// used for push and flush errors.
    pub retry_gap: Duration,
    /// Capacity of the bounded channel that carries pulled chunks from the
    /// pull task to the push task in `download_single`.
    pub push_queue_cap: usize,
}
14
/// Stateless task handle used as the `Handle` type of [`EmptyExecutor`].
///
/// It is a unit struct: it carries no data and its `abort` does nothing.
#[derive(Debug, Clone)]
pub struct EmptyHandle;
17impl Handle for EmptyHandle {
18 type Output = ();
19 fn abort(&mut self) -> Self::Output {}
20}
21
/// Placeholder executor named in the `DownloadResult` returned by
/// [`download_single`].
///
/// Its `execute` implementation ignores its arguments and returns an
/// [`EmptyHandle`].
#[derive(Clone, Debug)]
pub struct EmptyExecutor;
24impl Executor for EmptyExecutor {
25 type Handle = EmptyHandle;
26 fn execute(
27 self: alloc::sync::Arc<Self>,
28 _: alloc::sync::Arc<fast_steal::Task>,
29 _: alloc::sync::Arc<fast_steal::TaskList<Self>>,
30 ) -> Self::Handle {
31 EmptyHandle
32 }
33}
34
35pub async fn download_single<R, W>(
36 mut puller: R,
37 mut pusher: W,
38 options: DownloadOptions,
39) -> DownloadResult<EmptyExecutor, R::Error, W::Error>
40where
41 R: SeqPuller + 'static,
42 W: SeqPusher + 'static,
43{
44 let (tx, event_chain) = kanal::unbounded_async();
45 let (tx_push, rx_push) = kanal::bounded_async::<(ProgressEntry, Bytes)>(options.push_queue_cap);
46 let tx_clone = tx.clone();
47 const ID: usize = 0;
48 let push_handle = tokio::spawn(async move {
49 while let Ok((spin, data)) = rx_push.recv().await {
50 poll_ok!(
51 pusher.push(&data).await,
52 ID @ tx_clone => PushError,
53 options.retry_gap
54 );
55 tx_clone.send(Event::PushProgress(ID, spin)).await.unwrap();
56 }
57 poll_ok!(
58 pusher.flush().await,
59 tx_clone => FlushError,
60 options.retry_gap
61 );
62 });
63 let handle = tokio::spawn(async move {
64 tx.send(Event::Pulling(ID)).await.unwrap();
65 let mut downloaded: u64 = 0;
66 let mut stream = loop {
67 match puller.pull().await {
68 Ok(t) => break t,
69 Err((e, retry_gap)) => {
70 tx.send(Event::PullError(ID, e)).await.unwrap();
71 tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
72 }
73 }
74 };
75 loop {
76 match stream.try_next().await {
77 Ok(Some(chunk)) => {
78 let len = chunk.len() as u64;
79 let span = downloaded..(downloaded + len);
80 tx.send(Event::PullProgress(ID, span.clone()))
81 .await
82 .unwrap();
83 tx_push.send((span, chunk)).await.unwrap();
84 downloaded += len;
85 }
86 Ok(None) => break,
87 Err((e, retry_gap)) => {
88 tx.send(Event::PullError(ID, e)).await.unwrap();
89 tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await
90 }
91 }
92 }
93 tx.send(Event::Finished(ID)).await.unwrap();
94 });
95 DownloadResult::new(
96 event_chain,
97 push_handle,
98 Some(&[handle.abort_handle()]),
99 None,
100 )
101}
102
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use alloc::vec;
    use alloc::vec::Vec;
    use std::dbg;

    /// Downloads 3 KiB of mock data through `download_single` and checks that
    /// the merged pull and push progress both cover the whole payload and that
    /// the pusher ends up holding exactly the mock data.
    #[tokio::test]
    async fn test_sequential_download() {
        let payload = build_mock_data(3 * 1024);
        let total_len = payload.len();

        // After merging, all progress should collapse into this single range.
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..total_len as u64];

        let sink = MemPusher::with_capacity(total_len);
        let options = DownloadOptions {
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_single(MockPuller::new(&payload), sink.clone(), options).await;

        let mut pulled: Vec<ProgressEntry> = Vec::new();
        let mut pushed: Vec<ProgressEntry> = Vec::new();
        // The loop ends when `recv` returns an error.
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::PullProgress(_, span) => {
                    pulled.merge_progress(span);
                }
                Event::PushProgress(_, span) => {
                    pushed.merge_progress(span);
                }
                _ => {}
            }
        }

        dbg!(&pulled);
        dbg!(&pushed);
        assert_eq!(pulled, expected);
        assert_eq!(pushed, expected);

        result.join().await.unwrap();
        assert_eq!(&**sink.receive.lock().await, payload);
    }
}