use crate::download::{Download, Status, Summary};
use futures::stream::{self, StreamExt};
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle};
use reqwest::{
header::{HeaderMap, HeaderValue, IntoHeaderName, RANGE},
StatusCode,
};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_retry::{policies::ExponentialBackoff, RetryTransientMiddleware};
use reqwest_tracing::TracingMiddleware;
use std::{fs, path::PathBuf, sync::Arc};
use tokio::{fs::OpenOptions, io::AsyncWriteExt};
use tracing::debug;
/// Field-less unit struct.
///
/// NOTE(review): nothing in this file constructs or references `TimeTrace`; its purpose has
/// to be confirmed against the rest of the crate.
pub struct TimeTrace;
/// Configuration for downloading a batch of files concurrently.
///
/// Instances are assembled by `DownloaderBuilder`; all fields are private.
#[derive(Debug, Clone)]
pub struct Downloader {
/// Base directory; each download's `filename` is joined onto it to form the output path.
directory: PathBuf,
/// Maximum number of retries handed to the exponential-backoff retry policy.
retries: u32,
/// Upper bound on the number of downloads driven at the same time.
concurrent_downloads: usize,
/// Options for the main progress bar and the per-download progress bars.
style_options: StyleOptions,
/// When `true`, the downloader asks whether a download is resumable and, if so, sends a
/// `Range` header starting at the size of the file already on disk.
resumable: bool,
/// Optional extra headers, installed as client defaults and also added to each GET request.
headers: Option<HeaderMap>,
}
impl Downloader {
/// Value assigned to `retries` by `DownloaderBuilder::default`.
const DEFAULT_RETRIES: u32 = 3;
/// Value assigned to `concurrent_downloads` by `DownloaderBuilder::default`.
const DEFAULT_CONCURRENT_DOWNLOADS: usize = 32;
pub async fn download(&self, downloads: &[Download]) -> Vec<Summary> {
self.download_inner(downloads, None).await
}
/// Fetches every entry of `downloads`, routing the HTTP client through `proxy`.
///
/// Delegates to [`Downloader::download_inner`] with `Some(proxy)` and returns its summaries
/// unchanged.
pub async fn download_with_proxy(
    &self,
    downloads: &[Download],
    proxy: reqwest::Proxy,
) -> Vec<Summary> {
    let via_proxy = Some(proxy);
    self.download_inner(downloads, via_proxy).await
}
/// Builds the HTTP client and progress bars, then fetches all `downloads` concurrently.
///
/// The client is a `reqwest` client (optionally using `proxy`, and with the configured extra
/// headers as defaults) wrapped with tracing and transient-error retry middleware. At most
/// `concurrent_downloads` fetches are in flight at once; a configured value of `0` is treated
/// as `1`.
///
/// Returns one [`Summary`] per download. Because the fetches are buffered unordered, the
/// summaries are collected in completion order, not input order.
///
/// # Panics
///
/// Panics if the underlying `reqwest` client cannot be built.
pub async fn download_inner(
    &self,
    downloads: &[Download],
    proxy: Option<reqwest::Proxy>,
) -> Vec<Summary> {
    let retry_policy = ExponentialBackoff::builder().build_with_max_retries(self.retries);

    // Assemble the plain reqwest client first; the middleware stack wraps it afterwards.
    let mut inner_client_builder = reqwest::Client::builder();
    if let Some(proxy) = proxy {
        inner_client_builder = inner_client_builder.proxy(proxy);
    }
    if let Some(headers) = &self.headers {
        inner_client_builder = inner_client_builder.default_headers(headers.clone());
    }
    let inner_client = inner_client_builder
        .build()
        .expect("failed to build the HTTP client");
    let client = ClientBuilder::new(inner_client)
        .with(TracingMiddleware::default())
        .with(RetryTransientMiddleware::new_with_policy(retry_policy))
        .build();

    // Draw to the terminal only if at least one of the two bar kinds is enabled. The flags
    // are read directly so the style options do not have to be cloned just to be inspected.
    let bars_enabled = self.style_options.main.enabled || self.style_options.child.enabled;
    let multi = Arc::new(if bars_enabled {
        MultiProgress::new()
    } else {
        MultiProgress::with_draw_target(ProgressDrawTarget::hidden())
    });

    // The main bar counts finished downloads out of the total number requested.
    let main = Arc::new(
        multi.add(
            self.style_options
                .main
                .clone()
                .to_progress_bar(downloads.len() as u64),
        ),
    );
    main.tick();

    // A limit of zero would leave the buffered stream with no slot to start a download in,
    // so nothing would ever complete; always allow at least one in-flight fetch.
    let concurrency = self.concurrent_downloads.max(1);
    let summaries = stream::iter(downloads)
        .map(|d| self.fetch(&client, d, multi.clone(), main.clone()))
        .buffer_unordered(concurrency)
        .collect::<Vec<_>>()
        .await;

    if self.style_options.main.clear {
        main.finish_and_clear();
    } else {
        main.finish();
    }
    summaries
}
/// Downloads a single file into `self.directory` and reports the outcome as a [`Summary`].
///
/// Flow:
/// 1. When resuming is enabled, asks the server whether the download is resumable and, if the
///    destination file already exists, records its current size.
/// 2. Queries the content length; if it equals the size already on disk the download is
///    reported as skipped without issuing the GET request.
/// 3. Sends the GET request (with a `Range` header starting at the on-disk size when
///    resuming) and streams the body into the destination file, appending when resuming and
///    truncating any previous content otherwise.
///
/// Every error encountered on the way is turned into a failed summary through
/// `Summary::fail`. `multi` hosts the per-file progress bar; `main` is incremented once when
/// the download succeeds.
async fn fetch(
    &self,
    client: &ClientWithMiddleware,
    download: &Download,
    multi: Arc<MultiProgress>,
    main: Arc<ProgressBar>,
) -> Summary {
    let mut size_on_disk: u64 = 0;
    let mut can_resume = false;
    let output = self.directory.join(&download.filename);
    // Placeholder summary used for every early return that happens before a response exists.
    let mut summary = Summary::new(
        download.clone(),
        StatusCode::BAD_REQUEST,
        size_on_disk,
        can_resume,
    );
    let file_exist = output.exists();

    if self.resumable {
        can_resume = match download.is_resumable(client).await {
            Ok(r) => r,
            Err(e) => {
                return summary.fail(e);
            }
        };
        if file_exist {
            debug!("A file with the same name already exists at the destination.");
            size_on_disk = match output.metadata() {
                Ok(m) => m.len(),
                Err(e) => {
                    return summary.fail(e);
                }
            };
        }
        summary.set_resumable(can_resume);
    }

    let content_length = match download.content_length(client).await {
        Ok(l) => l,
        Err(e) => {
            // Without a length we cannot tell how much of a partial file is still missing,
            // so a resume attempt is abandoned; a fresh download can proceed without it.
            if can_resume && file_exist {
                return summary.fail(e);
            }
            debug!("Error retrieving content length {e}");
            None
        }
    };

    // Nothing to transfer when the reported length matches what is already on disk. This
    // only depends on values known before the request, so decide before sending it.
    if content_length == Some(size_on_disk) {
        return summary.with_status(Status::Skipped(
            "the file was already fully downloaded".into(),
        ));
    }

    debug!("Fetching {}", &download.url);
    let mut req = client.get(download.url.clone());
    // `can_resume` can only be true when `self.resumable` is set (see above).
    if can_resume {
        req = req.header(RANGE, format!("bytes={size_on_disk}-"));
    }
    if let Some(ref h) = self.headers {
        req = req.headers(h.to_owned());
    }
    let res = match req.send().await {
        Ok(res) => res,
        Err(e) => {
            return summary.fail(e);
        }
    };
    if let Err(e) = res.error_for_status_ref() {
        return summary.fail(e);
    }

    // NOTE(review): the skip check above treats `content_length` as the full file size, while
    // this sum treats it as the number of bytes still to fetch — confirm against
    // `Download::content_length` which of the two it is.
    let size = content_length.unwrap_or_default() + size_on_disk;
    let status = res.status();
    summary = Summary::new(download.clone(), status, size, can_resume);
    if size_on_disk > 0 && size == size_on_disk {
        return summary.with_status(Status::Skipped(
            "the file was already fully downloaded".into(),
        ));
    }

    let pb = multi.add(
        self.style_options
            .child
            .clone()
            .to_progress_bar(size)
            .with_position(size_on_disk),
    );

    let output_dir = output.parent().unwrap_or(&output);
    debug!("Creating destination directory {:?}", output_dir);
    // NOTE(review): this is a blocking std call inside an async fn; it is kept because it is
    // the only user of the file-level `std::fs` import.
    if let Err(e) = fs::create_dir_all(output_dir) {
        return summary.fail(e);
    }

    debug!("Creating destination file {:?}", &output);
    // Append when resuming; otherwise truncate so that leftovers of a longer, older file
    // cannot survive past the end of the newly written content.
    let mut file = match OpenOptions::new()
        .create(true)
        .write(true)
        .append(can_resume)
        .truncate(!can_resume)
        .open(output)
        .await
    {
        Ok(file) => file,
        Err(e) => {
            return summary.fail(e);
        }
    };

    let mut final_size = size_on_disk;
    debug!("Retrieving chunks...");
    let mut stream = res.bytes_stream();
    while let Some(item) = stream.next().await {
        let mut chunk = match item {
            Ok(chunk) => chunk,
            Err(e) => {
                return summary.fail(e);
            }
        };
        let chunk_size = chunk.len() as u64;
        final_size += chunk_size;
        pb.inc(chunk_size);
        if let Err(e) = file.write_all_buf(&mut chunk).await {
            return summary.fail(e);
        }
    }

    // Flush explicitly: dropping the file does not report errors from writes that are still
    // pending, so a failed final write would otherwise be counted as a success.
    if let Err(e) = file.flush().await {
        return summary.fail(e);
    }

    if self.style_options.child.clear {
        pb.finish_and_clear();
    } else {
        pb.finish();
    }
    main.inc(1);

    Summary::new(download.clone(), status, final_size, can_resume).with_status(Status::Success)
}
}
/// Builder for [`Downloader`].
///
/// Wraps the `Downloader` value under construction; `build` hands it out.
pub struct DownloaderBuilder(Downloader);
impl DownloaderBuilder {
    /// Creates a builder holding the `Default` configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a default builder whose main and per-download progress bars are both hidden.
    pub fn hidden() -> Self {
        let builder = Self::default();
        let silent = StyleOptions::new(ProgressBarOpts::hidden(), ProgressBarOpts::hidden());
        builder.style_options(silent)
    }

    /// Sets the directory that downloaded files are written into.
    pub fn directory(self, directory: PathBuf) -> Self {
        Self(Downloader {
            directory,
            ..self.0
        })
    }

    /// Sets the maximum number of retries.
    pub fn retries(self, retries: u32) -> Self {
        Self(Downloader { retries, ..self.0 })
    }

    /// Sets how many downloads may run at the same time.
    pub fn concurrent_downloads(self, concurrent_downloads: usize) -> Self {
        Self(Downloader {
            concurrent_downloads,
            ..self.0
        })
    }

    /// Replaces the progress bar style options.
    pub fn style_options(self, style_options: StyleOptions) -> Self {
        Self(Downloader {
            style_options,
            ..self.0
        })
    }

    /// Returns a copy of the headers configured so far, or an empty map if there are none.
    fn new_header(&self) -> HeaderMap {
        self.0.headers.clone().unwrap_or_default()
    }

    /// Extends the headers configured so far with `headers`.
    pub fn headers(self, headers: HeaderMap) -> Self {
        let mut merged = self.new_header();
        merged.extend(headers);
        Self(Downloader {
            headers: Some(merged),
            ..self.0
        })
    }

    /// Inserts a single header into the headers configured so far.
    pub fn header<K: IntoHeaderName>(self, name: K, value: HeaderValue) -> Self {
        let mut merged = self.new_header();
        merged.insert(name, value);
        Self(Downloader {
            headers: Some(merged),
            ..self.0
        })
    }

    /// Consumes the builder and returns the configured [`Downloader`].
    pub fn build(self) -> Downloader {
        let Self(downloader) = self;
        downloader
    }
}
impl Default for DownloaderBuilder {
    /// Default configuration: the current working directory (or an empty path if it cannot
    /// be determined), the default retry and concurrency limits, the default progress bar
    /// styles, resuming enabled and no extra headers.
    fn default() -> Self {
        let directory = std::env::current_dir().unwrap_or_default();
        let config = Downloader {
            directory,
            retries: Downloader::DEFAULT_RETRIES,
            concurrent_downloads: Downloader::DEFAULT_CONCURRENT_DOWNLOADS,
            style_options: StyleOptions::default(),
            resumable: true,
            headers: None,
        };
        Self(config)
    }
}
/// Progress bar options for a batch of downloads.
#[derive(Debug, Clone)]
pub struct StyleOptions {
/// Options for the main bar, which the downloader sizes to the number of downloads.
main: ProgressBarOpts,
/// Options for the bar the downloader creates for each individual download.
child: ProgressBarOpts,
}
impl Default for StyleOptions {
    /// Main bar: position template with fine-grained progress characters, enabled and left on
    /// screen when finished. Child bars: the pip-like style.
    fn default() -> Self {
        let main = ProgressBarOpts::new(
            Some(ProgressBarOpts::TEMPLATE_BAR_WITH_POSITION.to_owned()),
            Some(ProgressBarOpts::CHARS_FINE.to_owned()),
            true,
            false,
        );
        Self::new(main, ProgressBarOpts::with_pip_style())
    }
}
impl StyleOptions {
pub fn new(main: ProgressBarOpts, child: ProgressBarOpts) -> Self {
Self { main, child }
}
pub fn set_main(&mut self, main: ProgressBarOpts) {
self.main = main;
}
pub fn set_child(&mut self, child: ProgressBarOpts) {
self.child = child;
}
pub fn is_enabled(self) -> bool {
self.main.enabled || self.child.enabled
}
}
/// Options describing how a single progress bar is built and finished.
#[derive(Debug, Clone)]
pub struct ProgressBarOpts {
/// Template string applied to the progress style when present.
template: Option<String>,
/// Progress characters applied to the progress style when present.
progress_chars: Option<String>,
/// When `false`, `to_progress_bar` returns a hidden bar.
enabled: bool,
/// When `true`, the downloader clears the bar on completion instead of leaving it visible.
clear: bool,
}
impl Default for ProgressBarOpts {
    /// No custom template or progress characters; enabled, and cleared when finished.
    fn default() -> Self {
        Self::new(None, None, true, true)
    }
}
impl ProgressBarOpts {
pub const TEMPLATE_BAR_WITH_POSITION: &'static str =
"{bar:40.blue} {pos:>}/{len} ({percent}%) eta {eta_precise:.blue}";
pub const TEMPLATE_PIP: &'static str =
"{bar:40.green/black} {bytes:>11.green}/{total_bytes:<11.green} {bytes_per_sec:>13.red} eta {eta:.blue}";
pub const CHARS_BLOCKY: &'static str = "█▛▌▖ ";
pub const CHARS_FADE_IN: &'static str = "█▓▒░ ";
pub const CHARS_FINE: &'static str = "█▉▊▋▌▍▎▏ ";
pub const CHARS_LINE: &'static str = "━╾╴─";
pub const CHARS_ROUGH: &'static str = "█ ";
pub const CHARS_VERTICAL: &'static str = "█▇▆▅▄▃▂▁ ";
pub fn new(
template: Option<String>,
progress_chars: Option<String>,
enabled: bool,
clear: bool,
) -> Self {
Self {
template,
progress_chars,
enabled,
clear,
}
}
pub fn to_progress_style(self) -> ProgressStyle {
let mut style = ProgressStyle::default_bar();
if let Some(template) = self.template {
style = style.template(&template).unwrap();
}
if let Some(progress_chars) = self.progress_chars {
style = style.progress_chars(&progress_chars);
}
style
}
pub fn to_progress_bar(self, len: u64) -> ProgressBar {
if !self.enabled {
return ProgressBar::hidden();
}
let style = self.to_progress_style();
ProgressBar::new(len).with_style(style)
}
pub fn with_pip_style() -> Self {
Self {
template: Some(ProgressBarOpts::TEMPLATE_PIP.into()),
progress_chars: Some(ProgressBarOpts::CHARS_LINE.into()),
enabled: true,
clear: true,
}
}
pub fn set_clear(&mut self, clear: bool) {
self.clear = clear;
}
pub fn hidden() -> Self {
Self {
enabled: false,
..ProgressBarOpts::default()
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// A builder that is built without any customization carries the default limits.
    #[test]
    fn test_builder_defaults() {
        let Downloader {
            retries,
            concurrent_downloads,
            ..
        } = DownloaderBuilder::new().build();
        assert_eq!(retries, Downloader::DEFAULT_RETRIES);
        assert_eq!(
            concurrent_downloads,
            Downloader::DEFAULT_CONCURRENT_DOWNLOADS
        );
    }
}