use futures::{Stream, StreamExt, channel::mpsc};
#[cfg(feature = "json")]
use serde::de::DeserializeOwned;
use sha2::{Digest as _, Sha256};
use std::{
io,
path::{Path, PathBuf},
time::Duration,
};
use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
pub use reqwest::IntoUrl;
/// HTTP downloader that stores fetched files in a local cache directory.
///
/// Cached files are named by a hex-encoded SHA-256 digest: the expected content
/// digest for `download_with_sha`, or the digest of the URL for `download`.
#[derive(Debug, Clone)]
pub struct Downloader {
/// HTTP client used for every request issued by this downloader.
client: reqwest::Client,
/// Directory in which downloaded files are stored.
cache_dir: PathBuf,
}
impl Downloader {
pub fn new<P: Into<PathBuf>>(cache_dir: P) -> io::Result<Self> {
let cache_dir = cache_dir.into();
if !cache_dir.exists() {
let _ = std::fs::create_dir_all(&cache_dir);
}
if cache_dir.exists() && !cache_dir.is_dir() {
return Err(io::Error::new(
io::ErrorKind::NotADirectory,
"cache_dir should be a directory",
));
}
let client = reqwest::Client::builder()
.connect_timeout(Duration::from_secs(10))
.build()
.expect("Unsupported OS");
Ok(Self { client, cache_dir })
}
pub async fn check_cache_from_sha(&self, sha256: [u8; 32]) -> Option<PathBuf> {
let file_path = self.path_from_sha(sha256);
if file_path.exists() {
if let Ok(hash) = sha256_from_path(&file_path).await {
if hash == sha256 {
return Some(file_path);
}
}
let _ = tokio::fs::remove_file(&file_path).await;
}
None
}
pub fn check_cache_from_url<U: reqwest::IntoUrl>(&self, url: U) -> Option<PathBuf> {
let file_path = self.path_from_url(url.as_str());
if file_path.exists() {
Some(file_path)
} else {
None
}
}
#[cfg(feature = "json")]
pub async fn download_json_no_cache<T, U>(&self, url: U) -> io::Result<T>
where
T: DeserializeOwned,
U: reqwest::IntoUrl,
{
self.client
.get(url)
.send()
.await
.map_err(io::Error::other)?
.json()
.await
.map_err(io::Error::other)
}
pub async fn download<U: reqwest::IntoUrl>(
&self,
url: U,
mut chan: Option<mpsc::Sender<f32>>,
) -> io::Result<PathBuf> {
let url = url.into_url().map_err(io::Error::other)?;
if let Some(p) = self.check_cache_from_url(url.clone()) {
return Ok(p);
}
let file_path = self.path_from_url(url.as_str());
chan_send(chan.as_mut(), 0.0);
let mut cur_pos = 0;
let mut tmp_file = AsyncTempFile::new()?;
{
let mut tmp_file = tokio::io::BufWriter::new(tmp_file.as_mut());
let response = self
.client
.get(url)
.send()
.await
.map_err(io::Error::other)?;
let response_size = response.content_length();
let mut response_stream = response.bytes_stream();
let response_size = match response_size {
Some(x) => x as usize,
None => response_stream.size_hint().0,
};
while let Some(x) = response_stream.next().await {
let mut data = x.map_err(io::Error::other)?;
cur_pos += data.len();
tmp_file.write_all_buf(&mut data).await?;
chan_send(chan.as_mut(), (cur_pos as f32) / (response_size as f32));
}
}
tmp_file.persist(&file_path).await?;
Ok(file_path)
}
pub async fn download_with_sha<U: reqwest::IntoUrl>(
&self,
url: U,
sha256: [u8; 32],
mut chan: Option<mpsc::Sender<f32>>,
) -> io::Result<PathBuf> {
let url = url.into_url().map_err(io::Error::other)?;
tracing::info!(
"Download {:?} with sha256: {:?}",
url,
const_hex::encode(sha256)
);
if let Some(p) = self.check_cache_from_sha(sha256).await {
return Ok(p);
}
let file_path = self.path_from_sha(sha256);
chan_send(chan.as_mut(), 0.0);
let mut tmp_file = AsyncTempFile::new()?;
{
let mut tmp_file = tokio::io::BufWriter::new(tmp_file.as_mut());
let response = self
.client
.get(url)
.send()
.await
.map_err(io::Error::other)?;
let mut cur_pos = 0;
let response_size = response.content_length();
let mut response_stream = response.bytes_stream();
let response_size = match response_size {
Some(x) => x as usize,
None => response_stream.size_hint().0,
};
let mut hasher = Sha256::new();
while let Some(x) = response_stream.next().await {
let mut data = x.map_err(io::Error::other)?;
cur_pos += data.len();
hasher.update(&data);
tmp_file.write_all_buf(&mut data).await?;
chan_send(chan.as_mut(), (cur_pos as f32) / (response_size as f32));
}
let hash: [u8; 32] = hasher
.finalize()
.as_slice()
.try_into()
.expect("SHA-256 is 32 bytes");
if hash != sha256 {
tracing::warn!("{hash:?} != {sha256:?}");
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid SHA256",
));
}
}
tmp_file.persist(&file_path).await?;
Ok(file_path)
}
fn path_from_url(&self, url: &str) -> PathBuf {
let file_name: [u8; 32] = Sha256::new()
.chain_update(url)
.finalize()
.as_slice()
.try_into()
.expect("SHA-256 is 32 bytes");
self.path_from_sha(file_name)
}
fn path_from_sha(&self, sha256: [u8; 32]) -> PathBuf {
let file_name = const_hex::encode(sha256);
self.cache_dir.join(file_name)
}
}
/// Temporary file created with `tempfile::tempfile` and wrapped in a
/// `tokio::fs::File` for asynchronous access.
struct AsyncTempFile(tokio::fs::File);
impl AsyncTempFile {
    /// Creates a new temporary file via `tempfile::tempfile`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported while creating the file.
    fn new() -> io::Result<Self> {
        let f = tempfile::tempfile()?;
        Ok(Self(tokio::fs::File::from_std(f)))
    }

    /// Copies the whole content of the temporary file into a new file at `path`.
    ///
    /// `path` must not exist yet: the destination is opened with `create_new`.
    /// When the copy fails, the partially written destination is removed (best
    /// effort) so that a truncated file is never left at `path`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error from creating the destination or from the copy.
    async fn persist(&mut self, path: &Path) -> io::Result<()> {
        let mut dest = tokio::fs::File::create_new(path).await?;
        let copied: io::Result<()> = async {
            // Make sure every pending write has landed before reading it back.
            self.0.flush().await?;
            self.0.seek(io::SeekFrom::Start(0)).await?;
            tokio::io::copy(&mut self.0, &mut dest).await?;
            dest.flush().await
        }
        .await;
        if let Err(e) = copied {
            // Close the handle first, then discard the incomplete file.
            drop(dest);
            let _ = tokio::fs::remove_file(path).await;
            return Err(e);
        }
        Ok(())
    }
}
/// Gives mutable access to the wrapped `tokio::fs::File`.
impl AsMut<tokio::fs::File> for AsyncTempFile {
fn as_mut(&mut self) -> &mut tokio::fs::File {
&mut self.0
}
}
/// Computes the SHA-256 digest of the file at `p`.
///
/// The file is read through a buffered reader in 512-byte chunks until a read
/// returns zero bytes.
///
/// # Errors
///
/// Returns the I/O error from opening or reading the file.
async fn sha256_from_path(p: &Path) -> io::Result<[u8; 32]> {
    let mut source = tokio::io::BufReader::new(tokio::fs::File::open(p).await?);
    let mut digest = Sha256::new();
    let mut chunk = [0; 512];
    loop {
        match source.read(&mut chunk).await? {
            // A zero-length read marks the end of the file.
            0 => break,
            n => digest.update(&chunk[..n]),
        }
    }
    let bytes: [u8; 32] = digest
        .finalize()
        .as_slice()
        .try_into()
        .expect("SHA-256 is 32 bytes");
    Ok(bytes)
}
/// Sends `msg` through `chan` when a sender is present.
///
/// The send is non-blocking and its result is ignored, so a failed send never
/// interrupts the caller.
fn chan_send(chan: Option<&mut mpsc::Sender<f32>>, msg: f32) {
    let Some(sender) = chan else {
        return;
    };
    let _ = sender.try_send(msg);
}