use anyhow::{Context, Result, bail};
use sha2::{Digest, Sha256};
use std::path::Path;
use tokio::fs;
use tracing::{debug, info, warn};
pub struct ChecksumVerifier;
impl ChecksumVerifier {
pub async fn compute_sha256(file_path: &Path) -> Result<String> {
debug!("Computing SHA256 checksum for: {:?}", file_path);
let contents = fs::read(file_path)
.await
.with_context(|| format!("Failed to read file: {:?}", file_path))?;
let mut hasher = Sha256::new();
hasher.update(&contents);
let result = hasher.finalize();
Ok(format!("{:x}", result))
}
pub async fn verify_checksum(file_path: &Path, expected_checksum: &str) -> Result<()> {
info!("Verifying checksum for: {:?}", file_path);
let actual_checksum = Self::compute_sha256(file_path).await?;
if actual_checksum.to_lowercase() != expected_checksum.to_lowercase() {
bail!(
"Checksum verification failed!\n Expected: {}\n Actual: {}",
expected_checksum,
actual_checksum
);
}
info!("Checksum verification successful");
Ok(())
}
pub async fn fetch_expected_checksum(
checksums_url: &str,
binary_name: &str,
) -> Result<Option<String>> {
debug!("Fetching checksums from: {}", checksums_url);
let client = reqwest::Client::new();
let response = client
.get(checksums_url)
.send()
.await
.context("Failed to fetch checksums file")?;
if !response.status().is_success() {
warn!("Failed to fetch checksums file: HTTP {}", response.status());
return Ok(None);
}
let content = response
.text()
.await
.context("Failed to read checksums file content")?;
for line in content.lines() {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() == 2 {
let (checksum, filename) = (parts[0], parts[1]);
if filename == binary_name || filename.contains(binary_name) {
debug!("Found checksum for {}: {}", binary_name, checksum);
return Ok(Some(checksum.to_string()));
}
}
}
warn!("No checksum found for binary: {}", binary_name);
Ok(None)
}
pub async fn verify_from_release(
file_path: &Path,
checksums_url: &str,
binary_name: &str,
) -> Result<bool> {
match Self::fetch_expected_checksum(checksums_url, binary_name).await? {
Some(expected) => {
Self::verify_checksum(file_path, &expected).await?;
Ok(true)
}
None => {
warn!("No checksum available for verification, skipping");
Ok(false)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file whose contents are exactly `contents`.
    fn temp_file_with(contents: &[u8]) -> NamedTempFile {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(contents).unwrap();
        temp_file.flush().unwrap();
        temp_file
    }

    #[tokio::test]
    async fn test_compute_sha256() {
        let temp_file = temp_file_with(b"Hello, World!");
        let checksum = ChecksumVerifier::compute_sha256(temp_file.path())
            .await
            .unwrap();
        assert_eq!(
            checksum,
            "dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f"
        );
    }

    #[tokio::test]
    async fn test_compute_sha256_empty_file() {
        let temp_file = temp_file_with(b"");
        let checksum = ChecksumVerifier::compute_sha256(temp_file.path())
            .await
            .unwrap();
        assert_eq!(
            checksum,
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[tokio::test]
    async fn test_compute_sha256_larger_than_one_chunk() {
        // Larger than any single read buffer, and not a multiple of common
        // buffer sizes, so the final partial chunk is exercised too.
        let data: Vec<u8> = (0..200_003u32)
            .map(|i| (i % 251) as u8) // always < 251, fits in u8
            .collect();
        let temp_file = temp_file_with(&data);
        let checksum = ChecksumVerifier::compute_sha256(temp_file.path())
            .await
            .unwrap();
        assert_eq!(checksum, format!("{:x}", Sha256::digest(&data)));
    }

    #[tokio::test]
    async fn test_compute_sha256_missing_file() {
        let temp_file = temp_file_with(b"gone");
        let path = temp_file.path().to_path_buf();
        temp_file.close().unwrap();
        let result = ChecksumVerifier::compute_sha256(&path).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_verify_checksum_success() {
        let temp_file = temp_file_with(b"Test content");
        let actual = ChecksumVerifier::compute_sha256(temp_file.path())
            .await
            .unwrap();
        ChecksumVerifier::verify_checksum(temp_file.path(), &actual)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_verify_checksum_failure() {
        let temp_file = temp_file_with(b"Test content");
        let wrong_checksum = "0000000000000000000000000000000000000000000000000000000000000000";
        let result = ChecksumVerifier::verify_checksum(temp_file.path(), wrong_checksum).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Checksum verification failed")
        );
    }

    #[tokio::test]
    async fn test_verify_checksum_case_insensitive() {
        let temp_file = temp_file_with(b"Test");
        let lowercase = "532eaabd9574880dbf76b9b8cc00832c20a6ec113d682299550d7a6e0f345e25";
        let uppercase = "532EAABD9574880DBF76B9B8CC00832C20A6EC113D682299550D7A6E0F345E25";
        ChecksumVerifier::verify_checksum(temp_file.path(), lowercase)
            .await
            .unwrap();
        ChecksumVerifier::verify_checksum(temp_file.path(), uppercase)
            .await
            .unwrap();
    }
}