fast_pull/core/multi.rs

1extern crate alloc;
2use super::macros::{check_running, poll_ok};
3use crate::{DownloadResult, Event, ProgressEntry, RandReader, RandWriter, Total, WorkerId};
4use alloc::{sync::Arc, vec::Vec};
5use bytes::Bytes;
6use core::{num::NonZeroUsize, sync::atomic::AtomicBool, time::Duration};
7use fast_steal::{SplitTask, StealTask, Task, TaskList};
8use futures::TryStreamExt;
9use tokio::sync::Mutex;
10
11#[derive(Debug, Clone)]
12pub struct DownloadOptions {
13    pub download_chunks: Vec<ProgressEntry>,
14    pub concurrent: NonZeroUsize,
15    pub retry_gap: Duration,
16    pub write_queue_cap: usize,
17}
18
19pub async fn download_multi<R, W>(
20    reader: R,
21    mut writer: W,
22    options: DownloadOptions,
23) -> DownloadResult<R::Error, W::Error>
24where
25    R: RandReader + 'static,
26    W: RandWriter + 'static,
27{
28    let (tx, event_chain) = kanal::unbounded_async();
29    let (tx_write, rx_write) =
30        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.write_queue_cap);
31    let tx_clone = tx.clone();
32    let handle = tokio::spawn(async move {
33        while let Ok((id, spin, data)) = rx_write.recv().await {
34            poll_ok!(
35                {},
36                writer.write(spin.clone(), data.clone()).await,
37                id @ tx_clone => WriteError,
38                options.retry_gap
39            );
40            tx_clone.send(Event::WriteProgress(id, spin)).await.unwrap();
41        }
42        poll_ok!(
43            {},
44            writer.flush().await,
45            tx_clone => FlushError,
46            options.retry_gap
47        );
48    });
49    let mutex = Arc::new(Mutex::new(()));
50    let task_list = Arc::new(TaskList::from(&options.download_chunks[..]));
51    let tasks = Arc::from_iter(
52        Task::from(&*task_list)
53            .split_task(options.concurrent.get() as u64)
54            .map(Arc::new),
55    );
56    let running = Arc::new(AtomicBool::new(true));
57    for (id, task) in tasks.iter().enumerate() {
58        let task = task.clone();
59        let tasks = tasks.clone();
60        let task_list = task_list.clone();
61        let mutex = mutex.clone();
62        let tx = tx.clone();
63        let running = running.clone();
64        let mut reader = reader.clone();
65        let tx_write = tx_write.clone();
66        tokio::spawn(async move {
67            'steal_task: loop {
68                check_running!(id, running, tx);
69                let mut start = task.start();
70                if start >= task.end() {
71                    let guard = mutex.lock().await;
72                    if task.steal(&tasks, 16 * 1024) {
73                        continue;
74                    }
75                    drop(guard);
76                    tx.send(Event::Finished(id)).await.unwrap();
77                    return;
78                }
79                let download_range = &task_list.get_range(start..task.end());
80                for range in download_range {
81                    check_running!(id, running, tx);
82                    tx.send(Event::Reading(id)).await.unwrap();
83                    let mut stream = reader.read(range);
84                    let mut downloaded = 0;
85                    loop {
86                        check_running!(id, running, tx);
87                        match stream.try_next().await {
88                            Ok(Some(mut chunk)) => {
89                                let len = chunk.len() as u64;
90                                task.fetch_add_start(len);
91                                start += len;
92                                let range_start = range.start + downloaded;
93                                downloaded += len;
94                                let range_end = range.start + downloaded;
95                                let span = range_start..range_end.min(task_list.get(task.end()));
96                                let len = span.total() as usize;
97                                tx.send(Event::ReadProgress(id, span.clone()))
98                                    .await
99                                    .unwrap();
100                                tx_write
101                                    .send((id, span, chunk.split_to(len)))
102                                    .await
103                                    .unwrap();
104                                if start >= task.end() {
105                                    continue 'steal_task;
106                                }
107                            }
108                            Ok(None) => break,
109                            Err(e) => tx.send(Event::ReadError(id, e)).await.unwrap(),
110                        }
111                    }
112                }
113            }
114        });
115    }
116    DownloadResult::new(event_chain, handle, running)
117}
118
#[cfg(test)]
mod tests {
    extern crate std;
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandReader, MockRandWriter, build_mock_data},
    };
    use alloc::vec;

    /// Downloads one 3 KiB range with 32 workers and checks three things: the merged read
    /// progress covers exactly the requested range, the merged write progress covers it too,
    /// and the writer task joins cleanly. The mock writer's own `assert` then runs.
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let reader = MockRandReader::new(&mock_data);
        let writer = MockRandWriter::new(&mock_data);
        // Intentional: a one-element list of byte ranges, not a list built from a range.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            reader,
            writer.clone(),
            DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
            },
        )
        .await;

        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        // Drain events until every sender (workers and writer task) has been dropped.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(
            download_progress, download_chunks,
            "read progress must cover exactly the requested ranges"
        );
        assert_eq!(
            write_progress, download_chunks,
            "write progress must cover exactly the requested ranges"
        );

        result.join().await.unwrap();
        writer.assert().await;
    }
}