fast_pull/core/single.rs

1extern crate alloc;
2use super::macros::{check_running, poll_ok};
3use crate::{DownloadResult, Event, ProgressEntry, SeqReader, SeqWriter};
4use alloc::sync::Arc;
5use bytes::Bytes;
6use core::{sync::atomic::AtomicBool, time::Duration};
7use futures::TryStreamExt;
8
/// Tuning knobs for [`download_single`].
#[derive(Debug, Clone)]
pub struct DownloadOptions {
    /// Pause handed to the retry logic between failed write/flush attempts.
    pub retry_gap: Duration,
    /// Capacity (in chunks) of the bounded queue that carries data from the read task to the
    /// write task.
    pub write_queue_cap: usize,
}
14
15pub async fn download_single<R, W>(
16    mut reader: R,
17    mut writer: W,
18    options: DownloadOptions,
19) -> DownloadResult<R::Error, W::Error>
20where
21    R: SeqReader + 'static,
22    W: SeqWriter + 'static,
23{
24    let (tx, event_chain) = kanal::unbounded_async();
25    let (tx_write, rx_write) =
26        kanal::bounded_async::<(ProgressEntry, Bytes)>(options.write_queue_cap);
27    let tx_clone = tx.clone();
28    const ID: usize = 0;
29    let handle = tokio::spawn(async move {
30        while let Ok((spin, data)) = rx_write.recv().await {
31            poll_ok!(
32                {},
33                writer.write(data.clone()).await,
34                ID @ tx_clone => WriteError,
35                options.retry_gap
36            );
37            tx_clone.send(Event::WriteProgress(ID, spin)).await.unwrap();
38        }
39        poll_ok!(
40            {},
41            writer.flush().await,
42            tx_clone => FlushError,
43            options.retry_gap
44        );
45    });
46    let running = Arc::new(AtomicBool::new(true));
47    let running_clone = running.clone();
48    tokio::spawn(async move {
49        check_running!(ID, running, tx);
50        tx.send(Event::Reading(ID)).await.unwrap();
51        let mut downloaded: u64 = 0;
52        let mut stream = reader.read();
53        loop {
54            check_running!(ID, running, tx);
55            match stream.try_next().await {
56                Ok(Some(chunk)) => {
57                    let len = chunk.len() as u64;
58                    let span = downloaded..(downloaded + len);
59                    tx.send(Event::ReadProgress(ID, span.clone()))
60                        .await
61                        .unwrap();
62                    tx_write.send((span, chunk)).await.unwrap();
63                    downloaded += len;
64                }
65                Ok(None) => break,
66                Err(e) => tx.send(Event::ReadError(ID, e)).await.unwrap(),
67            }
68        }
69        tx.send(Event::Finished(ID)).await.unwrap();
70    });
71    DownloadResult::new(event_chain, handle, running_clone)
72}
73
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress,
        core::mock::{MockSeqReader, MockSeqWriter, build_mock_data},
    };
    use alloc::vec;
    use std::dbg;
    use vec::Vec;

    /// End-to-end check: a 3 KiB mock payload read sequentially must be reported as one
    /// contiguous read range and one contiguous write range, and the mock writer's own
    /// verification must pass afterwards.
    #[tokio::test]
    async fn test_sequential_download() {
        let payload = build_mock_data(3 * 1024);
        let source = MockSeqReader::new(payload.clone());
        let sink = MockSeqWriter::new(&payload);

        // Clippy flags `vec![a..b]` because it is often a typo for `(a..b).collect()`; here a
        // one-element list holding a single range is exactly what is meant.
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..payload.len() as u64];

        let options = DownloadOptions {
            retry_gap: Duration::from_secs(1),
            write_queue_cap: 1024,
        };
        let result = download_single(source, sink.clone(), options).await;

        let mut read_ranges: Vec<ProgressEntry> = Vec::new();
        let mut written_ranges: Vec<ProgressEntry> = Vec::new();
        // The loop ends once every sender held by the spawned tasks has been dropped.
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::WriteProgress(_, range) => {
                    written_ranges.merge_progress(range);
                }
                Event::ReadProgress(_, range) => {
                    read_ranges.merge_progress(range);
                }
                _ => {}
            }
        }

        dbg!(&read_ranges);
        dbg!(&written_ranges);
        assert_eq!(read_ranges, expected);
        assert_eq!(written_ranges, expected);

        result.join().await.unwrap();
        sink.assert().await;
    }
}