use std::{
fs::{self, File},
io::{self, Write},
path::{Path, PathBuf},
};
use anyhow::{Context, Result};
use flate2::read::GzDecoder;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use tar::Archive;
mod gnu_riscv;
mod http_backend;
mod rialo_rust;
pub mod s3_backend;
pub mod source_builder;
pub use gnu_riscv::{GnuRiscvToolchain, DEFAULT_GNU_RISCV_VERSION};
pub use http_backend::HttpToolchainClient;
pub use rialo_rust::{
ResolvedToolchainVersion, RialoRustToolchain, ToolchainSource, RUST_COMMIT_HASH,
RUST_NIGHTLY_VERSION,
};
pub use s3_backend::S3StorageBackend;
pub use source_builder::{
BuildSystemConfig, RustSourceBuilder, SourceBuildConfig, SourceBuildable,
};
/// Identifies which family of toolchain an operation refers to.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum ToolchainType {
    /// GNU RISC-V toolchain.
    GnuRiscv,
    /// Rialo Rust toolchain.
    RialoRust,
}
/// Where toolchain artifacts are fetched from.
///
/// Chosen by `determine_download_source_from_env` from the
/// `RIALO_TOOLCHAIN_SOURCE` environment variable and consumed by
/// `download_with_fallback_strategy`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum DownloadSource {
    /// Try S3 first; on failure, fall back to GitHub (`RIALO_TOOLCHAIN_SOURCE=auto`).
    PreferS3WithFallback,
    /// Download only from S3 (`RIALO_TOOLCHAIN_SOURCE=s3`).
    S3,
    /// Download only from GitHub (`RIALO_TOOLCHAIN_SOURCE=github`).
    GitHub,
}
/// Static description of a toolchain: identity, source, and install location.
#[derive(Debug)]
#[derive(Clone)]
pub struct ToolchainConfig {
    /// Toolchain name.
    pub name: String,
    /// Toolchain version string.
    pub version: String,
    /// URL the toolchain archive is downloaded from.
    pub download_url: String,
    /// Directory the toolchain is installed into.
    pub install_path: PathBuf,
    /// Optional expected checksum of the download. Presumably a SHA-256 hex digest
    /// for `verify_checksum`; nothing in this struct enforces that (TODO confirm).
    pub checksum: Option<String>,
}
/// Common interface for an installable toolchain.
pub trait Toolchain {
    /// Returns the configuration this toolchain was created from.
    fn get_config(&self) -> &ToolchainConfig;

    /// Reports whether the toolchain is already installed; the check is
    /// implementation-defined.
    fn is_installed(&self) -> Result<bool>;

    /// Installs the toolchain.
    fn install(&self) -> Result<()>;

    /// Validates the toolchain; presumably returns `Err` describing what is
    /// wrong when the installation is unusable (confirm against implementors).
    fn validate(&self) -> Result<()>;

    /// Returns the path to the toolchain's binaries (presumably a `bin`
    /// directory — TODO confirm against implementors).
    fn get_bin_path(&self) -> Result<PathBuf>;
}
/// Returns the root directory under which toolchains are installed.
///
/// If `RIALO_BUILD_TOOLCHAIN_HOME` is set to a non-empty value, that path is
/// used as-is. Otherwise the root is `~/.local/share/rialo/toolchains`.
///
/// # Errors
///
/// Returns an error when no override is set and the home directory cannot be
/// determined.
pub fn get_toolchain_root() -> Result<PathBuf> {
    use std::env;

    // `var_os` accepts non-UTF-8 paths, which `var` would reject and silently
    // fall through to the default. An empty value counts as unset: an empty
    // `PathBuf` would make every toolchain path relative to the current directory.
    if let Some(custom_path) =
        env::var_os("RIALO_BUILD_TOOLCHAIN_HOME").filter(|value| !value.is_empty())
    {
        let custom_path = PathBuf::from(custom_path);
        log::debug!("Using custom toolchain location: {}", custom_path.display());
        return Ok(custom_path);
    }

    let home = dirs::home_dir().context("Failed to get home directory")?;
    let path = home.join(".local/share/rialo/toolchains");
    log::debug!("Using toolchain location: {}", path.display());
    Ok(path)
}
/// The subset of a release `manifest.json` needed to pick a Rust toolchain.
///
/// Both fields are optional, so manifests lacking either key still parse. No
/// `deny_unknown_fields` is set, so any other keys in the file are ignored.
#[derive(Deserialize, Debug)]
struct MinimalManifest {
    /// Manifest format version; `detect_rialoman_release_toolchain_version`
    /// treats a missing value as version 1.
    manifest_version: Option<u32>,
    /// Toolchain identifier pinned by the release, if any.
    rust_toolchain: Option<String>,
}
/// Looks for `manifest.json` two directory levels above the running executable.
///
/// For example, for an executable at `<root>/bin/<exe>` this checks
/// `<root>/manifest.json`. Returns `None` in three cases: the executable path
/// cannot be resolved, it has no grandparent directory, or the file does not
/// exist.
fn find_rialoman_release_manifest_path() -> Option<PathBuf> {
    let exe_path = std::env::current_exe().ok()?;
    let release_root = exe_path.parent().and_then(Path::parent)?;
    let candidate = release_root.join("manifest.json");
    if candidate.exists() {
        Some(candidate)
    } else {
        None
    }
}
/// Reads the Rust toolchain pinned by a rialoman release manifest, if any.
///
/// Returns `Ok(None)` in two cases:
/// - no manifest is found next to the executable (standalone mode);
/// - the manifest has no `rust_toolchain` entry. For manifests declaring a
///   version greater than 1, a warning is logged in this case.
///
/// # Errors
///
/// Fails if a manifest is found but cannot be read, or is not valid JSON for
/// `MinimalManifest`.
pub fn detect_rialoman_release_toolchain_version() -> Result<Option<String>> {
    let manifest_path = match find_rialoman_release_manifest_path() {
        Some(found) => found,
        None => {
            log::debug!("No release manifest found (standalone mode)");
            return Ok(None);
        }
    };

    let raw = std::fs::read_to_string(&manifest_path)
        .with_context(|| format!("Failed to read {}", manifest_path.display()))?;
    let parsed: MinimalManifest = serde_json::from_str(&raw)
        .with_context(|| format!("Failed to parse {}", manifest_path.display()))?;

    match parsed.rust_toolchain {
        Some(tc) => {
            log::debug!("Found rust_toolchain: {tc}");
            Ok(Some(tc))
        }
        None => {
            // Older (v1 / unversioned) manifests never carried this field, so
            // only newer ones warrant a warning.
            let version = parsed.manifest_version.unwrap_or(1);
            if version > 1 {
                log::warn!("Manifest v{version} has no rust_toolchain field - using default toolchain");
            }
            Ok(None)
        }
    }
}
pub fn download_file(url: &str, dest: &Path) -> Result<()> {
log::info!("Downloading from {}", url);
let response =
reqwest::blocking::get(url).with_context(|| format!("Failed to download from {url}"))?;
if !response.status().is_success() {
return Err(anyhow::anyhow!(
"Download failed with status: {}",
response.status()
));
}
let mut dest_file = File::create(dest)
.with_context(|| format!("Failed to create file at {}", dest.display()))?;
let content = response.bytes().context("Failed to read response bytes")?;
dest_file
.write_all(&content)
.with_context(|| format!("Failed to write to {}", dest.display()))?;
log::debug!("Downloaded {} bytes", content.len());
Ok(())
}
/// Verifies that the SHA-256 digest of `file_path` equals `expected_checksum`.
///
/// `expected_checksum` is a hex string. Surrounding whitespace is ignored and
/// the comparison is case-insensitive, so digests copied from `.sha256` files
/// (trailing newline) or upper-case listings are accepted.
///
/// # Errors
///
/// Fails if the file cannot be opened or read, or if the digests differ.
pub fn verify_checksum(file_path: &Path, expected_checksum: &str) -> Result<()> {
    log::debug!("Verifying checksum for {}", file_path.display());
    let mut file = File::open(file_path)
        .with_context(|| format!("Failed to open file {}", file_path.display()))?;
    let mut hasher = Sha256::new();
    io::copy(&mut file, &mut hasher)
        .with_context(|| format!("Failed to read file {}", file_path.display()))?;
    // `hex::encode` yields lower-case digits. Hex case and stray whitespace
    // carry no information, so normalize the expected value before comparing.
    let hash_str = hex::encode(hasher.finalize());
    let expected = expected_checksum.trim();
    if !hash_str.eq_ignore_ascii_case(expected) {
        return Err(anyhow::anyhow!(
            "Checksum mismatch: expected {}, got {}",
            expected,
            hash_str
        ));
    }
    log::debug!("Checksum verified");
    Ok(())
}
/// Unpacks the gzip-compressed tar archive at `archive_path` into `dest_dir`.
///
/// # Errors
///
/// Fails if the archive cannot be opened, or if decompressing or unpacking
/// any entry fails.
pub fn extract_tar_gz(archive_path: &Path, dest_dir: &Path) -> Result<()> {
    log::info!("Extracting archive to {}", dest_dir.display());

    let compressed = File::open(archive_path)
        .with_context(|| format!("Failed to open archive {}", archive_path.display()))?;
    Archive::new(GzDecoder::new(compressed))
        .unpack(dest_dir)
        .with_context(|| format!("Failed to extract archive to {}", dest_dir.display()))?;

    log::info!("Extraction complete");
    Ok(())
}
/// Maps the host OS and CPU architecture to a Rust target triple.
///
/// Only macOS and Linux on `x86_64` or `aarch64` are recognized.
///
/// # Errors
///
/// Returns an error naming the `os-arch` pair for any other host.
pub fn get_platform() -> Result<String> {
    use std::env::consts::{ARCH, OS};

    let triple = match (OS, ARCH) {
        ("macos", "x86_64") => "x86_64-apple-darwin",
        ("macos", "aarch64") => "aarch64-apple-darwin",
        ("linux", "x86_64") => "x86_64-unknown-linux-gnu",
        ("linux", "aarch64") => "aarch64-unknown-linux-gnu",
        (os, arch) => return Err(anyhow::anyhow!("Unsupported platform: {os}-{arch}")),
    };
    Ok(triple.to_string())
}
/// Returns `true` when `which::which` can resolve `command` to an executable.
///
/// Lookup errors of any kind are reported as `false`.
pub fn command_exists(command: &str) -> bool {
    let resolved = which::which(command);
    resolved.is_ok()
}
/// Runs the S3 and/or GitHub download closures according to `RIALO_TOOLCHAIN_SOURCE`.
///
/// Behavior by source:
/// - `s3`: only `s3_download` is called.
/// - `github`: only `github_download` is called.
/// - `auto` (the default): `s3_download` is tried first; if it fails, its error
///   is logged and `github_download` is called instead.
///
/// # Errors
///
/// Returns the failing closure's error with added context. When both sources
/// fail in `auto` mode, the context also carries the S3 error chain, so the
/// cause of the first failure is not lost.
pub fn download_with_fallback_strategy<S, G>(s3_download: S, github_download: G) -> Result<()>
where
    S: FnOnce() -> Result<()>,
    G: FnOnce() -> Result<()>,
{
    let source = determine_download_source_from_env();
    match source {
        DownloadSource::S3 => {
            log::info!("Attempting download from S3 (RIALO_TOOLCHAIN_SOURCE=s3)");
            s3_download().context(
                "S3 download failed. Set RIALO_TOOLCHAIN_SOURCE=github to use GitHub instead.",
            )?;
            log::info!("Successfully downloaded from S3");
        }
        DownloadSource::GitHub => {
            log::info!("Attempting download from GitHub (RIALO_TOOLCHAIN_SOURCE=github)");
            github_download().context("GitHub download failed")?;
            log::info!("Successfully downloaded from GitHub");
        }
        DownloadSource::PreferS3WithFallback => {
            log::info!("Attempting download from S3 (will fallback to GitHub if unavailable)");
            match s3_download() {
                Ok(()) => log::info!("Successfully downloaded from S3"),
                Err(s3_err) => {
                    log::warn!("S3 download failed: {}", s3_err);
                    log::info!("Falling back to GitHub releases");
                    // `{:#}` renders the full anyhow chain for the S3 failure.
                    github_download().with_context(|| {
                        format!("Both S3 and GitHub downloads failed (S3 error: {s3_err:#})")
                    })?;
                    log::info!("Successfully downloaded from GitHub");
                }
            }
        }
    }
    Ok(())
}
/// Returns the S3 bucket to fetch toolchain artifacts from.
///
/// Uses `RIALO_TOOLCHAIN_S3_BUCKET`, with surrounding whitespace trimmed, when
/// it holds a non-blank value. Otherwise returns `"rialo-artifacts"`.
///
/// A set-but-empty variable (e.g. `RIALO_TOOLCHAIN_S3_BUCKET=` in CI) falls
/// back to the default rather than yielding an empty, invalid bucket name.
pub fn get_s3_bucket() -> String {
    std::env::var("RIALO_TOOLCHAIN_S3_BUCKET")
        .ok()
        .map(|bucket| bucket.trim().to_string())
        .filter(|bucket| !bucket.is_empty())
        .unwrap_or_else(|| "rialo-artifacts".to_string())
}
/// Chooses the download source from the `RIALO_TOOLCHAIN_SOURCE` environment variable.
///
/// Accepted values are `s3`, `github` and `auto`. Matching ignores ASCII case
/// and surrounding whitespace.
///
/// An unset, empty, or non-UTF-8 variable means `auto`
/// ([`DownloadSource::PreferS3WithFallback`]). Any other value prints a warning
/// to stderr and also falls back to `auto`.
pub fn determine_download_source_from_env() -> DownloadSource {
    let env_value = std::env::var("RIALO_TOOLCHAIN_SOURCE").unwrap_or_else(|_| "auto".to_string());
    // Tolerate incidental casing/whitespace (`S3`, `GitHub`, `auto\n`). An empty
    // value is treated like an unset variable rather than as a typo.
    let normalized = env_value.trim().to_ascii_lowercase();
    match normalized.as_str() {
        "s3" => DownloadSource::S3,
        "github" => DownloadSource::GitHub,
        "auto" | "" => DownloadSource::PreferS3WithFallback,
        _ => {
            eprintln!(
                "⚠️ Warning: Unrecognized RIALO_TOOLCHAIN_SOURCE value '{}'. \
                 Valid values are: 's3', 'github', 'auto'. \
                 Defaulting to 'auto' (prefer S3 with GitHub fallback).",
                env_value
            );
            DownloadSource::PreferS3WithFallback
        }
    }
}
/// Lists installed toolchains as `(name, version)` pairs, sorted.
///
/// Each subdirectory of the toolchain root whose name contains a `-` is split
/// at its last `-`: for example `foo-1.2` becomes `("foo", "1.2")`. Entries that
/// are not directories, or whose name has no `-`, are skipped. A missing root
/// yields an empty list.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the toolchain root cannot be determined;
/// - the root cannot be read;
/// - a directory entry cannot be read.
pub fn list_installed_toolchains() -> Result<Vec<(String, String)>> {
    let toolchain_root = get_toolchain_root()?;
    if !toolchain_root.exists() {
        return Ok(Vec::new());
    }

    let entries = fs::read_dir(&toolchain_root)
        .with_context(|| format!("Failed to read directory {}", toolchain_root.display()))?;

    let mut toolchains = Vec::new();
    for entry in entries {
        let entry = entry.with_context(|| {
            format!("Failed to read an entry in {}", toolchain_root.display())
        })?;
        if !entry.path().is_dir() {
            continue;
        }
        let dir_name = entry.file_name();
        let dir_name = dir_name.to_string_lossy();
        if let Some((toolchain_name, version)) = dir_name.rsplit_once('-') {
            toolchains.push((toolchain_name.to_string(), version.to_string()));
        }
    }

    // `read_dir` order is platform- and filesystem-dependent. Sort so callers,
    // and any listing shown to users, see a stable order.
    toolchains.sort();
    Ok(toolchains)
}