fast_pull/core/
single.rs

1extern crate alloc;
2use super::macros::poll_ok;
3use crate::{DownloadResult, Event, ProgressEntry, SeqPuller, SeqPusher};
4use alloc::sync::Arc;
5use bytes::Bytes;
6use core::time::Duration;
7use fast_steal::{Executor, Handle};
8use futures::TryStreamExt;
9
/// Tuning knobs consumed by [`download_single`].
#[derive(Debug, Clone)]
pub struct DownloadOptions {
    /// Fallback delay slept after a pull error when the error does not carry
    /// its own retry gap; the same value is handed to the retry helper used
    /// for push and flush.
    pub retry_gap: Duration,
    /// Capacity of the bounded channel that queues pulled chunks for the
    /// push task.
    pub push_queue_cap: usize,
}
15
16#[derive(Clone)]
17pub struct EmptyHandle;
18impl Handle for EmptyHandle {
19    type Output = ();
20    fn abort(&mut self) -> Self::Output {}
21}
22
23#[derive(Debug, Clone)]
24pub struct EmptyExecutor;
25impl Executor for EmptyExecutor {
26    type Handle = EmptyHandle;
27    fn execute(
28        self: Arc<Self>,
29        _: Arc<fast_steal::Task>,
30        _: Arc<fast_steal::TaskList<Self>>,
31    ) -> Self::Handle {
32        EmptyHandle
33    }
34}
35
36pub async fn download_single<R, W>(
37    mut puller: R,
38    mut pusher: W,
39    options: DownloadOptions,
40) -> DownloadResult<EmptyExecutor, R::Error, W::Error>
41where
42    R: SeqPuller + 'static,
43    W: SeqPusher + 'static,
44{
45    let (tx, event_chain) = kanal::unbounded_async();
46    let (tx_push, rx_push) = kanal::bounded_async::<(ProgressEntry, Bytes)>(options.push_queue_cap);
47    let tx_clone = tx.clone();
48    const ID: usize = 0;
49    let push_handle = tokio::spawn(async move {
50        while let Ok((spin, data)) = rx_push.recv().await {
51            poll_ok!(
52                pusher.push(&data).await,
53                ID @ tx_clone => PushError,
54                options.retry_gap
55            );
56            let _ = tx_clone.send(Event::PushProgress(ID, spin)).await;
57        }
58        poll_ok!(
59            pusher.flush().await,
60            tx_clone => FlushError,
61            options.retry_gap
62        );
63    });
64    let handle = tokio::spawn(async move {
65        let _ = tx.send(Event::Pulling(ID)).await;
66        let mut downloaded: u64 = 0;
67        let mut stream = loop {
68            match puller.pull().await {
69                Ok(t) => break t,
70                Err((e, retry_gap)) => {
71                    let _ = tx.send(Event::PullError(ID, e)).await;
72                    tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await;
73                }
74            }
75        };
76        loop {
77            match stream.try_next().await {
78                Ok(Some(chunk)) => {
79                    let len = chunk.len() as u64;
80                    let span = downloaded..(downloaded + len);
81                    let _ = tx.send(Event::PullProgress(ID, span.clone())).await;
82                    tx_push.send((span, chunk)).await.unwrap();
83                    downloaded += len;
84                }
85                Ok(None) => break,
86                Err((e, retry_gap)) => {
87                    let _ = tx.send(Event::PullError(ID, e)).await;
88                    tokio::time::sleep(retry_gap.unwrap_or(options.retry_gap)).await
89                }
90            }
91        }
92        let _ = tx.send(Event::Finished(ID)).await;
93    });
94    DownloadResult::new(
95        event_chain,
96        push_handle,
97        Some(&[handle.abort_handle()]),
98        None,
99    )
100}
101
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress,
        mem::MemPusher,
        mock::{MockPuller, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;
    use vec::Vec;

    /// Runs `download_single` over 3 KiB of mock data and checks that the
    /// merged pull and push progress both cover the whole range and that the
    /// in-memory pusher ends up holding exactly the mock data.
    #[tokio::test]
    async fn test_sequential_download() {
        let mock_data = build_mock_data(3 * 1024);
        let puller = MockPuller::new(&mock_data);
        let pusher = MemPusher::with_capacity(mock_data.len());
        // A single range spanning the whole payload is the expected outcome.
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..mock_data.len() as u64];
        let options = DownloadOptions {
            retry_gap: Duration::from_secs(1),
            push_queue_cap: 1024,
        };
        let result = download_single(puller, pusher.clone(), options).await;

        let mut pulled: Vec<ProgressEntry> = Vec::new();
        let mut pushed: Vec<ProgressEntry> = Vec::new();
        // Drain events until the channel reports an error on receive.
        loop {
            match result.event_chain.recv().await {
                Ok(Event::PullProgress(_, span)) => {
                    pulled.merge_progress(span);
                }
                Ok(Event::PushProgress(_, span)) => {
                    pushed.merge_progress(span);
                }
                Ok(_) => {}
                Err(_) => break,
            }
        }
        dbg!(&pulled);
        dbg!(&pushed);
        assert_eq!(pulled, expected);
        assert_eq!(pushed, expected);

        result.join().await.unwrap();
        assert_eq!(&**pusher.receive.lock(), mock_data);
    }
}
153}