fast_pull/core/
single.rs

1extern crate alloc;
2use super::macros::{check_running, poll_ok};
3use crate::{DownloadResult, Event, ProgressEntry, SeqReader, SeqWriter};
4use alloc::sync::Arc;
5use bytes::Bytes;
6use core::{sync::atomic::AtomicBool, time::Duration};
7use futures::TryStreamExt;
8
/// Tuning knobs for [`download_single`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DownloadOptions {
    /// Pause used between attempts: `download_single` sleeps for this long after a read
    /// error before polling the stream again, and hands the same value to `poll_ok!` for the
    /// write and flush steps.
    pub retry_gap: Duration,
    /// Capacity of the bounded queue that carries `(ProgressEntry, Bytes)` pairs from the
    /// reading task to the writing task.
    pub write_queue_cap: usize,
}
14
15pub async fn download_single<R, W>(
16    mut reader: R,
17    mut writer: W,
18    options: DownloadOptions,
19) -> DownloadResult<R::Error, W::Error>
20where
21    R: SeqReader + 'static,
22    W: SeqWriter + 'static,
23{
24    let (tx, event_chain) = kanal::unbounded_async();
25    let (tx_write, rx_write) =
26        kanal::bounded_async::<(ProgressEntry, Bytes)>(options.write_queue_cap);
27    let tx_clone = tx.clone();
28    const ID: usize = 0;
29    let handle = tokio::spawn(async move {
30        while let Ok((spin, data)) = rx_write.recv().await {
31            poll_ok!(
32                {},
33                writer.write(data.clone()).await,
34                ID @ tx_clone => WriteError,
35                options.retry_gap
36            );
37            tx_clone.send(Event::WriteProgress(ID, spin)).await.unwrap();
38        }
39        poll_ok!(
40            {},
41            writer.flush().await,
42            tx_clone => FlushError,
43            options.retry_gap
44        );
45    });
46    let running = Arc::new(AtomicBool::new(true));
47    let running_clone = running.clone();
48    tokio::spawn(async move {
49        check_running!(ID, running, tx);
50        tx.send(Event::Reading(ID)).await.unwrap();
51        let mut downloaded: u64 = 0;
52        let mut stream = reader.read();
53        loop {
54            check_running!(ID, running, tx);
55            match stream.try_next().await {
56                Ok(Some(chunk)) => {
57                    let len = chunk.len() as u64;
58                    let span = downloaded..(downloaded + len);
59                    tx.send(Event::ReadProgress(ID, span.clone()))
60                        .await
61                        .unwrap();
62                    tx_write.send((span, chunk)).await.unwrap();
63                    downloaded += len;
64                }
65                Ok(None) => break,
66                Err(e) => {
67                    tx.send(Event::ReadError(ID, e)).await.unwrap();
68                    tokio::time::sleep(options.retry_gap).await;
69                }
70            }
71        }
72        tx.send(Event::Finished(ID)).await.unwrap();
73    });
74    DownloadResult::new(event_chain, handle, running_clone)
75}
76
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress,
        core::mock::{MockSeqReader, MockSeqWriter, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;
    use vec::Vec;

    /// Runs `download_single` over 3 KiB of mock data and checks that the merged read and
    /// write progress each collapse to the single range `0..len`, and that the mock writer's
    /// own assertion passes.
    #[tokio::test]
    async fn test_sequential_download() {
        let payload = build_mock_data(3 * 1024);
        let source = MockSeqReader::new(payload.clone());
        let sink = MockSeqWriter::new(&payload);
        // A one-element vec holding a range is intentional: it is the expected merged progress.
        #[allow(clippy::single_range_in_vec_init)]
        let expected_spans = vec![0..payload.len() as u64];

        let download_options = DownloadOptions {
            retry_gap: Duration::from_secs(1),
            write_queue_cap: 1024,
        };
        let result = download_single(source, sink.clone(), download_options).await;

        // Drain events until the channel reports an error, merging the reported spans.
        let mut read_spans: Vec<ProgressEntry> = Vec::new();
        let mut written_spans: Vec<ProgressEntry> = Vec::new();
        loop {
            let Ok(event) = result.event_chain.recv().await else {
                break;
            };
            match event {
                Event::ReadProgress(_, span) => {
                    read_spans.merge_progress(span);
                }
                Event::WriteProgress(_, span) => {
                    written_spans.merge_progress(span);
                }
                _ => {}
            }
        }
        dbg!(&read_spans);
        dbg!(&written_spans);
        assert_eq!(read_spans, expected_spans);
        assert_eq!(written_spans, expected_spans);

        result.join().await.unwrap();
        sink.assert().await;
    }
}