use std::path::Path;
use anyhow::{Context, Result};
use s3::{creds::Credentials, Bucket, Region};
/// Toolchain archive storage backed by an S3 bucket.
///
/// Instances are created with `S3StorageBackend::new`, which configures the
/// bucket from AWS environment variables.
pub struct S3StorageBackend {
/// Bucket handle used for every object request issued by this backend.
bucket: Box<Bucket>,
}
impl S3StorageBackend {
pub async fn new(bucket_name: String) -> Result<Option<Self>> {
let access_key = std::env::var("AWS_ACCESS_KEY_ID").ok();
let secret_key = std::env::var("AWS_SECRET_ACCESS_KEY").ok();
if access_key.is_none() || secret_key.is_none() {
return Ok(None);
}
let region_str = std::env::var("AWS_DEFAULT_REGION")
.or_else(|_| std::env::var("AWS_REGION"))
.unwrap_or_else(|_| "us-east-1".to_string());
let endpoint_url = std::env::var("AWS_ENDPOINT_URL_S3")
.or_else(|_| std::env::var("AWS_ENDPOINT_URL"))
.ok();
let region = if let Some(endpoint) = endpoint_url {
log::debug!("Using custom S3 endpoint: {}", endpoint);
Region::Custom {
region: region_str,
endpoint,
}
} else {
region_str.parse().unwrap_or(Region::UsEast1)
};
let credentials = Credentials::new(
access_key.as_deref(),
secret_key.as_deref(),
None, None, None, )
.context("Failed to create AWS credentials")?;
let mut bucket =
Bucket::new(&bucket_name, region, credentials).context("Failed to create S3 bucket")?;
if std::env::var("AWS_ENDPOINT_URL_S3").is_ok() || std::env::var("AWS_ENDPOINT_URL").is_ok()
{
log::debug!("Enabling path-style addressing for custom endpoint");
bucket = bucket.with_path_style();
}
Ok(Some(Self { bucket }))
}
pub async fn upload_toolchain(
&self,
toolchain_name: &str,
version: &str,
platform: &str,
archive_path: &Path,
) -> Result<()> {
let archive_name = self.get_archive_name(toolchain_name, version, platform);
let key = format!("toolchains/{toolchain_name}/{version}/{archive_name}.tar.gz");
log::info!("Uploading to s3://{}/{}...", self.bucket.name(), key);
let bytes = tokio::fs::read(archive_path)
.await
.with_context(|| format!("Failed to read archive file {}", archive_path.display()))?;
self.bucket
.put_object(&key, &bytes)
.await
.context("Failed to upload toolchain to S3")?;
let checksum = compute_sha256(archive_path)?;
self.upload_checksum(&key, &checksum).await?;
log::info!("Uploaded to S3: {}", key);
Ok(())
}
pub async fn download_toolchain(
&self,
toolchain_name: &str,
version: &str,
platform: &str,
dest_path: &Path,
) -> Result<()> {
let archive_name = self.get_archive_name(toolchain_name, version, platform);
let key = format!("toolchains/{toolchain_name}/{version}/{archive_name}.tar.gz");
log::info!("Downloading from s3://{}/{}", self.bucket.name(), key);
let response = self.bucket.get_object(&key).await.with_context(|| {
format!(
"Failed to download from S3: s3://{}/{}",
self.bucket.name(),
key
)
})?;
if response.status_code() != 200 {
return Err(anyhow::anyhow!(
"Failed to download from S3: HTTP {}",
response.status_code()
));
}
tokio::fs::write(dest_path, response.bytes())
.await
.with_context(|| format!("Failed to write to {}", dest_path.display()))?;
let expected_checksum = self.download_checksum(&key).await?;
if let Err(e) = verify_checksum(dest_path, &expected_checksum) {
let _ = std::fs::remove_file(dest_path);
return Err(e);
}
log::info!("Downloaded from S3 (checksum verified)");
Ok(())
}
pub async fn check_availability(
&self,
toolchain_name: &str,
version: &str,
platform: &str,
) -> bool {
let archive_name = self.get_archive_name(toolchain_name, version, platform);
let key = format!("toolchains/{toolchain_name}/{version}/{archive_name}.tar.gz");
self.bucket.head_object(&key).await.is_ok()
}
fn get_archive_name(&self, toolchain_name: &str, version: &str, platform: &str) -> String {
match toolchain_name {
"gnu-riscv" => format!("riscv64-elf-{platform}-{version}"),
"rialo-rust" => format!("rialo-rust-{version}-{platform}"),
_ => format!("{toolchain_name}-{version}-{platform}"),
}
}
async fn upload_checksum(&self, archive_key: &str, checksum: &str) -> Result<()> {
let checksum_key = format!("{archive_key}.sha256");
self.bucket
.put_object(&checksum_key, checksum.as_bytes())
.await
.context("Failed to upload checksum to S3")?;
Ok(())
}
async fn download_checksum(&self, archive_key: &str) -> Result<String> {
let checksum_key = format!("{archive_key}.sha256");
let response = self
.bucket
.get_object(&checksum_key)
.await
.with_context(|| {
format!(
"Failed to download checksum from S3: s3://{}/{}",
self.bucket.name(),
checksum_key
)
})?;
if response.status_code() != 200 {
return Err(anyhow::anyhow!(
"Failed to download checksum: HTTP {}",
response.status_code()
));
}
let checksum = String::from_utf8(response.bytes().to_vec())
.context("Checksum file contains invalid UTF-8")?
.trim()
.to_string();
Ok(checksum)
}
}
fn compute_sha256(file_path: &Path) -> Result<String> {
use std::{fs::File, io};
use sha2::{Digest, Sha256};
let mut file =
File::open(file_path).with_context(|| format!("Failed to open {}", file_path.display()))?;
let mut hasher = Sha256::new();
io::copy(&mut file, &mut hasher)
.with_context(|| format!("Failed to read {}", file_path.display()))?;
Ok(hex::encode(hasher.finalize()))
}
/// Checks that the SHA-256 digest of `file_path` equals `expected_checksum`.
///
/// The comparison ignores ASCII case, because hexadecimal digests written by
/// other tools may be upper-case while the locally computed digest is always
/// lower-case.
///
/// # Errors
///
/// Fails if the file cannot be hashed, or with a "Checksum mismatch" error
/// naming the file, the expected digest and the actual digest.
fn verify_checksum(file_path: &Path, expected_checksum: &str) -> Result<()> {
    let actual_checksum = compute_sha256(file_path)?;
    if !actual_checksum.eq_ignore_ascii_case(expected_checksum) {
        return Err(anyhow::anyhow!(
            "Checksum mismatch for {}: expected {}, got {}",
            file_path.display(),
            expected_checksum,
            actual_checksum
        ));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use std::io::Write;

    use tempfile::NamedTempFile;

    use super::*;

    /// Payload written to the temporary files used by the checksum tests.
    const HELLO_WORLD: &[u8] = b"Hello, World!";
    /// SHA-256 digest of `HELLO_WORLD`, hex-encoded.
    const HELLO_WORLD_SHA256: &str =
        "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f";

    /// Builds a backend around a dummy bucket; no request is ever sent.
    fn test_backend() -> S3StorageBackend {
        let credentials = Credentials::new(Some("test"), Some("test"), None, None, None).unwrap();
        let bucket = Bucket::new("test-bucket", Region::UsEast1, credentials).unwrap();
        S3StorageBackend { bucket }
    }

    /// Creates a flushed temporary file containing `HELLO_WORLD`.
    fn hello_world_file() -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(HELLO_WORLD).unwrap();
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_archive_naming_gnu_riscv() {
        let name = test_backend().get_archive_name("gnu-riscv", "13.2.0", "x86_64-apple-darwin");
        assert_eq!(name, "riscv64-elf-x86_64-apple-darwin-13.2.0");
    }

    #[test]
    fn test_archive_naming_rialo_rust() {
        let name = test_backend().get_archive_name("rialo-rust", "latest", "aarch64-apple-darwin");
        assert_eq!(name, "rialo-rust-latest-aarch64-apple-darwin");
    }

    #[test]
    fn test_archive_naming_generic() {
        let name = test_backend().get_archive_name(
            "custom-toolchain",
            "1.0.0",
            "x86_64-unknown-linux-gnu",
        );
        assert_eq!(name, "custom-toolchain-1.0.0-x86_64-unknown-linux-gnu");
    }

    #[test]
    fn test_compute_sha256() {
        let file = hello_world_file();
        let checksum = compute_sha256(file.path()).unwrap();
        assert_eq!(checksum, HELLO_WORLD_SHA256);
    }

    #[test]
    fn test_verify_checksum_success() {
        let file = hello_world_file();
        let result = verify_checksum(file.path(), HELLO_WORLD_SHA256);
        assert!(result.is_ok());
    }

    #[test]
    fn test_verify_checksum_failure() {
        let file = hello_world_file();
        let result = verify_checksum(file.path(), "wrong_checksum");
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("Checksum mismatch"));
    }
}