oma_fetch/
download.rs

1use crate::{CompressFile, DownloadSource, Event, checksum::ChecksumValidator, send_request};
2use std::{
3    borrow::Cow,
4    io::{self, SeekFrom},
5    path::Path,
6    time::Duration,
7};
8
9use async_compression::futures::bufread::{
10    BzDecoder, GzipDecoder, Lz4Decoder, LzmaDecoder, XzDecoder, ZstdDecoder,
11};
12use bon::bon;
13use futures::{AsyncRead, TryStreamExt, io::BufReader};
14use reqwest::{
15    Client, Method, RequestBuilder,
16    header::{ACCEPT_RANGES, CONTENT_LENGTH, HeaderValue, RANGE},
17};
18use snafu::{ResultExt, Snafu};
19use tokio::{
20    fs::{self, File},
21    io::{AsyncBufReadExt as _, AsyncReadExt as _, AsyncSeekExt, AsyncWriteExt},
22    time::timeout,
23};
24
25use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
26use tracing::{debug, trace};
27
28use crate::{DownloadEntry, DownloadSourceType};
29
/// Buffer size (64 KiB) for reading local files, both when hashing an existing
/// file and when copying a local source.
const READ_FILE_BUFSIZE: usize = 64 * 1024;
/// Buffer size (8 KiB) for each read from an HTTP response body.
const DOWNLOAD_BUFSIZE: usize = 8 * 1024;
32
33#[derive(Debug, Snafu)]
34pub enum BuilderError {
35    #[snafu(display("Download task {file_name} sources is empty"))]
36    EmptySource { file_name: String },
37    #[snafu(display("Not allow set illegal download threads: {count}"))]
38    IllegalDownloadThread { count: usize },
39}
40
41pub(crate) struct SingleDownloader<'a> {
42    client: &'a Client,
43    pub entry: &'a DownloadEntry,
44    total: usize,
45    retry_times: usize,
46    msg: Option<Cow<'static, str>>,
47    download_list_index: usize,
48    file_type: CompressFile,
49    timeout: Duration,
50}
51
52pub enum DownloadResult {
53    Success(SuccessSummary),
54    Failed { file_name: String },
55}
56
/// Details describing a successful download.
#[derive(Debug)]
pub struct SuccessSummary {
    /// File name of the downloaded entry.
    pub file_name: String,
    /// Index of the entry in the download list.
    pub index: usize,
    /// `true` when the destination was (re)written; `false` when an existing
    /// file already matched its checksum (HTTP cache hit).
    pub wrote: bool,
    /// URL of the source that succeeded.
    pub url: String,
}
64
65#[derive(Debug, Snafu)]
66pub enum SingleDownloadError {
67    #[snafu(display("Failed to set permission"))]
68    SetPermission { source: io::Error },
69    #[snafu(display("Failed to open file as rw mode"))]
70    OpenAsWriteMode { source: io::Error },
71    #[snafu(display("Failed to open file"))]
72    Open { source: io::Error },
73    #[snafu(display("Failed to create file"))]
74    Create { source: io::Error },
75    #[snafu(display("Failed to seek file"))]
76    Seek { source: io::Error },
77    #[snafu(display("Failed to write file"))]
78    Write { source: io::Error },
79    #[snafu(display("Failed to flush file"))]
80    Flush { source: io::Error },
81    #[snafu(display("Failed to Remove file"))]
82    Remove { source: io::Error },
83    #[snafu(display("Failed to create symlink"))]
84    CreateSymlink { source: io::Error },
85    #[snafu(display("Request Error"))]
86    ReqwestError { source: reqwest::Error },
87    #[snafu(display("Broken pipe"))]
88    BrokenPipe { source: io::Error },
89    #[snafu(display("Send request timeout"))]
90    SendRequestTimeout,
91    #[snafu(display("Download file timeout"))]
92    DownloadTimeout,
93    #[snafu(display("checksum mismatch"))]
94    ChecksumMismatch,
95}
96
97#[bon]
98impl<'a> SingleDownloader<'a> {
99    #[builder]
100    pub(crate) fn new(
101        client: &'a Client,
102        entry: &'a DownloadEntry,
103        total: usize,
104        retry_times: usize,
105        msg: Option<Cow<'static, str>>,
106        download_list_index: usize,
107        file_type: CompressFile,
108        timeout: Duration,
109    ) -> Result<SingleDownloader<'a>, BuilderError> {
110        if entry.source.is_empty() {
111            return Err(BuilderError::EmptySource {
112                file_name: entry.filename.to_string(),
113            });
114        }
115
116        Ok(Self {
117            client,
118            entry,
119            total,
120            retry_times,
121            msg,
122            download_list_index,
123            file_type,
124            timeout,
125        })
126    }
127
128    pub(crate) async fn try_download(self, callback: &impl AsyncFn(Event)) -> DownloadResult {
129        let mut sources = self.entry.source.clone();
130        assert!(!sources.is_empty());
131
132        sources.sort_unstable_by(|a, b| b.source_type.cmp(&a.source_type));
133
134        let msg = self.msg.as_deref().unwrap_or(&*self.entry.filename);
135
136        for (index, c) in sources.iter().enumerate() {
137            let download_res = match &c.source_type {
138                DownloadSourceType::Http { auth } => {
139                    self.try_http_download(c, auth, callback).await
140                }
141                DownloadSourceType::Local(as_symlink) => {
142                    self.download_local(c, *as_symlink, callback).await
143                }
144            };
145
146            match download_res {
147                Ok(b) => {
148                    callback(Event::DownloadDone {
149                        index: self.download_list_index,
150                        msg: msg.into(),
151                    })
152                    .await;
153
154                    return DownloadResult::Success(SuccessSummary {
155                        file_name: self.entry.filename.to_string(),
156                        url: c.url.clone(),
157                        index: self.download_list_index,
158                        wrote: b,
159                    });
160                }
161                Err(e) => {
162                    if index == sources.len() - 1 {
163                        callback(Event::Failed {
164                            file_name: self.entry.filename.clone(),
165                            error: e,
166                        })
167                        .await;
168
169                        return DownloadResult::Failed {
170                            file_name: self.entry.filename.to_string(),
171                        };
172                    }
173
174                    callback(Event::NextUrl {
175                        index: self.download_list_index,
176                        file_name: self.entry.filename.to_string(),
177                        err: e,
178                    })
179                    .await;
180                }
181            }
182        }
183
184        unreachable!()
185    }
186
187    /// Download file with retry (http)
188    async fn try_http_download(
189        &self,
190        source: &DownloadSource,
191        auth: &Option<(String, String)>,
192        callback: &impl AsyncFn(Event),
193    ) -> Result<bool, SingleDownloadError> {
194        let mut times = 1;
195        let mut allow_resume = self.entry.allow_resume;
196        loop {
197            match self
198                .http_download(allow_resume, source, auth, callback)
199                .await
200            {
201                Ok(s) => {
202                    return Ok(s);
203                }
204                Err(e) => match e {
205                    SingleDownloadError::ChecksumMismatch => {
206                        if self.retry_times == times {
207                            return Err(e);
208                        }
209
210                        if times > 1 {
211                            callback(Event::ChecksumMismatch {
212                                index: self.download_list_index,
213                                filename: self.entry.filename.to_string(),
214                                times,
215                            })
216                            .await;
217                        }
218
219                        times += 1;
220                        allow_resume = false;
221                    }
222                    e => {
223                        return Err(e);
224                    }
225                },
226            }
227        }
228    }
229
230    async fn http_download(
231        &self,
232        allow_resume: bool,
233        source: &DownloadSource,
234        auth: &Option<(String, String)>,
235        callback: &impl AsyncFn(Event),
236    ) -> Result<bool, SingleDownloadError> {
237        let file = self.entry.dir.join(&*self.entry.filename);
238        let file_exist = file.exists();
239        let mut file_size = file.metadata().ok().map(|x| x.len()).unwrap_or(0);
240
241        trace!("{} Exist file size is: {file_size}", file.display());
242        trace!("{} download url is: {}", file.display(), source.url);
243        let mut dest = None;
244        let mut validator = None;
245        let is_symlink = file.is_symlink();
246
247        debug!("file {} is symlink = {}", file.display(), is_symlink);
248
249        if is_symlink {
250            tokio::fs::remove_file(&file).await.context(RemoveSnafu)?;
251        }
252
253        // 如果要下载的文件已经存在,则验证 Checksum 是否正确,若正确则添加总进度条的进度,并返回
254        // 如果不存在,则继续往下走
255        if file_exist && !is_symlink {
256            trace!(
257                "File {} already exists, verifying checksum ...",
258                self.entry.filename
259            );
260
261            if let Some(hash) = &self.entry.hash {
262                trace!("Hash {} exists for the existing file.", hash);
263
264                let mut f = tokio::fs::OpenOptions::new()
265                    .write(true)
266                    .read(true)
267                    .open(&file)
268                    .await
269                    .context(OpenAsWriteModeSnafu)?;
270
271                trace!("oma opened file {} read/write.", self.entry.filename);
272
273                let mut v = hash.get_validator();
274
275                trace!("Validator created.");
276
277                let (read, finish) = checksum(callback, &mut f, &mut v).await;
278
279                if finish {
280                    trace!("Checksum {} matches, cache hit!", self.entry.filename);
281
282                    callback(Event::ProgressDone(self.download_list_index)).await;
283
284                    return Ok(false);
285                }
286
287                debug!(
288                    "Checksum mismatch, initiating re-download for file {} ...",
289                    self.entry.filename
290                );
291
292                if !allow_resume {
293                    callback(Event::GlobalProgressSub(read)).await;
294                } else {
295                    dest = Some(f);
296                    validator = Some(v);
297                }
298            }
299        }
300
301        callback(Event::NewProgressSpinner {
302            index: self.download_list_index,
303            msg: self.download_message(),
304            total: self.total,
305        })
306        .await;
307
308        let req = self.build_request_with_basic_auth(&source.url, Method::HEAD, auth);
309        let resp_head = timeout(self.timeout, send_request(&source.url, req)).await;
310
311        callback(Event::ProgressDone(self.download_list_index)).await;
312
313        let resp_head = match resp_head {
314            Ok(Ok(resp)) => resp,
315            Ok(Err(e)) => {
316                return Err(SingleDownloadError::ReqwestError { source: e });
317            }
318            Err(_) => {
319                return Err(SingleDownloadError::SendRequestTimeout);
320            }
321        };
322
323        let head = resp_head.headers();
324
325        // 看看头是否有 ACCEPT_RANGES 这个变量
326        // 如果有,而且值不为 none,则可以断点续传
327        // 反之,则不能断点续传
328        let server_can_resume = match head.get(ACCEPT_RANGES) {
329            Some(x) if x == "none" => false,
330            Some(_) => true,
331            None => false,
332        };
333
334        // 从服务器获取文件的总大小
335        let total_size = {
336            let total_size = head
337                .get(CONTENT_LENGTH)
338                .map(|x| x.to_owned())
339                .unwrap_or(HeaderValue::from(0));
340
341            total_size
342                .to_str()
343                .ok()
344                .and_then(|x| x.parse::<u64>().ok())
345                .unwrap_or_default()
346        };
347
348        trace!("File total size is: {total_size}");
349
350        let mut req = self.build_request_with_basic_auth(&source.url, Method::GET, auth);
351
352        let mut resume = server_can_resume;
353
354        if !allow_resume {
355            resume = false;
356        }
357
358        if server_can_resume && allow_resume {
359            // 如果已存在的文件大小大于或等于要下载的文件,则重置文件大小,重新下载
360            // 因为已经走过一次 chekcusm 了,函数走到这里,则说明肯定文件完整性不对
361            if total_size <= file_size {
362                trace!(
363                    "Resetting size indicator for file to 0, as the file to download is larger that the one that already exists."
364                );
365                callback(Event::GlobalProgressSub(file_size)).await;
366                file_size = 0;
367                resume = false;
368            }
369
370            // 发送 RANGE 的头,传入的是已经下载的文件的大小
371            trace!("oma will set header range as bytes={file_size}-");
372            req = req.header(RANGE, format!("bytes={file_size}-"));
373        }
374
375        debug!("Can resume = {server_can_resume}, will resume = {resume}",);
376
377        let resp = timeout(self.timeout, req.send()).await;
378
379        callback(Event::ProgressDone(self.download_list_index)).await;
380
381        let resp = match resp {
382            Ok(resp) => resp
383                .and_then(|resp| resp.error_for_status())
384                .context(ReqwestSnafu)?,
385            Err(_) => return Err(SingleDownloadError::SendRequestTimeout),
386        };
387
388        callback(Event::NewProgressBar {
389            index: self.download_list_index,
390            msg: self.download_message(),
391            total: self.total,
392            size: total_size,
393        })
394        .await;
395
396        let source = resp;
397
398        let hash = &self.entry.hash;
399
400        let mut self_progress = 0;
401        let (mut dest, mut validator) = if !resume {
402            // 如果不能 resume,则使用创建模式
403            trace!(
404                "oma will open file {} in creation mode.",
405                self.entry.filename
406            );
407
408            let f = match File::create(&file).await {
409                Ok(f) => f,
410                Err(e) => {
411                    callback(Event::ProgressDone(self.download_list_index)).await;
412                    return Err(SingleDownloadError::Create { source: e });
413                }
414            };
415
416            if file_size > 0 {
417                callback(Event::GlobalProgressSub(file_size)).await;
418            }
419
420            if let Err(e) = f.set_len(0).await {
421                callback(Event::ProgressDone(self.download_list_index)).await;
422                return Err(SingleDownloadError::Create { source: e });
423            }
424
425            (f, hash.as_ref().map(|hash| hash.get_validator()))
426        } else if let Some((dest, validator)) = dest.zip(validator) {
427            callback(Event::ProgressInc {
428                index: self.download_list_index,
429                size: file_size,
430            })
431            .await;
432
433            trace!(
434                "oma will re-use opened destination file for {}",
435                self.entry.filename
436            );
437            self_progress += file_size;
438
439            (dest, Some(validator))
440        } else {
441            trace!(
442                "oma will open file {} in creation mode.",
443                self.entry.filename
444            );
445
446            let f = match File::create(&file).await {
447                Ok(f) => f,
448                Err(e) => {
449                    callback(Event::ProgressDone(self.download_list_index)).await;
450                    return Err(SingleDownloadError::Create { source: e });
451                }
452            };
453
454            if let Err(e) = f.set_len(0).await {
455                callback(Event::ProgressDone(self.download_list_index)).await;
456                return Err(SingleDownloadError::Create { source: e });
457            }
458
459            (f, hash.as_ref().map(|hash| hash.get_validator()))
460        };
461
462        if server_can_resume && allow_resume {
463            // 把文件指针移动到末尾
464            trace!("oma will seek to end-of-file for {}", self.entry.filename);
465            if let Err(e) = dest.seek(SeekFrom::End(0)).await {
466                callback(Event::ProgressDone(self.download_list_index)).await;
467                return Err(SingleDownloadError::Seek { source: e });
468            }
469        }
470        // 下载!
471        trace!("Starting download!");
472
473        let bytes_stream = source
474            .bytes_stream()
475            .map_err(io::Error::other)
476            .into_async_read();
477
478        let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
479            CompressFile::Xz => &mut XzDecoder::new(BufReader::new(bytes_stream)),
480            CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(bytes_stream)),
481            CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(bytes_stream)),
482            CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(bytes_stream)),
483            CompressFile::Lzma => &mut LzmaDecoder::new(BufReader::new(bytes_stream)),
484            CompressFile::Lz4 => &mut Lz4Decoder::new(BufReader::new(bytes_stream)),
485            CompressFile::Nothing => &mut BufReader::new(bytes_stream),
486        };
487
488        let mut reader = reader.compat();
489
490        let mut buf = vec![0u8; DOWNLOAD_BUFSIZE];
491
492        loop {
493            let size = match timeout(self.timeout, reader.read(&mut buf[..])).await {
494                Ok(Ok(size)) => size,
495                Ok(Err(e)) => {
496                    callback(Event::ProgressDone(self.download_list_index)).await;
497                    return Err(SingleDownloadError::BrokenPipe { source: e });
498                }
499                Err(_) => {
500                    callback(Event::ProgressDone(self.download_list_index)).await;
501                    return Err(SingleDownloadError::DownloadTimeout);
502                }
503            };
504
505            if size == 0 {
506                break;
507            }
508
509            if let Err(e) = dest.write_all(&buf[..size]).await {
510                callback(Event::ProgressDone(self.download_list_index)).await;
511                return Err(SingleDownloadError::Write { source: e });
512            }
513
514            callback(Event::ProgressInc {
515                index: self.download_list_index,
516                size: size as u64,
517            })
518            .await;
519
520            self_progress += size as u64;
521
522            callback(Event::GlobalProgressAdd(size as u64)).await;
523
524            if let Some(ref mut v) = validator {
525                v.update(&buf[..size]);
526            }
527        }
528
529        // 下载完成,告诉运行时不再写这个文件了
530        trace!("Download complete! Shutting down destination file stream ...");
531        if let Err(e) = dest.shutdown().await {
532            callback(Event::ProgressDone(self.download_list_index)).await;
533            return Err(SingleDownloadError::Flush { source: e });
534        }
535
536        // 最后看看 checksum 验证是否通过
537        if let Some(v) = validator {
538            if !v.finish() {
539                debug!("Checksum mismatch for file {}", self.entry.filename);
540                trace!("{self_progress}");
541
542                callback(Event::GlobalProgressSub(self_progress)).await;
543                callback(Event::ProgressDone(self.download_list_index)).await;
544                return Err(SingleDownloadError::ChecksumMismatch);
545            }
546
547            trace!(
548                "Checksum verification successful for file {}",
549                self.entry.filename
550            );
551        }
552
553        callback(Event::ProgressDone(self.download_list_index)).await;
554
555        Ok(true)
556    }
557
558    fn build_request_with_basic_auth(
559        &self,
560        url: &str,
561        method: Method,
562        auth: &Option<(String, String)>,
563    ) -> RequestBuilder {
564        let mut req = self.client.request(method, url);
565
566        if let Some((user, password)) = auth {
567            trace!("Authenticating as user: {} ...", user);
568            req = req.basic_auth(user, Some(password));
569        }
570
571        req
572    }
573
574    /// Download local source file
575    async fn download_local(
576        &self,
577        source: &DownloadSource,
578        as_symlink: bool,
579        callback: &impl AsyncFn(Event),
580    ) -> Result<bool, SingleDownloadError> {
581        debug!("{:?}", self.entry);
582
583        let url = source.url.strip_prefix("file:").unwrap();
584
585        let url_path = Path::new(url);
586
587        let total_size = tokio::fs::metadata(url_path)
588            .await
589            .context(OpenSnafu)?
590            .len();
591
592        let file = self.entry.dir.join(&*self.entry.filename);
593        if file.is_symlink() || (as_symlink && file.is_file()) {
594            tokio::fs::remove_file(&file).await.context(RemoveSnafu)?;
595        }
596
597        if as_symlink {
598            if let Some(hash) = &self.entry.hash {
599                self.checksum_local(callback, url_path, hash).await?;
600            }
601
602            tokio::fs::symlink(url_path, file)
603                .await
604                .context(CreateSymlinkSnafu)?;
605
606            return Ok(true);
607        }
608
609        callback(Event::NewProgressBar {
610            index: self.download_list_index,
611            total: self.total,
612            msg: self.download_message(),
613            size: total_size,
614        })
615        .await;
616
617        trace!("Path for file: {}", url_path.display());
618
619        let from = File::open(&url_path).await.context(CreateSnafu)?;
620        let from = tokio::io::BufReader::new(from).compat();
621
622        trace!("Successfully opened file: {}", url_path.display());
623
624        let mut to = File::create(self.entry.dir.join(&*self.entry.filename))
625            .await
626            .context(CreateSnafu)?;
627
628        let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
629            CompressFile::Xz => &mut XzDecoder::new(BufReader::new(from)),
630            CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(from)),
631            CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(from)),
632            CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(from)),
633            CompressFile::Lzma => &mut LzmaDecoder::new(BufReader::new(from)),
634            CompressFile::Lz4 => &mut Lz4Decoder::new(BufReader::new(from)),
635            CompressFile::Nothing => &mut BufReader::new(from),
636        };
637
638        let mut reader = reader.compat();
639
640        trace!(
641            "Successfully created file: {}",
642            self.entry.dir.join(&*self.entry.filename).display()
643        );
644
645        let mut v = self.entry.hash.as_ref().map(|v| v.get_validator());
646
647        let mut buf = vec![0u8; READ_FILE_BUFSIZE];
648        let mut self_progress = 0;
649
650        loop {
651            let size = reader.read(&mut buf[..]).await.context(BrokenPipeSnafu)?;
652            self_progress += size;
653
654            if size == 0 {
655                break;
656            }
657
658            to.write_all(&buf[..size]).await.context(WriteSnafu)?;
659
660            callback(Event::ProgressInc {
661                index: self.download_list_index,
662                size: size as u64,
663            })
664            .await;
665
666            if let Some(ref mut v) = v {
667                v.update(&buf[..size]);
668            }
669
670            callback(Event::GlobalProgressAdd(size as u64)).await;
671        }
672
673        if v.is_some_and(|v| !v.finish()) {
674            callback(Event::GlobalProgressSub(self_progress as u64)).await;
675            callback(Event::ProgressDone(self.download_list_index)).await;
676            return Err(SingleDownloadError::ChecksumMismatch);
677        }
678
679        callback(Event::ProgressDone(self.download_list_index)).await;
680
681        Ok(true)
682    }
683
684    async fn checksum_local(
685        &self,
686        callback: &impl AsyncFn(Event),
687        url_path: &Path,
688        hash: &crate::checksum::Checksum,
689    ) -> Result<(), SingleDownloadError> {
690        let mut f = fs::File::open(url_path).await.context(OpenSnafu)?;
691        let (size, finish) = checksum(callback, &mut f, &mut hash.get_validator()).await;
692
693        if !finish {
694            callback(Event::GlobalProgressSub(size)).await;
695            callback(Event::ProgressDone(self.download_list_index)).await;
696            return Err(SingleDownloadError::ChecksumMismatch);
697        }
698
699        Ok(())
700    }
701
702    fn download_message(&self) -> String {
703        self.msg
704            .as_deref()
705            .unwrap_or(&self.entry.filename)
706            .to_string()
707    }
708}
709
710async fn checksum(
711    callback: &impl AsyncFn(Event),
712    f: &mut File,
713    v: &mut ChecksumValidator,
714) -> (u64, bool) {
715    let mut reader = tokio::io::BufReader::with_capacity(READ_FILE_BUFSIZE, f);
716
717    let mut read = 0;
718
719    loop {
720        let buffer = match reader.fill_buf().await {
721            Ok([]) => break,
722            Ok(buffer) => buffer,
723            Err(e) => {
724                debug!("Error while reading file: {e}");
725                break;
726            }
727        };
728
729        v.update(buffer);
730
731        callback(Event::GlobalProgressAdd(buffer.len() as u64)).await;
732        read += buffer.len() as u64;
733        let len = buffer.len();
734
735        reader.consume(len);
736    }
737
738    (read, v.finish())
739}