oma_fetch/
download.rs

1use crate::{CompressFile, DownloadSource, Event, checksum::ChecksumValidator};
2use std::{
3    fs::Permissions,
4    io::{self, ErrorKind, SeekFrom},
5    os::unix::fs::PermissionsExt,
6    path::Path,
7    time::Duration,
8};
9
10use async_compression::futures::bufread::{BzDecoder, GzipDecoder, XzDecoder, ZstdDecoder};
11use bon::Builder;
12use futures::{AsyncRead, TryStreamExt, io::BufReader};
13use reqwest::{
14    Client, Method, RequestBuilder,
15    header::{ACCEPT_RANGES, CONTENT_LENGTH, HeaderValue, RANGE},
16};
17use snafu::{ResultExt, Snafu};
18use tokio::{
19    fs::{self, File},
20    io::{AsyncBufReadExt as _, AsyncReadExt as _, AsyncSeekExt, AsyncWriteExt},
21    time::timeout,
22};
23
24use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
25use tracing::debug;
26
27use crate::{DownloadEntry, DownloadSourceType};
28
/// Buffer size (64 KiB) for reading local files, used both when checksumming an existing
/// file and when copying a local source.
const READ_FILE_BUFSIZE: usize = 64 * 1024;
/// Buffer size (8 KiB) for each read from an HTTP response body.
const DOWNLOAD_BUFSIZE: usize = 8 * 1024;
31
32#[derive(Snafu, Debug)]
33#[snafu(display("source list is empty"))]
34pub struct EmptySource {
35    file_name: String,
36}
37
38#[derive(Builder)]
39pub(crate) struct SingleDownloader<'a> {
40    client: &'a Client,
41    #[builder(with = |entry: &'a DownloadEntry| -> Result<_, EmptySource> {
42        if entry.source.is_empty() {
43            return Err(EmptySource { file_name: entry.filename.to_string() });
44        } else {
45            return Ok(entry);
46        }
47    })]
48    pub entry: &'a DownloadEntry,
49    progress: (usize, usize),
50    retry_times: usize,
51    msg: Option<String>,
52    download_list_index: usize,
53    file_type: CompressFile,
54    set_permission: Option<u32>,
55    timeout: Duration,
56}
57
58pub enum DownloadResult {
59    Success(SuccessSummary),
60    Failed { file_name: String },
61}
62
/// Details reported for a successfully completed download.
#[derive(Debug)]
pub struct SuccessSummary {
    /// Name of the file that was obtained.
    pub file_name: String,
    /// Position of the entry in the download list.
    pub index: usize,
    /// `false` when an already-present file passed checksum verification and nothing was
    /// written; `true` when the file was (re)written.
    pub wrote: bool,
}
69
70#[derive(Debug, Snafu)]
71pub enum SingleDownloadError {
72    #[snafu(display("Failed to set permission"))]
73    SetPermission { source: io::Error },
74    #[snafu(display("Failed to open file as rw mode"))]
75    OpenAsWriteMode { source: io::Error },
76    #[snafu(display("Failed to open file"))]
77    Open { source: io::Error },
78    #[snafu(display("Failed to create file"))]
79    Create { source: io::Error },
80    #[snafu(display("Failed to seek file"))]
81    Seek { source: io::Error },
82    #[snafu(display("Failed to write file"))]
83    Write { source: io::Error },
84    #[snafu(display("Failed to flush file"))]
85    Flush { source: io::Error },
86    #[snafu(display("Failed to Remove file"))]
87    Remove { source: io::Error },
88    #[snafu(display("Failed to create symlink"))]
89    CreateSymlink { source: io::Error },
90    #[snafu(display("Request Error"))]
91    ReqwestError { source: reqwest::Error },
92    #[snafu(display("Broken pipe"))]
93    BrokenPipe { source: io::Error },
94    #[snafu(display("Send request timeout"))]
95    SendRequestTimeout,
96    #[snafu(display("Download file timeout"))]
97    DownloadTimeout,
98    #[snafu(display("checksum mismatch"))]
99    ChecksumMismatch,
100}
101
102impl SingleDownloader<'_> {
103    pub(crate) async fn try_download(self, callback: &impl AsyncFn(Event)) -> DownloadResult {
104        let mut sources = self.entry.source.clone();
105        assert!(!sources.is_empty());
106
107        sources.sort_unstable_by(|a, b| b.source_type.cmp(&a.source_type));
108
109        let msg = self.msg.as_deref().unwrap_or(&*self.entry.filename);
110
111        for (index, c) in sources.iter().enumerate() {
112            let download_res = match &c.source_type {
113                DownloadSourceType::Http { auth } => {
114                    self.try_http_download(c, auth, callback).await
115                }
116                DownloadSourceType::Local(as_symlink) => {
117                    self.download_local(c, *as_symlink, callback).await
118                }
119            };
120
121            match download_res {
122                Ok(b) => {
123                    callback(Event::DownloadDone {
124                        index: self.download_list_index,
125                        msg: msg.into(),
126                    })
127                    .await;
128
129                    return DownloadResult::Success(SuccessSummary {
130                        file_name: self.entry.filename.to_string(),
131                        index: self.download_list_index,
132                        wrote: b,
133                    });
134                }
135                Err(e) => {
136                    if index == sources.len() - 1 {
137                        callback(Event::Failed {
138                            file_name: self.entry.filename.clone(),
139                            error: e,
140                        })
141                        .await;
142
143                        return DownloadResult::Failed {
144                            file_name: self.entry.filename.to_string(),
145                        };
146                    }
147
148                    callback(Event::NextUrl {
149                        index: self.download_list_index,
150                        file_name: self.entry.filename.to_string(),
151                        err: e,
152                    })
153                    .await;
154                }
155            }
156        }
157
158        unreachable!()
159    }
160
161    /// Download file with retry (http)
162    async fn try_http_download(
163        &self,
164        source: &DownloadSource,
165        auth: &Option<(String, String)>,
166        callback: &impl AsyncFn(Event),
167    ) -> Result<bool, SingleDownloadError> {
168        let mut times = 1;
169        let mut allow_resume = self.entry.allow_resume;
170        loop {
171            match self
172                .http_download(allow_resume, source, auth, callback)
173                .await
174            {
175                Ok(s) => {
176                    return Ok(s);
177                }
178                Err(e) => match e {
179                    SingleDownloadError::ChecksumMismatch => {
180                        if self.retry_times == times {
181                            return Err(e);
182                        }
183
184                        if times > 1 {
185                            callback(Event::ChecksumMismatch {
186                                index: self.download_list_index,
187                                filename: self.entry.filename.to_string(),
188                                times,
189                            })
190                            .await;
191                        }
192
193                        times += 1;
194                        allow_resume = false;
195                    }
196                    e => {
197                        return Err(e);
198                    }
199                },
200            }
201        }
202    }
203
204    async fn http_download(
205        &self,
206        allow_resume: bool,
207        source: &DownloadSource,
208        auth: &Option<(String, String)>,
209        callback: &impl AsyncFn(Event),
210    ) -> Result<bool, SingleDownloadError> {
211        let file = self.entry.dir.join(&*self.entry.filename);
212        let file_exist = file.exists();
213        let mut file_size = file.metadata().ok().map(|x| x.len()).unwrap_or(0);
214
215        debug!("{} Exist file size is: {file_size}", file.display());
216        debug!("{} download url is: {}", file.display(), source.url);
217        let mut dest = None;
218        let mut validator = None;
219
220        if file.is_symlink() {
221            tokio::fs::remove_file(&file).await.context(RemoveSnafu)?;
222        }
223
224        // 如果要下载的文件已经存在,则验证 Checksum 是否正确,若正确则添加总进度条的进度,并返回
225        // 如果不存在,则继续往下走
226        if file_exist && !file.is_symlink() {
227            debug!("File: {} exists", self.entry.filename);
228
229            self.set_permission_with_path(&file).await?;
230
231            if let Some(hash) = &self.entry.hash {
232                debug!("Hash exist! It is: {}", hash);
233
234                let mut f = tokio::fs::OpenOptions::new()
235                    .write(true)
236                    .read(true)
237                    .open(&file)
238                    .await
239                    .context(OpenAsWriteModeSnafu)?;
240
241                debug!(
242                    "oma opened file: {} with write and read mode",
243                    self.entry.filename
244                );
245
246                let mut v = hash.get_validator();
247
248                debug!("Validator created.");
249
250                let (read, finish) = checksum(callback, &mut f, &mut v).await;
251
252                if finish {
253                    debug!(
254                        "{} checksum success, no need to download anything.",
255                        self.entry.filename
256                    );
257
258                    callback(Event::ProgressDone(self.download_list_index)).await;
259
260                    return Ok(false);
261                }
262
263                debug!(
264                    "checksum fail, will download this file: {}",
265                    self.entry.filename
266                );
267
268                if !allow_resume {
269                    callback(Event::GlobalProgressSub(read)).await;
270                } else {
271                    dest = Some(f);
272                    validator = Some(v);
273                }
274            }
275        }
276
277        let msg = self.progress_msg();
278        callback(Event::NewProgressSpinner {
279            index: self.download_list_index,
280            msg: msg.clone(),
281        })
282        .await;
283
284        let req = self.build_request_with_basic_auth(&source.url, Method::HEAD, auth);
285
286        let resp_head = timeout(self.timeout, req.send()).await;
287
288        callback(Event::ProgressDone(self.download_list_index)).await;
289
290        let resp_head = match resp_head {
291            Ok(Ok(resp)) => resp,
292            Ok(Err(e)) => {
293                return Err(SingleDownloadError::ReqwestError { source: e });
294            }
295            Err(_) => {
296                return Err(SingleDownloadError::SendRequestTimeout);
297            }
298        };
299
300        let head = resp_head.headers();
301
302        // 看看头是否有 ACCEPT_RANGES 这个变量
303        // 如果有,而且值不为 none,则可以断点续传
304        // 反之,则不能断点续传
305        let mut can_resume = match head.get(ACCEPT_RANGES) {
306            Some(x) if x == "none" => false,
307            Some(_) => true,
308            None => false,
309        };
310
311        debug!("Can resume? {can_resume}");
312
313        // 从服务器获取文件的总大小
314        let total_size = {
315            let total_size = head
316                .get(CONTENT_LENGTH)
317                .map(|x| x.to_owned())
318                .unwrap_or(HeaderValue::from(0));
319
320            total_size
321                .to_str()
322                .ok()
323                .and_then(|x| x.parse::<u64>().ok())
324                .unwrap_or_default()
325        };
326
327        debug!("File total size is: {total_size}");
328
329        let mut req = self.build_request_with_basic_auth(&source.url, Method::GET, auth);
330
331        if can_resume && allow_resume {
332            // 如果已存在的文件大小大于或等于要下载的文件,则重置文件大小,重新下载
333            // 因为已经走过一次 chekcusm 了,函数走到这里,则说明肯定文件完整性不对
334            if total_size <= file_size {
335                debug!("Exist file size is reset to 0, because total size <= exist file size");
336                callback(Event::GlobalProgressSub(file_size)).await;
337                file_size = 0;
338                can_resume = false;
339            }
340
341            // 发送 RANGE 的头,传入的是已经下载的文件的大小
342            debug!("oma will set header range as bytes={file_size}-");
343            req = req.header(RANGE, format!("bytes={}-", file_size));
344        }
345
346        debug!("Can resume? {can_resume}");
347
348        let resp = timeout(self.timeout, req.send()).await;
349
350        callback(Event::ProgressDone(self.download_list_index)).await;
351
352        let resp = match resp {
353            Ok(resp) => resp
354                .and_then(|resp| resp.error_for_status())
355                .context(ReqwestSnafu)?,
356            Err(_) => return Err(SingleDownloadError::SendRequestTimeout),
357        };
358
359        callback(Event::NewProgressBar {
360            index: self.download_list_index,
361            msg,
362            size: total_size,
363        })
364        .await;
365
366        let source = resp;
367
368        let hash = &self.entry.hash;
369
370        let mut self_progress = 0;
371        let (mut dest, mut validator) = if !can_resume || !allow_resume {
372            // 如果不能 resume,则使用创建模式
373            debug!(
374                "oma will open file: {} as create mode.",
375                self.entry.filename
376            );
377
378            let f = match File::create(&file).await {
379                Ok(f) => f,
380                Err(e) => {
381                    callback(Event::ProgressDone(self.download_list_index)).await;
382                    return Err(SingleDownloadError::Create { source: e });
383                }
384            };
385
386            if file_size > 0 {
387                callback(Event::GlobalProgressSub(file_size)).await;
388            }
389
390            if let Err(e) = f.set_len(0).await {
391                callback(Event::ProgressDone(self.download_list_index)).await;
392                return Err(SingleDownloadError::Create { source: e });
393            }
394
395            self.set_permission(&f).await?;
396
397            (f, hash.as_ref().map(|hash| hash.get_validator()))
398        } else if let Some((dest, validator)) = dest.zip(validator) {
399            callback(Event::ProgressInc {
400                index: self.download_list_index,
401                size: file_size,
402            })
403            .await;
404
405            debug!(
406                "oma will re use opened dest file for {}",
407                self.entry.filename
408            );
409            self_progress += file_size;
410
411            (dest, Some(validator))
412        } else {
413            debug!(
414                "oma will open file: {} as create mode.",
415                self.entry.filename
416            );
417
418            let f = match File::create(&file).await {
419                Ok(f) => f,
420                Err(e) => {
421                    callback(Event::ProgressDone(self.download_list_index)).await;
422                    return Err(SingleDownloadError::Create { source: e });
423                }
424            };
425
426            if let Err(e) = f.set_len(0).await {
427                callback(Event::ProgressDone(self.download_list_index)).await;
428                return Err(SingleDownloadError::Create { source: e });
429            }
430
431            self.set_permission(&f).await?;
432
433            (f, hash.as_ref().map(|hash| hash.get_validator()))
434        };
435
436        if can_resume && allow_resume {
437            // 把文件指针移动到末尾
438            debug!("oma will seek file: {} to end", self.entry.filename);
439            if let Err(e) = dest.seek(SeekFrom::End(0)).await {
440                callback(Event::ProgressDone(self.download_list_index)).await;
441                return Err(SingleDownloadError::Seek { source: e });
442            }
443        }
444        // 下载!
445        debug!("Start download!");
446
447        let bytes_stream = source
448            .bytes_stream()
449            .map_err(|e| io::Error::new(ErrorKind::Other, e))
450            .into_async_read();
451
452        let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
453            CompressFile::Xz => &mut XzDecoder::new(BufReader::new(bytes_stream)),
454            CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(bytes_stream)),
455            CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(bytes_stream)),
456            CompressFile::Nothing => &mut BufReader::new(bytes_stream),
457            CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(bytes_stream)),
458        };
459
460        let mut reader = reader.compat();
461
462        let mut buf = vec![0u8; DOWNLOAD_BUFSIZE];
463
464        loop {
465            let size = match timeout(self.timeout, reader.read(&mut buf[..])).await {
466                Ok(Ok(size)) => size,
467                Ok(Err(e)) => {
468                    callback(Event::ProgressDone(self.download_list_index)).await;
469                    return Err(SingleDownloadError::BrokenPipe { source: e });
470                }
471                Err(_) => {
472                    callback(Event::ProgressDone(self.download_list_index)).await;
473                    return Err(SingleDownloadError::DownloadTimeout);
474                }
475            };
476
477            if size == 0 {
478                break;
479            }
480
481            if let Err(e) = dest.write_all(&buf[..size]).await {
482                callback(Event::ProgressDone(self.download_list_index)).await;
483                return Err(SingleDownloadError::Write { source: e });
484            }
485
486            callback(Event::ProgressInc {
487                index: self.download_list_index,
488                size: size as u64,
489            })
490            .await;
491
492            self_progress += size as u64;
493
494            callback(Event::GlobalProgressAdd(size as u64)).await;
495
496            if let Some(ref mut v) = validator {
497                v.update(&buf[..size]);
498            }
499        }
500
501        // 下载完成,告诉运行时不再写这个文件了
502        debug!("Download complete! shutting down dest file stream ...");
503        if let Err(e) = dest.shutdown().await {
504            callback(Event::ProgressDone(self.download_list_index)).await;
505            return Err(SingleDownloadError::Flush { source: e });
506        }
507
508        // 最后看看 checksum 验证是否通过
509        if let Some(v) = validator {
510            if !v.finish() {
511                debug!("checksum fail: {}", self.entry.filename);
512                debug!("{self_progress}");
513
514                callback(Event::GlobalProgressSub(self_progress)).await;
515                callback(Event::ProgressDone(self.download_list_index)).await;
516                return Err(SingleDownloadError::ChecksumMismatch);
517            }
518
519            debug!("checksum success: {}", self.entry.filename);
520        }
521
522        callback(Event::ProgressDone(self.download_list_index)).await;
523
524        Ok(true)
525    }
526
527    async fn set_permission(&self, f: &File) -> Result<(), SingleDownloadError> {
528        if let Some(mode) = self.set_permission {
529            debug!("Setting {} permission to {:#o}", self.entry.filename, mode);
530            f.set_permissions(Permissions::from_mode(mode))
531                .await
532                .context(SetPermissionSnafu)?;
533        }
534
535        Ok(())
536    }
537
538    async fn set_permission_with_path(&self, path: &Path) -> Result<(), SingleDownloadError> {
539        if let Some(mode) = self.set_permission {
540            debug!("Setting {} permission to {:#o}", self.entry.filename, mode);
541
542            fs::set_permissions(path, Permissions::from_mode(mode))
543                .await
544                .context(SetPermissionSnafu)?;
545        }
546
547        Ok(())
548    }
549
550    fn build_request_with_basic_auth(
551        &self,
552        url: &str,
553        method: Method,
554        auth: &Option<(String, String)>,
555    ) -> RequestBuilder {
556        let mut req = self.client.request(method, url);
557
558        if let Some((user, password)) = auth {
559            debug!("auth user: {}", user);
560            req = req.basic_auth(user, Some(password));
561        }
562
563        req
564    }
565
566    fn progress_msg(&self) -> String {
567        let (count, len) = &self.progress;
568        let msg = self.msg.as_deref().unwrap_or(&self.entry.filename);
569        let msg = format!("({count}/{len}) {msg}");
570
571        msg
572    }
573
574    /// Download local source file
575    async fn download_local(
576        &self,
577        source: &DownloadSource,
578        as_symlink: bool,
579        callback: &impl AsyncFn(Event),
580    ) -> Result<bool, SingleDownloadError> {
581        debug!("{:?}", self.entry);
582        let msg = self.progress_msg();
583
584        let url = source.url.strip_prefix("file:").unwrap();
585
586        let url_path = Path::new(url);
587
588        let total_size = tokio::fs::metadata(url_path)
589            .await
590            .context(OpenSnafu)?
591            .len();
592
593        let file = self.entry.dir.join(&*self.entry.filename);
594        if file.is_symlink() || (as_symlink && file.is_file()) {
595            tokio::fs::remove_file(&file).await.context(RemoveSnafu)?;
596        }
597
598        if as_symlink {
599            if let Some(hash) = &self.entry.hash {
600                self.checksum_local(callback, url_path, hash).await?;
601            }
602
603            tokio::fs::symlink(url_path, file)
604                .await
605                .context(CreateSymlinkSnafu)?;
606
607            return Ok(true);
608        }
609
610        callback(Event::NewProgressBar {
611            index: self.download_list_index,
612            msg,
613            size: total_size,
614        })
615        .await;
616
617        debug!("File path is: {}", url_path.display());
618
619        let from = File::open(&url_path).await.context(CreateSnafu)?;
620        let from = tokio::io::BufReader::new(from).compat();
621
622        debug!("Success open file: {}", url_path.display());
623
624        let mut to = File::create(self.entry.dir.join(&*self.entry.filename))
625            .await
626            .context(CreateSnafu)?;
627
628        self.set_permission(&to).await?;
629
630        let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
631            CompressFile::Xz => &mut XzDecoder::new(BufReader::new(from)),
632            CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(from)),
633            CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(from)),
634            CompressFile::Nothing => &mut BufReader::new(from),
635            CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(from)),
636        };
637
638        let mut reader = reader.compat();
639
640        debug!(
641            "Success create file: {}",
642            self.entry.dir.join(&*self.entry.filename).display()
643        );
644
645        let mut v = self.entry.hash.as_ref().map(|v| v.get_validator());
646
647        let mut buf = vec![0u8; READ_FILE_BUFSIZE];
648        let mut self_progress = 0;
649
650        loop {
651            let size = reader.read(&mut buf[..]).await.context(BrokenPipeSnafu)?;
652            self_progress += size;
653
654            if size == 0 {
655                break;
656            }
657
658            to.write_all(&buf[..size]).await.context(WriteSnafu)?;
659
660            callback(Event::ProgressInc {
661                index: self.download_list_index,
662                size: size as u64,
663            })
664            .await;
665
666            if let Some(ref mut v) = v {
667                v.update(&buf[..size]);
668            }
669
670            callback(Event::GlobalProgressAdd(size as u64)).await;
671        }
672
673        if v.is_some_and(|v| !v.finish()) {
674            callback(Event::GlobalProgressSub(self_progress as u64)).await;
675            callback(Event::ProgressDone(self.download_list_index)).await;
676            return Err(SingleDownloadError::ChecksumMismatch);
677        }
678
679        callback(Event::ProgressDone(self.download_list_index)).await;
680
681        Ok(true)
682    }
683
684    async fn checksum_local(
685        &self,
686        callback: &impl AsyncFn(Event),
687        url_path: &Path,
688        hash: &crate::checksum::Checksum,
689    ) -> Result<(), SingleDownloadError> {
690        let mut f = fs::File::open(url_path).await.context(OpenSnafu)?;
691        let (size, finish) = checksum(callback, &mut f, &mut hash.get_validator()).await;
692
693        if !finish {
694            callback(Event::GlobalProgressSub(size)).await;
695            callback(Event::ProgressDone(self.download_list_index)).await;
696            return Err(SingleDownloadError::ChecksumMismatch);
697        }
698
699        Ok(())
700    }
701}
702
703async fn checksum(
704    callback: &impl AsyncFn(Event),
705    f: &mut File,
706    v: &mut ChecksumValidator,
707) -> (u64, bool) {
708    let mut reader = tokio::io::BufReader::with_capacity(READ_FILE_BUFSIZE, f);
709
710    let mut read = 0;
711
712    loop {
713        let Ok(buffer) = reader.fill_buf().await else {
714            debug!("Read file get fk, so re-download it");
715            break;
716        };
717
718        if buffer.is_empty() {
719            break;
720        }
721
722        v.update(buffer);
723
724        callback(Event::GlobalProgressAdd(buffer.len() as u64)).await;
725        read += buffer.len() as u64;
726        let len = buffer.len();
727
728        reader.consume(len);
729    }
730
731    (read, v.finish())
732}