use std::{
fs::File,
io::{BufRead, BufReader, Write as _},
path::Path,
time::Duration,
};
use anyhow::{anyhow, Context, Result};
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
use reqwest::{blocking::Client, Url};
use sha2::{Digest, Sha256};
/// Base URL used when the `RIALO_TOOLCHAINS_DIST_BASE` environment variable cannot be read.
const DEFAULT_BASE_URL: &str = "https://rialo-artifacts.s3.us-east-2.amazonaws.com/toolchains/";
/// Name of the environment variable that overrides [`DEFAULT_BASE_URL`].
const BASE_URL_ENV: &str = "RIALO_TOOLCHAINS_DIST_BASE";
/// Timeout, in seconds, configured on the HTTP client built by `HttpToolchainClient::new`.
const DOWNLOAD_TIMEOUT_SECS: u64 = 300;
/// Blocking HTTP client that resolves, checks, and downloads toolchain archives
/// relative to a configurable base URL.
#[derive(Debug, Clone)]
pub struct HttpToolchainClient {
    /// Underlying `reqwest` blocking client used for every request.
    http: Client,
    /// Base URL onto which archive and checksum paths are joined.
    base_url: Url,
}
impl HttpToolchainClient {
/// Creates a client whose base URL comes from the `RIALO_TOOLCHAINS_DIST_BASE`
/// environment variable, falling back to the built-in default when it cannot be read.
///
/// The base URL's path is normalised to end with `/`. `Url::join` treats the last
/// path segment of a base without a trailing slash as a file name and replaces it,
/// so an override such as `https://host/toolchains` would otherwise silently drop
/// the `toolchains` segment from every archive and checksum URL.
///
/// # Errors
///
/// Returns an error if the base URL cannot be parsed or the HTTP client cannot be built.
pub fn new() -> Result<Self> {
    let base = std::env::var(BASE_URL_ENV).unwrap_or_else(|_| DEFAULT_BASE_URL.to_owned());
    let mut base_url = Url::parse(&base).context("invalid base toolchain URL")?;
    // Normalise on the parsed URL rather than the raw string so that a query
    // string or fragment in the override is not mistaken for part of the path.
    if !base_url.path().ends_with('/') {
        let dir_path = format!("{}/", base_url.path());
        base_url.set_path(&dir_path);
    }
    let http = Client::builder()
        .timeout(Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
        .build()
        .context("failed to build HTTP client")?;
    Ok(Self { http, base_url })
}
/// Builds the URL of a toolchain archive:
/// `<base>/<toolchain_name>/<version>/<archive_name>.tar.gz`.
///
/// # Errors
///
/// Returns an error if the relative path cannot be joined onto the base URL.
pub fn archive_url(
    &self,
    toolchain_name: &str,
    version: &str,
    archive_name: &str,
) -> Result<Url> {
    let relative = format!("{toolchain_name}/{version}/{archive_name}.tar.gz");
    let joined = self.base_url.join(&relative);
    joined.context("failed to construct toolchain archive URL")
}
/// Builds the URL of the checksum file published alongside a toolchain archive:
/// `<base>/<toolchain_name>/<version>/<archive_name>.tar.gz.sha256`.
///
/// # Errors
///
/// Returns an error if the relative path cannot be joined onto the base URL.
pub fn checksum_url(
    &self,
    toolchain_name: &str,
    version: &str,
    archive_name: &str,
) -> Result<Url> {
    let relative = format!("{toolchain_name}/{version}/{archive_name}.tar.gz.sha256");
    let joined = self.base_url.join(&relative);
    joined.context("failed to construct toolchain checksum URL")
}
/// Fetches the published SHA-256 checksum for a toolchain archive.
///
/// The checksum file is expected to start with the hex digest, optionally followed
/// by whitespace and other content (for example a file name); only the first
/// whitespace-separated token is returned.
///
/// # Errors
///
/// Returns an error if the request fails, the server responds with a non-success
/// status, the body cannot be read, the body is empty, or the first token is not
/// a 64-character hexadecimal string.
pub fn fetch_checksum(
    &self,
    toolchain_name: &str,
    version: &str,
    archive_name: &str,
) -> Result<String> {
    let url = self.checksum_url(toolchain_name, version, archive_name)?;
    let res = self
        .http
        .get(url.clone())
        .send()
        .with_context(|| format!("failed to fetch checksum from {url}"))?;
    if !res.status().is_success() {
        return Err(anyhow!(
            "checksum fetch failed with status: {}",
            res.status()
        ));
    }
    let body = res.text().context("failed to read checksum response")?;
    let checksum = body
        .split_whitespace()
        .next()
        .ok_or_else(|| anyhow!("empty checksum file"))?;
    // A SHA-256 digest is always 64 hex digits. Anything else can never match the
    // computed digest, so reject it here instead of after downloading the archive.
    let is_sha256_hex = checksum.len() == 64 && checksum.bytes().all(|b| b.is_ascii_hexdigit());
    if !is_sha256_hex {
        return Err(anyhow!("malformed SHA256 checksum in {url}"));
    }
    Ok(checksum.to_string())
}
pub fn download_toolchain(
&self,
toolchain_name: &str,
version: &str,
archive_name: &str,
dest: &Path,
) -> Result<()> {
let expected_checksum = self
.fetch_checksum(toolchain_name, version, archive_name)
.context("failed to fetch toolchain checksum")?;
let url = self.archive_url(toolchain_name, version, archive_name)?;
log::info!("Downloading {} from {}", archive_name, url.as_str());
let resp = self
.http
.get(url.clone())
.send()
.with_context(|| format!("failed to download toolchain from {url}"))?;
if !resp.status().is_success() {
return Err(anyhow!(
"Download failed with status: {} {}",
resp.status(),
resp.status().canonical_reason().unwrap_or("Unknown")
));
}
let total = resp.content_length();
let pb = ProgressBar::with_draw_target(total, ProgressDrawTarget::stderr());
pb.set_style(
ProgressStyle::with_template(
"{spinner:.green} {msg} {bytes}/{total_bytes} ({bytes_per_sec}, {eta})",
)
.unwrap()
.progress_chars("=>-"),
);
pb.set_message(format!("Downloading {}", archive_name));
let mut file = File::create(dest)
.with_context(|| format!("failed to create file at {}", dest.display()))?;
let mut hasher = Sha256::new();
let mut reader = BufReader::with_capacity(1024 * 1024, resp);
loop {
let chunk = reader.fill_buf()?;
if chunk.is_empty() {
break;
}
file.write_all(chunk)?;
hasher.write_all(chunk)?;
let len = chunk.len();
reader.consume(len);
pb.inc(len as u64);
}
pb.finish_with_message(format!("Downloaded {}", archive_name));
let calculated = hex::encode(hasher.finalize());
if !calculated.eq_ignore_ascii_case(&expected_checksum) {
std::fs::remove_file(dest).ok(); return Err(anyhow!(
"SHA256 mismatch for {} (expected {}, got {})",
archive_name,
expected_checksum,
calculated
));
}
log::info!("SHA256 checksum verified");
Ok(())
}
/// Reports whether a toolchain archive exists by issuing a `HEAD` request for it.
///
/// Returns `Ok(true)` when the server answers with a success status and
/// `Ok(false)` for any other status.
///
/// # Errors
///
/// Returns an error if the archive URL cannot be built or the request itself fails.
pub fn check_availability(
    &self,
    toolchain_name: &str,
    version: &str,
    archive_name: &str,
) -> Result<bool> {
    let target = self.archive_url(toolchain_name, version, archive_name)?;
    let attempt = self.http.head(target.clone()).send();
    let response = attempt
        .with_context(|| format!("failed to check toolchain availability at {target}"))?;
    let available = response.status().is_success();
    Ok(available)
}
}