use crate::error::{AllSourceError, Result};
use parquet::file::reader::{FileReader, SerializedFileReader};
use sha2::{Digest, Sha256};
use std::path::Path;
/// Stateless collection of integrity checks for on-disk storage artifacts.
///
/// The type carries no data; every operation is an associated function.
pub struct StorageIntegrity;

impl StorageIntegrity {
    /// Returns the SHA-256 digest of `data` as a 64-character lowercase hex string.
    pub fn compute_checksum(data: &[u8]) -> String {
        format!("{:x}", Sha256::digest(data))
    }

    /// Reports whether the SHA-256 digest of `data` equals `expected`.
    ///
    /// `expected` is a hex-encoded digest. Hex digits are compared without
    /// regard to ASCII case, so an upper-case rendering of the correct digest
    /// is accepted rather than being reported as a mismatch.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`; the `Result` wrapper is
    /// part of the established interface.
    pub fn verify_checksum(data: &[u8], expected: &str) -> Result<bool> {
        Ok(Self::compute_checksum(data).eq_ignore_ascii_case(expected))
    }

    /// Like [`Self::verify_checksum`], but turns a mismatch into an error.
    ///
    /// # Errors
    ///
    /// Returns `AllSourceError::StorageError` naming both the expected and the
    /// computed digest when they differ.
    pub fn verify_or_error(data: &[u8], expected: &str) -> Result<()> {
        // Hash once and reuse the digest for both the comparison and the
        // error message; hashing can be expensive for large inputs.
        let computed = Self::compute_checksum(data);
        if computed.eq_ignore_ascii_case(expected) {
            return Ok(());
        }
        Err(AllSourceError::StorageError(format!(
            "Checksum mismatch: expected {expected}, computed {computed}"
        )))
    }

    /// Returns a lowercase hex SHA-256 digest that also covers the data length
    /// and an optional label.
    ///
    /// The hashed byte stream is, in order: `data.len()` as a little-endian
    /// `u64`, the UTF-8 bytes of `label` (nothing when `label` is `None`), and
    /// finally `data` itself. Note that `None` and `Some("")` therefore yield
    /// the same digest.
    pub fn compute_checksum_with_metadata(data: &[u8], label: Option<&str>) -> String {
        let mut hasher = Sha256::new();
        hasher.update((data.len() as u64).to_le_bytes());
        if let Some(label) = label {
            hasher.update(label.as_bytes());
        }
        hasher.update(data);
        format!("{:x}", hasher.finalize())
    }

    /// Verifies a WAL segment file against the checksum stored at its start.
    ///
    /// The first 64 bytes of the file are read as a hex SHA-256 digest and the
    /// remaining bytes are the payload that digest must match.
    ///
    /// Returns `Ok(false)` when `segment_path` does not exist and `Ok(true)`
    /// when the payload matches the stored digest.
    ///
    /// # Errors
    ///
    /// Returns `AllSourceError::StorageError` when the file cannot be read, is
    /// shorter than the 64-byte checksum header, or the digest does not match.
    pub fn verify_wal_segment(segment_path: &Path) -> Result<bool> {
        /// Size of the hex-encoded SHA-256 header at the start of a segment.
        const CHECKSUM_HEX_LEN: usize = 64;

        if !segment_path.exists() {
            return Ok(false);
        }
        let data = std::fs::read(segment_path).map_err(|e| {
            AllSourceError::StorageError(format!("Failed to read WAL segment: {e}"))
        })?;
        if data.len() < CHECKSUM_HEX_LEN {
            return Err(AllSourceError::StorageError(
                "WAL segment too short for checksum".to_string(),
            ));
        }
        let (header, payload) = data.split_at(CHECKSUM_HEX_LEN);
        // Lossy decoding keeps a garbled header from aborting early; it simply
        // fails the comparison below and is echoed in the mismatch error.
        let stored_checksum = String::from_utf8_lossy(header);
        Self::verify_or_error(payload, &stored_checksum)?;
        Ok(true)
    }

    /// Performs a metadata-level sanity check of a Parquet file.
    ///
    /// The file is opened with `SerializedFileReader`, which must succeed, and
    /// the metadata of every column chunk in every row group is inspected.
    /// This function itself does not read or decode any page data.
    ///
    /// Returns `Ok(false)` when `file_path` does not exist and `Ok(true)` when
    /// all checks pass.
    ///
    /// # Errors
    ///
    /// Returns `AllSourceError::StorageError` when the file cannot be opened,
    /// the reader cannot be constructed, a column chunk reports a negative
    /// offset or size, or a column chunk reports zero bytes while claiming to
    /// hold values.
    pub fn verify_parquet_file(file_path: &Path) -> Result<bool> {
        if !file_path.exists() {
            return Ok(false);
        }
        let file = std::fs::File::open(file_path).map_err(|e| {
            AllSourceError::StorageError(format!("Failed to open Parquet file: {e}"))
        })?;
        let reader = SerializedFileReader::new(file).map_err(|e| {
            AllSourceError::StorageError(format!(
                "Parquet metadata verification failed for {}: {e}",
                file_path.display()
            ))
        })?;
        let metadata = reader.metadata();
        for rg_idx in 0..metadata.num_row_groups() {
            let row_group = metadata.row_group(rg_idx);
            for col_idx in 0..row_group.num_columns() {
                let col = row_group.column(col_idx);
                // Read the raw signed fields instead of `byte_range()`: that
                // helper asserts non-negativity, which would make a corrupt
                // footer panic instead of surfacing as a storage error.
                let chunk_start = col
                    .dictionary_page_offset()
                    .unwrap_or_else(|| col.data_page_offset());
                let chunk_len = col.compressed_size();
                if chunk_start < 0 || chunk_len < 0 {
                    return Err(AllSourceError::StorageError(format!(
                        "Parquet column chunk {col_idx} in row group {rg_idx} has a negative offset or size in {}",
                        file_path.display()
                    )));
                }
                if chunk_len == 0 && col.num_values() > 0 {
                    return Err(AllSourceError::StorageError(format!(
                        "Parquet column chunk {col_idx} in row group {rg_idx} has zero bytes but {} values in {}",
                        col.num_values(),
                        file_path.display()
                    )));
                }
            }
        }
        Ok(true)
    }

    /// Verifies a list of files, choosing the check by file extension.
    ///
    /// Paths ending in `.wal` go through [`Self::verify_wal_segment`], paths
    /// ending in `.parquet` through [`Self::verify_parquet_file`]; any other
    /// path is recorded as `false` without touching the file system. The
    /// extension match is case-sensitive.
    ///
    /// After each path is processed, `progress_callback` (if provided) is
    /// invoked with the number of paths completed so far and the total count.
    ///
    /// Returns one `bool` per input path, in input order.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by an individual
    /// verification; results gathered before it are discarded.
    pub fn batch_verify<P: AsRef<Path>>(
        paths: &[P],
        progress_callback: Option<&dyn Fn(usize, usize)>,
    ) -> Result<Vec<bool>> {
        let total = paths.len();
        let mut results = Vec::with_capacity(total);
        for (idx, path) in paths.iter().enumerate() {
            let path = path.as_ref();
            let verified = match path.extension().and_then(|ext| ext.to_str()) {
                Some("wal") => Self::verify_wal_segment(path)?,
                Some("parquet") => Self::verify_parquet_file(path)?,
                _ => false,
            };
            results.push(verified);
            if let Some(callback) = progress_callback {
                callback(idx + 1, total);
            }
        }
        Ok(results)
    }
}
/// Outcome record for an integrity check of a single item.
#[derive(Debug, Clone, PartialEq)]
pub struct IntegrityCheckResult {
    /// Path string of the checked item, stored exactly as given to the constructor.
    pub path: String,
    /// `true` for records built with [`IntegrityCheckResult::success`],
    /// `false` for those built with [`IntegrityCheckResult::failure`].
    pub valid: bool,
    /// Checksum attached to a successful record; `None` on failure records.
    pub checksum: Option<String>,
    /// Error description attached to a failure record; `None` on success records.
    pub error: Option<String>,
}

impl IntegrityCheckResult {
    /// Builds a passing record for `path` that carries `checksum` and no error.
    pub fn success(path: String, checksum: String) -> Self {
        let checksum = Some(checksum);
        IntegrityCheckResult {
            path,
            checksum,
            error: None,
            valid: true,
        }
    }

    /// Builds a failing record for `path` that carries `error` and no checksum.
    pub fn failure(path: String, error: String) -> Self {
        let error = Some(error);
        IntegrityCheckResult {
            path,
            error,
            checksum: None,
            valid: false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
        record_batch::RecordBatch,
    };
    use parquet::arrow::ArrowWriter;
    use std::sync::Arc;
    use tempfile::NamedTempFile;

    /// Writes a small two-column, three-row Parquet file to a temp file.
    fn create_test_parquet_file() -> NamedTempFile {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));
        let ids = Int32Array::from(vec![1, 2, 3]);
        let names = StringArray::from(vec!["alpha", "beta", "gamma"]);
        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(names)])
            .expect("valid batch");
        let tmp = NamedTempFile::new().expect("create temp file");
        let mut writer =
            ArrowWriter::try_new(tmp.reopen().expect("reopen"), schema, None).expect("writer");
        writer.write(&batch).expect("write batch");
        writer.close().expect("close writer");
        tmp
    }

    /// Writes a WAL segment laid out as `verify_wal_segment` expects it:
    /// the 64-character hex checksum of `payload`, followed by `payload`.
    fn create_test_wal_segment(payload: &[u8]) -> NamedTempFile {
        let tmp = NamedTempFile::new().expect("create temp file");
        let mut contents = StorageIntegrity::compute_checksum(payload).into_bytes();
        contents.extend_from_slice(payload);
        std::fs::write(tmp.path(), contents).expect("write WAL segment");
        tmp
    }

    #[test]
    fn test_verify_parquet_file_valid() {
        let tmp = create_test_parquet_file();
        let result = StorageIntegrity::verify_parquet_file(tmp.path());
        assert!(result.is_ok());
        assert!(result.unwrap());
    }

    #[test]
    fn test_verify_parquet_file_nonexistent() {
        let result = StorageIntegrity::verify_parquet_file(Path::new("/nonexistent/file.parquet"));
        assert!(result.is_ok());
        assert!(!result.unwrap());
    }

    #[test]
    fn test_verify_parquet_file_corrupt() {
        let tmp = NamedTempFile::new().expect("create temp file");
        std::fs::write(tmp.path(), b"this is not a parquet file").expect("write corrupt data");
        let result = StorageIntegrity::verify_parquet_file(tmp.path());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AllSourceError::StorageError(_)));
    }

    #[test]
    fn test_verify_parquet_file_truncated() {
        let tmp = NamedTempFile::new().expect("create temp file");
        std::fs::write(tmp.path(), b"PAR1").expect("write truncated data");
        let result = StorageIntegrity::verify_parquet_file(tmp.path());
        assert!(result.is_err());
    }

    #[test]
    fn test_verify_wal_segment_valid() {
        let tmp = create_test_wal_segment(b"wal payload");
        assert!(StorageIntegrity::verify_wal_segment(tmp.path()).expect("verify segment"));

        // A segment consisting of the header alone has an empty payload.
        let header_only = create_test_wal_segment(b"");
        assert!(StorageIntegrity::verify_wal_segment(header_only.path()).expect("verify segment"));
    }

    #[test]
    fn test_verify_wal_segment_nonexistent() {
        let result = StorageIntegrity::verify_wal_segment(Path::new("/nonexistent/segment.wal"));
        assert!(!result.expect("a missing segment is not an error"));
    }

    #[test]
    fn test_verify_wal_segment_too_short() {
        let tmp = NamedTempFile::new().expect("create temp file");
        std::fs::write(tmp.path(), b"short").expect("write short segment");
        let result = StorageIntegrity::verify_wal_segment(tmp.path());
        assert!(matches!(result, Err(AllSourceError::StorageError(_))));
    }

    #[test]
    fn test_verify_wal_segment_corrupt_payload() {
        let tmp = create_test_wal_segment(b"wal payload");
        let mut bytes = std::fs::read(tmp.path()).expect("read segment");
        let last = bytes.len() - 1;
        bytes[last] ^= 0xff;
        std::fs::write(tmp.path(), bytes).expect("rewrite segment");
        let result = StorageIntegrity::verify_wal_segment(tmp.path());
        assert!(matches!(result, Err(AllSourceError::StorageError(_))));
    }

    #[test]
    fn test_batch_verify_with_parquet() {
        let tmp = create_test_parquet_file();
        // Copy into a temp directory so the `.parquet` file is removed when
        // `dir` is dropped, even if an assertion below fails.
        let dir = tempfile::tempdir().expect("create temp dir");
        let parquet_path = dir.path().join("data.parquet");
        std::fs::copy(tmp.path(), &parquet_path).expect("copy to .parquet");
        let paths = vec![parquet_path];
        let results = StorageIntegrity::batch_verify(&paths, None).expect("batch verify");
        assert_eq!(results.len(), 1);
        assert!(results[0]);
    }

    #[test]
    fn test_batch_verify_unknown_extension() {
        let paths = vec![std::path::PathBuf::from("notes.txt")];
        let results = StorageIntegrity::batch_verify(&paths, None).expect("batch verify");
        assert_eq!(results, vec![false]);
    }

    #[test]
    fn test_batch_verify_reports_progress() {
        let paths = vec![
            std::path::PathBuf::from("a.txt"),
            std::path::PathBuf::from("b.txt"),
        ];
        let seen = std::cell::RefCell::new(Vec::new());
        let record = |done: usize, total: usize| seen.borrow_mut().push((done, total));
        let callback: &dyn Fn(usize, usize) = &record;
        let results =
            StorageIntegrity::batch_verify(&paths, Some(callback)).expect("batch verify");
        assert_eq!(results, vec![false, false]);
        assert_eq!(*seen.borrow(), vec![(1, 2), (2, 2)]);
    }

    #[test]
    fn test_compute_checksum() {
        let data = b"hello world";
        let checksum = StorageIntegrity::compute_checksum(data);
        assert_eq!(checksum.len(), 64);
        // Known SHA-256 digest of "hello world".
        assert_eq!(
            checksum,
            "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"
        );
        let checksum2 = StorageIntegrity::compute_checksum(data);
        assert_eq!(checksum, checksum2);
    }

    #[test]
    fn test_verify_checksum() {
        let data = b"test data";
        let checksum = StorageIntegrity::compute_checksum(data);
        assert!(StorageIntegrity::verify_checksum(data, &checksum).unwrap());
        assert!(!StorageIntegrity::verify_checksum(data, "wrong").unwrap());
    }

    #[test]
    fn test_verify_or_error() {
        let data = b"test data";
        let checksum = StorageIntegrity::compute_checksum(data);
        assert!(StorageIntegrity::verify_or_error(data, &checksum).is_ok());
        let result = StorageIntegrity::verify_or_error(data, "wrong");
        assert!(result.is_err());
        assert!(matches!(result, Err(AllSourceError::StorageError(_))));
    }

    #[test]
    fn test_checksum_with_metadata() {
        let data = b"test";
        let checksum1 = StorageIntegrity::compute_checksum_with_metadata(data, Some("label1"));
        let checksum2 = StorageIntegrity::compute_checksum_with_metadata(data, Some("label2"));
        assert_ne!(checksum1, checksum2);
        let checksum3 = StorageIntegrity::compute_checksum_with_metadata(data, Some("label1"));
        assert_eq!(checksum1, checksum3);
    }

    #[test]
    fn test_different_data_different_checksums() {
        let data1 = b"hello";
        let data2 = b"world";
        let checksum1 = StorageIntegrity::compute_checksum(data1);
        let checksum2 = StorageIntegrity::compute_checksum(data2);
        assert_ne!(checksum1, checksum2);
    }

    #[test]
    fn test_empty_data() {
        let data = b"";
        let checksum = StorageIntegrity::compute_checksum(data);
        assert_eq!(checksum.len(), 64);
        // Known SHA-256 digest of the empty input.
        assert_eq!(
            checksum,
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        assert!(StorageIntegrity::verify_checksum(data, &checksum).unwrap());
    }

    #[test]
    fn test_large_data() {
        let data = vec![0u8; 1_000_000];
        let checksum = StorageIntegrity::compute_checksum(&data);
        assert_eq!(checksum.len(), 64);
        assert!(StorageIntegrity::verify_checksum(&data, &checksum).unwrap());
    }

    #[test]
    fn test_integrity_check_result() {
        let success = IntegrityCheckResult::success("test.wal".to_string(), "abc123".to_string());
        assert!(success.valid);
        assert_eq!(success.checksum, Some("abc123".to_string()));
        assert_eq!(success.error, None);
        let failure = IntegrityCheckResult::failure(
            "test.wal".to_string(),
            "corruption detected".to_string(),
        );
        assert!(!failure.valid);
        assert_eq!(failure.checksum, None);
        assert_eq!(failure.error, Some("corruption detected".to_string()));
    }
}