1use crate::{CompressFile, DownloadSource, Event, checksum::ChecksumValidator};
2use std::{
3 fs::Permissions,
4 future::Future,
5 io::{self, ErrorKind, SeekFrom},
6 os::unix::fs::PermissionsExt,
7 path::Path,
8 time::Duration,
9};
10
11use async_compression::futures::bufread::{BzDecoder, GzipDecoder, XzDecoder, ZstdDecoder};
12use bon::Builder;
13use futures::{AsyncRead, TryStreamExt, io::BufReader};
14use oma_utils::url_no_escape::url_no_escape;
15use reqwest::{
16 Client, Method, RequestBuilder,
17 header::{ACCEPT_RANGES, CONTENT_LENGTH, HeaderValue, RANGE},
18};
19use snafu::{ResultExt, Snafu};
20use tokio::{
21 fs::{self, File},
22 io::{AsyncReadExt as _, AsyncSeekExt, AsyncWriteExt},
23 time::timeout,
24};
25
26use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
27use tracing::debug;
28
29use crate::{DownloadEntry, DownloadSourceType};
30
31#[derive(Snafu, Debug)]
32#[snafu(display("source list is empty"))]
33pub struct EmptySource {
34 file_name: String,
35}
36
37#[derive(Builder)]
38pub(crate) struct SingleDownloader<'a> {
39 client: &'a Client,
40 #[builder(with = |entry: &'a DownloadEntry| -> Result<_, EmptySource> {
41 if entry.source.is_empty() {
42 return Err(EmptySource { file_name: entry.filename.to_string() });
43 } else {
44 return Ok(entry);
45 }
46 })]
47 pub entry: &'a DownloadEntry,
48 progress: (usize, usize),
49 retry_times: usize,
50 msg: Option<String>,
51 download_list_index: usize,
52 file_type: CompressFile,
53 set_permission: Option<u32>,
54 timeout: Duration,
55}
56
/// Outcome of downloading one entry with `SingleDownloader::try_download`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DownloadResult {
    /// One of the entry's sources succeeded.
    Success(SuccessSummary),
    /// Every source failed; the error of the last one was sent through the
    /// progress callback.
    Failed { file_name: String },
}

/// Details of a successful download.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SuccessSummary {
    /// Name of the downloaded file.
    pub file_name: String,
    /// Index of the entry in the download list.
    pub index: usize,
    /// `false` when an existing file already matched its checksum and nothing
    /// had to be written.
    pub wrote: bool,
}
68
69#[derive(Debug, Snafu)]
70pub enum SingleDownloadError {
71 #[snafu(display("Failed to set permission"))]
72 SetPermission { source: io::Error },
73 #[snafu(display("Failed to open file as rw mode"))]
74 OpenAsWriteMode { source: io::Error },
75 #[snafu(display("Failed to open file"))]
76 Open { source: io::Error },
77 #[snafu(display("Failed to create file"))]
78 Create { source: io::Error },
79 #[snafu(display("Failed to seek file"))]
80 Seek { source: io::Error },
81 #[snafu(display("Failed to write file"))]
82 Write { source: io::Error },
83 #[snafu(display("Failed to flush file"))]
84 Flush { source: io::Error },
85 #[snafu(display("Failed to Remove file"))]
86 Remove { source: io::Error },
87 #[snafu(display("Failed to create symlink"))]
88 CreateSymlink { source: io::Error },
89 #[snafu(display("Request Error"))]
90 ReqwestError { source: reqwest::Error },
91 #[snafu(display("Broken pipe"))]
92 BrokenPipe { source: io::Error },
93 #[snafu(display("Send request timeout"))]
94 SendRequestTimeout,
95 #[snafu(display("Download file timeout"))]
96 DownloadTimeout,
97 #[snafu(display("checksum mismatch"))]
98 ChecksumMismatch,
99}
100
101impl SingleDownloader<'_> {
102 pub(crate) async fn try_download<F, Fut>(self, callback: &F) -> DownloadResult
103 where
104 F: Fn(Event) -> Fut,
105 Fut: Future<Output = ()>,
106 {
107 let mut sources = self.entry.source.clone();
108 assert!(!sources.is_empty());
109
110 sources.sort_unstable_by(|a, b| b.source_type.cmp(&a.source_type));
111
112 let msg = self.msg.as_deref().unwrap_or(&*self.entry.filename);
113
114 for (index, c) in sources.iter().enumerate() {
115 let download_res = match &c.source_type {
116 DownloadSourceType::Http { auth } => {
117 self.try_http_download(c, auth, callback).await
118 }
119 DownloadSourceType::Local(as_symlink) => {
120 self.download_local(c, *as_symlink, callback).await
121 }
122 };
123
124 match download_res {
125 Ok(b) => {
126 callback(Event::DownloadDone {
127 index: self.download_list_index,
128 msg: msg.into(),
129 })
130 .await;
131
132 return DownloadResult::Success(SuccessSummary {
133 file_name: self.entry.filename.to_string(),
134 index: self.download_list_index,
135 wrote: b,
136 });
137 }
138 Err(e) => {
139 if index == sources.len() - 1 {
140 callback(Event::Failed {
141 file_name: self.entry.filename.clone(),
142 error: e,
143 })
144 .await;
145
146 return DownloadResult::Failed {
147 file_name: self.entry.filename.to_string(),
148 };
149 }
150
151 callback(Event::NextUrl {
152 index: self.download_list_index,
153 file_name: self.entry.filename.to_string(),
154 err: e,
155 })
156 .await;
157 }
158 }
159 }
160
161 unreachable!()
162 }
163
164 async fn try_http_download<F, Fut>(
166 &self,
167 source: &DownloadSource,
168 auth: &Option<(String, String)>,
169 callback: &F,
170 ) -> Result<bool, SingleDownloadError>
171 where
172 F: Fn(Event) -> Fut,
173 Fut: Future<Output = ()>,
174 {
175 let mut times = 1;
176 let mut allow_resume = self.entry.allow_resume;
177 loop {
178 match self
179 .http_download(allow_resume, source, auth, callback)
180 .await
181 {
182 Ok(s) => {
183 return Ok(s);
184 }
185 Err(e) => match e {
186 SingleDownloadError::ChecksumMismatch => {
187 if self.retry_times == times {
188 return Err(e);
189 }
190
191 if times > 1 {
192 callback(Event::ChecksumMismatch {
193 index: self.download_list_index,
194 filename: self.entry.filename.to_string(),
195 times,
196 })
197 .await;
198 }
199
200 times += 1;
201 allow_resume = false;
202 }
203 e => {
204 return Err(e);
205 }
206 },
207 }
208 }
209 }
210
211 async fn http_download<F, Fut>(
212 &self,
213 allow_resume: bool,
214 source: &DownloadSource,
215 auth: &Option<(String, String)>,
216 callback: &F,
217 ) -> Result<bool, SingleDownloadError>
218 where
219 F: Fn(Event) -> Fut,
220 Fut: Future<Output = ()>,
221 {
222 let file = self.entry.dir.join(&*self.entry.filename);
223 let file_exist = file.exists();
224 let mut file_size = file.metadata().ok().map(|x| x.len()).unwrap_or(0);
225
226 debug!("{} Exist file size is: {file_size}", file.display());
227 debug!("{} download url is: {}", file.display(), source.url);
228 let mut dest = None;
229 let mut validator = None;
230
231 if file.is_symlink() {
232 tokio::fs::remove_file(&file).await.context(RemoveSnafu)?;
233 }
234
235 if file_exist && !file.is_symlink() {
238 debug!("File: {} exists", self.entry.filename);
239
240 self.set_permission_with_path(&file).await?;
241
242 if let Some(hash) = &self.entry.hash {
243 debug!("Hash exist! It is: {}", hash);
244
245 let mut f = tokio::fs::OpenOptions::new()
246 .write(true)
247 .read(true)
248 .open(&file)
249 .await
250 .context(OpenAsWriteModeSnafu)?;
251
252 debug!(
253 "oma opened file: {} with write and read mode",
254 self.entry.filename
255 );
256
257 let mut v = hash.get_validator();
258
259 debug!("Validator created.");
260
261 let (read, finish) = checksum(callback, file_size, &mut f, &mut v).await;
262
263 if finish {
264 debug!(
265 "{} checksum success, no need to download anything.",
266 self.entry.filename
267 );
268
269 callback(Event::ProgressDone(self.download_list_index)).await;
270
271 return Ok(false);
272 }
273
274 debug!(
275 "checksum fail, will download this file: {}",
276 self.entry.filename
277 );
278
279 if !allow_resume {
280 callback(Event::GlobalProgressSub(read)).await;
281 } else {
282 dest = Some(f);
283 validator = Some(v);
284 }
285 }
286 }
287
288 let msg = self.progress_msg();
289 callback(Event::NewProgressSpinner {
290 index: self.download_list_index,
291 msg: msg.clone(),
292 })
293 .await;
294
295 let req = self.build_request_with_basic_auth(&source.url, Method::HEAD, auth);
296
297 let resp_head = timeout(self.timeout, req.send()).await;
298
299 callback(Event::ProgressDone(self.download_list_index)).await;
300
301 let resp_head = match resp_head {
302 Ok(Ok(resp)) => resp,
303 Ok(Err(e)) => {
304 return Err(SingleDownloadError::ReqwestError { source: e });
305 }
306 Err(_) => {
307 return Err(SingleDownloadError::SendRequestTimeout);
308 }
309 };
310
311 let head = resp_head.headers();
312
313 let mut can_resume = match head.get(ACCEPT_RANGES) {
317 Some(x) if x == "none" => false,
318 Some(_) => true,
319 None => false,
320 };
321
322 debug!("Can resume? {can_resume}");
323
324 let total_size = {
326 let total_size = head
327 .get(CONTENT_LENGTH)
328 .map(|x| x.to_owned())
329 .unwrap_or(HeaderValue::from(0));
330
331 total_size
332 .to_str()
333 .ok()
334 .and_then(|x| x.parse::<u64>().ok())
335 .unwrap_or_default()
336 };
337
338 debug!("File total size is: {total_size}");
339
340 let mut req = self.build_request_with_basic_auth(&source.url, Method::GET, auth);
341
342 if can_resume && allow_resume {
343 if total_size <= file_size {
346 debug!("Exist file size is reset to 0, because total size <= exist file size");
347 callback(Event::GlobalProgressSub(file_size)).await;
348 file_size = 0;
349 can_resume = false;
350 }
351
352 debug!("oma will set header range as bytes={file_size}-");
354 req = req.header(RANGE, format!("bytes={}-", file_size));
355 }
356
357 debug!("Can resume? {can_resume}");
358
359 let resp = timeout(self.timeout, req.send()).await;
360
361 callback(Event::ProgressDone(self.download_list_index)).await;
362
363 let resp = match resp {
364 Ok(resp) => resp
365 .and_then(|resp| resp.error_for_status())
366 .context(ReqwestSnafu)?,
367 Err(_) => return Err(SingleDownloadError::SendRequestTimeout),
368 };
369
370 callback(Event::NewProgressBar {
371 index: self.download_list_index,
372 msg,
373 size: total_size,
374 })
375 .await;
376
377 let source = resp;
378
379 let hash = &self.entry.hash;
382 let mut validator = if let Some(v) = validator {
383 callback(Event::ProgressInc {
384 index: self.download_list_index,
385 size: file_size,
386 })
387 .await;
388
389 Some(v)
390 } else {
391 hash.as_ref().map(|hash| hash.get_validator())
392 };
393
394 let mut self_progress = 0;
395 let mut dest = if !can_resume || !allow_resume {
396 debug!(
398 "oma will open file: {} as create mode.",
399 self.entry.filename
400 );
401
402 let f = match File::create(&file).await {
403 Ok(f) => f,
404 Err(e) => {
405 callback(Event::ProgressDone(self.download_list_index)).await;
406 return Err(SingleDownloadError::Create { source: e });
407 }
408 };
409
410 if let Err(e) = f.set_len(0).await {
411 callback(Event::ProgressDone(self.download_list_index)).await;
412 return Err(SingleDownloadError::Create { source: e });
413 }
414
415 self.set_permission(&f).await?;
416
417 f
418 } else if let Some(dest) = dest {
419 debug!(
420 "oma will re use opened dest file for {}",
421 self.entry.filename
422 );
423 self_progress += file_size;
424
425 dest
426 } else {
427 debug!(
428 "oma will open file: {} as create mode.",
429 self.entry.filename
430 );
431
432 let f = match File::create(&file).await {
433 Ok(f) => f,
434 Err(e) => {
435 callback(Event::ProgressDone(self.download_list_index)).await;
436 return Err(SingleDownloadError::Create { source: e });
437 }
438 };
439
440 if let Err(e) = f.set_len(0).await {
441 callback(Event::ProgressDone(self.download_list_index)).await;
442 return Err(SingleDownloadError::Create { source: e });
443 }
444
445 self.set_permission(&f).await?;
446
447 f
448 };
449
450 if can_resume && allow_resume {
451 debug!("oma will seek file: {} to end", self.entry.filename);
453 if let Err(e) = dest.seek(SeekFrom::End(0)).await {
454 callback(Event::ProgressDone(self.download_list_index)).await;
455 return Err(SingleDownloadError::Seek { source: e });
456 }
457 }
458 debug!("Start download!");
460
461 let bytes_stream = source
462 .bytes_stream()
463 .map_err(|e| io::Error::new(ErrorKind::Other, e))
464 .into_async_read();
465
466 let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
467 CompressFile::Xz => &mut XzDecoder::new(BufReader::new(bytes_stream)),
468 CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(bytes_stream)),
469 CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(bytes_stream)),
470 CompressFile::Nothing => &mut BufReader::new(bytes_stream),
471 CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(bytes_stream)),
472 };
473
474 let mut reader = reader.compat();
475
476 let mut buf = vec![0u8; 8 * 1024];
477
478 loop {
479 let size = match timeout(self.timeout, reader.read(&mut buf[..])).await {
480 Ok(Ok(size)) => size,
481 Ok(Err(e)) => {
482 callback(Event::ProgressDone(self.download_list_index)).await;
483 return Err(SingleDownloadError::BrokenPipe { source: e });
484 }
485 Err(_) => {
486 callback(Event::ProgressDone(self.download_list_index)).await;
487 return Err(SingleDownloadError::DownloadTimeout);
488 }
489 };
490
491 if size == 0 {
492 break;
493 }
494
495 if let Err(e) = dest.write_all(&buf[..size]).await {
496 callback(Event::ProgressDone(self.download_list_index)).await;
497 return Err(SingleDownloadError::Write { source: e });
498 }
499
500 callback(Event::ProgressInc {
501 index: self.download_list_index,
502 size: size as u64,
503 })
504 .await;
505
506 self_progress += size as u64;
507
508 callback(Event::GlobalProgressAdd(size as u64)).await;
509
510 if let Some(ref mut v) = validator {
511 v.update(&buf[..size]);
512 }
513 }
514
515 debug!("Download complete! shutting down dest file stream ...");
517 if let Err(e) = dest.shutdown().await {
518 callback(Event::ProgressDone(self.download_list_index)).await;
519 return Err(SingleDownloadError::Flush { source: e });
520 }
521
522 if let Some(v) = validator {
524 if !v.finish() {
525 debug!("checksum fail: {}", self.entry.filename);
526 debug!("{self_progress}");
527
528 callback(Event::GlobalProgressSub(self_progress)).await;
529 callback(Event::ProgressDone(self.download_list_index)).await;
530 return Err(SingleDownloadError::ChecksumMismatch);
531 }
532
533 debug!("checksum success: {}", self.entry.filename);
534 }
535
536 callback(Event::ProgressDone(self.download_list_index)).await;
537
538 Ok(true)
539 }
540
541 async fn set_permission(&self, f: &File) -> Result<(), SingleDownloadError> {
542 if let Some(mode) = self.set_permission {
543 debug!("Setting {} permission to {:#o}", self.entry.filename, mode);
544 f.set_permissions(Permissions::from_mode(mode))
545 .await
546 .context(SetPermissionSnafu)?;
547 }
548
549 Ok(())
550 }
551
552 async fn set_permission_with_path(&self, path: &Path) -> Result<(), SingleDownloadError> {
553 if let Some(mode) = self.set_permission {
554 debug!("Setting {} permission to {:#o}", self.entry.filename, mode);
555
556 fs::set_permissions(path, Permissions::from_mode(mode))
557 .await
558 .context(SetPermissionSnafu)?;
559 }
560
561 Ok(())
562 }
563
564 fn build_request_with_basic_auth(
565 &self,
566 url: &str,
567 method: Method,
568 auth: &Option<(String, String)>,
569 ) -> RequestBuilder {
570 let mut req = self.client.request(method, url);
571
572 if let Some((user, password)) = auth {
573 debug!("auth user: {}", user);
574 req = req.basic_auth(user, Some(password));
575 }
576
577 req
578 }
579
580 fn progress_msg(&self) -> String {
581 let (count, len) = &self.progress;
582 let msg = self.msg.as_deref().unwrap_or(&self.entry.filename);
583 let msg = format!("({count}/{len}) {msg}");
584
585 msg
586 }
587
588 async fn download_local<F, Fut>(
590 &self,
591 source: &DownloadSource,
592 as_symlink: bool,
593 callback: &F,
594 ) -> Result<bool, SingleDownloadError>
595 where
596 F: Fn(Event) -> Fut,
597 Fut: Future<Output = ()>,
598 {
599 debug!("{:?}", self.entry);
600 let msg = self.progress_msg();
601
602 let url = &source.url;
603
604 let url_path = url_no_escape(url.strip_prefix("file:").unwrap());
606
607 let url_path = Path::new(&url_path);
608
609 let total_size = tokio::fs::metadata(url_path)
610 .await
611 .context(OpenSnafu)?
612 .len();
613
614 callback(Event::NewProgressBar {
615 index: self.download_list_index,
616 msg,
617 size: total_size,
618 })
619 .await;
620
621 if as_symlink {
622 if let Some(hash) = &self.entry.hash {
623 self.checksum_local(callback, url_path, total_size, hash)
624 .await?;
625 }
626
627 let symlink = self.entry.dir.join(&*self.entry.filename);
628 if symlink.exists() {
629 tokio::fs::remove_file(&symlink)
630 .await
631 .context(RemoveSnafu)?;
632 }
633
634 tokio::fs::symlink(url_path, symlink)
635 .await
636 .context(CreateSymlinkSnafu)?;
637
638 callback(Event::GlobalProgressAdd(total_size as u64)).await;
639 callback(Event::ProgressDone(self.download_list_index)).await;
640
641 return Ok(true);
642 }
643
644 debug!("File path is: {}", url_path.display());
645
646 let from = File::open(&url_path).await.context(CreateSnafu)?;
647 let from = tokio::io::BufReader::new(from).compat();
648
649 debug!("Success open file: {}", url_path.display());
650
651 let mut to = File::create(self.entry.dir.join(&*self.entry.filename))
652 .await
653 .context(CreateSnafu)?;
654
655 self.set_permission(&to).await?;
656
657 let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
658 CompressFile::Xz => &mut XzDecoder::new(BufReader::new(from)),
659 CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(from)),
660 CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(from)),
661 CompressFile::Nothing => &mut BufReader::new(from),
662 CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(from)),
663 };
664
665 let mut reader = reader.compat();
666
667 debug!(
668 "Success create file: {}",
669 self.entry.dir.join(&*self.entry.filename).display()
670 );
671
672 let mut buf = vec![0u8; 8 * 1024];
673
674 loop {
675 let size = reader.read(&mut buf[..]).await.context(BrokenPipeSnafu)?;
676
677 if size == 0 {
678 break;
679 }
680
681 to.write_all(&buf[..size]).await.context(WriteSnafu)?;
682
683 callback(Event::ProgressInc {
684 index: self.download_list_index,
685 size: size as u64,
686 })
687 .await;
688
689 callback(Event::GlobalProgressAdd(size as u64)).await;
690 }
691
692 callback(Event::ProgressDone(self.download_list_index)).await;
693
694 Ok(true)
695 }
696
697 async fn checksum_local<F, Fut>(
698 &self,
699 callback: &F,
700 url_path: &Path,
701 total_size: u64,
702 hash: &crate::checksum::Checksum,
703 ) -> Result<(), SingleDownloadError>
704 where
705 F: Fn(Event) -> Fut,
706 Fut: Future<Output = ()>,
707 {
708 let mut f = fs::File::open(url_path).await.context(OpenSnafu)?;
709 let (size, finish) =
710 checksum(callback, total_size, &mut f, &mut hash.get_validator()).await;
711
712 if !finish {
713 callback(Event::GlobalProgressSub(size)).await;
714 callback(Event::ProgressDone(self.download_list_index)).await;
715 return Err(SingleDownloadError::ChecksumMismatch);
716 }
717
718 Ok(())
719 }
720}
721
722async fn checksum<F, Fut>(
723 callback: &F,
724 file_size: u64,
725 f: &mut File,
726 v: &mut ChecksumValidator,
727) -> (u64, bool)
728where
729 F: Fn(Event) -> Fut,
730 Fut: Future<Output = ()>,
731{
732 let mut buf = vec![0; 8192];
733 let mut read = 0;
734
735 loop {
736 if read == file_size {
737 break;
738 }
739
740 let Ok(read_count) = f.read(&mut buf[..]).await else {
741 debug!("Read file get fk, so re-download it");
742 break;
743 };
744
745 v.update(&buf[..read_count]);
746
747 callback(Event::GlobalProgressAdd(read_count as u64)).await;
748
749 read += read_count as u64;
750 }
751
752 (read, v.finish())
753}