use std::{
fs::File,
io::{BufRead, BufReader, Write as _},
path::Path,
thread,
time::{Duration, SystemTime},
};
use anyhow::{anyhow, Context, Result};
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
use reqwest::{blocking::Client, StatusCode, Url};
use sha2::{Digest, Sha256};
/// Toolchain archives are served from here unless [`BASE_URL_ENV`] overrides it.
const DEFAULT_BASE_URL: &str = "https://rialo-artifacts.s3.us-east-2.amazonaws.com/toolchains/";
/// Name of the environment variable that overrides [`DEFAULT_BASE_URL`].
const BASE_URL_ENV: &str = "RIALO_TOOLCHAINS_DIST_BASE";

/// Timeout, in seconds, configured on the HTTP client (five minutes).
const DOWNLOAD_TIMEOUT_SECS: u64 = 5 * 60;

/// Total number of tries, the first one included, for a request that keeps failing transiently.
const MAX_ATTEMPTS: u32 = 3;
/// Delay, in seconds, before the first retry; later retries back off exponentially from it.
const INITIAL_BACKOFF_SECS: u64 = 2;
#[derive(Debug, thiserror::Error)]
#[error("HTTP {}: {}", .0.as_u16(), .0.canonical_reason().unwrap_or("Unknown"))]
struct RetryableHttpStatus(StatusCode);
#[derive(Debug, thiserror::Error)]
#[error("error reading response body: {0}")]
struct RetryableBodyError(#[from] std::io::Error);
fn is_retryable(err: &anyhow::Error) -> bool {
for cause in err.chain() {
if cause.downcast_ref::<RetryableHttpStatus>().is_some() {
return true;
}
if cause.downcast_ref::<RetryableBodyError>().is_some() {
return true;
}
if let Some(re) = cause.downcast_ref::<reqwest::Error>() {
if re.is_connect() || re.is_timeout() || re.is_body() {
return true;
}
}
}
false
}
/// Computes the sleep before retry number `attempt` (1-based).
///
/// The base delay is `initial_backoff * 2^(attempt - 1)`, at millisecond resolution. Jitter is
/// added on top, taken from the clock's sub-second nanoseconds modulo half the base. This gives
/// up to +50%, or no jitter when the base is below 2 ms.
///
/// All arithmetic saturates, so huge backoffs or attempt numbers produce a capped delay instead
/// of panicking on overflow. `attempt == 0` is treated like `attempt == 1`.
fn backoff_with_jitter(initial_backoff: Duration, attempt: u32) -> Duration {
    // Clamp rather than truncate when the u128 millisecond count exceeds u64.
    let initial_ms = initial_backoff.as_millis().min(u128::from(u64::MAX)) as u64;
    // 2^(attempt - 1); `checked_shl` is None once the shift reaches 64 bits.
    let multiplier = 1u64
        .checked_shl(attempt.saturating_sub(1))
        .unwrap_or(u64::MAX);
    let base_ms = initial_ms.saturating_mul(multiplier);
    // Dependency-free jitter source; it only needs to differ between concurrent clients.
    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();
    let jitter_ms = u64::from(nanos) % (base_ms / 2).max(1);
    Duration::from_millis(base_ms.saturating_add(jitter_ms))
}
/// Runs `operation` through [`with_retry_with_backoff`] using the default initial backoff of
/// [`INITIAL_BACKOFF_SECS`] seconds (later retries back off exponentially with jitter).
fn with_retry<T>(description: &str, operation: impl Fn() -> Result<T>) -> Result<T> {
    let initial_backoff = Duration::from_secs(INITIAL_BACKOFF_SECS);
    with_retry_with_backoff(description, initial_backoff, operation)
}
/// Calls `operation` until it succeeds, returns a non-retryable error, or [`MAX_ATTEMPTS`]
/// attempts have been made.
///
/// Before attempt `n` (for `n >= 2`) the thread sleeps for
/// `backoff_with_jitter(initial_backoff, n - 1)`, logging a warning first.
///
/// # Errors
///
/// - An error for which [`is_retryable`] is false is returned immediately and unchanged.
/// - If every attempt fails with a retryable error, the last error is returned, wrapped in a
///   `"<description>: failed after N attempts"` context.
///
/// `operation` is `FnMut`, so callers may carry mutable state across attempts; any `Fn`
/// closure is still accepted.
fn with_retry_with_backoff<T>(
    description: &str,
    initial_backoff: Duration,
    mut operation: impl FnMut() -> Result<T>,
) -> Result<T> {
    let mut last_err = None;
    for attempt in 1..=MAX_ATTEMPTS {
        if attempt > 1 {
            let delay = backoff_with_jitter(initial_backoff, attempt - 1);
            log::warn!(
                "{description}: attempt {}/{MAX_ATTEMPTS} failed, retrying in {:.1}s\u{2026}",
                attempt - 1,
                delay.as_secs_f64(),
            );
            thread::sleep(delay);
        }
        match operation() {
            Ok(val) => return Ok(val),
            // Permanent failures are surfaced as-is, without the "failed after" context.
            Err(e) if !is_retryable(&e) => return Err(e),
            Err(e) => {
                log::warn!("{description}: {e:#}");
                last_err = Some(e);
            }
        }
    }
    // `last_err` is only None when MAX_ATTEMPTS is 0 and the loop never ran.
    Err(last_err.unwrap_or_else(|| anyhow!("all attempts failed")))
        .with_context(|| format!("{description}: failed after {MAX_ATTEMPTS} attempts"))
}
fn check_response_status(resp: &reqwest::blocking::Response, url: &Url) -> Result<()> {
let status = resp.status();
if status.is_success() {
return Ok(());
}
if status.is_server_error() || status == StatusCode::TOO_MANY_REQUESTS {
return Err(RetryableHttpStatus(status).into());
}
Err(anyhow!(
"request to {url} failed with status: {} {}",
status,
status.canonical_reason().unwrap_or("Unknown"),
))
}
#[derive(Debug, Clone)]
pub struct HttpToolchainClient {
http: Client,
base_url: Url,
}
impl HttpToolchainClient {
pub fn new() -> Result<Self> {
let base = std::env::var(BASE_URL_ENV).unwrap_or_else(|_| DEFAULT_BASE_URL.to_owned());
let base_url = Url::parse(&base).context("invalid base toolchain URL")?;
let http = Client::builder()
.timeout(Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
.build()?;
Ok(Self { http, base_url })
}
pub fn archive_url(
&self,
toolchain_name: &str,
version: &str,
archive_name: &str,
) -> Result<Url> {
let path = format!("{}/{}/{}.tar.gz", toolchain_name, version, archive_name);
self.base_url
.join(&path)
.context("failed to construct toolchain archive URL")
}
pub fn checksum_url(
&self,
toolchain_name: &str,
version: &str,
archive_name: &str,
) -> Result<Url> {
let path = format!(
"{}/{}/{}.tar.gz.sha256",
toolchain_name, version, archive_name
);
self.base_url
.join(&path)
.context("failed to construct toolchain checksum URL")
}
fn fetch_checksum_once(&self, url: &Url) -> Result<String> {
let resp = self
.http
.get(url.clone())
.send()
.with_context(|| format!("failed to fetch checksum from {url}"))?;
check_response_status(&resp, url)?;
let checksum = resp
.text()
.context("failed to read checksum response")?
.split_whitespace()
.next()
.ok_or_else(|| anyhow!("empty checksum file"))?
.to_string();
Ok(checksum)
}
fn download_archive_once(&self, url: &Url, archive_name: &str, dest: &Path) -> Result<String> {
let resp = self
.http
.get(url.clone())
.send()
.with_context(|| format!("failed to download toolchain from {url}"))?;
check_response_status(&resp, url)?;
let total = resp.content_length();
let pb = ProgressBar::with_draw_target(total, ProgressDrawTarget::stderr());
pb.set_style(
ProgressStyle::with_template(
"{spinner:.green} {msg} {bytes}/{total_bytes} ({bytes_per_sec}, {eta})",
)
.unwrap()
.progress_chars("=>-"),
);
pb.set_message(format!("Downloading {}", archive_name));
let mut file = File::create(dest)
.with_context(|| format!("failed to create file at {}", dest.display()))?;
let mut hasher = Sha256::new();
let mut reader = BufReader::with_capacity(1024 * 1024, resp);
loop {
let chunk = reader.fill_buf().map_err(RetryableBodyError)?;
if chunk.is_empty() {
break;
}
file.write_all(chunk)?;
hasher.write_all(chunk)?;
let len = chunk.len();
reader.consume(len);
pb.inc(len as u64);
}
pb.finish_with_message(format!("Downloaded {}", archive_name));
Ok(hex::encode(hasher.finalize()))
}
pub fn fetch_checksum(
&self,
toolchain_name: &str,
version: &str,
archive_name: &str,
) -> Result<String> {
let url = self.checksum_url(toolchain_name, version, archive_name)?;
with_retry(&format!("fetch checksum from {url}"), || {
self.fetch_checksum_once(&url)
})
}
pub fn download_toolchain(
&self,
toolchain_name: &str,
version: &str,
archive_name: &str,
dest: &Path,
) -> Result<()> {
let expected_checksum = self
.fetch_checksum(toolchain_name, version, archive_name)
.context("failed to fetch toolchain checksum")?;
let url = self.archive_url(toolchain_name, version, archive_name)?;
log::info!("Downloading {} from {}", archive_name, url.as_str());
let calculated = with_retry(&format!("download {archive_name}"), || {
self.download_archive_once(&url, archive_name, dest)
})?;
if !calculated.eq_ignore_ascii_case(&expected_checksum) {
std::fs::remove_file(dest).ok();
return Err(anyhow!(
"SHA256 mismatch for {} (expected {}, got {})",
archive_name,
expected_checksum,
calculated
));
}
log::info!("SHA256 checksum verified");
Ok(())
}
pub fn check_availability(
&self,
toolchain_name: &str,
version: &str,
archive_name: &str,
) -> Result<bool> {
let url = self.archive_url(toolchain_name, version, archive_name)?;
let res = self
.http
.head(url.clone())
.send()
.with_context(|| format!("failed to check toolchain availability at {url}"))?;
Ok(res.status().is_success())
}
}
#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use super::*;

    #[test]
    fn retries_on_transient_error_then_succeeds() {
        let attempts = Cell::new(0u32);
        let result = with_retry_with_backoff("test", Duration::ZERO, || {
            attempts.set(attempts.get() + 1);
            if attempts.get() < 3 {
                Err(RetryableHttpStatus(StatusCode::INTERNAL_SERVER_ERROR).into())
            } else {
                Ok("ok")
            }
        });
        assert_eq!(result.unwrap(), "ok");
        assert_eq!(attempts.get(), 3);
    }

    #[test]
    fn does_not_retry_non_retryable_error() {
        let attempts = Cell::new(0u32);
        let result: Result<()> = with_retry_with_backoff("test", Duration::ZERO, || {
            attempts.set(attempts.get() + 1);
            Err(anyhow!("permanent failure"))
        });
        let err = result.unwrap_err();
        assert_eq!(attempts.get(), 1);
        // Permanent errors are surfaced unchanged, without the "failed after" context.
        assert_eq!(format!("{err:#}"), "permanent failure");
    }

    #[test]
    fn exhausts_all_attempts_on_persistent_transient_error() {
        let attempts = Cell::new(0u32);
        let result: Result<()> = with_retry_with_backoff("test", Duration::ZERO, || {
            attempts.set(attempts.get() + 1);
            Err(RetryableHttpStatus(StatusCode::BAD_GATEWAY).into())
        });
        assert_eq!(attempts.get(), MAX_ATTEMPTS);
        let msg = format!("{:#}", result.unwrap_err());
        assert!(msg.contains("failed after"), "unexpected error: {msg}");
    }

    #[test]
    fn retryable_causes_are_detected_through_context() {
        let status = anyhow::Error::from(RetryableHttpStatus(StatusCode::SERVICE_UNAVAILABLE))
            .context("outer context");
        assert!(is_retryable(&status));

        let body = anyhow::Error::from(RetryableBodyError(std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "connection reset",
        )));
        assert!(is_retryable(&body));

        assert!(!is_retryable(&anyhow!("permanent failure")));
    }

    #[test]
    fn backoff_doubles_per_attempt_with_bounded_jitter() {
        let initial = Duration::from_millis(100);
        for attempt in 1..=3u32 {
            let base = initial * 2u32.pow(attempt - 1);
            let delay = backoff_with_jitter(initial, attempt);
            assert!(
                delay >= base && delay < base + base / 2,
                "attempt {attempt}: {delay:?}"
            );
        }
        assert_eq!(backoff_with_jitter(Duration::ZERO, 1), Duration::ZERO);
    }
}