oma_fetch/
download.rs

1use crate::{CompressFile, DownloadSource, Event, checksum::ChecksumValidator};
2use std::{
3    fs::Permissions,
4    future::Future,
5    io::{self, ErrorKind, SeekFrom},
6    os::unix::fs::PermissionsExt,
7    path::Path,
8    time::Duration,
9};
10
11use async_compression::futures::bufread::{BzDecoder, GzipDecoder, XzDecoder, ZstdDecoder};
12use bon::Builder;
13use futures::{AsyncRead, TryStreamExt, io::BufReader};
14use oma_utils::url_no_escape::url_no_escape;
15use reqwest::{
16    Client, Method, RequestBuilder,
17    header::{ACCEPT_RANGES, CONTENT_LENGTH, HeaderValue, RANGE},
18};
19use snafu::{ResultExt, Snafu};
20use tokio::{
21    fs::{self, File},
22    io::{AsyncReadExt as _, AsyncSeekExt, AsyncWriteExt},
23    time::timeout,
24};
25
26use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
27use tracing::debug;
28
29use crate::{DownloadEntry, DownloadSourceType};
30
/// Error returned by the `SingleDownloader` builder when the given `DownloadEntry` has an
/// empty `source` list, i.e. there is nothing to download from.
#[derive(Snafu, Debug)]
#[snafu(display("source list is empty"))]
pub struct EmptySource {
    /// `filename` of the entry that was rejected.
    file_name: String,
}
36
37#[derive(Builder)]
38pub(crate) struct SingleDownloader<'a> {
39    client: &'a Client,
40    #[builder(with = |entry: &'a DownloadEntry| -> Result<_, EmptySource> {
41        if entry.source.is_empty() {
42            return Err(EmptySource { file_name: entry.filename.to_string() });
43        } else {
44            return Ok(entry);
45        }
46    })]
47    pub entry: &'a DownloadEntry,
48    progress: (usize, usize),
49    retry_times: usize,
50    msg: Option<String>,
51    download_list_index: usize,
52    file_type: CompressFile,
53    set_permission: Option<u32>,
54    timeout: Duration,
55}
56
/// Outcome of `SingleDownloader::try_download`.
pub enum DownloadResult {
    /// One of the entry's sources succeeded.
    Success(SuccessSummary),
    /// Every source failed; `file_name` is the entry's file name.
    Failed { file_name: String },
}
61
/// Details of a successful single-file download.
#[derive(Debug)]
pub struct SuccessSummary {
    /// File name of the downloaded entry.
    pub file_name: String,
    /// The downloader's `download_list_index`.
    pub index: usize,
    /// `false` when an existing file already matched its checksum and nothing was
    /// written; `true` when the file was (re)written.
    pub wrote: bool,
}
68
/// Errors that can occur while downloading a single file from one source.
#[derive(Debug, Snafu)]
pub enum SingleDownloadError {
    /// Applying the configured permission mode to the destination failed.
    #[snafu(display("Failed to set permission"))]
    SetPermission { source: io::Error },
    /// Opening an existing destination file in read/write mode (for checksum/resume) failed.
    #[snafu(display("Failed to open file as rw mode"))]
    OpenAsWriteMode { source: io::Error },
    /// Opening or stat-ing a file failed.
    #[snafu(display("Failed to open file"))]
    Open { source: io::Error },
    /// Creating or truncating a file failed.
    #[snafu(display("Failed to create file"))]
    Create { source: io::Error },
    /// Seeking the destination file (to append resumed data) failed.
    #[snafu(display("Failed to seek file"))]
    Seek { source: io::Error },
    /// Writing data to the destination file failed.
    #[snafu(display("Failed to write file"))]
    Write { source: io::Error },
    /// Shutting down (flushing) the destination file failed.
    #[snafu(display("Failed to flush file"))]
    Flush { source: io::Error },
    /// Removing an existing file or symlink at the destination failed.
    #[snafu(display("Failed to Remove file"))]
    Remove { source: io::Error },
    /// Creating the destination symlink failed.
    #[snafu(display("Failed to create symlink"))]
    CreateSymlink { source: io::Error },
    /// The HTTP request failed, including error status codes.
    #[snafu(display("Request Error"))]
    ReqwestError { source: reqwest::Error },
    /// Reading from the (possibly decompressed) source stream failed.
    #[snafu(display("Broken pipe"))]
    BrokenPipe { source: io::Error },
    /// A request was not answered within the configured timeout.
    #[snafu(display("Send request timeout"))]
    SendRequestTimeout,
    /// A single read from the response body exceeded the configured timeout.
    #[snafu(display("Download file timeout"))]
    DownloadTimeout,
    /// The file content does not match the entry's checksum.
    #[snafu(display("checksum mismatch"))]
    ChecksumMismatch,
}
100
101impl SingleDownloader<'_> {
102    pub(crate) async fn try_download<F, Fut>(self, callback: &F) -> DownloadResult
103    where
104        F: Fn(Event) -> Fut,
105        Fut: Future<Output = ()>,
106    {
107        let mut sources = self.entry.source.clone();
108        assert!(!sources.is_empty());
109
110        sources.sort_unstable_by(|a, b| b.source_type.cmp(&a.source_type));
111
112        let msg = self.msg.as_deref().unwrap_or(&*self.entry.filename);
113
114        for (index, c) in sources.iter().enumerate() {
115            let download_res = match &c.source_type {
116                DownloadSourceType::Http { auth } => {
117                    self.try_http_download(c, auth, callback).await
118                }
119                DownloadSourceType::Local(as_symlink) => {
120                    self.download_local(c, *as_symlink, callback).await
121                }
122            };
123
124            match download_res {
125                Ok(b) => {
126                    callback(Event::DownloadDone {
127                        index: self.download_list_index,
128                        msg: msg.into(),
129                    })
130                    .await;
131
132                    return DownloadResult::Success(SuccessSummary {
133                        file_name: self.entry.filename.to_string(),
134                        index: self.download_list_index,
135                        wrote: b,
136                    });
137                }
138                Err(e) => {
139                    if index == sources.len() - 1 {
140                        callback(Event::Failed {
141                            file_name: self.entry.filename.clone(),
142                            error: e,
143                        })
144                        .await;
145
146                        return DownloadResult::Failed {
147                            file_name: self.entry.filename.to_string(),
148                        };
149                    }
150
151                    callback(Event::NextUrl {
152                        index: self.download_list_index,
153                        file_name: self.entry.filename.to_string(),
154                        err: e,
155                    })
156                    .await;
157                }
158            }
159        }
160
161        unreachable!()
162    }
163
164    /// Download file with retry (http)
165    async fn try_http_download<F, Fut>(
166        &self,
167        source: &DownloadSource,
168        auth: &Option<(String, String)>,
169        callback: &F,
170    ) -> Result<bool, SingleDownloadError>
171    where
172        F: Fn(Event) -> Fut,
173        Fut: Future<Output = ()>,
174    {
175        let mut times = 1;
176        let mut allow_resume = self.entry.allow_resume;
177        loop {
178            match self
179                .http_download(allow_resume, source, auth, callback)
180                .await
181            {
182                Ok(s) => {
183                    return Ok(s);
184                }
185                Err(e) => match e {
186                    SingleDownloadError::ChecksumMismatch => {
187                        if self.retry_times == times {
188                            return Err(e);
189                        }
190
191                        if times > 1 {
192                            callback(Event::ChecksumMismatch {
193                                index: self.download_list_index,
194                                filename: self.entry.filename.to_string(),
195                                times,
196                            })
197                            .await;
198                        }
199
200                        times += 1;
201                        allow_resume = false;
202                    }
203                    e => {
204                        return Err(e);
205                    }
206                },
207            }
208        }
209    }
210
211    async fn http_download<F, Fut>(
212        &self,
213        allow_resume: bool,
214        source: &DownloadSource,
215        auth: &Option<(String, String)>,
216        callback: &F,
217    ) -> Result<bool, SingleDownloadError>
218    where
219        F: Fn(Event) -> Fut,
220        Fut: Future<Output = ()>,
221    {
222        let file = self.entry.dir.join(&*self.entry.filename);
223        let file_exist = file.exists();
224        let mut file_size = file.metadata().ok().map(|x| x.len()).unwrap_or(0);
225
226        debug!("{} Exist file size is: {file_size}", file.display());
227        debug!("{} download url is: {}", file.display(), source.url);
228        let mut dest = None;
229        let mut validator = None;
230
231        if file.is_symlink() {
232            tokio::fs::remove_file(&file).await.context(RemoveSnafu)?;
233        }
234
235        // 如果要下载的文件已经存在,则验证 Checksum 是否正确,若正确则添加总进度条的进度,并返回
236        // 如果不存在,则继续往下走
237        if file_exist && !file.is_symlink() {
238            debug!("File: {} exists", self.entry.filename);
239
240            self.set_permission_with_path(&file).await?;
241
242            if let Some(hash) = &self.entry.hash {
243                debug!("Hash exist! It is: {}", hash);
244
245                let mut f = tokio::fs::OpenOptions::new()
246                    .write(true)
247                    .read(true)
248                    .open(&file)
249                    .await
250                    .context(OpenAsWriteModeSnafu)?;
251
252                debug!(
253                    "oma opened file: {} with write and read mode",
254                    self.entry.filename
255                );
256
257                let mut v = hash.get_validator();
258
259                debug!("Validator created.");
260
261                let (read, finish) = checksum(callback, file_size, &mut f, &mut v).await;
262
263                if finish {
264                    debug!(
265                        "{} checksum success, no need to download anything.",
266                        self.entry.filename
267                    );
268
269                    callback(Event::ProgressDone(self.download_list_index)).await;
270
271                    return Ok(false);
272                }
273
274                debug!(
275                    "checksum fail, will download this file: {}",
276                    self.entry.filename
277                );
278
279                if !allow_resume {
280                    callback(Event::GlobalProgressSub(read)).await;
281                } else {
282                    dest = Some(f);
283                    validator = Some(v);
284                }
285            }
286        }
287
288        let msg = self.progress_msg();
289        callback(Event::NewProgressSpinner {
290            index: self.download_list_index,
291            msg: msg.clone(),
292        })
293        .await;
294
295        let req = self.build_request_with_basic_auth(&source.url, Method::HEAD, auth);
296
297        let resp_head = timeout(self.timeout, req.send()).await;
298
299        callback(Event::ProgressDone(self.download_list_index)).await;
300
301        let resp_head = match resp_head {
302            Ok(Ok(resp)) => resp,
303            Ok(Err(e)) => {
304                return Err(SingleDownloadError::ReqwestError { source: e });
305            }
306            Err(_) => {
307                return Err(SingleDownloadError::SendRequestTimeout);
308            }
309        };
310
311        let head = resp_head.headers();
312
313        // 看看头是否有 ACCEPT_RANGES 这个变量
314        // 如果有,而且值不为 none,则可以断点续传
315        // 反之,则不能断点续传
316        let mut can_resume = match head.get(ACCEPT_RANGES) {
317            Some(x) if x == "none" => false,
318            Some(_) => true,
319            None => false,
320        };
321
322        debug!("Can resume? {can_resume}");
323
324        // 从服务器获取文件的总大小
325        let total_size = {
326            let total_size = head
327                .get(CONTENT_LENGTH)
328                .map(|x| x.to_owned())
329                .unwrap_or(HeaderValue::from(0));
330
331            total_size
332                .to_str()
333                .ok()
334                .and_then(|x| x.parse::<u64>().ok())
335                .unwrap_or_default()
336        };
337
338        debug!("File total size is: {total_size}");
339
340        let mut req = self.build_request_with_basic_auth(&source.url, Method::GET, auth);
341
342        if can_resume && allow_resume {
343            // 如果已存在的文件大小大于或等于要下载的文件,则重置文件大小,重新下载
344            // 因为已经走过一次 chekcusm 了,函数走到这里,则说明肯定文件完整性不对
345            if total_size <= file_size {
346                debug!("Exist file size is reset to 0, because total size <= exist file size");
347                callback(Event::GlobalProgressSub(file_size)).await;
348                file_size = 0;
349                can_resume = false;
350            }
351
352            // 发送 RANGE 的头,传入的是已经下载的文件的大小
353            debug!("oma will set header range as bytes={file_size}-");
354            req = req.header(RANGE, format!("bytes={}-", file_size));
355        }
356
357        debug!("Can resume? {can_resume}");
358
359        let resp = timeout(self.timeout, req.send()).await;
360
361        callback(Event::ProgressDone(self.download_list_index)).await;
362
363        let resp = match resp {
364            Ok(resp) => resp
365                .and_then(|resp| resp.error_for_status())
366                .context(ReqwestSnafu)?,
367            Err(_) => return Err(SingleDownloadError::SendRequestTimeout),
368        };
369
370        callback(Event::NewProgressBar {
371            index: self.download_list_index,
372            msg,
373            size: total_size,
374        })
375        .await;
376
377        let source = resp;
378
379        // 初始化 checksum 验证器
380        // 如果文件存在,则 checksum 验证器已经初试过一次,因此进度条加已经验证过的文件大小
381        let hash = &self.entry.hash;
382        let mut validator = if let Some(v) = validator {
383            callback(Event::ProgressInc {
384                index: self.download_list_index,
385                size: file_size,
386            })
387            .await;
388
389            Some(v)
390        } else {
391            hash.as_ref().map(|hash| hash.get_validator())
392        };
393
394        let mut self_progress = 0;
395        let mut dest = if !can_resume || !allow_resume {
396            // 如果不能 resume,则使用创建模式
397            debug!(
398                "oma will open file: {} as create mode.",
399                self.entry.filename
400            );
401
402            let f = match File::create(&file).await {
403                Ok(f) => f,
404                Err(e) => {
405                    callback(Event::ProgressDone(self.download_list_index)).await;
406                    return Err(SingleDownloadError::Create { source: e });
407                }
408            };
409
410            if let Err(e) = f.set_len(0).await {
411                callback(Event::ProgressDone(self.download_list_index)).await;
412                return Err(SingleDownloadError::Create { source: e });
413            }
414
415            self.set_permission(&f).await?;
416
417            f
418        } else if let Some(dest) = dest {
419            debug!(
420                "oma will re use opened dest file for {}",
421                self.entry.filename
422            );
423            self_progress += file_size;
424
425            dest
426        } else {
427            debug!(
428                "oma will open file: {} as create mode.",
429                self.entry.filename
430            );
431
432            let f = match File::create(&file).await {
433                Ok(f) => f,
434                Err(e) => {
435                    callback(Event::ProgressDone(self.download_list_index)).await;
436                    return Err(SingleDownloadError::Create { source: e });
437                }
438            };
439
440            if let Err(e) = f.set_len(0).await {
441                callback(Event::ProgressDone(self.download_list_index)).await;
442                return Err(SingleDownloadError::Create { source: e });
443            }
444
445            self.set_permission(&f).await?;
446
447            f
448        };
449
450        if can_resume && allow_resume {
451            // 把文件指针移动到末尾
452            debug!("oma will seek file: {} to end", self.entry.filename);
453            if let Err(e) = dest.seek(SeekFrom::End(0)).await {
454                callback(Event::ProgressDone(self.download_list_index)).await;
455                return Err(SingleDownloadError::Seek { source: e });
456            }
457        }
458        // 下载!
459        debug!("Start download!");
460
461        let bytes_stream = source
462            .bytes_stream()
463            .map_err(|e| io::Error::new(ErrorKind::Other, e))
464            .into_async_read();
465
466        let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
467            CompressFile::Xz => &mut XzDecoder::new(BufReader::new(bytes_stream)),
468            CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(bytes_stream)),
469            CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(bytes_stream)),
470            CompressFile::Nothing => &mut BufReader::new(bytes_stream),
471            CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(bytes_stream)),
472        };
473
474        let mut reader = reader.compat();
475
476        let mut buf = vec![0u8; 8 * 1024];
477
478        loop {
479            let size = match timeout(self.timeout, reader.read(&mut buf[..])).await {
480                Ok(Ok(size)) => size,
481                Ok(Err(e)) => {
482                    callback(Event::ProgressDone(self.download_list_index)).await;
483                    return Err(SingleDownloadError::BrokenPipe { source: e });
484                }
485                Err(_) => {
486                    callback(Event::ProgressDone(self.download_list_index)).await;
487                    return Err(SingleDownloadError::DownloadTimeout);
488                }
489            };
490
491            if size == 0 {
492                break;
493            }
494
495            if let Err(e) = dest.write_all(&buf[..size]).await {
496                callback(Event::ProgressDone(self.download_list_index)).await;
497                return Err(SingleDownloadError::Write { source: e });
498            }
499
500            callback(Event::ProgressInc {
501                index: self.download_list_index,
502                size: size as u64,
503            })
504            .await;
505
506            self_progress += size as u64;
507
508            callback(Event::GlobalProgressAdd(size as u64)).await;
509
510            if let Some(ref mut v) = validator {
511                v.update(&buf[..size]);
512            }
513        }
514
515        // 下载完成,告诉运行时不再写这个文件了
516        debug!("Download complete! shutting down dest file stream ...");
517        if let Err(e) = dest.shutdown().await {
518            callback(Event::ProgressDone(self.download_list_index)).await;
519            return Err(SingleDownloadError::Flush { source: e });
520        }
521
522        // 最后看看 checksum 验证是否通过
523        if let Some(v) = validator {
524            if !v.finish() {
525                debug!("checksum fail: {}", self.entry.filename);
526                debug!("{self_progress}");
527
528                callback(Event::GlobalProgressSub(self_progress)).await;
529                callback(Event::ProgressDone(self.download_list_index)).await;
530                return Err(SingleDownloadError::ChecksumMismatch);
531            }
532
533            debug!("checksum success: {}", self.entry.filename);
534        }
535
536        callback(Event::ProgressDone(self.download_list_index)).await;
537
538        Ok(true)
539    }
540
541    async fn set_permission(&self, f: &File) -> Result<(), SingleDownloadError> {
542        if let Some(mode) = self.set_permission {
543            debug!("Setting {} permission to {:#o}", self.entry.filename, mode);
544            f.set_permissions(Permissions::from_mode(mode))
545                .await
546                .context(SetPermissionSnafu)?;
547        }
548
549        Ok(())
550    }
551
552    async fn set_permission_with_path(&self, path: &Path) -> Result<(), SingleDownloadError> {
553        if let Some(mode) = self.set_permission {
554            debug!("Setting {} permission to {:#o}", self.entry.filename, mode);
555
556            fs::set_permissions(path, Permissions::from_mode(mode))
557                .await
558                .context(SetPermissionSnafu)?;
559        }
560
561        Ok(())
562    }
563
564    fn build_request_with_basic_auth(
565        &self,
566        url: &str,
567        method: Method,
568        auth: &Option<(String, String)>,
569    ) -> RequestBuilder {
570        let mut req = self.client.request(method, url);
571
572        if let Some((user, password)) = auth {
573            debug!("auth user: {}", user);
574            req = req.basic_auth(user, Some(password));
575        }
576
577        req
578    }
579
580    fn progress_msg(&self) -> String {
581        let (count, len) = &self.progress;
582        let msg = self.msg.as_deref().unwrap_or(&self.entry.filename);
583        let msg = format!("({count}/{len}) {msg}");
584
585        msg
586    }
587
588    /// Download local source file
589    async fn download_local<F, Fut>(
590        &self,
591        source: &DownloadSource,
592        as_symlink: bool,
593        callback: &F,
594    ) -> Result<bool, SingleDownloadError>
595    where
596        F: Fn(Event) -> Fut,
597        Fut: Future<Output = ()>,
598    {
599        debug!("{:?}", self.entry);
600        let msg = self.progress_msg();
601
602        let url = &source.url;
603
604        // 传入的参数不对,应该 panic
605        let url_path = url_no_escape(url.strip_prefix("file:").unwrap());
606
607        let url_path = Path::new(&url_path);
608
609        let total_size = tokio::fs::metadata(url_path)
610            .await
611            .context(OpenSnafu)?
612            .len();
613
614        callback(Event::NewProgressBar {
615            index: self.download_list_index,
616            msg,
617            size: total_size,
618        })
619        .await;
620
621        if as_symlink {
622            if let Some(hash) = &self.entry.hash {
623                self.checksum_local(callback, url_path, total_size, hash)
624                    .await?;
625            }
626
627            let symlink = self.entry.dir.join(&*self.entry.filename);
628            if symlink.exists() {
629                tokio::fs::remove_file(&symlink)
630                    .await
631                    .context(RemoveSnafu)?;
632            }
633
634            tokio::fs::symlink(url_path, symlink)
635                .await
636                .context(CreateSymlinkSnafu)?;
637
638            callback(Event::GlobalProgressAdd(total_size as u64)).await;
639            callback(Event::ProgressDone(self.download_list_index)).await;
640
641            return Ok(true);
642        }
643
644        debug!("File path is: {}", url_path.display());
645
646        let from = File::open(&url_path).await.context(CreateSnafu)?;
647        let from = tokio::io::BufReader::new(from).compat();
648
649        debug!("Success open file: {}", url_path.display());
650
651        let mut to = File::create(self.entry.dir.join(&*self.entry.filename))
652            .await
653            .context(CreateSnafu)?;
654
655        self.set_permission(&to).await?;
656
657        let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
658            CompressFile::Xz => &mut XzDecoder::new(BufReader::new(from)),
659            CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(from)),
660            CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(from)),
661            CompressFile::Nothing => &mut BufReader::new(from),
662            CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(from)),
663        };
664
665        let mut reader = reader.compat();
666
667        debug!(
668            "Success create file: {}",
669            self.entry.dir.join(&*self.entry.filename).display()
670        );
671
672        let mut buf = vec![0u8; 8 * 1024];
673
674        loop {
675            let size = reader.read(&mut buf[..]).await.context(BrokenPipeSnafu)?;
676
677            if size == 0 {
678                break;
679            }
680
681            to.write_all(&buf[..size]).await.context(WriteSnafu)?;
682
683            callback(Event::ProgressInc {
684                index: self.download_list_index,
685                size: size as u64,
686            })
687            .await;
688
689            callback(Event::GlobalProgressAdd(size as u64)).await;
690        }
691
692        callback(Event::ProgressDone(self.download_list_index)).await;
693
694        Ok(true)
695    }
696
697    async fn checksum_local<F, Fut>(
698        &self,
699        callback: &F,
700        url_path: &Path,
701        total_size: u64,
702        hash: &crate::checksum::Checksum,
703    ) -> Result<(), SingleDownloadError>
704    where
705        F: Fn(Event) -> Fut,
706        Fut: Future<Output = ()>,
707    {
708        let mut f = fs::File::open(url_path).await.context(OpenSnafu)?;
709        let (size, finish) =
710            checksum(callback, total_size, &mut f, &mut hash.get_validator()).await;
711
712        if !finish {
713            callback(Event::GlobalProgressSub(size)).await;
714            callback(Event::ProgressDone(self.download_list_index)).await;
715            return Err(SingleDownloadError::ChecksumMismatch);
716        }
717
718        Ok(())
719    }
720}
721
722async fn checksum<F, Fut>(
723    callback: &F,
724    file_size: u64,
725    f: &mut File,
726    v: &mut ChecksumValidator,
727) -> (u64, bool)
728where
729    F: Fn(Event) -> Fut,
730    Fut: Future<Output = ()>,
731{
732    let mut buf = vec![0; 8192];
733    let mut read = 0;
734
735    loop {
736        if read == file_size {
737            break;
738        }
739
740        let Ok(read_count) = f.read(&mut buf[..]).await else {
741            debug!("Read file get fk, so re-download it");
742            break;
743        };
744
745        v.update(&buf[..read_count]);
746
747        callback(Event::GlobalProgressAdd(read_count as u64)).await;
748
749        read += read_count as u64;
750    }
751
752    (read, v.finish())
753}