use anyhow::{Context, Result, bail};
use sha2::{Digest, Sha256};
use std::path::Path;
use tokio::fs;
use tracing::{debug, info, warn};
/// Stateless namespace for SHA-256 checksum helpers.
///
/// The type carries no data; everything is exposed as associated functions
/// (for example `ChecksumVerifier::compute_sha256`). The derives let it be
/// logged, copied, or default-constructed when embedded in other types.
#[derive(Debug, Clone, Copy, Default)]
pub struct ChecksumVerifier;
impl ChecksumVerifier {
    /// Computes the SHA-256 digest of the file at `file_path`.
    ///
    /// The whole file is read into memory before it is hashed.
    ///
    /// # Returns
    /// The digest as lowercase hex, prefixed with the literal tag `sha256:`.
    ///
    /// # Errors
    /// Returns an error (with the path attached as context) if the file
    /// cannot be read.
    pub async fn compute_sha256(file_path: &Path) -> Result<String> {
        debug!("Computing SHA256 checksum for: {:?}", file_path);
        let contents = fs::read(file_path)
            .await
            .with_context(|| format!("Failed to read file: {file_path:?}"))?;
        let mut hasher = Sha256::new();
        hasher.update(&contents);
        let digest = hasher.finalize();
        Ok(format!("sha256:{digest:x}"))
    }

    /// Verifies that the file at `file_path` hashes to `expected_checksum`.
    ///
    /// `expected_checksum` may be given with or without the `sha256:` tag and
    /// in any letter case; surrounding whitespace is ignored. Accepting the
    /// bare form matters because [`Self::fetch_expected_checksum`] returns the
    /// raw token from the manifest, while [`Self::compute_sha256`] always
    /// emits the tagged form.
    ///
    /// # Errors
    /// Returns an error if the file cannot be read or if the digests differ.
    /// The mismatch error message contains both the expected and the actual
    /// value as they were supplied/computed.
    pub async fn verify_checksum(file_path: &Path, expected_checksum: &str) -> Result<()> {
        info!("Verifying checksum for: {:?}", file_path);
        let actual_checksum = Self::compute_sha256(file_path).await?;

        // Compare only the hex digests so a tagged value and a bare value of
        // the same digest are treated as equal.
        let actual_digest = Self::strip_algorithm_prefix(&actual_checksum);
        let expected_digest = Self::strip_algorithm_prefix(expected_checksum.trim());

        if !actual_digest.eq_ignore_ascii_case(expected_digest) {
            bail!(
                "Checksum verification failed!\n Expected: {expected_checksum}\n Actual: {actual_checksum}"
            );
        }
        info!("Checksum verification successful");
        Ok(())
    }

    /// Downloads a checksum manifest and looks up the entry for `binary_name`.
    ///
    /// The manifest is expected to contain lines of the form
    /// `<checksum> <filename>` separated by whitespace; lines with any other
    /// number of fields are ignored. See [`Self::find_checksum`] for the
    /// matching rules.
    ///
    /// # Returns
    /// * `Ok(Some(checksum))` – the checksum token exactly as it appears in
    ///   the manifest.
    /// * `Ok(None)` – the server answered with a non-success status, or no
    ///   line matched `binary_name` (a warning is logged in both cases).
    ///
    /// # Errors
    /// Returns an error if the request cannot be sent or the response body
    /// cannot be read as text.
    pub async fn fetch_expected_checksum(
        checksums_url: &str,
        binary_name: &str,
    ) -> Result<Option<String>> {
        debug!("Fetching checksums from: {}", checksums_url);
        let client = reqwest::Client::new();
        let response =
            client.get(checksums_url).send().await.context("Failed to fetch checksums file")?;
        if !response.status().is_success() {
            warn!("Failed to fetch checksums file: HTTP {}", response.status());
            return Ok(None);
        }
        let content = response.text().await.context("Failed to read checksums file content")?;

        match Self::find_checksum(&content, binary_name) {
            Some(checksum) => {
                debug!("Found checksum for {}: {}", binary_name, checksum);
                Ok(Some(checksum.to_string()))
            }
            None => {
                warn!("No checksum found for binary: {}", binary_name);
                Ok(None)
            }
        }
    }

    /// Fetches the expected checksum for `binary_name` and, if one is found,
    /// verifies `file_path` against it.
    ///
    /// # Returns
    /// * `Ok(true)` – a checksum was found and the file matches it.
    /// * `Ok(false)` – no checksum was available, so nothing was verified.
    ///   Callers that require verification must treat this as a failure.
    ///
    /// # Errors
    /// Propagates errors from fetching the manifest and from a failed
    /// verification (including a digest mismatch).
    pub async fn verify_from_release(
        file_path: &Path,
        checksums_url: &str,
        binary_name: &str,
    ) -> Result<bool> {
        match Self::fetch_expected_checksum(checksums_url, binary_name).await? {
            Some(expected) => {
                Self::verify_checksum(file_path, &expected).await?;
                Ok(true)
            }
            None => {
                warn!("No checksum available for verification, skipping");
                Ok(false)
            }
        }
    }

    /// Returns `checksum` without a leading `sha256:` tag (matched
    /// ASCII-case-insensitively); values without the tag are returned as-is.
    fn strip_algorithm_prefix(checksum: &str) -> &str {
        const PREFIX: &str = "sha256:";
        // `get` yields `None` when the input is shorter than the tag or the
        // cut would fall inside a multi-byte character, so this never panics.
        match checksum.get(..PREFIX.len()) {
            Some(head) if head.eq_ignore_ascii_case(PREFIX) => &checksum[PREFIX.len()..],
            _ => checksum,
        }
    }

    /// Finds the checksum recorded for `binary_name` in manifest text.
    ///
    /// Only lines with exactly two whitespace-separated fields are considered.
    /// A leading `*` on the filename field (the binary-mode marker written by
    /// `sha256sum -b`) is ignored. An entry whose filename equals
    /// `binary_name`, or ends in `/<binary_name>`, wins regardless of its
    /// position; otherwise the first entry whose filename starts with
    /// `<binary_name>-` is used as a fallback.
    fn find_checksum<'a>(content: &'a str, binary_name: &str) -> Option<&'a str> {
        // Built once instead of once per manifest line.
        let dashed_prefix = format!("{binary_name}-");
        let path_suffix = format!("/{binary_name}");

        let mut prefix_match: Option<&'a str> = None;
        for line in content.lines() {
            let mut fields = line.split_whitespace();
            if let (Some(checksum), Some(entry), None) =
                (fields.next(), fields.next(), fields.next())
            {
                let filename = entry.strip_prefix('*').unwrap_or(entry);
                if filename == binary_name || filename.ends_with(&path_suffix) {
                    return Some(checksum);
                }
                // Remember the first loose match but keep scanning: a later
                // exact entry must not be shadowed by e.g. `name-other-arch`.
                if prefix_match.is_none() && filename.starts_with(&dashed_prefix) {
                    prefix_match = Some(checksum);
                }
            }
        }
        prefix_match
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file whose contents are exactly `bytes`.
    fn temp_file_with(bytes: &[u8]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(bytes).unwrap();
        file
    }

    /// The digest of a known input matches its published SHA-256 value.
    #[tokio::test]
    async fn test_compute_sha256() {
        let fixture = temp_file_with(b"Hello, World!");
        let digest = ChecksumVerifier::compute_sha256(fixture.path()).await.unwrap();
        assert_eq!(
            digest,
            "sha256:dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"
        );
    }

    /// A file verifies successfully against its own computed checksum.
    #[tokio::test]
    async fn test_verify_checksum_success() {
        let fixture = temp_file_with(b"Test content");
        let own_checksum = ChecksumVerifier::compute_sha256(fixture.path()).await.unwrap();
        ChecksumVerifier::verify_checksum(fixture.path(), &own_checksum).await.unwrap();
    }

    /// A non-matching checksum is rejected with the verification-failed error.
    #[tokio::test]
    async fn test_verify_checksum_failure() {
        let fixture = temp_file_with(b"Test content");
        let all_zeroes = "0000000000000000000000000000000000000000000000000000000000000000";
        let failure = ChecksumVerifier::verify_checksum(fixture.path(), all_zeroes)
            .await
            .unwrap_err();
        assert!(failure.to_string().contains("Checksum verification failed"));
    }

    /// Hex digits in the expected checksum are compared without regard to case.
    #[tokio::test]
    async fn test_verify_checksum_case_insensitive() {
        let fixture = temp_file_with(b"Test");
        let spellings = [
            "sha256:532eaabd9574880dbf76b9b8cc00832c20a6ec113d682299550d7a6e0f345e25",
            "sha256:532EAABD9574880DBF76B9B8CC00832C20A6EC113D682299550D7A6E0F345E25",
        ];
        for expected in spellings {
            ChecksumVerifier::verify_checksum(fixture.path(), expected).await.unwrap();
        }
    }
}