1use crate::{CompressFile, DownloadSource, Event, checksum::ChecksumValidator};
2use std::{
3 fs::Permissions,
4 io::{self, ErrorKind, SeekFrom},
5 os::unix::fs::PermissionsExt,
6 path::Path,
7 time::Duration,
8};
9
10use async_compression::futures::bufread::{BzDecoder, GzipDecoder, XzDecoder, ZstdDecoder};
11use bon::Builder;
12use futures::{AsyncRead, TryStreamExt, io::BufReader};
13use reqwest::{
14 Client, Method, RequestBuilder,
15 header::{ACCEPT_RANGES, CONTENT_LENGTH, HeaderValue, RANGE},
16};
17use snafu::{ResultExt, Snafu};
18use tokio::{
19 fs::{self, File},
20 io::{AsyncBufReadExt as _, AsyncReadExt as _, AsyncSeekExt, AsyncWriteExt},
21 time::timeout,
22};
23
24use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
25use tracing::debug;
26
27use crate::{DownloadEntry, DownloadSourceType};
28
/// Buffer size (64 KiB) used when reading files from local disk: hashing an
/// existing file and copying a local source.
const READ_FILE_BUFSIZE: usize = 64 * 1024;
/// Buffer size (8 KiB) used for each read from the (possibly decompressed)
/// HTTP response body.
const DOWNLOAD_BUFSIZE: usize = 8 * 1024;
31
32#[derive(Snafu, Debug)]
33#[snafu(display("source list is empty"))]
34pub struct EmptySource {
35 file_name: String,
36}
37
38#[derive(Builder)]
39pub(crate) struct SingleDownloader<'a> {
40 client: &'a Client,
41 #[builder(with = |entry: &'a DownloadEntry| -> Result<_, EmptySource> {
42 if entry.source.is_empty() {
43 return Err(EmptySource { file_name: entry.filename.to_string() });
44 } else {
45 return Ok(entry);
46 }
47 })]
48 pub entry: &'a DownloadEntry,
49 progress: (usize, usize),
50 retry_times: usize,
51 msg: Option<String>,
52 download_list_index: usize,
53 file_type: CompressFile,
54 set_permission: Option<u32>,
55 timeout: Duration,
56}
57
58pub enum DownloadResult {
59 Success(SuccessSummary),
60 Failed { file_name: String },
61}
62
/// Details about an entry that was processed successfully.
#[derive(Debug)]
pub struct SuccessSummary {
    /// Name of the file that was fetched.
    pub file_name: String,
    /// Position of the entry in the download list.
    pub index: usize,
    /// `false` when an existing file already passed checksum verification and
    /// nothing had to be downloaded; `true` otherwise.
    pub wrote: bool,
}
69
70#[derive(Debug, Snafu)]
71pub enum SingleDownloadError {
72 #[snafu(display("Failed to set permission"))]
73 SetPermission { source: io::Error },
74 #[snafu(display("Failed to open file as rw mode"))]
75 OpenAsWriteMode { source: io::Error },
76 #[snafu(display("Failed to open file"))]
77 Open { source: io::Error },
78 #[snafu(display("Failed to create file"))]
79 Create { source: io::Error },
80 #[snafu(display("Failed to seek file"))]
81 Seek { source: io::Error },
82 #[snafu(display("Failed to write file"))]
83 Write { source: io::Error },
84 #[snafu(display("Failed to flush file"))]
85 Flush { source: io::Error },
86 #[snafu(display("Failed to Remove file"))]
87 Remove { source: io::Error },
88 #[snafu(display("Failed to create symlink"))]
89 CreateSymlink { source: io::Error },
90 #[snafu(display("Request Error"))]
91 ReqwestError { source: reqwest::Error },
92 #[snafu(display("Broken pipe"))]
93 BrokenPipe { source: io::Error },
94 #[snafu(display("Send request timeout"))]
95 SendRequestTimeout,
96 #[snafu(display("Download file timeout"))]
97 DownloadTimeout,
98 #[snafu(display("checksum mismatch"))]
99 ChecksumMismatch,
100}
101
102impl SingleDownloader<'_> {
103 pub(crate) async fn try_download(self, callback: &impl AsyncFn(Event)) -> DownloadResult {
104 let mut sources = self.entry.source.clone();
105 assert!(!sources.is_empty());
106
107 sources.sort_unstable_by(|a, b| b.source_type.cmp(&a.source_type));
108
109 let msg = self.msg.as_deref().unwrap_or(&*self.entry.filename);
110
111 for (index, c) in sources.iter().enumerate() {
112 let download_res = match &c.source_type {
113 DownloadSourceType::Http { auth } => {
114 self.try_http_download(c, auth, callback).await
115 }
116 DownloadSourceType::Local(as_symlink) => {
117 self.download_local(c, *as_symlink, callback).await
118 }
119 };
120
121 match download_res {
122 Ok(b) => {
123 callback(Event::DownloadDone {
124 index: self.download_list_index,
125 msg: msg.into(),
126 })
127 .await;
128
129 return DownloadResult::Success(SuccessSummary {
130 file_name: self.entry.filename.to_string(),
131 index: self.download_list_index,
132 wrote: b,
133 });
134 }
135 Err(e) => {
136 if index == sources.len() - 1 {
137 callback(Event::Failed {
138 file_name: self.entry.filename.clone(),
139 error: e,
140 })
141 .await;
142
143 return DownloadResult::Failed {
144 file_name: self.entry.filename.to_string(),
145 };
146 }
147
148 callback(Event::NextUrl {
149 index: self.download_list_index,
150 file_name: self.entry.filename.to_string(),
151 err: e,
152 })
153 .await;
154 }
155 }
156 }
157
158 unreachable!()
159 }
160
161 async fn try_http_download(
163 &self,
164 source: &DownloadSource,
165 auth: &Option<(String, String)>,
166 callback: &impl AsyncFn(Event),
167 ) -> Result<bool, SingleDownloadError> {
168 let mut times = 1;
169 let mut allow_resume = self.entry.allow_resume;
170 loop {
171 match self
172 .http_download(allow_resume, source, auth, callback)
173 .await
174 {
175 Ok(s) => {
176 return Ok(s);
177 }
178 Err(e) => match e {
179 SingleDownloadError::ChecksumMismatch => {
180 if self.retry_times == times {
181 return Err(e);
182 }
183
184 if times > 1 {
185 callback(Event::ChecksumMismatch {
186 index: self.download_list_index,
187 filename: self.entry.filename.to_string(),
188 times,
189 })
190 .await;
191 }
192
193 times += 1;
194 allow_resume = false;
195 }
196 e => {
197 return Err(e);
198 }
199 },
200 }
201 }
202 }
203
204 async fn http_download(
205 &self,
206 allow_resume: bool,
207 source: &DownloadSource,
208 auth: &Option<(String, String)>,
209 callback: &impl AsyncFn(Event),
210 ) -> Result<bool, SingleDownloadError> {
211 let file = self.entry.dir.join(&*self.entry.filename);
212 let file_exist = file.exists();
213 let mut file_size = file.metadata().ok().map(|x| x.len()).unwrap_or(0);
214
215 debug!("{} Exist file size is: {file_size}", file.display());
216 debug!("{} download url is: {}", file.display(), source.url);
217 let mut dest = None;
218 let mut validator = None;
219
220 if file.is_symlink() {
221 tokio::fs::remove_file(&file).await.context(RemoveSnafu)?;
222 }
223
224 if file_exist && !file.is_symlink() {
227 debug!("File: {} exists", self.entry.filename);
228
229 self.set_permission_with_path(&file).await?;
230
231 if let Some(hash) = &self.entry.hash {
232 debug!("Hash exist! It is: {}", hash);
233
234 let mut f = tokio::fs::OpenOptions::new()
235 .write(true)
236 .read(true)
237 .open(&file)
238 .await
239 .context(OpenAsWriteModeSnafu)?;
240
241 debug!(
242 "oma opened file: {} with write and read mode",
243 self.entry.filename
244 );
245
246 let mut v = hash.get_validator();
247
248 debug!("Validator created.");
249
250 let (read, finish) = checksum(callback, &mut f, &mut v).await;
251
252 if finish {
253 debug!(
254 "{} checksum success, no need to download anything.",
255 self.entry.filename
256 );
257
258 callback(Event::ProgressDone(self.download_list_index)).await;
259
260 return Ok(false);
261 }
262
263 debug!(
264 "checksum fail, will download this file: {}",
265 self.entry.filename
266 );
267
268 if !allow_resume {
269 callback(Event::GlobalProgressSub(read)).await;
270 } else {
271 dest = Some(f);
272 validator = Some(v);
273 }
274 }
275 }
276
277 let msg = self.progress_msg();
278 callback(Event::NewProgressSpinner {
279 index: self.download_list_index,
280 msg: msg.clone(),
281 })
282 .await;
283
284 let req = self.build_request_with_basic_auth(&source.url, Method::HEAD, auth);
285
286 let resp_head = timeout(self.timeout, req.send()).await;
287
288 callback(Event::ProgressDone(self.download_list_index)).await;
289
290 let resp_head = match resp_head {
291 Ok(Ok(resp)) => resp,
292 Ok(Err(e)) => {
293 return Err(SingleDownloadError::ReqwestError { source: e });
294 }
295 Err(_) => {
296 return Err(SingleDownloadError::SendRequestTimeout);
297 }
298 };
299
300 let head = resp_head.headers();
301
302 let mut can_resume = match head.get(ACCEPT_RANGES) {
306 Some(x) if x == "none" => false,
307 Some(_) => true,
308 None => false,
309 };
310
311 debug!("Can resume? {can_resume}");
312
313 let total_size = {
315 let total_size = head
316 .get(CONTENT_LENGTH)
317 .map(|x| x.to_owned())
318 .unwrap_or(HeaderValue::from(0));
319
320 total_size
321 .to_str()
322 .ok()
323 .and_then(|x| x.parse::<u64>().ok())
324 .unwrap_or_default()
325 };
326
327 debug!("File total size is: {total_size}");
328
329 let mut req = self.build_request_with_basic_auth(&source.url, Method::GET, auth);
330
331 if can_resume && allow_resume {
332 if total_size <= file_size {
335 debug!("Exist file size is reset to 0, because total size <= exist file size");
336 callback(Event::GlobalProgressSub(file_size)).await;
337 file_size = 0;
338 can_resume = false;
339 }
340
341 debug!("oma will set header range as bytes={file_size}-");
343 req = req.header(RANGE, format!("bytes={}-", file_size));
344 }
345
346 debug!("Can resume? {can_resume}");
347
348 let resp = timeout(self.timeout, req.send()).await;
349
350 callback(Event::ProgressDone(self.download_list_index)).await;
351
352 let resp = match resp {
353 Ok(resp) => resp
354 .and_then(|resp| resp.error_for_status())
355 .context(ReqwestSnafu)?,
356 Err(_) => return Err(SingleDownloadError::SendRequestTimeout),
357 };
358
359 callback(Event::NewProgressBar {
360 index: self.download_list_index,
361 msg,
362 size: total_size,
363 })
364 .await;
365
366 let source = resp;
367
368 let hash = &self.entry.hash;
369
370 let mut self_progress = 0;
371 let (mut dest, mut validator) = if !can_resume || !allow_resume {
372 debug!(
374 "oma will open file: {} as create mode.",
375 self.entry.filename
376 );
377
378 let f = match File::create(&file).await {
379 Ok(f) => f,
380 Err(e) => {
381 callback(Event::ProgressDone(self.download_list_index)).await;
382 return Err(SingleDownloadError::Create { source: e });
383 }
384 };
385
386 if file_size > 0 {
387 callback(Event::GlobalProgressSub(file_size)).await;
388 }
389
390 if let Err(e) = f.set_len(0).await {
391 callback(Event::ProgressDone(self.download_list_index)).await;
392 return Err(SingleDownloadError::Create { source: e });
393 }
394
395 self.set_permission(&f).await?;
396
397 (f, hash.as_ref().map(|hash| hash.get_validator()))
398 } else if let Some((dest, validator)) = dest.zip(validator) {
399 callback(Event::ProgressInc {
400 index: self.download_list_index,
401 size: file_size,
402 })
403 .await;
404
405 debug!(
406 "oma will re use opened dest file for {}",
407 self.entry.filename
408 );
409 self_progress += file_size;
410
411 (dest, Some(validator))
412 } else {
413 debug!(
414 "oma will open file: {} as create mode.",
415 self.entry.filename
416 );
417
418 let f = match File::create(&file).await {
419 Ok(f) => f,
420 Err(e) => {
421 callback(Event::ProgressDone(self.download_list_index)).await;
422 return Err(SingleDownloadError::Create { source: e });
423 }
424 };
425
426 if let Err(e) = f.set_len(0).await {
427 callback(Event::ProgressDone(self.download_list_index)).await;
428 return Err(SingleDownloadError::Create { source: e });
429 }
430
431 self.set_permission(&f).await?;
432
433 (f, hash.as_ref().map(|hash| hash.get_validator()))
434 };
435
436 if can_resume && allow_resume {
437 debug!("oma will seek file: {} to end", self.entry.filename);
439 if let Err(e) = dest.seek(SeekFrom::End(0)).await {
440 callback(Event::ProgressDone(self.download_list_index)).await;
441 return Err(SingleDownloadError::Seek { source: e });
442 }
443 }
444 debug!("Start download!");
446
447 let bytes_stream = source
448 .bytes_stream()
449 .map_err(|e| io::Error::new(ErrorKind::Other, e))
450 .into_async_read();
451
452 let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
453 CompressFile::Xz => &mut XzDecoder::new(BufReader::new(bytes_stream)),
454 CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(bytes_stream)),
455 CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(bytes_stream)),
456 CompressFile::Nothing => &mut BufReader::new(bytes_stream),
457 CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(bytes_stream)),
458 };
459
460 let mut reader = reader.compat();
461
462 let mut buf = vec![0u8; DOWNLOAD_BUFSIZE];
463
464 loop {
465 let size = match timeout(self.timeout, reader.read(&mut buf[..])).await {
466 Ok(Ok(size)) => size,
467 Ok(Err(e)) => {
468 callback(Event::ProgressDone(self.download_list_index)).await;
469 return Err(SingleDownloadError::BrokenPipe { source: e });
470 }
471 Err(_) => {
472 callback(Event::ProgressDone(self.download_list_index)).await;
473 return Err(SingleDownloadError::DownloadTimeout);
474 }
475 };
476
477 if size == 0 {
478 break;
479 }
480
481 if let Err(e) = dest.write_all(&buf[..size]).await {
482 callback(Event::ProgressDone(self.download_list_index)).await;
483 return Err(SingleDownloadError::Write { source: e });
484 }
485
486 callback(Event::ProgressInc {
487 index: self.download_list_index,
488 size: size as u64,
489 })
490 .await;
491
492 self_progress += size as u64;
493
494 callback(Event::GlobalProgressAdd(size as u64)).await;
495
496 if let Some(ref mut v) = validator {
497 v.update(&buf[..size]);
498 }
499 }
500
501 debug!("Download complete! shutting down dest file stream ...");
503 if let Err(e) = dest.shutdown().await {
504 callback(Event::ProgressDone(self.download_list_index)).await;
505 return Err(SingleDownloadError::Flush { source: e });
506 }
507
508 if let Some(v) = validator {
510 if !v.finish() {
511 debug!("checksum fail: {}", self.entry.filename);
512 debug!("{self_progress}");
513
514 callback(Event::GlobalProgressSub(self_progress)).await;
515 callback(Event::ProgressDone(self.download_list_index)).await;
516 return Err(SingleDownloadError::ChecksumMismatch);
517 }
518
519 debug!("checksum success: {}", self.entry.filename);
520 }
521
522 callback(Event::ProgressDone(self.download_list_index)).await;
523
524 Ok(true)
525 }
526
527 async fn set_permission(&self, f: &File) -> Result<(), SingleDownloadError> {
528 if let Some(mode) = self.set_permission {
529 debug!("Setting {} permission to {:#o}", self.entry.filename, mode);
530 f.set_permissions(Permissions::from_mode(mode))
531 .await
532 .context(SetPermissionSnafu)?;
533 }
534
535 Ok(())
536 }
537
538 async fn set_permission_with_path(&self, path: &Path) -> Result<(), SingleDownloadError> {
539 if let Some(mode) = self.set_permission {
540 debug!("Setting {} permission to {:#o}", self.entry.filename, mode);
541
542 fs::set_permissions(path, Permissions::from_mode(mode))
543 .await
544 .context(SetPermissionSnafu)?;
545 }
546
547 Ok(())
548 }
549
550 fn build_request_with_basic_auth(
551 &self,
552 url: &str,
553 method: Method,
554 auth: &Option<(String, String)>,
555 ) -> RequestBuilder {
556 let mut req = self.client.request(method, url);
557
558 if let Some((user, password)) = auth {
559 debug!("auth user: {}", user);
560 req = req.basic_auth(user, Some(password));
561 }
562
563 req
564 }
565
566 fn progress_msg(&self) -> String {
567 let (count, len) = &self.progress;
568 let msg = self.msg.as_deref().unwrap_or(&self.entry.filename);
569 let msg = format!("({count}/{len}) {msg}");
570
571 msg
572 }
573
574 async fn download_local(
576 &self,
577 source: &DownloadSource,
578 as_symlink: bool,
579 callback: &impl AsyncFn(Event),
580 ) -> Result<bool, SingleDownloadError> {
581 debug!("{:?}", self.entry);
582 let msg = self.progress_msg();
583
584 let url = source.url.strip_prefix("file:").unwrap();
585
586 let url_path = Path::new(url);
587
588 let total_size = tokio::fs::metadata(url_path)
589 .await
590 .context(OpenSnafu)?
591 .len();
592
593 let file = self.entry.dir.join(&*self.entry.filename);
594 if file.is_symlink() || (as_symlink && file.is_file()) {
595 tokio::fs::remove_file(&file).await.context(RemoveSnafu)?;
596 }
597
598 if as_symlink {
599 if let Some(hash) = &self.entry.hash {
600 self.checksum_local(callback, url_path, hash).await?;
601 }
602
603 tokio::fs::symlink(url_path, file)
604 .await
605 .context(CreateSymlinkSnafu)?;
606
607 return Ok(true);
608 }
609
610 callback(Event::NewProgressBar {
611 index: self.download_list_index,
612 msg,
613 size: total_size,
614 })
615 .await;
616
617 debug!("File path is: {}", url_path.display());
618
619 let from = File::open(&url_path).await.context(CreateSnafu)?;
620 let from = tokio::io::BufReader::new(from).compat();
621
622 debug!("Success open file: {}", url_path.display());
623
624 let mut to = File::create(self.entry.dir.join(&*self.entry.filename))
625 .await
626 .context(CreateSnafu)?;
627
628 self.set_permission(&to).await?;
629
630 let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
631 CompressFile::Xz => &mut XzDecoder::new(BufReader::new(from)),
632 CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(from)),
633 CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(from)),
634 CompressFile::Nothing => &mut BufReader::new(from),
635 CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(from)),
636 };
637
638 let mut reader = reader.compat();
639
640 debug!(
641 "Success create file: {}",
642 self.entry.dir.join(&*self.entry.filename).display()
643 );
644
645 let mut v = self.entry.hash.as_ref().map(|v| v.get_validator());
646
647 let mut buf = vec![0u8; READ_FILE_BUFSIZE];
648 let mut self_progress = 0;
649
650 loop {
651 let size = reader.read(&mut buf[..]).await.context(BrokenPipeSnafu)?;
652 self_progress += size;
653
654 if size == 0 {
655 break;
656 }
657
658 to.write_all(&buf[..size]).await.context(WriteSnafu)?;
659
660 callback(Event::ProgressInc {
661 index: self.download_list_index,
662 size: size as u64,
663 })
664 .await;
665
666 if let Some(ref mut v) = v {
667 v.update(&buf[..size]);
668 }
669
670 callback(Event::GlobalProgressAdd(size as u64)).await;
671 }
672
673 if v.is_some_and(|v| !v.finish()) {
674 callback(Event::GlobalProgressSub(self_progress as u64)).await;
675 callback(Event::ProgressDone(self.download_list_index)).await;
676 return Err(SingleDownloadError::ChecksumMismatch);
677 }
678
679 callback(Event::ProgressDone(self.download_list_index)).await;
680
681 Ok(true)
682 }
683
684 async fn checksum_local(
685 &self,
686 callback: &impl AsyncFn(Event),
687 url_path: &Path,
688 hash: &crate::checksum::Checksum,
689 ) -> Result<(), SingleDownloadError> {
690 let mut f = fs::File::open(url_path).await.context(OpenSnafu)?;
691 let (size, finish) = checksum(callback, &mut f, &mut hash.get_validator()).await;
692
693 if !finish {
694 callback(Event::GlobalProgressSub(size)).await;
695 callback(Event::ProgressDone(self.download_list_index)).await;
696 return Err(SingleDownloadError::ChecksumMismatch);
697 }
698
699 Ok(())
700 }
701}
702
703async fn checksum(
704 callback: &impl AsyncFn(Event),
705 f: &mut File,
706 v: &mut ChecksumValidator,
707) -> (u64, bool) {
708 let mut reader = tokio::io::BufReader::with_capacity(READ_FILE_BUFSIZE, f);
709
710 let mut read = 0;
711
712 loop {
713 let Ok(buffer) = reader.fill_buf().await else {
714 debug!("Read file get fk, so re-download it");
715 break;
716 };
717
718 if buffer.is_empty() {
719 break;
720 }
721
722 v.update(buffer);
723
724 callback(Event::GlobalProgressAdd(buffer.len() as u64)).await;
725 read += buffer.len() as u64;
726 let len = buffer.len();
727
728 reader.consume(len);
729 }
730
731 (read, v.finish())
732}