fast_pull/core/single.rs

1extern crate alloc;
2use super::macros::poll_ok;
3use crate::{DownloadResult, Event, ProgressEntry, SeqPuller, SeqPusher};
4use bytes::Bytes;
5use core::time::Duration;
6use fast_steal::{Executor, Handle};
7use futures::TryStreamExt;
8
/// Tuning knobs for [`download_single`].
#[derive(Debug, Clone)]
pub struct DownloadOptions {
    /// How long to wait before polling the pull stream again after a pull
    /// error.
    ///
    /// Also handed to the retry logic used for push and flush failures.
    pub retry_gap: Duration,
    /// Capacity of the bounded queue that carries pulled chunks from the pull
    /// task to the push task.
    pub push_queue_cap: usize,
}
14
15#[derive(Clone)]
16pub struct EmptyHandle;
17impl Handle for EmptyHandle {
18    type Output = ();
19    fn abort(&mut self) -> Self::Output {}
20}
21pub struct EmptyExecutor;
22impl Executor for EmptyExecutor {
23    type Handle = EmptyHandle;
24    fn execute(
25        self: alloc::sync::Arc<Self>,
26        _: alloc::sync::Arc<fast_steal::Task>,
27        _: alloc::sync::Arc<fast_steal::TaskList<Self>>,
28    ) -> Self::Handle {
29        EmptyHandle
30    }
31}
32
33pub async fn download_single<R, W>(
34    mut puller: R,
35    mut pusher: W,
36    options: DownloadOptions,
37) -> DownloadResult<EmptyExecutor, R::Error, W::Error>
38where
39    R: SeqPuller + 'static,
40    W: SeqPusher + 'static,
41{
42    let (tx, event_chain) = kanal::unbounded_async();
43    let (tx_push, rx_push) = kanal::bounded_async::<(ProgressEntry, Bytes)>(options.push_queue_cap);
44    let tx_clone = tx.clone();
45    const ID: usize = 0;
46    let push_handle = tokio::spawn(async move {
47        while let Ok((spin, data)) = rx_push.recv().await {
48            poll_ok!(
49                pusher.push(&data).await,
50                ID @ tx_clone => PushError,
51                options.retry_gap
52            );
53            tx_clone.send(Event::PushProgress(ID, spin)).await.unwrap();
54        }
55        poll_ok!(
56            pusher.flush().await,
57            tx_clone => FlushError,
58            options.retry_gap
59        );
60    });
61    let handle = tokio::spawn(async move {
62        tx.send(Event::Pulling(ID)).await.unwrap();
63        let mut downloaded: u64 = 0;
64        let mut stream = puller.pull();
65        loop {
66            match stream.try_next().await {
67                Ok(Some(chunk)) => {
68                    let len = chunk.len() as u64;
69                    let span = downloaded..(downloaded + len);
70                    tx.send(Event::PullProgress(ID, span.clone()))
71                        .await
72                        .unwrap();
73                    tx_push.send((span, chunk)).await.unwrap();
74                    downloaded += len;
75                }
76                Ok(None) => break,
77                Err(e) => {
78                    tx.send(Event::PullError(ID, e)).await.unwrap();
79                    tokio::time::sleep(options.retry_gap).await;
80                }
81            }
82        }
83        tx.send(Event::Finished(ID)).await.unwrap();
84    });
85    DownloadResult::new(event_chain, push_handle, &[handle.abort_handle()], None)
86}
87
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use alloc::{vec, vec::Vec};

    /// Runs a 3 KiB mock download end to end. It checks four things:
    ///
    /// * pull progress covers the whole payload;
    /// * push progress covers the whole payload;
    /// * a `Finished` event is seen;
    /// * the pushed bytes equal the source.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        // The expected progress list deliberately holds exactly one range.
        // This lint assumes a single-range vec was meant to list the
        // range's elements, which is not the case here.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_single(
            puller,
            pusher.clone(),
            DownloadOptions {
                retry_gap: Duration::from_secs(1),
                push_queue_cap: 1024,
            },
        )
        .await;

        let mut pull_progress: Vec<ProgressEntry> = Vec::new();
        let mut push_progress: Vec<ProgressEntry> = Vec::new();
        let mut finished = false;
        // The channel closes once both spawned tasks drop their senders.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::PullProgress(_, p) => {
                    pull_progress.merge_progress(p);
                }
                Event::PushProgress(_, p) => {
                    push_progress.merge_progress(p);
                }
                Event::Finished(_) => {
                    finished = true;
                }
                _ => {}
            }
        }
        assert!(finished, "pull task never reported Finished");
        assert_eq!(pull_progress, download_chunks);
        assert_eq!(push_progress, download_chunks);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}