use async_trait::async_trait;
use axum::http::StatusCode;
use futures_util::StreamExt;
use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressState, ProgressStyle};
use log::info;
use std::collections::VecDeque;
use std::error;
use std::fmt::{Display, Write};
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
use crate::storage::Storage;
/// Download helpers layered on top of [`Storage`].
///
/// Both methods have default implementations; a blanket impl provides the
/// trait for every `Storage + 'static` type.
#[async_trait]
pub trait StorageExt: Storage {
    /// Downloads every `(remote_url, local_filename)` pair concurrently, one
    /// spawned task and one progress bar per file.
    ///
    /// Files that already exist locally are reused unless `force_download` is
    /// set. `on_download_msg` is logged before starting and `on_finished_msg`
    /// after all tasks complete, but only when at least one file is missing
    /// (or `force_download` is set).
    ///
    /// Returns the local paths in the same order as `remote_file_spec`.
    ///
    /// # Errors
    ///
    /// Fails if an existence check fails, if a download task panics or is
    /// cancelled, if any individual download fails, or if the progress bars
    /// cannot be cleared. On the first failure the remaining tasks are no
    /// longer awaited.
    async fn download_many<
        T1: Display + Send + Sync + 'static,
        T2: Display + Send + Sync + 'static,
    >(
        &self,
        remote_file_spec: Vec<(T1, T2)>,
        force_download: bool,
        on_download_msg: impl Display + Send + Sync + 'static,
        on_finished_msg: impl Display + Send + Sync + 'static,
    ) -> anyhow::Result<VecDeque<PathBuf>> {
        // Only decides whether to log; `||` short-circuits, so existence
        // checks stop at the first missing file.
        let mut has_to_download = force_download;
        for (_, local_filename) in remote_file_spec.iter() {
            has_to_download = has_to_download || !self.exists(&local_filename.to_string()).await?;
        }
        if has_to_download {
            info!("{on_download_msg}");
        }
        let m = MultiProgress::new();
        let mut tasks = vec![];
        for (remote_file, local_filename) in remote_file_spec {
            let remote_file = remote_file.to_string();
            let local_filename = local_filename.to_string();
            let bar = m.add(download_bar(&local_filename));
            // The spawned task must own its storage handle.
            let this = self.clone();
            tasks.push(tokio::spawn(async move {
                this.fetch_remote_data_file(
                    &remote_file,
                    &local_filename,
                    force_download,
                    move |downloaded, total| {
                        bar.set_length(total as u64);
                        bar.set_position(downloaded as u64);
                    },
                )
                .await
            }));
        }
        let mut results = VecDeque::new();
        for task in tasks {
            // Outer `?`: task panicked/cancelled; inner `?`: download error.
            results.push_back(task.await??);
        }
        m.clear()?;
        if has_to_download {
            info!("{on_finished_msg}");
        }
        Ok(results)
    }

    /// Downloads `url` into `local_file`, reporting progress through
    /// `cbk(downloaded_bytes, total_bytes)` after each chunk is written.
    ///
    /// If `force` is false and `local_file` already exists, no request is made
    /// and the existing path is returned. `total_bytes` is 0 when the response
    /// has no `Content-Length`. The body is streamed into `"{local_file}.temp"`,
    /// which is moved onto `local_file` only after every chunk has been written
    /// and flushed.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` if the request fails, the status is not
    /// `200 OK`, the body stream errors, or creating/writing/flushing/moving the
    /// file fails. A `.temp` file may be left behind when a download fails
    /// midway.
    async fn fetch_remote_data_file<Cb: Fn(usize, usize) + Send + Sync + 'static>(
        &self,
        url: &str,
        local_file: &str,
        force: bool,
        cbk: Cb,
    ) -> std::io::Result<PathBuf> {
        // Check `force` first so a forced download skips the existence lookup.
        if !force && self.exists(local_file).await? {
            return Ok(self.path_buf(local_file));
        }
        let resp = reqwest::get(url).await.map_err(io_err)?;
        let status_code = resp.status();
        if status_code != StatusCode::OK {
            return Err(io_err(format!(
                "Error downloading {url}. Invalid status code {status_code}"
            )));
        }
        let total_bytes = resp.content_length().unwrap_or_default() as usize;
        let temp_file = format!("{local_file}.temp");
        let mut file = self.create(&temp_file).await?;
        let mut stream = resp.bytes_stream();
        let mut downloaded_bytes = 0;
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.map_err(io_err)?;
            file.write_all(&chunk).await?;
            downloaded_bytes += chunk.len();
            cbk(downloaded_bytes, total_bytes);
        }
        // The writer returned by `create` may buffer writes (as
        // `tokio::fs::File` does). Flush so pending write errors surface here
        // and the data is complete before the rename. Then close the handle;
        // some platforms refuse to rename open files.
        file.flush().await?;
        drop(file);
        self.mv(&temp_file, local_file).await?;
        Ok(self.path_buf(local_file))
    }
}
/// Builds a progress bar for downloading `file`, drawn to stderr.
///
/// The line shows the file name right-aligned in a `NAME_LEN` (32) character
/// column, a spinner, the bar and `{bytes}/{total_bytes}`. Names longer than
/// the column scroll like a marquee, advancing one character every
/// `NAME_SHIFT_INTERVAL` (300) ms and wrapping around. The bar is created
/// without a length; callers set it once the total size is known.
pub fn download_bar(file: &str) -> ProgressBar {
    const NAME_LEN: usize = 32;
    const NAME_SHIFT_INTERVAL: usize = 300;
    let pb = ProgressBar::with_draw_target(None, ProgressDrawTarget::stderr());
    let file_string = file.to_string();
    // Window over chars, not bytes: byte slicing panics when a window edge
    // lands inside a multi-byte UTF-8 character. `{:>w$}` padding also counts
    // chars, so the length check must count chars too.
    let name_chars: Vec<char> = file.chars().collect();
    let template = format!(
        "{{file:>{}}} {{spinner:.green}} [{{wide_bar:.cyan/blue}}] {{bytes}}/{{total_bytes}}",
        NAME_LEN
    );
    pb.set_style(
        ProgressStyle::with_template(&template)
            .expect("download bar template is a valid indicatif template")
            .with_key("file", move |state: &ProgressState, w: &mut dyn Write| {
                if name_chars.len() > NAME_LEN {
                    let elapsed_ms = state.elapsed().as_millis() as usize;
                    let offset =
                        (elapsed_ms / NAME_SHIFT_INTERVAL) % (name_chars.len() - NAME_LEN + 1);
                    let view: String = name_chars[offset..offset + NAME_LEN].iter().collect();
                    write!(w, "{view: >w$}", w = NAME_LEN).unwrap();
                } else {
                    write!(w, "{file_string: >w$}", w = NAME_LEN).unwrap();
                }
            })
            // Custom ETA formatting; not referenced by the current template.
            .with_key("eta", |state: &ProgressState, w: &mut dyn Write| {
                write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
            })
            .progress_chars("#>-"),
    );
    pb
}
// Blanket impl: every `'static` `Storage` gets the download helpers for free.
impl<T> StorageExt for T where T: Storage + 'static {}
/// Wraps any boxable error, or a plain message string, in a `std::io::Error`
/// of kind `Other`, so unrelated failures can travel through `io::Result`.
fn io_err<E: Into<Box<dyn error::Error + Send + Sync>>>(e: E) -> std::io::Error {
    use std::io::{Error, ErrorKind};
    Error::new(ErrorKind::Other, e)
}
#[cfg(test)]
mod tests {
    use rand::distributions::Alphanumeric;
    use rand::{thread_rng, Rng};
    use std::path::Path;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use crate::storage::AppFs;
    use crate::storage_ext::StorageExt;

    /// Random 7-character alphanumeric string, so each run uses a fresh file.
    fn rand_string() -> String {
        thread_rng()
            .sample_iter(&Alphanumeric)
            .take(7)
            .map(char::from)
            .collect()
    }

    /// Downloads a real file, then checks the caching contract without timing.
    /// A non-forced call for the same file must succeed even against an
    /// unreachable URL (proving no request was made). A forced call against
    /// that URL must fail (proving `force` bypasses the cache).
    ///
    /// Requires network access for the initial download.
    #[tokio::test]
    async fn downloads_remote_file() -> std::io::Result<()> {
        let remote_file = "https://raw.githubusercontent.com/seanmonstar/reqwest/master/README.md";
        // Assumes nothing listens on local TCP port 1 in the test environment,
        // so any request here fails fast with "connection refused".
        let unreachable = "http://127.0.0.1:1/unreachable";
        let file_name = format!("foo/{}.txt", rand_string());
        let app_fs = AppFs::new(Path::new("/tmp/downloads_remote_file_test"));

        let progress = Arc::new(AtomicUsize::new(0));
        let progress_in_cb = Arc::clone(&progress);
        let downloaded = app_fs
            .fetch_remote_data_file(remote_file, &file_name, false, move |done, _| {
                progress_in_cb.store(done, Ordering::SeqCst);
            })
            .await?;
        assert!(
            progress.load(Ordering::SeqCst) > 0,
            "progress callback never reported downloaded bytes"
        );

        let cached = app_fs
            .fetch_remote_data_file(unreachable, &file_name, false, |_, _| {})
            .await?;
        assert_eq!(downloaded, cached);

        let forced = app_fs
            .fetch_remote_data_file(unreachable, &file_name, true, |_, _| {})
            .await;
        assert!(forced.is_err(), "forced download should not use the cached file");
        Ok(())
    }
}