use crate::error::{DownloadError, DownloadResult};
use futures::{stream, StreamExt, TryStreamExt};
use sha1::Digest;
use sha1::Sha1;
use std::cmp::min;
use std::path::PathBuf;
use time::OffsetDateTime;
use tokio::fs::create_dir_all;
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use tokio::sync::mpsc::channel;
use tokio::sync::mpsc::Receiver;
use tokio::sync::mpsc::Sender;
use tokio::task;
/// Describes a single file to fetch: where it comes from, where it is stored
/// and, optionally, the checksum it is expected to have.
#[derive(Debug, Clone)]
pub struct Download {
    /// Address the file is requested from.
    pub url: String,
    /// Path on disk the response body is written to.
    pub file: PathBuf,
    /// Expected SHA-1 digest as raw bytes. When `None`, verification only
    /// checks that `file` exists as a regular file.
    pub sha1: Option<Vec<u8>>,
}
impl Download {
#[instrument(
name = "download_file",
level = "trace",
skip_all,
fields(
url = self.url,
file = %self.file.to_string_lossy(),
current_file,
total_files,
)
)]
pub async fn download(
&self,
client: reqwest::Client,
mut progress_sender: Option<Sender<DownloadProgress>>,
current_file: usize,
total_files: usize,
) -> DownloadResult<()> {
if let Some(parent) = self.file.parent() {
trace!("Creating parent folder");
create_dir_all(parent).await?;
}
let response = client.get(&self.url).send().await?.error_for_status()?;
trace!("Sending request to get content-length");
let total_bytes = response
.content_length()
.ok_or(DownloadError::NoContentLength)?;
let mut progress = DownloadProgress {
url: self.url.clone(),
file: self.file.clone(),
current_file,
total_files,
downloaded_bytes: 0,
total_bytes,
};
trace!("Send initial progress");
progress.send(&mut progress_sender).await;
let mut file = File::create(&self.file).await?;
let mut stream = response.bytes_stream();
trace!("Writing content to disk.");
let mut last_chunk_time = OffsetDateTime::now_utc().unix_timestamp_nanos();
while let Some(item) = stream.next().await {
let chunk = item?;
file.write_all(&chunk).await?;
progress.downloaded_bytes = min(
progress.downloaded_bytes + (chunk.len() as u64),
progress.total_bytes,
);
let now = OffsetDateTime::now_utc().unix_timestamp_nanos();
if now - last_chunk_time > 500000000 {
last_chunk_time = now;
trace!("Send progress");
progress.send(&mut progress_sender).await;
}
}
file.sync_all().await?;
trace!("Send progress");
progress.downloaded_bytes = progress.total_bytes;
progress.send(&mut progress_sender).await;
Ok(())
}
#[instrument(
name = "verify_file",
level = "trace",
skip_all,
fields(
url = self.url,
file = %self.file.to_string_lossy(),
)
)]
pub async fn verify(&self) -> DownloadResult<bool> {
let this = self.clone();
task::spawn_blocking(move || this.blocking_verify())
.await
.unwrap()
}
fn blocking_verify(self) -> DownloadResult<bool> {
if let Some(sha) = self.sha1 {
if !self.file.is_file() {
return Ok(false);
}
let mut file = std::fs::File::open(self.file)?;
let mut hasher = Sha1::new();
std::io::copy(&mut file, &mut hasher)?;
let hash = hasher.finalize().to_vec();
Ok(sha == hash)
} else {
Ok(self.file.is_file())
}
}
}
/// Snapshot of how far a single file download has advanced.
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// Address the file is being fetched from.
    pub url: String,
    /// Path on disk the file is written to.
    pub file: PathBuf,
    /// Position of this file within the batch it belongs to.
    pub current_file: usize,
    /// Number of files in the batch.
    pub total_files: usize,
    /// Bytes written so far.
    pub downloaded_bytes: u64,
    /// Total size of the file in bytes.
    pub total_bytes: u64,
}
impl DownloadProgress {
    /// Pushes a copy of this snapshot into `sender`, if one is present.
    ///
    /// When the send fails because the receiving side is gone, `sender` is
    /// reset to `None` so later calls become no-ops.
    pub(crate) async fn send(&self, sender: &mut Option<Sender<Self>>) {
        let delivered = match sender.as_ref() {
            Some(tx) => tx.send(self.clone()).await.is_ok(),
            // Nobody is listening; nothing to do.
            None => return,
        };
        if !delivered {
            trace!("Sending failed because receiver is no longer around. Dropping sender...");
            *sender = None;
        }
    }
}
/// Downloads a batch of files concurrently.
///
/// Up to `parallel_downloads` files are processed at the same time; a value
/// of `0` is treated as `1`. For each entry:
///
/// * If the target file does not exist yet it is downloaded. If it already
///   exists, no request is made on the first attempt and a single progress
///   update reporting the file's current size as complete is sent instead.
/// * When `verify` is set, the file is checked afterwards. A failed check
///   triggers a fresh download, up to `retries` additional attempts.
///
/// Progress snapshots are sent through `progress_sender` when one is given.
///
/// # Errors
///
/// Returns [`DownloadError::ChecksumMismatch`] when a file still fails
/// verification after the last attempt. Any error raised while downloading,
/// inspecting or verifying a file is propagated immediately and is not
/// retried; only verification failures are.
#[instrument(
    name = "download",
    level = "trace",
    skip_all,
    fields(parallel_downloads, verify)
)]
pub async fn download(
    downloads: Vec<Download>,
    progress_sender: Option<Sender<DownloadProgress>>,
    parallel_downloads: u16,
    retries: u16,
    verify: bool,
) -> DownloadResult<()> {
    // `buffer_unordered(0)` never polls the source stream and therefore never
    // completes, so make sure at least one download may be in flight.
    let concurrency = usize::from(parallel_downloads).max(1);
    let client = reqwest::Client::new();
    let total = downloads.len();
    let downloads = downloads.into_iter().enumerate();
    stream::iter(downloads)
        .map(move |(n, d)| {
            let client = client.clone();
            let mut sender = progress_sender.clone();
            async move {
                for x in 0..=retries {
                    if x > 0 {
                        trace!("Retrying to download file for the {}th time", x);
                    }
                    if !d.file.exists() || x > 0 {
                        trace!("File does not exist or retrying");
                        d.download(client.clone(), sender.clone(), n, total).await?;
                    } else {
                        trace!("File does exist, sending progress update");
                        let file = File::open(&d.file).await?;
                        let size = file.metadata().await?.len();
                        DownloadProgress {
                            url: d.url.clone(),
                            file: d.file.clone(),
                            current_file: n,
                            total_files: total,
                            downloaded_bytes: size,
                            total_bytes: size,
                        }
                        .send(&mut sender)
                        .await;
                    }
                    // Done unless verification was requested and failed.
                    if !verify || d.verify().await? {
                        return Ok(());
                    }
                    if x == retries {
                        debug!("Verification of file failed");
                        return Err(DownloadError::ChecksumMismatch);
                    }
                    debug!("Verification of file failed, retrying...");
                }
                // The last loop iteration always returns; this only satisfies
                // the type checker.
                Ok(())
            }
        })
        .buffer_unordered(concurrency)
        .try_collect::<()>()
        .await?;
    Ok(())
}
/// Creates a bounded channel for [`DownloadProgress`] updates.
///
/// `buffer` is passed straight to tokio's `mpsc::channel` as the channel
/// capacity. The returned pair is `(sender, receiver)`.
pub fn download_progress_channel(
    buffer: usize,
) -> (Sender<DownloadProgress>, Receiver<DownloadProgress>) {
    let (progress_tx, progress_rx) = channel::<DownloadProgress>(buffer);
    (progress_tx, progress_rx)
}