rialo-build-lib 0.10.1

Shared library for Rialo program building logic
Documentation
// Copyright (c) Subzero Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

//! HTTP-based toolchain download (no AWS credentials required).
//!
//! Downloads toolchains from public HTTP URLs following the same pattern as
//! binary artifacts in rialoman.

use std::{
    fs::File,
    io::{BufRead, BufReader, Write as _},
    path::Path,
    thread,
    time::{Duration, SystemTime},
};

use anyhow::{anyhow, Context, Result};
use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle};
use reqwest::{blocking::Client, StatusCode, Url};
use sha2::{Digest, Sha256};

/// Default base URL for toolchain downloads.
///
/// Used by [`HttpToolchainClient::new`] when [`BASE_URL_ENV`] cannot be read
/// from the environment.
const DEFAULT_BASE_URL: &str = "https://rialo-artifacts.s3.us-east-2.amazonaws.com/toolchains/";

/// Name of the environment variable that overrides [`DEFAULT_BASE_URL`].
const BASE_URL_ENV: &str = "RIALO_TOOLCHAINS_DIST_BASE";

/// Download timeout in seconds (5 minutes for large toolchains).
///
/// Passed to the HTTP client builder's `timeout` setting.
const DOWNLOAD_TIMEOUT_SECS: u64 = 300;

/// Maximum number of attempts for retryable HTTP requests.
///
/// This counts the initial attempt, so at most `MAX_ATTEMPTS - 1` retries
/// are performed.
const MAX_ATTEMPTS: u32 = 3;

/// Initial backoff delay in seconds (doubles each retry: 2s, 4s), before
/// jitter is added.
const INITIAL_BACKOFF_SECS: u64 = 2;

/// Marker error for retryable HTTP status codes (5xx, 429).
///
/// Produced by [`check_response_status`] and recognised by [`is_retryable`].
/// Its `Display` output has the form `HTTP <code>: <canonical reason>`, with
/// `Unknown` substituted when the status has no canonical reason.
#[derive(Debug, thiserror::Error)]
#[error("HTTP {}: {}", .0.as_u16(), .0.canonical_reason().unwrap_or("Unknown"))]
struct RetryableHttpStatus(StatusCode);

/// Marker error for retryable body-read failures (e.g. connection reset
/// mid-download).
///
/// Wraps the underlying [`std::io::Error`] (convertible via `From`) so that
/// [`is_retryable`] can tell a failed body read apart from other I/O errors
/// such as a failed write to the destination file.
#[derive(Debug, thiserror::Error)]
#[error("error reading response body: {0}")]
struct RetryableBodyError(#[from] std::io::Error);

/// Walk the `anyhow` error chain and return `true` if any cause is transient.
fn is_retryable(err: &anyhow::Error) -> bool {
    for cause in err.chain() {
        if cause.downcast_ref::<RetryableHttpStatus>().is_some() {
            return true;
        }
        if cause.downcast_ref::<RetryableBodyError>().is_some() {
            return true;
        }
        if let Some(re) = cause.downcast_ref::<reqwest::Error>() {
            if re.is_connect() || re.is_timeout() || re.is_body() {
                return true;
            }
        }
    }
    false
}

/// Compute an exponential backoff duration with lightweight jitter.
///
/// `attempt` is 1-indexed (first retry = 1), so the base delay is
/// `initial_backoff * 2^(attempt - 1)`. A jitter in `[0, base / 2)`
/// milliseconds is added on top (always `0` when the base is below 2 ms).
/// Jitter is derived from `SystemTime` sub-second nanos to avoid pulling in
/// an RNG crate.
///
/// All arithmetic saturates: `attempt == 0` is treated like `1`, and very
/// large attempts or backoffs clamp to `u64::MAX` milliseconds instead of
/// overflowing or panicking.
fn backoff_with_jitter(initial_backoff: Duration, attempt: u32) -> Duration {
    // `as_millis` yields a u128; clamp instead of silently truncating.
    let initial_ms = u64::try_from(initial_backoff.as_millis()).unwrap_or(u64::MAX);
    // Guard against underflow when a caller passes 0 for the 1-indexed attempt.
    let exponent = attempt.saturating_sub(1);
    // 2^exponent, clamped once the shift would leave the u64 range.
    let multiplier = 1u64.checked_shl(exponent).unwrap_or(u64::MAX);
    let base_ms = initial_ms.saturating_mul(multiplier);

    // `.max(1)` keeps the modulus non-zero for tiny (or zero) base delays.
    let jitter_ms = u64::from(
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos(),
    ) % (base_ms / 2).max(1);

    Duration::from_millis(base_ms.saturating_add(jitter_ms))
}

/// Run `operation` under the default retry policy.
///
/// Delegates to [`with_retry_with_backoff`] with an initial backoff of
/// [`INITIAL_BACKOFF_SECS`] seconds, so `operation` is tried up to
/// [`MAX_ATTEMPTS`] times and only transient errors trigger a retry.
/// `description` is used as a prefix in log and error messages.
fn with_retry<T>(description: &str, operation: impl Fn() -> Result<T>) -> Result<T> {
    let initial_backoff = Duration::from_secs(INITIAL_BACKOFF_SECS);
    with_retry_with_backoff(description, initial_backoff, operation)
}

/// Like [`with_retry`] but with an explicit initial backoff duration.
fn with_retry_with_backoff<T>(
    description: &str,
    initial_backoff: Duration,
    operation: impl Fn() -> Result<T>,
) -> Result<T> {
    let mut last_err = None;
    for attempt in 1..=MAX_ATTEMPTS {
        if attempt > 1 {
            let delay = backoff_with_jitter(initial_backoff, attempt - 1);
            log::warn!(
                "{description}: attempt {}/{MAX_ATTEMPTS} failed, retrying in {:.1}s\u{2026}",
                attempt - 1,
                delay.as_secs_f64(),
            );
            thread::sleep(delay);
        }

        match operation() {
            Ok(val) => return Ok(val),
            Err(e) => {
                if !is_retryable(&e) {
                    return Err(e);
                }
                log::warn!("{description}: {e:#}");
                last_err = Some(e);
            }
        }
    }

    Err(last_err.unwrap_or_else(|| anyhow!("all attempts failed")))
        .with_context(|| format!("{description}: failed after {MAX_ATTEMPTS} attempts"))
}

/// Check the HTTP status of `resp` and return an appropriate error for
/// non-success codes.
///
/// * Success statuses yield `Ok(())`.
/// * Retryable statuses (5xx, 429) are wrapped in [`RetryableHttpStatus`] so
///   that [`is_retryable`] recognises them.
/// * Any other status becomes a plain `anyhow` error that names `url`, the
///   numeric code, and the canonical reason phrase.
fn check_response_status(resp: &reqwest::blocking::Response, url: &Url) -> Result<()> {
    let status = resp.status();
    if status.is_success() {
        return Ok(());
    }
    if status.is_server_error() || status == StatusCode::TOO_MANY_REQUESTS {
        return Err(RetryableHttpStatus(status).into());
    }
    // `StatusCode`'s `Display` already renders "<code> <reason>"; format the
    // numeric code explicitly so the reason is not printed twice
    // (e.g. "404 Not Found Not Found").
    Err(anyhow!(
        "request to {url} failed with status: {} {}",
        status.as_u16(),
        status.canonical_reason().unwrap_or("Unknown"),
    ))
}

/// HTTP client for downloading toolchains.
///
/// Bundles a blocking `reqwest` client with the base URL against which
/// archive and checksum URLs are resolved.
#[derive(Debug, Clone)]
pub struct HttpToolchainClient {
    /// Blocking HTTP client used for every request.
    http: Client,
    /// Base URL that toolchain-relative paths are joined onto.
    base_url: Url,
}

impl HttpToolchainClient {
    /// Create a new HTTP toolchain client.
    ///
    /// The base URL is read from the `RIALO_TOOLCHAINS_DIST_BASE` environment
    /// variable ([`BASE_URL_ENV`]) and falls back to [`DEFAULT_BASE_URL`] when
    /// the variable cannot be read. A missing trailing slash on the base path
    /// is added, so the whole base path is treated as a directory when
    /// toolchain paths are joined onto it.
    ///
    /// # Errors
    ///
    /// Fails if the base URL cannot be parsed or the HTTP client cannot be
    /// built.
    pub fn new() -> Result<Self> {
        let base = std::env::var(BASE_URL_ENV).unwrap_or_else(|_| DEFAULT_BASE_URL.to_owned());
        let mut base_url = Url::parse(&base).context("invalid base toolchain URL")?;

        // `Url::join` resolves relative paths like an HTML link: without a
        // trailing slash the last segment of the base path is *replaced*, so
        // ".../toolchains" + "name/ver/x.tar.gz" would silently drop
        // "toolchains". Normalise the base to a directory-style path.
        if !base_url.path().ends_with('/') {
            let dir_path = format!("{}/", base_url.path());
            base_url.set_path(&dir_path);
        }

        let http = Client::builder()
            .timeout(Duration::from_secs(DOWNLOAD_TIMEOUT_SECS))
            .build()
            .context("failed to build HTTP client")?;

        Ok(Self { http, base_url })
    }

    /// Build the URL from which a toolchain archive is downloaded.
    ///
    /// The relative path `{toolchain_name}/{version}/{archive_name}.tar.gz`
    /// is joined onto the client's base URL.
    ///
    /// # Errors
    ///
    /// Fails if the relative path cannot be joined onto the base URL.
    pub fn archive_url(
        &self,
        toolchain_name: &str,
        version: &str,
        archive_name: &str,
    ) -> Result<Url> {
        let relative = format!("{}/{}/{}.tar.gz", toolchain_name, version, archive_name);
        let joined = self.base_url.join(&relative);
        joined.context("failed to construct toolchain archive URL")
    }

    /// Build the URL of the SHA-256 checksum file for a toolchain archive.
    ///
    /// The relative path
    /// `{toolchain_name}/{version}/{archive_name}.tar.gz.sha256` is joined
    /// onto the client's base URL.
    ///
    /// # Errors
    ///
    /// Fails if the relative path cannot be joined onto the base URL.
    pub fn checksum_url(
        &self,
        toolchain_name: &str,
        version: &str,
        archive_name: &str,
    ) -> Result<Url> {
        let relative = format!(
            "{}/{}/{}.tar.gz.sha256",
            toolchain_name, version, archive_name
        );
        let joined = self.base_url.join(&relative);
        joined.context("failed to construct toolchain checksum URL")
    }

    /// Single attempt to fetch a checksum from `url`.
    ///
    /// The response body is expected to start with the hex-encoded SHA-256
    /// digest; anything after the first whitespace (such as a file name) is
    /// ignored.
    ///
    /// # Errors
    ///
    /// Fails if the request cannot be sent, the status is not a success, the
    /// body cannot be read, the body is empty, or the first token is not a
    /// 64-character hexadecimal string. The last two are permanent errors and
    /// are not retried.
    fn fetch_checksum_once(&self, url: &Url) -> Result<String> {
        let resp = self
            .http
            .get(url.clone())
            .send()
            .with_context(|| format!("failed to fetch checksum from {url}"))?;

        check_response_status(&resp, url)?;

        let body = resp.text().context("failed to read checksum response")?;
        let checksum = body
            .split_whitespace()
            .next()
            .ok_or_else(|| anyhow!("empty checksum file"))?;

        // Reject anything that cannot be a SHA-256 digest (e.g. an HTML error
        // page served with a 200 status) now, rather than downloading the
        // whole archive only to report a confusing checksum mismatch.
        if checksum.len() != 64 || !checksum.bytes().all(|b| b.is_ascii_hexdigit()) {
            return Err(anyhow!(
                "malformed checksum file at {url}: expected a 64-character hex SHA-256 digest"
            ));
        }

        Ok(checksum.to_owned())
    }

    /// Single attempt to download the archive at `url` into `dest`, returning
    /// the hex-encoded SHA-256 of the bytes written.
    fn download_archive_once(&self, url: &Url, archive_name: &str, dest: &Path) -> Result<String> {
        let resp = self
            .http
            .get(url.clone())
            .send()
            .with_context(|| format!("failed to download toolchain from {url}"))?;

        check_response_status(&resp, url)?;

        let total = resp.content_length();
        let pb = ProgressBar::with_draw_target(total, ProgressDrawTarget::stderr());
        pb.set_style(
            ProgressStyle::with_template(
                "{spinner:.green} {msg} {bytes}/{total_bytes} ({bytes_per_sec}, {eta})",
            )
            .unwrap()
            .progress_chars("=>-"),
        );
        pb.set_message(format!("Downloading {}", archive_name));

        let mut file = File::create(dest)
            .with_context(|| format!("failed to create file at {}", dest.display()))?;
        let mut hasher = Sha256::new();
        let mut reader = BufReader::with_capacity(1024 * 1024, resp);

        loop {
            let chunk = reader.fill_buf().map_err(RetryableBodyError)?;
            if chunk.is_empty() {
                break;
            }

            file.write_all(chunk)?;
            hasher.write_all(chunk)?;

            let len = chunk.len();
            reader.consume(len);

            pb.inc(len as u64);
        }

        pb.finish_with_message(format!("Downloaded {}", archive_name));

        Ok(hex::encode(hasher.finalize()))
    }

    /// Fetch the expected SHA-256 for a toolchain archive from its checksum
    /// file.
    ///
    /// The checksum URL is derived via [`Self::checksum_url`]; the request is
    /// retried automatically on transient network errors.
    pub fn fetch_checksum(
        &self,
        toolchain_name: &str,
        version: &str,
        archive_name: &str,
    ) -> Result<String> {
        let url = self.checksum_url(toolchain_name, version, archive_name)?;
        let description = format!("fetch checksum from {url}");

        with_retry(&description, || self.fetch_checksum_once(&url))
    }

    /// Download a toolchain archive to `dest` and verify its integrity.
    ///
    /// The expected SHA-256 is fetched first, then the archive is streamed to
    /// `dest` while its digest is computed. Both requests retry automatically
    /// on transient network errors. A checksum mismatch is *not* retried (it
    /// indicates corruption, not a transient failure); in that case `dest` is
    /// removed on a best-effort basis and an error is returned.
    pub fn download_toolchain(
        &self,
        toolchain_name: &str,
        version: &str,
        archive_name: &str,
        dest: &Path,
    ) -> Result<()> {
        // The expected digest comes first; `fetch_checksum` retries on its own.
        let expected_checksum = self
            .fetch_checksum(toolchain_name, version, archive_name)
            .context("failed to fetch toolchain checksum")?;

        let url = self.archive_url(toolchain_name, version, archive_name)?;
        log::info!("Downloading {} from {}", archive_name, url.as_str());

        // Stream the archive to disk, retrying transient failures.
        let description = format!("download {archive_name}");
        let calculated = with_retry(&description, || {
            self.download_archive_once(&url, archive_name, dest)
        })?;

        // Hex digests are compared case-insensitively.
        if calculated.eq_ignore_ascii_case(&expected_checksum) {
            log::info!("SHA256 checksum verified");
            return Ok(());
        }

        // Mismatch: discard the corrupt file (best-effort) and report it.
        let _ = std::fs::remove_file(dest);
        Err(anyhow!(
            "SHA256 mismatch for {} (expected {}, got {})",
            archive_name,
            expected_checksum,
            calculated
        ))
    }

    /// Check whether a toolchain archive exists at the remote location.
    ///
    /// Issues a single `HEAD` request (no retry) for the archive URL and
    /// returns `Ok(true)` only for a success status. Any other status yields
    /// `Ok(false)`.
    ///
    /// # Errors
    ///
    /// Fails if the archive URL cannot be constructed or the request itself
    /// cannot be sent.
    pub fn check_availability(
        &self,
        toolchain_name: &str,
        version: &str,
        archive_name: &str,
    ) -> Result<bool> {
        let url = self.archive_url(toolchain_name, version, archive_name)?;

        let request = self.http.head(url.clone());
        let status = request
            .send()
            .with_context(|| format!("failed to check toolchain availability at {url}"))?
            .status();

        Ok(status.is_success())
    }
}

#[cfg(test)]
mod tests {
    use std::cell::Cell;

    use super::*;

    /// Two transient failures followed by a success: three calls in total.
    #[test]
    fn retries_on_transient_error_then_succeeds() {
        let calls = Cell::new(0u32);
        let outcome = with_retry_with_backoff("test", Duration::ZERO, || {
            let current = calls.get() + 1;
            calls.set(current);
            if current >= 3 {
                Ok("ok")
            } else {
                Err(RetryableHttpStatus(StatusCode::INTERNAL_SERVER_ERROR).into())
            }
        });
        assert_eq!(outcome.unwrap(), "ok");
        assert_eq!(calls.get(), 3);
    }

    /// A plain `anyhow` error is not transient, so only one call is made.
    #[test]
    fn does_not_retry_non_retryable_error() {
        let calls = Cell::new(0u32);
        let outcome: Result<()> = with_retry_with_backoff("test", Duration::ZERO, || {
            calls.set(calls.get() + 1);
            Err(anyhow!("permanent failure"))
        });
        assert!(outcome.is_err());
        assert_eq!(calls.get(), 1);
    }

    /// A transient error that never clears uses up every attempt and is
    /// reported with the "failed after" context.
    #[test]
    fn exhausts_all_attempts_on_persistent_transient_error() {
        let calls = Cell::new(0u32);
        let outcome: Result<()> = with_retry_with_backoff("test", Duration::ZERO, || {
            calls.set(calls.get() + 1);
            Err(RetryableHttpStatus(StatusCode::BAD_GATEWAY).into())
        });
        assert_eq!(calls.get(), MAX_ATTEMPTS);
        let msg = format!("{:#}", outcome.unwrap_err());
        assert!(msg.contains("failed after"), "unexpected error: {msg}");
    }
}