use std::collections::VecDeque;
use std::fmt::Debug;
use std::io::{self, ErrorKind, Read};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use ecow::EcoString;
use native_tls::{Certificate, TlsConnector};
use once_cell::sync::OnceCell;
use ureq::Response;
/// Receives notifications about the lifecycle of a download.
///
/// Implementors decide how (or whether) to present the information, e.g. by
/// drawing a progress bar or by ignoring it entirely.
pub trait Progress {
    /// Invoked when a download is about to begin.
    fn print_start(&mut self);

    /// Invoked while the download is running with a snapshot of the current
    /// transfer statistics.
    fn print_progress(&mut self, state: &DownloadState);

    /// Invoked once the download has completed with the final transfer
    /// statistics.
    fn print_finish(&mut self, state: &DownloadState);
}
/// A [`Progress`] implementation that silently discards every notification.
pub struct ProgressSink;

impl Progress for ProgressSink {
    /// Does nothing.
    fn print_start(&mut self) {}

    /// Does nothing; the state is ignored.
    fn print_progress(&mut self, _state: &DownloadState) {}

    /// Does nothing; the state is ignored.
    fn print_finish(&mut self, _state: &DownloadState) {}
}
/// A snapshot of the statistics of an ongoing or finished download.
#[derive(Debug)]
pub struct DownloadState {
    /// The expected size of the body in bytes, if it is known.
    pub content_len: Option<usize>,
    /// The number of body bytes received so far.
    pub total_downloaded: usize,
    /// Recent throughput samples, each being the number of bytes received
    /// during one reporting interval. New samples are pushed to the front.
    pub bytes_per_second: VecDeque<usize>,
    /// The instant at which this state was created.
    pub start_time: Instant,
}
/// An HTTP(S) downloader with an optional custom root certificate.
pub struct Downloader {
    /// The value sent as the `User-Agent` of every request.
    user_agent: EcoString,
    /// Path to a PEM file holding a custom root certificate, if any.
    cert_path: Option<PathBuf>,
    /// The custom root certificate, either supplied at construction time or
    /// read from `cert_path` on first use.
    cert: OnceCell<Certificate>,
}
impl Downloader {
    /// Creates a downloader that sends the given user agent and uses no
    /// custom root certificate.
    pub fn new(user_agent: impl Into<EcoString>) -> Self {
        Self {
            user_agent: user_agent.into(),
            cert_path: None,
            cert: OnceCell::new(),
        }
    }

    /// Creates a downloader that sends the given user agent and trusts the
    /// PEM certificate stored at `cert_path`.
    ///
    /// The file is not touched here; it is read the first time the
    /// certificate is needed (see [`Downloader::cert`]).
    pub fn with_path(user_agent: impl Into<EcoString>, cert_path: PathBuf) -> Self {
        Self {
            user_agent: user_agent.into(),
            cert_path: Some(cert_path),
            cert: OnceCell::new(),
        }
    }

    /// Creates a downloader that sends the given user agent and trusts the
    /// given, already loaded certificate.
    pub fn with_cert(user_agent: impl Into<EcoString>, cert: Certificate) -> Self {
        Self {
            user_agent: user_agent.into(),
            cert_path: None,
            cert: OnceCell::with_value(cert),
        }
    }

    /// Returns the custom root certificate used by this downloader.
    ///
    /// - `None` if neither a certificate nor a certificate path was supplied.
    /// - `Some(Ok(cert))` if a certificate was supplied directly or was
    ///   loaded successfully from the configured path.
    /// - `Some(Err(err))` if reading or parsing the certificate file failed.
    pub fn cert(&self) -> Option<io::Result<&Certificate>> {
        // A certificate passed to `with_cert` lives in the cell even though no
        // path is configured, so the cell has to be consulted before the path.
        // Looking only at the path would silently drop such a certificate.
        if let Some(cert) = self.cert.get() {
            return Some(Ok(cert));
        }

        let path = self.cert_path.as_ref()?;
        Some(self.cert.get_or_try_init(|| {
            let pem = std::fs::read(path)?;
            Certificate::from_pem(&pem).map_err(io::Error::other)
        }))
    }

    /// Sends a GET request to `url` and returns the response.
    ///
    /// A fresh agent is built for every call. It uses the configured user
    /// agent, the proxy that `env_proxy` reports for the URL (a proxy that
    /// `ureq` cannot parse is ignored), and the custom root certificate if
    /// there is one.
    ///
    /// # Errors
    ///
    /// Fails if the certificate cannot be loaded, if the TLS connector cannot
    /// be built, or if the request itself fails.
    // `ureq::Error` is the error type of the public signature, so it is
    // returned by value despite the lint.
    #[allow(clippy::result_large_err)]
    pub fn download(&self, url: &str) -> Result<ureq::Response, ureq::Error> {
        let mut builder = ureq::AgentBuilder::new();
        let mut tls = TlsConnector::builder();

        builder = builder.user_agent(&self.user_agent);

        // Route the request through the proxy from the environment, if any.
        if let Some(proxy) = env_proxy::for_url_str(url)
            .to_url()
            .and_then(|url| ureq::Proxy::new(url).ok())
        {
            builder = builder.proxy(proxy);
        }

        // Trust the custom root certificate in addition to the defaults.
        if let Some(cert) = self.cert() {
            tls.add_root_certificate(cert?.clone());
        }

        let connector = tls.build().map_err(io::Error::other)?;
        builder = builder.tls_connector(Arc::new(connector));

        builder.build().get(url).call()
    }

    /// Downloads the body at `url` into memory while reporting to `progress`.
    ///
    /// `progress.print_start()` is called before the request is sent; the
    /// remaining callbacks are driven while the body is read.
    ///
    /// # Errors
    ///
    /// Fails if the request fails or if reading the body fails.
    // See `download` for why the lint is silenced.
    #[allow(clippy::result_large_err)]
    pub fn download_with_progress(
        &self,
        url: &str,
        progress: &mut dyn Progress,
    ) -> Result<Vec<u8>, ureq::Error> {
        progress.print_start();
        let response = self.download(url)?;
        Ok(RemoteReader::from_response(response, progress).download()?)
    }
}
impl Debug for Downloader {
    /// Formats the downloader, printing a fixed placeholder instead of the
    /// certificate contents when a certificate is present in the cell.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The certificate is rendered as an opaque marker; only its presence
        // is visible in the output.
        let cert_placeholder = self
            .cert
            .get()
            .map(|_| typst_utils::debug(|out| write!(out, "Certificate(..)")));

        let mut repr = f.debug_struct("Downloader");
        repr.field("user_agent", &self.user_agent);
        repr.field("cert_path", &self.cert_path);
        repr.field("cert", &cert_placeholder);
        repr.finish()
    }
}
/// Maximum number of throughput samples kept in
/// `DownloadState::bytes_per_second`; the oldest sample is dropped once this
/// many are stored.
const SAMPLES: usize = 5;
/// Reads a response body into memory while reporting progress.
struct RemoteReader<'p> {
    /// The source of the body bytes.
    reader: Box<dyn Read + Send + Sync + 'static>,
    /// Statistics that are updated while reading and handed to `progress`.
    state: DownloadState,
    /// When the current reporting interval started; `None` until the first
    /// chunk has been read.
    last_progress: Option<Instant>,
    /// The receiver of progress notifications.
    progress: &'p mut dyn Progress,
}
impl<'p> RemoteReader<'p> {
    /// Wraps the body of `response` so that it can be read with progress
    /// reporting.
    ///
    /// The expected length is taken from the `Content-Length` header; it is
    /// `None` when the header is absent or does not parse as a `usize`.
    fn from_response(response: Response, progress: &'p mut dyn Progress) -> Self {
        let content_len: Option<usize> = response
            .header("Content-Length")
            .and_then(|header| header.parse().ok());

        Self {
            reader: response.into_reader(),
            last_progress: None,
            state: DownloadState {
                content_len,
                total_downloaded: 0,
                bytes_per_second: VecDeque::with_capacity(SAMPLES),
                start_time: Instant::now(),
            },
            progress,
        }
    }

    /// Reads the whole body into memory and returns it.
    ///
    /// Whenever at least one second has passed since the current reporting
    /// interval started, the bytes received in that interval are recorded as
    /// a new sample and `print_progress` is called. `print_finish` is called
    /// once after the body has been read completely.
    ///
    /// # Errors
    ///
    /// Returns the first read error that is not `ErrorKind::Interrupted`;
    /// interrupted reads are retried. `print_finish` is not called on error.
    fn download(mut self) -> io::Result<Vec<u8>> {
        /// Size of the scratch buffer used for each read.
        const CHUNK_SIZE: usize = 8192;
        /// Upper bound for the up-front allocation of the result buffer.
        const MAX_PREALLOCATION: usize = 16 * 1024 * 1024;

        let mut buffer = vec![0u8; CHUNK_SIZE];

        // The content length comes from a header controlled by the server, so
        // it must not be trusted as an allocation size: an absurd value would
        // panic or abort before a single byte is read. Reserve at most a
        // bounded amount; the vector grows on demand for larger bodies.
        let capacity = self
            .state
            .content_len
            .map_or(CHUNK_SIZE, |len| len.min(MAX_PREALLOCATION));
        let mut data = Vec::with_capacity(capacity);

        // Bytes received since the current reporting interval started.
        let mut downloaded_this_sec = 0;

        loop {
            let read = match self.reader.read(&mut buffer) {
                // End of the body.
                Ok(0) => break,
                Ok(n) => n,
                // Retry reads that were merely interrupted.
                Err(ref e) if e.kind() == ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            };
            data.extend_from_slice(&buffer[..read]);

            // The first reporting interval starts with the first chunk.
            let last_printed = *self.last_progress.get_or_insert_with(Instant::now);
            let elapsed = Instant::now().saturating_duration_since(last_printed);

            downloaded_this_sec += read;
            self.state.total_downloaded += read;

            if elapsed >= Duration::from_secs(1) {
                // Keep only the most recent `SAMPLES` samples, newest first.
                if self.state.bytes_per_second.len() == SAMPLES {
                    self.state.bytes_per_second.pop_back();
                }
                self.state.bytes_per_second.push_front(downloaded_this_sec);
                downloaded_this_sec = 0;

                self.progress.print_progress(&self.state);
                self.last_progress = Some(Instant::now());
            }
        }

        self.progress.print_finish(&self.state);
        Ok(data)
    }
}