// fast_pull/core/single.rs

1extern crate alloc;
2use super::macros::{check_running, poll_ok};
3use crate::{DownloadResult, Event, ProgressEntry, SeqReader, SeqWriter};
4use alloc::sync::Arc;
5use bytes::Bytes;
6use core::{sync::atomic::AtomicBool, time::Duration};
7use futures::TryStreamExt;
8
/// Tuning knobs accepted by [`download_single`].
#[derive(Clone, Debug)]
pub struct DownloadOptions {
    /// Gap handed to the `poll_ok!` retry helper around the writer's `write`
    /// and `flush` calls (presumably the pause between attempts; confirm
    /// against the macro definition).
    pub retry_gap: Duration,
    /// Capacity of the bounded queue that carries `(ProgressEntry, Bytes)`
    /// pairs from the reading task to the writing task.
    pub write_queue_cap: usize,
}
14
15pub async fn download_single<R, W>(
16    mut reader: R,
17    mut writer: W,
18    options: DownloadOptions,
19) -> DownloadResult<R::Error, W::Error>
20where
21    R: SeqReader + 'static,
22    W: SeqWriter + 'static,
23{
24    let (tx, event_chain) = kanal::unbounded_async();
25    let (tx_write, rx_write) =
26        kanal::bounded_async::<(ProgressEntry, Bytes)>(options.write_queue_cap);
27    let tx_clone = tx.clone();
28    const ID: usize = 0;
29    let handle = tokio::spawn(async move {
30        while let Ok((spin, data)) = rx_write.recv().await {
31            poll_ok!(
32                {},
33                writer.write(data.clone()).await,
34                ID @ tx_clone => WriteError,
35                options.retry_gap
36            );
37            tx_clone.send(Event::WriteProgress(ID, spin)).await.unwrap();
38        }
39        poll_ok!(
40            {},
41            writer.flush().await,
42            tx_clone => FlushError,
43            options.retry_gap
44        );
45    });
46    let running = Arc::new(AtomicBool::new(true));
47    let running_clone = running.clone();
48    tokio::spawn(async move {
49        check_running!(ID, running, tx);
50        tx.send(Event::Reading(ID)).await.unwrap();
51        let mut downloaded: u64 = 0;
52        let mut stream = reader.read();
53        loop {
54            check_running!(ID, running, tx);
55            match stream.try_next().await {
56                Ok(Some(chunk)) => {
57                    let len = chunk.len() as u64;
58                    let span = downloaded..(downloaded + len);
59                    tx.send(Event::ReadProgress(ID, span.clone()))
60                        .await
61                        .unwrap();
62                    tx_write.send((span, chunk)).await.unwrap();
63                    downloaded += len;
64                }
65                Ok(None) => break,
66                Err(e) => tx.send(Event::ReadError(ID, e)).await.unwrap(),
67            }
68        }
69        tx.send(Event::Finished(ID)).await.unwrap();
70    });
71    DownloadResult::new(event_chain, handle, running_clone)
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        MergeProgress,
        core::mock::{MockSeqReader, MockSeqWriter, build_mock_data},
    };

    /// Runs `download_single` over mock data and checks that both the merged
    /// read progress and the merged write progress collapse to the single
    /// range `0..len`, then lets the mock writer verify what it received.
    #[tokio::test]
    async fn test_sequential_download() {
        let payload = build_mock_data(3 * 1024);
        let reader = MockSeqReader::new(payload.clone());
        let writer = MockSeqWriter::new(&payload);
        let total = payload.len() as u64;
        // A one-element vec holding a range is intended here: it is the
        // expected fully merged progress list.
        #[allow(clippy::single_range_in_vec_init)]
        let expected = vec![0..total];

        let options = DownloadOptions {
            retry_gap: Duration::from_secs(1),
            write_queue_cap: 1024,
        };
        let result = download_single(reader, writer.clone(), options).await;

        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        // Drain events until `recv` reports an error.
        loop {
            let Ok(event) = result.event_chain.recv().await else {
                break;
            };
            match event {
                Event::ReadProgress(_, range) => {
                    download_progress.merge_progress(range);
                }
                Event::WriteProgress(_, range) => {
                    write_progress.merge_progress(range);
                }
                _ => {}
            }
        }
        dbg!(&download_progress);
        dbg!(&write_progress);
        assert_eq!(download_progress, expected);
        assert_eq!(write_progress, expected);

        result.join().await.unwrap();
        writer.assert().await;
    }
}