1use crate::{CompressFile, DownloadSource, Event, checksum::ChecksumValidator, send_request};
2use std::{
3 borrow::Cow,
4 io::{self, SeekFrom},
5 path::Path,
6 time::Duration,
7};
8
9use async_compression::futures::bufread::{
10 BzDecoder, GzipDecoder, Lz4Decoder, LzmaDecoder, XzDecoder, ZstdDecoder,
11};
12use bon::bon;
13use futures::{AsyncRead, TryStreamExt, io::BufReader};
14use reqwest::{
15 Client, Method, RequestBuilder,
16 header::{ACCEPT_RANGES, CONTENT_LENGTH, HeaderValue, RANGE},
17};
18use snafu::{ResultExt, Snafu};
19use tokio::{
20 fs::{self, File},
21 io::{AsyncBufReadExt as _, AsyncReadExt as _, AsyncSeekExt, AsyncWriteExt},
22 time::timeout,
23};
24
25use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
26use tracing::{debug, trace};
27
28use crate::{DownloadEntry, DownloadSourceType};
29
/// Buffer size for reading local files (checksum passes and local copies): 64 KiB.
const READ_FILE_BUFSIZE: usize = 64 * 1024;
/// Buffer size for each read from an HTTP response body: 8 KiB.
const DOWNLOAD_BUFSIZE: usize = 8 * 1024;
32
33#[derive(Debug, Snafu)]
34pub enum BuilderError {
35 #[snafu(display("Download task {file_name} sources is empty"))]
36 EmptySource { file_name: String },
37 #[snafu(display("Not allow set illegal download threads: {count}"))]
38 IllegalDownloadThread { count: usize },
39}
40
41pub(crate) struct SingleDownloader<'a> {
42 client: &'a Client,
43 pub entry: &'a DownloadEntry,
44 total: usize,
45 retry_times: usize,
46 msg: Option<Cow<'static, str>>,
47 download_list_index: usize,
48 file_type: CompressFile,
49 timeout: Duration,
50}
51
52pub enum DownloadResult {
53 Success(SuccessSummary),
54 Failed { file_name: String },
55}
56
/// Details about an entry that was obtained successfully.
#[derive(Debug)]
pub struct SuccessSummary {
    /// Name of the downloaded file.
    pub file_name: String,
    /// The downloader's `download_list_index`.
    pub index: usize,
    /// `false` when an existing file already matched its checksum and nothing was written.
    pub wrote: bool,
    /// URL of the source that succeeded.
    pub url: String,
}
64
65#[derive(Debug, Snafu)]
66pub enum SingleDownloadError {
67 #[snafu(display("Failed to set permission"))]
68 SetPermission { source: io::Error },
69 #[snafu(display("Failed to open file as rw mode"))]
70 OpenAsWriteMode { source: io::Error },
71 #[snafu(display("Failed to open file"))]
72 Open { source: io::Error },
73 #[snafu(display("Failed to create file"))]
74 Create { source: io::Error },
75 #[snafu(display("Failed to seek file"))]
76 Seek { source: io::Error },
77 #[snafu(display("Failed to write file"))]
78 Write { source: io::Error },
79 #[snafu(display("Failed to flush file"))]
80 Flush { source: io::Error },
81 #[snafu(display("Failed to Remove file"))]
82 Remove { source: io::Error },
83 #[snafu(display("Failed to create symlink"))]
84 CreateSymlink { source: io::Error },
85 #[snafu(display("Request Error"))]
86 ReqwestError { source: reqwest::Error },
87 #[snafu(display("Broken pipe"))]
88 BrokenPipe { source: io::Error },
89 #[snafu(display("Send request timeout"))]
90 SendRequestTimeout,
91 #[snafu(display("Download file timeout"))]
92 DownloadTimeout,
93 #[snafu(display("checksum mismatch"))]
94 ChecksumMismatch,
95}
96
97#[bon]
98impl<'a> SingleDownloader<'a> {
99 #[builder]
100 pub(crate) fn new(
101 client: &'a Client,
102 entry: &'a DownloadEntry,
103 total: usize,
104 retry_times: usize,
105 msg: Option<Cow<'static, str>>,
106 download_list_index: usize,
107 file_type: CompressFile,
108 timeout: Duration,
109 ) -> Result<SingleDownloader<'a>, BuilderError> {
110 if entry.source.is_empty() {
111 return Err(BuilderError::EmptySource {
112 file_name: entry.filename.to_string(),
113 });
114 }
115
116 Ok(Self {
117 client,
118 entry,
119 total,
120 retry_times,
121 msg,
122 download_list_index,
123 file_type,
124 timeout,
125 })
126 }
127
128 pub(crate) async fn try_download(self, callback: &impl AsyncFn(Event)) -> DownloadResult {
129 let mut sources = self.entry.source.clone();
130 assert!(!sources.is_empty());
131
132 sources.sort_unstable_by(|a, b| b.source_type.cmp(&a.source_type));
133
134 let msg = self.msg.as_deref().unwrap_or(&*self.entry.filename);
135
136 for (index, c) in sources.iter().enumerate() {
137 let download_res = match &c.source_type {
138 DownloadSourceType::Http { auth } => {
139 self.try_http_download(c, auth, callback).await
140 }
141 DownloadSourceType::Local(as_symlink) => {
142 self.download_local(c, *as_symlink, callback).await
143 }
144 };
145
146 match download_res {
147 Ok(b) => {
148 callback(Event::DownloadDone {
149 index: self.download_list_index,
150 msg: msg.into(),
151 })
152 .await;
153
154 return DownloadResult::Success(SuccessSummary {
155 file_name: self.entry.filename.to_string(),
156 url: c.url.clone(),
157 index: self.download_list_index,
158 wrote: b,
159 });
160 }
161 Err(e) => {
162 if index == sources.len() - 1 {
163 callback(Event::Failed {
164 file_name: self.entry.filename.clone(),
165 error: e,
166 })
167 .await;
168
169 return DownloadResult::Failed {
170 file_name: self.entry.filename.to_string(),
171 };
172 }
173
174 callback(Event::NextUrl {
175 index: self.download_list_index,
176 file_name: self.entry.filename.to_string(),
177 err: e,
178 })
179 .await;
180 }
181 }
182 }
183
184 unreachable!()
185 }
186
187 async fn try_http_download(
189 &self,
190 source: &DownloadSource,
191 auth: &Option<(String, String)>,
192 callback: &impl AsyncFn(Event),
193 ) -> Result<bool, SingleDownloadError> {
194 let mut times = 1;
195 let mut allow_resume = self.entry.allow_resume;
196 loop {
197 match self
198 .http_download(allow_resume, source, auth, callback)
199 .await
200 {
201 Ok(s) => {
202 return Ok(s);
203 }
204 Err(e) => match e {
205 SingleDownloadError::ChecksumMismatch => {
206 if self.retry_times == times {
207 return Err(e);
208 }
209
210 if times > 1 {
211 callback(Event::ChecksumMismatch {
212 index: self.download_list_index,
213 filename: self.entry.filename.to_string(),
214 times,
215 })
216 .await;
217 }
218
219 times += 1;
220 allow_resume = false;
221 }
222 e => {
223 return Err(e);
224 }
225 },
226 }
227 }
228 }
229
230 async fn http_download(
231 &self,
232 allow_resume: bool,
233 source: &DownloadSource,
234 auth: &Option<(String, String)>,
235 callback: &impl AsyncFn(Event),
236 ) -> Result<bool, SingleDownloadError> {
237 let file = self.entry.dir.join(&*self.entry.filename);
238 let file_exist = file.exists();
239 let mut file_size = file.metadata().ok().map(|x| x.len()).unwrap_or(0);
240
241 trace!("{} Exist file size is: {file_size}", file.display());
242 trace!("{} download url is: {}", file.display(), source.url);
243 let mut dest = None;
244 let mut validator = None;
245 let is_symlink = file.is_symlink();
246
247 debug!("file {} is symlink = {}", file.display(), is_symlink);
248
249 if is_symlink {
250 tokio::fs::remove_file(&file).await.context(RemoveSnafu)?;
251 }
252
253 if file_exist && !is_symlink {
256 trace!(
257 "File {} already exists, verifying checksum ...",
258 self.entry.filename
259 );
260
261 if let Some(hash) = &self.entry.hash {
262 trace!("Hash {} exists for the existing file.", hash);
263
264 let mut f = tokio::fs::OpenOptions::new()
265 .write(true)
266 .read(true)
267 .open(&file)
268 .await
269 .context(OpenAsWriteModeSnafu)?;
270
271 trace!("oma opened file {} read/write.", self.entry.filename);
272
273 let mut v = hash.get_validator();
274
275 trace!("Validator created.");
276
277 let (read, finish) = checksum(callback, &mut f, &mut v).await;
278
279 if finish {
280 trace!("Checksum {} matches, cache hit!", self.entry.filename);
281
282 callback(Event::ProgressDone(self.download_list_index)).await;
283
284 return Ok(false);
285 }
286
287 debug!(
288 "Checksum mismatch, initiating re-download for file {} ...",
289 self.entry.filename
290 );
291
292 if !allow_resume {
293 callback(Event::GlobalProgressSub(read)).await;
294 } else {
295 dest = Some(f);
296 validator = Some(v);
297 }
298 }
299 }
300
301 callback(Event::NewProgressSpinner {
302 index: self.download_list_index,
303 msg: self.download_message(),
304 total: self.total,
305 })
306 .await;
307
308 let req = self.build_request_with_basic_auth(&source.url, Method::HEAD, auth);
309 let resp_head = timeout(self.timeout, send_request(&source.url, req)).await;
310
311 callback(Event::ProgressDone(self.download_list_index)).await;
312
313 let resp_head = match resp_head {
314 Ok(Ok(resp)) => resp,
315 Ok(Err(e)) => {
316 return Err(SingleDownloadError::ReqwestError { source: e });
317 }
318 Err(_) => {
319 return Err(SingleDownloadError::SendRequestTimeout);
320 }
321 };
322
323 let head = resp_head.headers();
324
325 let server_can_resume = match head.get(ACCEPT_RANGES) {
329 Some(x) if x == "none" => false,
330 Some(_) => true,
331 None => false,
332 };
333
334 let total_size = {
336 let total_size = head
337 .get(CONTENT_LENGTH)
338 .map(|x| x.to_owned())
339 .unwrap_or(HeaderValue::from(0));
340
341 total_size
342 .to_str()
343 .ok()
344 .and_then(|x| x.parse::<u64>().ok())
345 .unwrap_or_default()
346 };
347
348 trace!("File total size is: {total_size}");
349
350 let mut req = self.build_request_with_basic_auth(&source.url, Method::GET, auth);
351
352 let mut resume = server_can_resume;
353
354 if !allow_resume {
355 resume = false;
356 }
357
358 if server_can_resume && allow_resume {
359 if total_size <= file_size {
362 trace!(
363 "Resetting size indicator for file to 0, as the file to download is larger that the one that already exists."
364 );
365 callback(Event::GlobalProgressSub(file_size)).await;
366 file_size = 0;
367 resume = false;
368 }
369
370 trace!("oma will set header range as bytes={file_size}-");
372 req = req.header(RANGE, format!("bytes={file_size}-"));
373 }
374
375 debug!("Can resume = {server_can_resume}, will resume = {resume}",);
376
377 let resp = timeout(self.timeout, req.send()).await;
378
379 callback(Event::ProgressDone(self.download_list_index)).await;
380
381 let resp = match resp {
382 Ok(resp) => resp
383 .and_then(|resp| resp.error_for_status())
384 .context(ReqwestSnafu)?,
385 Err(_) => return Err(SingleDownloadError::SendRequestTimeout),
386 };
387
388 callback(Event::NewProgressBar {
389 index: self.download_list_index,
390 msg: self.download_message(),
391 total: self.total,
392 size: total_size,
393 })
394 .await;
395
396 let source = resp;
397
398 let hash = &self.entry.hash;
399
400 let mut self_progress = 0;
401 let (mut dest, mut validator) = if !resume {
402 trace!(
404 "oma will open file {} in creation mode.",
405 self.entry.filename
406 );
407
408 let f = match File::create(&file).await {
409 Ok(f) => f,
410 Err(e) => {
411 callback(Event::ProgressDone(self.download_list_index)).await;
412 return Err(SingleDownloadError::Create { source: e });
413 }
414 };
415
416 if file_size > 0 {
417 callback(Event::GlobalProgressSub(file_size)).await;
418 }
419
420 if let Err(e) = f.set_len(0).await {
421 callback(Event::ProgressDone(self.download_list_index)).await;
422 return Err(SingleDownloadError::Create { source: e });
423 }
424
425 (f, hash.as_ref().map(|hash| hash.get_validator()))
426 } else if let Some((dest, validator)) = dest.zip(validator) {
427 callback(Event::ProgressInc {
428 index: self.download_list_index,
429 size: file_size,
430 })
431 .await;
432
433 trace!(
434 "oma will re-use opened destination file for {}",
435 self.entry.filename
436 );
437 self_progress += file_size;
438
439 (dest, Some(validator))
440 } else {
441 trace!(
442 "oma will open file {} in creation mode.",
443 self.entry.filename
444 );
445
446 let f = match File::create(&file).await {
447 Ok(f) => f,
448 Err(e) => {
449 callback(Event::ProgressDone(self.download_list_index)).await;
450 return Err(SingleDownloadError::Create { source: e });
451 }
452 };
453
454 if let Err(e) = f.set_len(0).await {
455 callback(Event::ProgressDone(self.download_list_index)).await;
456 return Err(SingleDownloadError::Create { source: e });
457 }
458
459 (f, hash.as_ref().map(|hash| hash.get_validator()))
460 };
461
462 if server_can_resume && allow_resume {
463 trace!("oma will seek to end-of-file for {}", self.entry.filename);
465 if let Err(e) = dest.seek(SeekFrom::End(0)).await {
466 callback(Event::ProgressDone(self.download_list_index)).await;
467 return Err(SingleDownloadError::Seek { source: e });
468 }
469 }
470 trace!("Starting download!");
472
473 let bytes_stream = source
474 .bytes_stream()
475 .map_err(io::Error::other)
476 .into_async_read();
477
478 let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
479 CompressFile::Xz => &mut XzDecoder::new(BufReader::new(bytes_stream)),
480 CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(bytes_stream)),
481 CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(bytes_stream)),
482 CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(bytes_stream)),
483 CompressFile::Lzma => &mut LzmaDecoder::new(BufReader::new(bytes_stream)),
484 CompressFile::Lz4 => &mut Lz4Decoder::new(BufReader::new(bytes_stream)),
485 CompressFile::Nothing => &mut BufReader::new(bytes_stream),
486 };
487
488 let mut reader = reader.compat();
489
490 let mut buf = vec![0u8; DOWNLOAD_BUFSIZE];
491
492 loop {
493 let size = match timeout(self.timeout, reader.read(&mut buf[..])).await {
494 Ok(Ok(size)) => size,
495 Ok(Err(e)) => {
496 callback(Event::ProgressDone(self.download_list_index)).await;
497 return Err(SingleDownloadError::BrokenPipe { source: e });
498 }
499 Err(_) => {
500 callback(Event::ProgressDone(self.download_list_index)).await;
501 return Err(SingleDownloadError::DownloadTimeout);
502 }
503 };
504
505 if size == 0 {
506 break;
507 }
508
509 if let Err(e) = dest.write_all(&buf[..size]).await {
510 callback(Event::ProgressDone(self.download_list_index)).await;
511 return Err(SingleDownloadError::Write { source: e });
512 }
513
514 callback(Event::ProgressInc {
515 index: self.download_list_index,
516 size: size as u64,
517 })
518 .await;
519
520 self_progress += size as u64;
521
522 callback(Event::GlobalProgressAdd(size as u64)).await;
523
524 if let Some(ref mut v) = validator {
525 v.update(&buf[..size]);
526 }
527 }
528
529 trace!("Download complete! Shutting down destination file stream ...");
531 if let Err(e) = dest.shutdown().await {
532 callback(Event::ProgressDone(self.download_list_index)).await;
533 return Err(SingleDownloadError::Flush { source: e });
534 }
535
536 if let Some(v) = validator {
538 if !v.finish() {
539 debug!("Checksum mismatch for file {}", self.entry.filename);
540 trace!("{self_progress}");
541
542 callback(Event::GlobalProgressSub(self_progress)).await;
543 callback(Event::ProgressDone(self.download_list_index)).await;
544 return Err(SingleDownloadError::ChecksumMismatch);
545 }
546
547 trace!(
548 "Checksum verification successful for file {}",
549 self.entry.filename
550 );
551 }
552
553 callback(Event::ProgressDone(self.download_list_index)).await;
554
555 Ok(true)
556 }
557
558 fn build_request_with_basic_auth(
559 &self,
560 url: &str,
561 method: Method,
562 auth: &Option<(String, String)>,
563 ) -> RequestBuilder {
564 let mut req = self.client.request(method, url);
565
566 if let Some((user, password)) = auth {
567 trace!("Authenticating as user: {} ...", user);
568 req = req.basic_auth(user, Some(password));
569 }
570
571 req
572 }
573
574 async fn download_local(
576 &self,
577 source: &DownloadSource,
578 as_symlink: bool,
579 callback: &impl AsyncFn(Event),
580 ) -> Result<bool, SingleDownloadError> {
581 debug!("{:?}", self.entry);
582
583 let url = source.url.strip_prefix("file:").unwrap();
584
585 let url_path = Path::new(url);
586
587 let total_size = tokio::fs::metadata(url_path)
588 .await
589 .context(OpenSnafu)?
590 .len();
591
592 let file = self.entry.dir.join(&*self.entry.filename);
593 if file.is_symlink() || (as_symlink && file.is_file()) {
594 tokio::fs::remove_file(&file).await.context(RemoveSnafu)?;
595 }
596
597 if as_symlink {
598 if let Some(hash) = &self.entry.hash {
599 self.checksum_local(callback, url_path, hash).await?;
600 }
601
602 tokio::fs::symlink(url_path, file)
603 .await
604 .context(CreateSymlinkSnafu)?;
605
606 return Ok(true);
607 }
608
609 callback(Event::NewProgressBar {
610 index: self.download_list_index,
611 total: self.total,
612 msg: self.download_message(),
613 size: total_size,
614 })
615 .await;
616
617 trace!("Path for file: {}", url_path.display());
618
619 let from = File::open(&url_path).await.context(CreateSnafu)?;
620 let from = tokio::io::BufReader::new(from).compat();
621
622 trace!("Successfully opened file: {}", url_path.display());
623
624 let mut to = File::create(self.entry.dir.join(&*self.entry.filename))
625 .await
626 .context(CreateSnafu)?;
627
628 let reader: &mut (dyn AsyncRead + Unpin + Send) = match self.file_type {
629 CompressFile::Xz => &mut XzDecoder::new(BufReader::new(from)),
630 CompressFile::Gzip => &mut GzipDecoder::new(BufReader::new(from)),
631 CompressFile::Bz2 => &mut BzDecoder::new(BufReader::new(from)),
632 CompressFile::Zstd => &mut ZstdDecoder::new(BufReader::new(from)),
633 CompressFile::Lzma => &mut LzmaDecoder::new(BufReader::new(from)),
634 CompressFile::Lz4 => &mut Lz4Decoder::new(BufReader::new(from)),
635 CompressFile::Nothing => &mut BufReader::new(from),
636 };
637
638 let mut reader = reader.compat();
639
640 trace!(
641 "Successfully created file: {}",
642 self.entry.dir.join(&*self.entry.filename).display()
643 );
644
645 let mut v = self.entry.hash.as_ref().map(|v| v.get_validator());
646
647 let mut buf = vec![0u8; READ_FILE_BUFSIZE];
648 let mut self_progress = 0;
649
650 loop {
651 let size = reader.read(&mut buf[..]).await.context(BrokenPipeSnafu)?;
652 self_progress += size;
653
654 if size == 0 {
655 break;
656 }
657
658 to.write_all(&buf[..size]).await.context(WriteSnafu)?;
659
660 callback(Event::ProgressInc {
661 index: self.download_list_index,
662 size: size as u64,
663 })
664 .await;
665
666 if let Some(ref mut v) = v {
667 v.update(&buf[..size]);
668 }
669
670 callback(Event::GlobalProgressAdd(size as u64)).await;
671 }
672
673 if v.is_some_and(|v| !v.finish()) {
674 callback(Event::GlobalProgressSub(self_progress as u64)).await;
675 callback(Event::ProgressDone(self.download_list_index)).await;
676 return Err(SingleDownloadError::ChecksumMismatch);
677 }
678
679 callback(Event::ProgressDone(self.download_list_index)).await;
680
681 Ok(true)
682 }
683
684 async fn checksum_local(
685 &self,
686 callback: &impl AsyncFn(Event),
687 url_path: &Path,
688 hash: &crate::checksum::Checksum,
689 ) -> Result<(), SingleDownloadError> {
690 let mut f = fs::File::open(url_path).await.context(OpenSnafu)?;
691 let (size, finish) = checksum(callback, &mut f, &mut hash.get_validator()).await;
692
693 if !finish {
694 callback(Event::GlobalProgressSub(size)).await;
695 callback(Event::ProgressDone(self.download_list_index)).await;
696 return Err(SingleDownloadError::ChecksumMismatch);
697 }
698
699 Ok(())
700 }
701
702 fn download_message(&self) -> String {
703 self.msg
704 .as_deref()
705 .unwrap_or(&self.entry.filename)
706 .to_string()
707 }
708}
709
710async fn checksum(
711 callback: &impl AsyncFn(Event),
712 f: &mut File,
713 v: &mut ChecksumValidator,
714) -> (u64, bool) {
715 let mut reader = tokio::io::BufReader::with_capacity(READ_FILE_BUFSIZE, f);
716
717 let mut read = 0;
718
719 loop {
720 let buffer = match reader.fill_buf().await {
721 Ok([]) => break,
722 Ok(buffer) => buffer,
723 Err(e) => {
724 debug!("Error while reading file: {e}");
725 break;
726 }
727 };
728
729 v.update(buffer);
730
731 callback(Event::GlobalProgressAdd(buffer.len() as u64)).await;
732 read += buffer.len() as u64;
733 let len = buffer.len();
734
735 reader.consume(len);
736 }
737
738 (read, v.finish())
739}