fast_down/core/
single.rs

1use super::DownloadResult;
2use crate::{ConnectErrorKind, Event, SeqWriter};
3use reqwest::{Client, IntoUrl};
4use std::{
5    sync::{
6        atomic::{AtomicBool, Ordering},
7        Arc,
8    },
9    time::Duration,
10};
11use tokio::sync::mpsc;
12
13#[derive(Debug, Clone)]
14pub struct DownloadOptions {
15    pub client: Client,
16    pub retry_gap: Duration,
17}
18
19pub async fn download(
20    url: impl IntoUrl,
21    mut writer: impl SeqWriter + 'static,
22    options: DownloadOptions,
23) -> Result<DownloadResult, reqwest::Error> {
24    let url = url.into_url()?;
25    let (tx, event_chain) = mpsc::channel(1024);
26    let (tx_write, mut rx_write) = mpsc::channel(1024);
27    let tx_clone = tx.clone();
28    let handle = tokio::spawn(async move {
29        while let Some((spin, data)) = rx_write.recv().await {
30            loop {
31                match writer.write_sequentially(&data).await {
32                    Ok(_) => break,
33                    Err(e) => tx_clone.send(Event::WriteError(e)).await.unwrap(),
34                }
35                tokio::time::sleep(options.retry_gap).await;
36            }
37            tx_clone.send(Event::WriteProgress(spin)).await.unwrap();
38        }
39        loop {
40            match writer.flush().await {
41                Ok(_) => break,
42                Err(e) => tx_clone.send(Event::WriteError(e)).await.unwrap(),
43            };
44            tokio::time::sleep(options.retry_gap).await;
45        }
46    });
47    let running = Arc::new(AtomicBool::new(true));
48    let running_clone = running.clone();
49    tokio::spawn(async move {
50        let mut downloaded: u64 = 0;
51        let mut response = loop {
52            if !running.load(Ordering::Relaxed) {
53                tx.send(Event::Abort(0)).await.unwrap();
54                return;
55            }
56            tx.send(Event::Connecting(0)).await.unwrap();
57            match options.client.get(url.clone()).send().await {
58                Ok(response) => break response,
59                Err(e) => tx
60                    .send(Event::ConnectError(0, ConnectErrorKind::Reqwest(e)))
61                    .await
62                    .unwrap(),
63            }
64            tokio::time::sleep(options.retry_gap).await;
65        };
66        tx.send(Event::Downloading(0)).await.unwrap();
67        loop {
68            let chunk = loop {
69                if !running.load(Ordering::Relaxed) {
70                    tx.send(Event::Abort(0)).await.unwrap();
71                    return;
72                }
73                match response.chunk().await {
74                    Ok(chunk) => break chunk,
75                    Err(e) => tx.send(Event::DownloadError(0, e)).await.unwrap(),
76                }
77                tokio::time::sleep(options.retry_gap).await;
78            };
79            if chunk.is_none() {
80                break;
81            }
82            let chunk = chunk.unwrap();
83            let len = chunk.len() as u64;
84            let span = downloaded..(downloaded + len);
85            tx.send(Event::DownloadProgress(span.clone()))
86                .await
87                .unwrap();
88            tx_write.send((span, chunk)).await.unwrap();
89            downloaded += len as u64;
90        }
91        tx.send(Event::Finished(0)).await.unwrap();
92    });
93    Ok(DownloadResult::new(
94        event_chain,
95        handle,
96        Box::new(move || {
97            running_clone.store(false, Ordering::Relaxed);
98        }),
99    ))
100}
101
#[cfg(test)]
#[cfg(feature = "file")]
mod tests {
    use super::*;
    use crate::writer::file::SeqFileWriter;
    use crate::Total;
    use tempfile::NamedTempFile;

    /// Builds `size` bytes following the repeating pattern 0, 1, ..., 255, 0, 1, ...
    fn build_mock_data(size: usize) -> Vec<u8> {
        (0..size).map(|i| (i % 256) as u8).collect()
    }

    /// Serves `mock_body` at `path` from a mock HTTP server and downloads it into a
    /// temporary file with [`download`]. It then checks that:
    /// * the file on disk equals `mock_body` byte for byte;
    /// * the `DownloadProgress` spans add up to `mock_body.len()`;
    /// * the `WriteProgress` spans add up to `mock_body.len()`;
    /// * the mock endpoint was requested as expected.
    ///
    /// On failure, the collected events are included in the assertion message.
    async fn assert_download_round_trip(path: &str, mock_body: &[u8]) {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", path)
            .with_status(200)
            .with_body(mock_body)
            .create_async()
            .await;

        let temp_file = NamedTempFile::new().unwrap();
        let file = temp_file.reopen().unwrap().into();

        let result = download(
            format!("{}{}", server.url(), path),
            SeqFileWriter::new(file, 8 * 1024 * 1024),
            DownloadOptions {
                client: Client::new(),
                retry_gap: Duration::from_secs(1),
            },
        )
        .await
        .unwrap();

        // Collect events until every sender is dropped, i.e. both tasks have finished.
        let mut events = Vec::new();
        {
            let mut rx = result.event_chain.lock().await;
            while let Some(event) = rx.recv().await {
                events.push(event);
            }
        }
        result.join().await.unwrap();

        let file_content = tokio::fs::read(temp_file.path()).await.unwrap();
        assert_eq!(file_content, mock_body, "events: {:?}", events);

        let downloaded: u64 = events
            .iter()
            .map(|event| match event {
                Event::DownloadProgress(span) => span.total(),
                _ => 0,
            })
            .sum();
        let written: u64 = events
            .iter()
            .map(|event| match event {
                Event::WriteProgress(span) => span.total(),
                _ => 0,
            })
            .sum();
        let expected = mock_body.len() as u64;
        assert_eq!(downloaded, expected, "events: {:?}", events);
        assert_eq!(written, expected, "events: {:?}", events);

        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_downloads_small_file_correctly() {
        // 9-byte body.
        assert_download_round_trip("/small", b"test data").await;
    }

    #[tokio::test]
    async fn test_downloads_empty_file_correctly() {
        // Empty body: the file must stay empty and no bytes may be reported as progress.
        assert_download_round_trip("/empty", b"").await;
    }

    #[tokio::test]
    async fn test_downloads_large_file_correctly() {
        assert_download_round_trip("/large", &build_mock_data(5000)).await;
    }

    #[tokio::test]
    async fn test_downloads_exact_buffer_size_file() {
        // NOTE(review): 4096 bytes does not match the 8 MiB SeqFileWriter buffer used by
        // the helper; presumably this targets some other buffer boundary. Confirm intent.
        assert_download_round_trip("/exact_buffer_size_file", &build_mock_data(4096)).await;
    }
}