fast_pull/core/multi.rs

1extern crate alloc;
2use super::macros::{check_running, poll_ok};
3use crate::{DownloadResult, Event, ProgressEntry, RandReader, RandWriter, Total, WorkerId};
4use alloc::sync::Arc;
5use bytes::Bytes;
6use core::{num::NonZeroUsize, sync::atomic::AtomicBool, time::Duration};
7use fast_steal::{SplitTask, StealTask, Task, TaskList};
8use futures::{TryStreamExt, lock::Mutex};
9
10#[derive(Debug, Clone)]
11pub struct DownloadOptions {
12    pub download_chunks: Vec<ProgressEntry>,
13    pub concurrent: NonZeroUsize,
14    pub retry_gap: Duration,
15    pub write_queue_cap: usize,
16}
17
18pub async fn download_multi<R, W>(
19    reader: R,
20    mut writer: W,
21    options: DownloadOptions,
22) -> DownloadResult<R::Error, W::Error>
23where
24    R: RandReader + 'static,
25    W: RandWriter + 'static,
26{
27    let (tx, event_chain) = kanal::unbounded_async();
28    let (tx_write, rx_write) =
29        kanal::bounded_async::<(WorkerId, ProgressEntry, Bytes)>(options.write_queue_cap);
30    let tx_clone = tx.clone();
31    let handle = tokio::spawn(async move {
32        while let Ok((id, spin, data)) = rx_write.recv().await {
33            poll_ok!(
34                {},
35                writer.write(spin.clone(), data.clone()).await,
36                id @ tx_clone => WriteError,
37                options.retry_gap
38            );
39            tx_clone.send(Event::WriteProgress(id, spin)).await.unwrap();
40        }
41        poll_ok!(
42            {},
43            writer.flush().await,
44            tx_clone => FlushError,
45            options.retry_gap
46        );
47    });
48    let mutex = Arc::new(Mutex::new(()));
49    let task_list = Arc::new(TaskList::from(&options.download_chunks[..]));
50    let tasks = Arc::from_iter(
51        Task::from(&*task_list)
52            .split_task(options.concurrent.get() as u64)
53            .map(Arc::new),
54    );
55    let running = Arc::new(AtomicBool::new(true));
56    for (id, task) in tasks.iter().enumerate() {
57        let task = task.clone();
58        let tasks = tasks.clone();
59        let task_list = task_list.clone();
60        let mutex = mutex.clone();
61        let tx = tx.clone();
62        let running = running.clone();
63        let mut reader = reader.clone();
64        let tx_write = tx_write.clone();
65        tokio::spawn(async move {
66            'steal_task: loop {
67                check_running!(id, running, tx);
68                let mut start = task.start();
69                if start >= task.end() {
70                    let guard = mutex.lock().await;
71                    if task.steal(&tasks, 16 * 1024) {
72                        continue;
73                    }
74                    drop(guard);
75                    tx.send(Event::Finished(id)).await.unwrap();
76                    return;
77                }
78                let download_range = &task_list.get_range(start..task.end());
79                for range in download_range {
80                    check_running!(id, running, tx);
81                    tx.send(Event::Reading(id)).await.unwrap();
82                    let mut stream = reader.read(range);
83                    let mut downloaded = 0;
84                    loop {
85                        check_running!(id, running, tx);
86                        match stream.try_next().await {
87                            Ok(Some(mut chunk)) => {
88                                let len = chunk.len() as u64;
89                                task.fetch_add_start(len);
90                                start += len;
91                                let range_start = range.start + downloaded;
92                                downloaded += len;
93                                let range_end = range.start + downloaded;
94                                let span = range_start..range_end.min(task_list.get(task.end()));
95                                let len = span.total() as usize;
96                                tx.send(Event::ReadProgress(id, span.clone()))
97                                    .await
98                                    .unwrap();
99                                tx_write
100                                    .send((id, span, chunk.split_to(len)))
101                                    .await
102                                    .unwrap();
103                                if start >= task.end() {
104                                    continue 'steal_task;
105                                }
106                            }
107                            Ok(None) => break,
108                            Err(e) => tx.send(Event::ReadError(id, e)).await.unwrap(),
109                        }
110                    }
111                }
112            }
113        });
114    }
115    DownloadResult::new(event_chain, handle, running)
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        MergeProgress, ProgressEntry,
        core::mock::{MockRandReader, MockRandWriter, build_mock_data},
    };

    /// Downloads a 3 KiB mock payload with `concurrent = 32`.
    ///
    /// Checks that the merged read progress and the merged write progress reported
    /// on the event chain both equal the requested chunk list. It then waits for the
    /// writer task and lets `MockRandWriter::assert` verify what was written
    /// (presumably against `mock_data`).
    #[tokio::test]
    async fn test_concurrent_download() {
        let mock_data = build_mock_data(3 * 1024);
        let reader = MockRandReader::new(&mock_data);
        let writer = MockRandWriter::new(&mock_data);
        // A single whole-file range is intended here, not a range-to-vec mistake.
        #[allow(clippy::single_range_in_vec_init)]
        let download_chunks = vec![0..mock_data.len() as u64];
        let result = download_multi(
            reader,
            writer.clone(),
            DownloadOptions {
                concurrent: NonZeroUsize::new(32).unwrap(),
                retry_gap: Duration::from_secs(1),
                write_queue_cap: 1024,
                download_chunks: download_chunks.clone(),
            },
        )
        .await;

        let mut download_progress: Vec<ProgressEntry> = Vec::new();
        let mut write_progress: Vec<ProgressEntry> = Vec::new();
        // The chain closes once every sender (workers and writer task) is dropped.
        while let Ok(e) = result.event_chain.recv().await {
            match e {
                Event::ReadProgress(_, p) => {
                    download_progress.merge_progress(p);
                }
                Event::WriteProgress(_, p) => {
                    write_progress.merge_progress(p);
                }
                _ => {}
            }
        }
        assert_eq!(
            download_progress, download_chunks,
            "read progress must cover exactly the requested chunks"
        );
        assert_eq!(
            write_progress, download_chunks,
            "write progress must cover exactly the requested chunks"
        );

        result.join().await.unwrap();
        writer.assert().await;
    }
}