use crc32fast::Hasher;
use std::io::{self, Write};
/// Selects which integrity check is applied to a block of bytes.
///
/// The default is [`ChecksumType::CRC32C`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ChecksumType {
    /// 32-bit CRC produced by `crc32fast::Hasher`.
    ///
    /// NOTE(review): `crc32fast` implements the IEEE CRC-32 polynomial, so
    /// despite the variant name the values are presumably not CRC-32C
    /// (Castagnoli) — confirm before interoperating with other tools.
    #[default]
    CRC32C,
    /// Checksumming disabled: the computed value is always `0` and
    /// verification always succeeds.
    None,
}
/// Namespace for stateless checksum helpers: one-shot computation,
/// verification, and a small length-prefixed framing format.
///
/// A frame is laid out as
/// `[payload length: u32 LE][payload bytes][checksum: u32 LE]`.
pub struct Checksum;

impl Checksum {
    /// Width in bytes of the little-endian `u32` length prefix that opens a frame.
    const LEN_PREFIX_SIZE: usize = 4;
    /// Width in bytes of the little-endian `u32` checksum that closes a frame.
    const TRAILER_SIZE: usize = 4;

    /// Computes the checksum of `data` using the algorithm selected by
    /// `checksum_type`.
    ///
    /// Returns `0` for [`ChecksumType::None`].
    pub fn compute(checksum_type: ChecksumType, data: &[u8]) -> u32 {
        match checksum_type {
            ChecksumType::CRC32C => {
                let mut hasher = Hasher::new();
                hasher.update(data);
                hasher.finalize()
            }
            ChecksumType::None => 0,
        }
    }

    /// Checks that the checksum of `data` equals `expected`.
    ///
    /// With [`ChecksumType::None`] nothing is computed and the call always
    /// succeeds, whatever `expected` is.
    ///
    /// # Errors
    ///
    /// Returns [`ChecksumError::Mismatch`] when the recomputed checksum
    /// differs from `expected`.
    pub fn verify(
        checksum_type: ChecksumType,
        data: &[u8],
        expected: u32,
    ) -> Result<(), ChecksumError> {
        if checksum_type == ChecksumType::None {
            return Ok(());
        }

        let actual = Self::compute(checksum_type, data);
        if actual != expected {
            return Err(ChecksumError::Mismatch {
                expected,
                actual,
                data_len: data.len(),
            });
        }
        Ok(())
    }

    /// Creates a [`ChecksumBuilder`] for feeding data incrementally.
    pub fn builder(checksum_type: ChecksumType) -> ChecksumBuilder {
        ChecksumBuilder::new(checksum_type)
    }

    /// Wraps `data` in a frame: length prefix, payload, checksum trailer.
    ///
    /// # Panics
    ///
    /// Panics if `data` is longer than `u32::MAX` bytes, because such a
    /// length cannot be represented in the frame's length prefix.
    pub fn encode_with_checksum(checksum_type: ChecksumType, data: &[u8]) -> Vec<u8> {
        // A longer payload would be silently truncated by the cast below and
        // produce a frame that can never be decoded, so fail loudly instead.
        assert!(
            data.len() as u64 <= u64::from(u32::MAX),
            "payload of {} bytes does not fit in the u32 length prefix",
            data.len()
        );
        let len_prefix = data.len() as u32;
        let checksum = Self::compute(checksum_type, data);

        let mut encoded =
            Vec::with_capacity(Self::LEN_PREFIX_SIZE + data.len() + Self::TRAILER_SIZE);
        encoded.extend_from_slice(&len_prefix.to_le_bytes());
        encoded.extend_from_slice(data);
        encoded.extend_from_slice(&checksum.to_le_bytes());
        encoded
    }

    /// Parses a frame produced by [`Checksum::encode_with_checksum`] and
    /// returns a copy of its payload.
    ///
    /// # Errors
    ///
    /// Returns [`ChecksumError::InvalidFormat`] when `encoded` is shorter than
    /// the 8 bytes of framing or when its size disagrees with the length
    /// prefix, and [`ChecksumError::Mismatch`] when the stored checksum does
    /// not match the payload (never for [`ChecksumType::None`]).
    pub fn decode_with_checksum(
        checksum_type: ChecksumType,
        encoded: &[u8],
    ) -> Result<Vec<u8>, ChecksumError> {
        let overhead = Self::LEN_PREFIX_SIZE + Self::TRAILER_SIZE;
        if encoded.len() < overhead {
            return Err(ChecksumError::InvalidFormat("Data too short".to_string()));
        }

        let (len_prefix, rest) = encoded.split_at(Self::LEN_PREFIX_SIZE);
        let declared_len =
            u32::from_le_bytes([len_prefix[0], len_prefix[1], len_prefix[2], len_prefix[3]]);

        // The prefix is untrusted input. Summing in u64 cannot overflow,
        // whereas `usize` arithmetic could on 32-bit targets.
        let expected_total = u64::from(declared_len) + overhead as u64;
        if encoded.len() as u64 != expected_total {
            return Err(ChecksumError::InvalidFormat(format!(
                "Expected {} bytes, got {}",
                expected_total,
                encoded.len()
            )));
        }

        // `rest` holds at least the trailer thanks to the minimum-size check,
        // and the payload length now provably equals `declared_len`.
        let (data, trailer) = rest.split_at(rest.len() - Self::TRAILER_SIZE);
        let expected_checksum =
            u32::from_le_bytes([trailer[0], trailer[1], trailer[2], trailer[3]]);

        Self::verify(checksum_type, data, expected_checksum)?;
        Ok(data.to_vec())
    }
}
/// Incremental checksum calculator.
///
/// Data is fed in pieces through `update` (or the `std::io::Write` impl) and
/// the final value is obtained with `finalize`.
pub struct ChecksumBuilder {
/// Algorithm this builder was created for.
// Stored by `new` but never read in this file, hence the lint suppression.
#[allow(dead_code)]
checksum_type: ChecksumType,
/// Running hasher state; `None` when the builder was created with
/// `ChecksumType::None`.
hasher: Option<Hasher>,
}
impl ChecksumBuilder {
    /// Creates a builder for `checksum_type`.
    ///
    /// A hasher is only allocated for [`ChecksumType::CRC32C`]; for
    /// [`ChecksumType::None`] the builder carries no state.
    pub fn new(checksum_type: ChecksumType) -> Self {
        Self {
            checksum_type,
            hasher: match checksum_type {
                ChecksumType::None => None,
                ChecksumType::CRC32C => Some(Hasher::new()),
            },
        }
    }

    /// Feeds another chunk of `data` into the running checksum.
    ///
    /// Does nothing when the builder has no hasher.
    pub fn update(&mut self, data: &[u8]) {
        if let Some(state) = self.hasher.as_mut() {
            state.update(data);
        }
    }

    /// Consumes the builder and returns the checksum of everything fed so
    /// far, or `0` when the builder has no hasher.
    pub fn finalize(self) -> u32 {
        self.hasher.map_or(0, |state| state.finalize())
    }
}
/// Lets a [`ChecksumBuilder`] be used as a byte sink: every written buffer
/// is folded into the running checksum.
impl Write for ChecksumBuilder {
    /// Consumes the whole of `buf` and reports its full length as written;
    /// never fails.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let consumed = buf.len();
        self.update(buf);
        Ok(consumed)
    }

    /// Nothing is buffered, so flushing is a no-op that always succeeds.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
/// Errors reported by checksum verification and frame decoding.
#[derive(Debug, thiserror::Error)]
pub enum ChecksumError {
/// The checksum recomputed over the data differs from the expected one.
#[error("Checksum mismatch: expected {expected:#010x}, got {actual:#010x} (data_len={data_len})")]
Mismatch {
/// Checksum the caller supplied (for frames, the value read from the trailer).
expected: u32,
/// Checksum recomputed from the data.
actual: u32,
/// Length in bytes of the data that was checked.
data_len: usize,
},
/// The encoded frame is structurally invalid: shorter than the 8 bytes of
/// framing, or its size disagrees with its length prefix. The string
/// carries the human-readable detail.
#[error("Invalid checksum format: {0}")]
InvalidFormat(String),
/// Wraps an I/O failure; `#[from]` provides the conversion from
/// `std::io::Error`. Nothing in this file constructs it directly.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly computed checksum verifies; a wrong checksum or altered
    /// data does not.
    #[test]
    fn test_checksum_basic() {
        let data = b"Hello, MoteDB!";
        let checksum = Checksum::compute(ChecksumType::CRC32C, data);
        assert!(Checksum::verify(ChecksumType::CRC32C, data, checksum).is_ok());
        // `wrapping_add` always yields a different value and cannot overflow.
        assert!(
            Checksum::verify(ChecksumType::CRC32C, data, checksum.wrapping_add(1)).is_err()
        );
        let corrupted = b"Hello, MoteDB?";
        assert!(Checksum::verify(ChecksumType::CRC32C, corrupted, checksum).is_err());
    }

    /// Pins the algorithm with the standard CRC-32 (IEEE) check value so an
    /// accidental change of polynomial is caught.
    #[test]
    fn test_checksum_known_answer() {
        assert_eq!(
            Checksum::compute(ChecksumType::CRC32C, b"123456789"),
            0xCBF4_3926
        );
    }

    /// With checksumming disabled the value is 0 and any expected value passes.
    #[test]
    fn test_checksum_none() {
        let data = b"Hello, MoteDB!";
        let checksum = Checksum::compute(ChecksumType::None, data);
        assert_eq!(checksum, 0);
        assert!(Checksum::verify(ChecksumType::None, data, 12345).is_ok());
    }

    /// Feeding data in pieces gives the same result as a one-shot computation.
    #[test]
    fn test_checksum_builder() {
        let data1 = b"Hello, ";
        let data2 = b"MoteDB!";
        let mut builder = Checksum::builder(ChecksumType::CRC32C);
        builder.update(data1);
        builder.update(data2);
        let checksum1 = builder.finalize();
        let checksum2 = Checksum::compute(ChecksumType::CRC32C, b"Hello, MoteDB!");
        assert_eq!(checksum1, checksum2);
    }

    /// An encoded frame decodes back to the original payload.
    #[test]
    fn test_checksum_encode_decode() {
        let data = b"Hello, MoteDB! This is a test message.";
        let encoded = Checksum::encode_with_checksum(ChecksumType::CRC32C, data);
        let decoded = Checksum::decode_with_checksum(ChecksumType::CRC32C, &encoded).unwrap();
        assert_eq!(data, decoded.as_slice());
    }

    /// Flipping a payload byte (index 10 lies inside the payload) is
    /// reported as a checksum mismatch.
    #[test]
    fn test_checksum_decode_corrupted() {
        let data = b"Hello, MoteDB!";
        let mut encoded = Checksum::encode_with_checksum(ChecksumType::CRC32C, data);
        encoded[10] ^= 0xFF;
        let result = Checksum::decode_with_checksum(ChecksumType::CRC32C, &encoded);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), ChecksumError::Mismatch { .. }));
    }

    /// Structurally broken frames are rejected as `InvalidFormat`, not as a
    /// checksum mismatch.
    #[test]
    fn test_checksum_decode_invalid_format() {
        // Shorter than the 8 bytes of framing.
        let short_data = b"abc";
        let result = Checksum::decode_with_checksum(ChecksumType::CRC32C, short_data);
        assert!(matches!(result, Err(ChecksumError::InvalidFormat(_))));

        // Length prefix claims 100 payload bytes but the frame has only 20.
        let mut invalid = vec![0u8; 20];
        invalid[0] = 100;
        let result = Checksum::decode_with_checksum(ChecksumType::CRC32C, &invalid);
        assert!(matches!(result, Err(ChecksumError::InvalidFormat(_))));
    }

    /// Repeated computation over the same input yields the same value.
    #[test]
    fn test_checksum_deterministic() {
        let data = b"Deterministic test";
        let checksum1 = Checksum::compute(ChecksumType::CRC32C, data);
        let checksum2 = Checksum::compute(ChecksumType::CRC32C, data);
        let checksum3 = Checksum::compute(ChecksumType::CRC32C, data);
        assert_eq!(checksum1, checksum2);
        assert_eq!(checksum2, checksum3);
    }

    /// The checksum of empty input is 0 and verifies.
    #[test]
    fn test_checksum_empty_data() {
        let data = b"";
        let checksum = Checksum::compute(ChecksumType::CRC32C, data);
        assert_eq!(checksum, 0);
        assert!(Checksum::verify(ChecksumType::CRC32C, data, checksum).is_ok());
    }

    /// The builder can be driven through `std::io::Write`.
    #[test]
    fn test_checksum_builder_write_trait() {
        use std::io::Write;
        let mut builder = Checksum::builder(ChecksumType::CRC32C);
        builder.write_all(b"Hello, ").unwrap();
        builder.write_all(b"MoteDB!").unwrap();
        builder.flush().unwrap();
        let checksum = builder.finalize();
        let expected = Checksum::compute(ChecksumType::CRC32C, b"Hello, MoteDB!");
        assert_eq!(checksum, expected);
    }
}