// fast_pull/core/multi.rs

1extern crate alloc;
2use super::macros::{check_running, poll_ok};
3use crate::{DownloadResult, Event, ProgressEntry, RandReader, RandWriter, Total, WorkerId};
4use alloc::{sync::Arc, vec::Vec};
5use bytes::Bytes;
6use core::{num::NonZeroUsize, sync::atomic::AtomicBool, time::Duration};
7use fast_steal::{SplitTask, StealTask, Task, TaskList};
8use futures::TryStreamExt;
9use tokio::sync::Mutex;
10
11#[derive(Debug, Clone)]
12pub struct DownloadOptions {
13    pub download_chunks: Vec<ProgressEntry>,
14    pub concurrent: NonZeroUsize,
15    pub retry_gap: Duration,
16    pub write_queue_cap: usize,
17}
18
19pub async fn download_multi<R, W>(
20    reader: R,
21    mut writer: W,
22    options: DownloadOptions,
23) -> DownloadResult<R::Error, W::Error>
24where
25    R: RandReader + 'static,
26    W: RandWriter + 'static,
27{
28    let (tx, event_chain) = kanal::unbounded_async();
29    let (tx_write, rx_write) =
30        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.write_queue_cap);
31    let tx_clone = tx.clone();
32    let handle = tokio::spawn(async move {
33        while let Ok((id, spin, data)) = rx_write.recv().await {
34            poll_ok!(
35                {},
36                writer.write(spin.clone(), data.clone()).await,
37                id @ tx_clone => WriteError,
38                options.retry_gap
39            );
40            tx_clone.send(Event::WriteProgress(id, spin)).await.unwrap();
41        }
42        poll_ok!(
43            {},
44            writer.flush().await,
45            tx_clone => FlushError,
46            options.retry_gap
47        );
48    });
49    let mutex = Arc::new(Mutex::new(()));
50    let task_list = Arc::new(TaskList::from(&options.download_chunks[..]));
51    let tasks = Arc::from_iter(
52        Task::from(&*task_list)
53            .split_task(options.concurrent.get() as u64)
54            .map(Arc::new),
55    );
56    let running = Arc::new(AtomicBool::new(true));
57    for (id, task) in tasks.iter().enumerate() {
58        let task = task.clone();
59        let tasks = tasks.clone();
60        let task_list = task_list.clone();
61        let mutex = mutex.clone();
62        let tx = tx.clone();
63        let running = running.clone();
64        let mut reader = reader.clone();
65        let tx_write = tx_write.clone();
66        tokio::spawn(async move {
67            'steal_task: loop {
68                check_running!(id, running, tx);
69                let mut start = task.start();
70                if start >= task.end() {
71                    let guard = mutex.lock().await;
72                    if task.steal(&tasks, 16 * 1024) {
73                        continue;
74                    }
75                    drop(guard);
76                    tx.send(Event::Finished(id)).await.unwrap();
77                    return;
78                }
79                let download_range = &task_list.get_range(start..task.end());
80                for range in download_range {
81                    check_running!(id, running, tx);
82                    tx.send(Event::Reading(id)).await.unwrap();
83                    let mut stream = reader.read(range);
84                    let mut downloaded = 0;
85                    loop {
86                        check_running!(id, running, tx);
87                        match stream.try_next().await {
88                            Ok(Some(mut chunk)) => {
89                                let len = chunk.len() as u64;
90                                task.fetch_add_start(len);
91                                start += len;
92                                let range_start = range.start + downloaded;
93                                downloaded += len;
94                                let range_end = range.start + downloaded;
95                                let span = range_start..range_end.min(task_list.get(task.end()));
96                                let len = span.total() as usize;
97                                tx.send(Event::ReadProgress(id, span.clone()))
98                                    .await
99                                    .unwrap();
100                                tx_write
101                                    .send((id, span, chunk.split_to(len)))
102                                    .await
103                                    .unwrap();
104                                if start >= task.end() {
105                                    continue 'steal_task;
106                                }
107                            }
108                            Ok(None) => break,
109                            Err(e) => {
110                                tx.send(Event::ReadError(id, e)).await.unwrap();
111                                tokio::time::sleep(options.retry_gap).await;
112                            }
113                        }
114                    }
115                }
116            }
117        });
118    }
119    DownloadResult::new(event_chain, handle, running)
120}
121
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandReader, MockRandWriter, build_mock_data},
    };
    use alloc::vec;

    /// Downloads a 3 KiB mock payload with 32 workers.
    ///
    /// Checks that the read and write progress reported on the event channel
    /// each merge to exactly the requested range, and that the writer task
    /// joins cleanly. `writer.assert()` presumably verifies the written bytes
    /// against the mock data.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let reader = MockRandReader::new(&mock_data);
        let writer = MockRandWriter::new(&mock_data);
        // A single whole-payload range is intended here, not a range of ranges.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            reader,
            writer.clone(),
            DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
            },
        )
        .await;

        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        // Drain events until the channel is closed (all senders dropped).
        while let Ok(event) = result.event_chain.recv().await {
            match event {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        // `assert_eq!` prints both sides on failure, so no extra debug output is needed.
        assert_eq!(download_progress, download_chunks);
        assert_eq!(write_progress, download_chunks);

        result.join().await.unwrap();
        writer.assert().await;
    }
}