#![forbid(unsafe_code)]
use reqwest::blocking::Client;
use reqwest::header::RANGE;
use reqwest::StatusCode;
use std::fs;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::time::Duration;
use vanta_core::{Area, VtaError, VtaResult};
/// Upper bound on redirect hops followed for a single request.
const MAX_REDIRECTS: usize = 10;
/// Progress callback, invoked with the number of bytes obtained by each
/// non-empty read of a response body.
pub type ProgressFn<'a> = dyn Fn(u64) + 'a;
/// Blocking HTTP(S) file downloader with retries, resumable `.part` staging
/// files and optional size caps.
///
/// By default plaintext `http://` is only accepted for loopback hosts; see
/// [`Downloader::insecure`] to lift that restriction.
#[derive(Debug, Clone)]
pub struct Downloader {
    /// HTTP client carrying the redirect policy, user agent and connect timeout.
    client: Client,
    /// Extra attempts made after a failed one (total attempts = `retries + 1`).
    retries: u32,
    /// Whether plaintext `http://` to non-loopback hosts is permitted.
    allow_http: bool,
}
impl Downloader {
    /// Creates a downloader that accepts `https://` URLs, plus plaintext
    /// `http://` only for loopback hosts.
    ///
    /// # Errors
    ///
    /// Returns an `Area::Net` error (number 4) if the HTTP client cannot be built.
    pub fn new() -> VtaResult<Downloader> {
        Self::build(false)
    }

    /// Creates a downloader that additionally accepts plaintext `http://` to any host.
    ///
    /// Redirects that downgrade from `https` to `http` are still refused.
    ///
    /// # Errors
    ///
    /// Returns an `Area::Net` error (number 4) if the HTTP client cannot be built.
    pub fn insecure() -> VtaResult<Downloader> {
        Self::build(true)
    }

    /// Shared constructor: installs the redirect policy, user agent and connect
    /// timeout, and defaults to three retries.
    fn build(allow_http: bool) -> VtaResult<Downloader> {
        let redirect = reqwest::redirect::Policy::custom(move |attempt| {
            if attempt.previous().len() >= MAX_REDIRECTS {
                return attempt.error("too many redirects");
            }
            let next = attempt.url();
            // Never follow an https -> http downgrade, even in insecure mode.
            let downgrade = next.scheme() == "http"
                && attempt
                    .previous()
                    .last()
                    .map_or(false, |u| u.scheme() == "https");
            // Without the insecure opt-in, a redirect must not reach a plaintext
            // host that `scheme_ok` would have refused as the starting URL.
            let refused = !allow_http && Self::is_remote_plaintext(next);
            if downgrade || refused {
                // Stopping leaves the 3xx response as the result, which
                // `fetch_one` rejects as a non-success status.
                attempt.stop()
            } else {
                attempt.follow()
            }
        });
        let client = Client::builder()
            .user_agent(concat!("vanta/", env!("CARGO_PKG_VERSION")))
            .connect_timeout(Duration::from_secs(30))
            .redirect(redirect)
            .build()
            .map_err(|e| VtaError::new(Area::Net, 4, format!("building HTTP client: {e}")))?;
        Ok(Downloader {
            client,
            retries: 3,
            allow_http,
        })
    }

    /// Sets how many extra attempts follow a failed one (the default is 3).
    pub fn with_retries(mut self, retries: u32) -> Self {
        self.retries = retries;
        self
    }

    /// Downloads `url` to `dest` with no size cap and no progress reporting.
    ///
    /// # Errors
    ///
    /// See [`Downloader::download_capped_with_progress`].
    pub fn download(&self, url: &str, dest: &Path) -> VtaResult<()> {
        self.download_capped(url, dest, None)
    }

    /// Downloads `url` to `dest`, failing if the file would exceed `max` bytes.
    ///
    /// # Errors
    ///
    /// See [`Downloader::download_capped_with_progress`].
    pub fn download_capped(&self, url: &str, dest: &Path, max: Option<u64>) -> VtaResult<()> {
        self.download_capped_with_progress(url, dest, max, None)
    }

    /// Downloads `url` to `dest`, retrying failed attempts with backoff.
    ///
    /// Data is streamed into `dest` plus a `.part` suffix and renamed into place
    /// on success. A partial file left by an earlier attempt is resumed with an
    /// HTTP range request when the server honours it. `max` caps the total file
    /// size in bytes, and `progress` receives the byte count of every read.
    ///
    /// # Errors
    ///
    /// All errors are in `Area::Net`:
    /// * number 5: plaintext `http://` to a non-loopback host without the insecure opt-in;
    /// * number 6: the body exceeds `max` (reported immediately, not retried);
    /// * number 1: request, HTTP status or I/O failure on the final attempt.
    pub fn download_capped_with_progress(
        &self,
        url: &str,
        dest: &Path,
        max: Option<u64>,
        progress: Option<&ProgressFn>,
    ) -> VtaResult<()> {
        self.scheme_ok(url)?;
        let mut last: Option<VtaError> = None;
        for attempt in 0..=self.retries {
            match self.fetch_one(url, dest, max, progress) {
                Ok(()) => return Ok(()),
                // A size-cap violation is deterministic; retrying would only
                // fetch the same oversize body again after a backoff sleep.
                Err(e) if e.number == 6 => return Err(e),
                Err(e) => {
                    last = Some(e);
                    if attempt < self.retries {
                        std::thread::sleep(backoff(attempt));
                    }
                }
            }
        }
        Err(last.unwrap_or_else(|| VtaError::new(Area::Net, 1, format!("download failed: {url}"))))
    }

    /// Tries each URL in order until one downloads successfully.
    ///
    /// # Errors
    ///
    /// See [`Downloader::download_any_with_progress`].
    pub fn download_any(&self, urls: &[String], dest: &Path, max: Option<u64>) -> VtaResult<()> {
        self.download_any_with_progress(urls, dest, max, None)
    }

    /// Tries each URL in order until one downloads successfully, discarding any
    /// partial file before each candidate so mirrors never mix content.
    ///
    /// # Errors
    ///
    /// Returns the last candidate's error, or an `Area::Net` error (number 1)
    /// when `urls` is empty.
    pub fn download_any_with_progress(
        &self,
        urls: &[String],
        dest: &Path,
        max: Option<u64>,
        progress: Option<&ProgressFn>,
    ) -> VtaResult<()> {
        let mut last: Option<VtaError> = None;
        for url in urls {
            let _ = fs::remove_file(part_path(dest));
            match self.download_capped_with_progress(url, dest, max, progress) {
                Ok(()) => return Ok(()),
                Err(e) => last = Some(e),
            }
        }
        Err(last.unwrap_or_else(|| {
            VtaError::new(Area::Net, 1, "no URLs supplied to download_any".to_string())
        }))
    }

    /// Rejects plaintext `http://` to non-loopback hosts unless the insecure
    /// opt-in is set.
    fn scheme_ok(&self, url: &str) -> VtaResult<()> {
        if self.allow_http {
            return Ok(());
        }
        // Parse the way reqwest will, so `HTTP://`, surrounding whitespace or
        // other accepted spellings cannot slip past a literal prefix check.
        // Unparseable URLs are left for the request itself to reject.
        let remote = reqwest::Url::parse(url)
            .map(|parsed| Self::is_remote_plaintext(&parsed))
            .unwrap_or(false);
        if remote {
            return Err(VtaError::new(
                Area::Net,
                5,
                format!(
                    "refusing plaintext http:// download of {url} \
                     (https required; set the insecure opt-in to override)"
                ),
            ));
        }
        Ok(())
    }

    /// True when `url` is plaintext `http://` to a host that is not loopback.
    fn is_remote_plaintext(url: &reqwest::Url) -> bool {
        url.scheme() == "http"
            && !url
                .as_str()
                .strip_prefix("http://")
                .map_or(false, is_loopback_authority)
    }

    /// Parses the first-byte offset from a `Content-Range: bytes START-END/TOTAL` header.
    fn content_range_start(headers: &reqwest::header::HeaderMap) -> Option<u64> {
        let value = headers
            .get(reqwest::header::CONTENT_RANGE)?
            .to_str()
            .ok()?;
        let (unit, range) = value.trim().split_once(' ')?;
        if !unit.eq_ignore_ascii_case("bytes") {
            return None;
        }
        range.trim_start().split('-').next()?.trim().parse().ok()
    }

    /// Performs a single download attempt of `url` into `dest`.
    ///
    /// Resumes from `dest`'s `.part` file when the server answers the range
    /// request with `206 Partial Content` starting at the expected offset.
    /// Otherwise it rewrites the `.part` file from scratch.
    fn fetch_one(
        &self,
        url: &str,
        dest: &Path,
        max: Option<u64>,
        progress: Option<&ProgressFn>,
    ) -> VtaResult<()> {
        let part = part_path(dest);
        let mut have = fs::metadata(&part).map(|m| m.len()).unwrap_or(0);
        // A leftover partial that is already over the cap can never finish.
        // Discard it instead of failing every attempt (and every later call) on it.
        if matches!(max, Some(cap) if have > cap) {
            let _ = fs::remove_file(&part);
            have = 0;
        }

        let mut req = self.client.get(url);
        if have > 0 {
            req = req.header(RANGE, format!("bytes={have}-"));
        }
        let mut resp = req
            .send()
            .map_err(|e| VtaError::new(Area::Net, 1, format!("requesting {url}: {e}")))?;

        let status = resp.status();
        if !status.is_success() {
            if have > 0 && status == StatusCode::RANGE_NOT_SATISFIABLE {
                // The server cannot serve from our offset, so the partial file is
                // unusable. Remove it so the next attempt starts from byte zero.
                let _ = fs::remove_file(&part);
            }
            return Err(VtaError::new(
                Area::Net,
                1,
                format!("HTTP {status} for {url}"),
            ));
        }

        let resuming = have > 0 && status == StatusCode::PARTIAL_CONTENT;
        if resuming && Self::content_range_start(resp.headers()) != Some(have) {
            // Appending bytes that begin at any other offset would silently
            // corrupt the file. Start over on the next attempt instead.
            let _ = fs::remove_file(&part);
            return Err(VtaError::new(
                Area::Net,
                1,
                format!("server did not resume {url} at byte {have}"),
            ));
        }

        if let Some(parent) = part.parent() {
            fs::create_dir_all(parent).map_err(|e| io(parent, e))?;
        }
        let mut file = if resuming {
            fs::OpenOptions::new()
                .append(true)
                .open(&part)
                .map_err(|e| io(&part, e))?
        } else {
            let _ = fs::remove_file(&part);
            fs::File::create(&part).map_err(|e| io(&part, e))?
        };

        let write_err = |e: std::io::Error| {
            VtaError::new(Area::Net, 1, format!("writing {}: {e}", part.display()))
        };
        let mut src = ProgressReader::new(&mut resp, progress);
        match max {
            Some(cap) => {
                // Bytes already on disk count toward the cap. `have <= cap` holds
                // here because of the reset above.
                let limit = cap.saturating_sub(if resuming { have } else { 0 });
                // Allow one byte past the limit so an oversize body is detected
                // rather than silently truncated.
                let mut limited = (&mut src).take(limit.saturating_add(1));
                let copied = std::io::copy(&mut limited, &mut file).map_err(write_err)?;
                if copied > limit {
                    // Close the handle before deleting the file.
                    drop(file);
                    let _ = fs::remove_file(&part);
                    return Err(VtaError::new(
                        Area::Net,
                        6,
                        format!("download of {url} exceeds declared size {cap} bytes"),
                    ));
                }
            }
            None => {
                std::io::copy(&mut src, &mut file).map_err(write_err)?;
            }
        }

        // Durability is best-effort: a failed fsync does not fail a complete download.
        let _ = file.sync_all();
        // Close the handle before renaming the file into place.
        drop(file);
        fs::rename(&part, dest).map_err(|e| io(dest, e))?;
        Ok(())
    }
}
/// `Read` adapter that forwards every read to `inner` and reports the length
/// of each non-empty read to the optional `progress` callback.
struct ProgressReader<'a, R> {
    inner: R,
    progress: Option<&'a ProgressFn<'a>>,
}
impl<'a, R> ProgressReader<'a, R> {
    /// Wraps `inner`. With `progress == None` the adapter is a plain pass-through.
    fn new(inner: R, progress: Option<&'a ProgressFn<'a>>) -> Self {
        Self { inner, progress }
    }
}
impl<R: Read> Read for ProgressReader<'_, R> {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Errors from the inner reader propagate before any callback runs.
        let got = self.inner.read(buf)?;
        // End of stream (a zero-length read) is not reported as progress.
        match self.progress {
            Some(report) if got != 0 => report(got as u64),
            _ => {}
        }
        Ok(got)
    }
}
/// Reports whether the authority at the start of `rest` (the text following
/// `http://`) names a loopback host: `localhost`, or a loopback IPv4/IPv6 address.
///
/// Matches the IP literal exactly rather than by prefix, so domains such as
/// `127.evil.com` are not mistaken for loopback addresses.
fn is_loopback_authority(rest: &str) -> bool {
    // The authority ends at the first path, query or fragment delimiter. The
    // backslash is included because URL parsers treat `\` like `/` for http.
    let authority = rest.split(['/', '?', '#', '\\']).next().unwrap_or(rest);
    // Any userinfo ends at the last '@'.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);
    let host = match host_port.strip_prefix('[') {
        Some(bracketed) => bracketed.split(']').next().unwrap_or(bracketed),
        None => host_port.split(':').next().unwrap_or(host_port),
    };
    // Trim a fully-qualified trailing dot only after the port has been removed.
    let host = host.trim_end_matches('.');
    if host.eq_ignore_ascii_case("localhost") {
        return true;
    }
    host.parse::<std::net::IpAddr>()
        .map(|ip| ip.is_loopback())
        .unwrap_or(false)
}
/// Returns the staging path for `dest`: the same path with `.part` appended to
/// the full file name. The raw OS string is extended instead of calling
/// `with_extension`, which would replace an existing extension.
fn part_path(dest: &Path) -> PathBuf {
    let mut raw = dest.to_path_buf().into_os_string();
    raw.push(".part");
    raw.into()
}
/// Delay before retrying after failed attempt number `attempt` (0-based).
/// Starts at 0.5 s and doubles each attempt, capped at 8 s from attempt 4 on.
fn backoff(attempt: u32) -> Duration {
    let exponent = attempt.min(4);
    Duration::from_millis(500u64 << exponent)
}
/// Wraps an I/O failure on `path` as a generic `Area::Net` error (number 1)
/// whose message is the path followed by the underlying error.
fn io(path: &Path, e: std::io::Error) -> VtaError {
    let detail = format!("{}: {}", path.display(), e);
    VtaError::new(Area::Net, 1, detail)
}
#[cfg(test)]
mod tests {
    use super::*;
    // Constructing the default (https-only) client must not fail.
    #[test]
    fn client_builds() {
        assert!(Downloader::new().is_ok());
    }
    // The staging path appends `.part` to the whole file name.
    #[test]
    fn part_path_appends_suffix() {
        assert_eq!(
            part_path(Path::new("/tmp/a.bin")),
            PathBuf::from("/tmp/a.bin.part")
        );
    }
    // With no candidate URLs there is nothing to try, so an error is returned.
    #[test]
    fn download_any_empty_errors() {
        let d = Downloader::new().unwrap();
        assert!(d.download_any(&[], Path::new("/tmp/none"), None).is_err());
    }
    // The default downloader refuses remote plaintext http with error number 5.
    // `scheme_ok` runs before any request is sent. https passes the scheme check.
    #[test]
    fn rejects_plaintext_http_scheme() {
        let d = Downloader::new().unwrap();
        let err = d
            .download("http://example.org/x", Path::new("/tmp/should-not-write"))
            .unwrap_err();
        assert_eq!(err.area, Area::Net);
        assert_eq!(err.number, 5);
        assert!(matches!(d.scheme_ok("https://example.org/x"), Ok(())));
    }
    // Loopback authorities (IPv4 127.x, `localhost`, bracketed IPv6 `::1`) are
    // recognised. A hostname that merely starts with "127" is not.
    #[test]
    fn loopback_http_is_allowed_scheme() {
        assert!(is_loopback_authority("127.0.0.1:8080/x"));
        assert!(is_loopback_authority("localhost/x"));
        assert!(is_loopback_authority("[::1]:9/x"));
        assert!(!is_loopback_authority("example.org/x"));
        assert!(!is_loopback_authority("127x.evil.com/x"));
    }
    // The insecure opt-in lifts the scheme restriction for remote http.
    #[test]
    fn insecure_allows_http() {
        let d = Downloader::insecure().unwrap();
        assert!(matches!(d.scheme_ok("http://example.org/x"), Ok(())));
    }
    // Serves a 10_000-byte file from a local test server (presumably
    // `vanta_test::serve` returns the bound port; confirm). A 1000-byte cap
    // must fail with error 6 and leave no destination file. A cap equal to the
    // file size must succeed.
    #[test]
    fn size_cap_aborts_oversize_download() {
        use std::collections::HashMap;
        let mut files = HashMap::new();
        files.insert("/big".to_string(), vec![0u8; 10_000]);
        let port = vanta_test::serve(files);
        let d = Downloader::new().unwrap();
        let dest = std::env::temp_dir().join(format!("vanta-net-cap-{}.bin", std::process::id()));
        let _ = fs::remove_file(&dest);
        let url = format!("http://127.0.0.1:{port}/big");
        let err = d.download_capped(&url, &dest, Some(1000)).unwrap_err();
        assert_eq!(err.number, 6);
        assert!(!dest.exists());
        assert!(d.download_capped(&url, &dest, Some(10_000)).is_ok());
        let _ = fs::remove_file(&dest);
    }
}