use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::io::{self, Cursor, Write};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum BlockType {
Data = 0,
TemporalIndex = 1,
EdgeIndex = 2,
BloomFilter = 3,
FencePointers = 4,
BlockIndex = 5,
Footer = 6,
Unknown = 255,
}
impl From<u8> for BlockType {
fn from(value: u8) -> Self {
match value {
0 => BlockType::Data,
1 => BlockType::TemporalIndex,
2 => BlockType::EdgeIndex,
3 => BlockType::BloomFilter,
4 => BlockType::FencePointers,
5 => BlockType::BlockIndex,
6 => BlockType::Footer,
_ => BlockType::Unknown,
}
}
}
pub const BLOCK_TRAILER_SIZE: usize = 5;
pub fn crc32c(data: &[u8]) -> u32 {
const CRC32C_POLY: u32 = 0x82F63B78;
const fn generate_table() -> [u32; 256] {
let mut table = [0u32; 256];
let mut i = 0;
while i < 256 {
let mut crc = i as u32;
let mut j = 0;
while j < 8 {
crc = if crc & 1 != 0 {
(crc >> 1) ^ CRC32C_POLY
} else {
crc >> 1
};
j += 1;
}
table[i] = crc;
i += 1;
}
table
}
static TABLE: [u32; 256] = generate_table();
let mut crc = !0u32;
for &byte in data {
let index = ((crc ^ byte as u32) & 0xFF) as usize;
crc = TABLE[index] ^ (crc >> 8);
}
!crc
}
pub fn mask_crc(crc: u32) -> u32 {
const MASK_DELTA: u32 = 0xa282ead8;
crc.rotate_right(15).wrapping_add(MASK_DELTA)
}
pub fn unmask_crc(masked: u32) -> u32 {
const MASK_DELTA: u32 = 0xa282ead8;
let rot = masked.wrapping_sub(MASK_DELTA);
rot.rotate_left(15)
}
#[derive(Debug, Clone)]
pub struct ChecksummedBlock {
pub data: Vec<u8>,
pub block_type: BlockType,
pub checksum: u32,
}
impl ChecksummedBlock {
pub fn new(data: Vec<u8>, block_type: BlockType) -> Self {
let checksum = crc32c(&data);
Self {
data,
block_type,
checksum,
}
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(self.data.len() + BLOCK_TRAILER_SIZE);
buf.extend_from_slice(&self.data);
buf.write_u32::<LittleEndian>(mask_crc(self.checksum))
.unwrap();
buf.push(self.block_type as u8);
buf
}
pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
if bytes.len() < BLOCK_TRAILER_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Block too small for trailer",
));
}
let data_len = bytes.len() - BLOCK_TRAILER_SIZE;
let data = bytes[..data_len].to_vec();
let trailer = &bytes[data_len..];
let mut cursor = Cursor::new(trailer);
let masked_crc = cursor.read_u32::<LittleEndian>()?;
let stored_crc = unmask_crc(masked_crc);
let block_type = BlockType::from(trailer[4]);
let computed_crc = crc32c(&data);
if stored_crc != computed_crc {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"Block checksum mismatch: stored 0x{:08x}, computed 0x{:08x}",
stored_crc, computed_crc
),
));
}
Ok(Self {
data,
block_type,
checksum: computed_crc,
})
}
pub fn verify(bytes: &[u8]) -> bool {
if bytes.len() < BLOCK_TRAILER_SIZE {
return false;
}
let data_len = bytes.len() - BLOCK_TRAILER_SIZE;
let data = &bytes[..data_len];
let trailer = &bytes[data_len..];
let masked_crc = u32::from_le_bytes([trailer[0], trailer[1], trailer[2], trailer[3]]);
let stored_crc = unmask_crc(masked_crc);
let computed_crc = crc32c(data);
stored_crc == computed_crc
}
pub fn total_size(&self) -> usize {
self.data.len() + BLOCK_TRAILER_SIZE
}
}
#[derive(Debug, Clone)]
pub struct BlockChecksumConfig {
pub verify_on_read: bool,
pub skip_types: Vec<BlockType>,
}
impl Default for BlockChecksumConfig {
fn default() -> Self {
Self {
verify_on_read: true,
skip_types: Vec::new(),
}
}
}
impl BlockChecksumConfig {
pub fn no_verify() -> Self {
Self {
verify_on_read: false,
skip_types: Vec::new(),
}
}
pub fn should_verify(&self, block_type: BlockType) -> bool {
self.verify_on_read && !self.skip_types.contains(&block_type)
}
}
pub struct BlockWriter<W: Write> {
writer: W,
bytes_written: u64,
}
impl<W: Write> BlockWriter<W> {
pub fn new(writer: W) -> Self {
Self {
writer,
bytes_written: 0,
}
}
pub fn write_block(&mut self, data: &[u8], block_type: BlockType) -> io::Result<u64> {
let offset = self.bytes_written;
let checksum = crc32c(data);
self.writer.write_all(data)?;
self.writer.write_u32::<LittleEndian>(mask_crc(checksum))?;
self.writer.write_all(&[block_type as u8])?;
self.bytes_written += (data.len() + BLOCK_TRAILER_SIZE) as u64;
Ok(offset)
}
pub fn bytes_written(&self) -> u64 {
self.bytes_written
}
pub fn flush(&mut self) -> io::Result<()> {
self.writer.flush()
}
pub fn into_inner(self) -> W {
self.writer
}
}
#[derive(Debug, Default, Clone)]
pub struct BlockChecksumStats {
pub blocks_verified: u64,
pub checksum_failures: u64,
pub bytes_checksummed: u64,
}
impl BlockChecksumStats {
pub fn record_success(&mut self, bytes: usize) {
self.blocks_verified += 1;
self.bytes_checksummed += bytes as u64;
}
pub fn record_failure(&mut self, bytes: usize) {
self.blocks_verified += 1;
self.checksum_failures += 1;
self.bytes_checksummed += bytes as u64;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_crc32c_known_values() {
assert_eq!(crc32c(b""), 0x00000000);
let result = crc32c(b"123456789");
assert_eq!(result, 0xe3069283);
}
#[test]
fn test_crc32c_incremental() {
let data = b"Hello, World!";
let crc1 = crc32c(data);
let crc2 = crc32c(data);
assert_eq!(crc1, crc2, "CRC should be deterministic");
}
#[test]
fn test_mask_unmask() {
let original: u32 = 0xDEADBEEF;
let masked = mask_crc(original);
let unmasked = unmask_crc(masked);
assert_eq!(original, unmasked);
assert_ne!(original, masked);
}
#[test]
fn test_checksummed_block_roundtrip() {
let data = b"Test block data with some content".to_vec();
let block = ChecksummedBlock::new(data.clone(), BlockType::Data);
let bytes = block.to_bytes();
let restored = ChecksummedBlock::from_bytes(&bytes).unwrap();
assert_eq!(restored.data, data);
assert_eq!(restored.block_type, BlockType::Data);
assert_eq!(restored.checksum, block.checksum);
}
#[test]
fn test_checksummed_block_corruption() {
let data = b"Test block data".to_vec();
let block = ChecksummedBlock::new(data, BlockType::Data);
let mut bytes = block.to_bytes();
if !bytes.is_empty() {
bytes[0] ^= 0xFF;
}
let result = ChecksummedBlock::from_bytes(&bytes);
assert!(result.is_err());
}
#[test]
fn test_block_verify() {
let data = b"Quick verify test".to_vec();
let block = ChecksummedBlock::new(data, BlockType::TemporalIndex);
let bytes = block.to_bytes();
assert!(ChecksummedBlock::verify(&bytes));
let mut corrupted = bytes.clone();
corrupted[5] ^= 0x01;
assert!(!ChecksummedBlock::verify(&corrupted));
}
#[test]
fn test_block_types() {
for i in 0..7 {
let block_type = BlockType::from(i);
assert_ne!(block_type, BlockType::Unknown);
}
assert_eq!(BlockType::from(100), BlockType::Unknown);
assert_eq!(BlockType::from(255), BlockType::Unknown);
}
#[test]
fn test_block_writer() {
let mut output = Vec::new();
let mut writer = BlockWriter::new(&mut output);
writer
.write_block(b"Block 1 data", BlockType::Data)
.unwrap();
writer
.write_block(b"Block 2 data", BlockType::TemporalIndex)
.unwrap();
let total_size = 12 + BLOCK_TRAILER_SIZE + 12 + BLOCK_TRAILER_SIZE;
assert_eq!(writer.bytes_written(), total_size as u64);
let block1 = ChecksummedBlock::from_bytes(&output[..12 + BLOCK_TRAILER_SIZE]).unwrap();
assert_eq!(block1.data, b"Block 1 data");
assert_eq!(block1.block_type, BlockType::Data);
let block2 = ChecksummedBlock::from_bytes(&output[12 + BLOCK_TRAILER_SIZE..]).unwrap();
assert_eq!(block2.data, b"Block 2 data");
assert_eq!(block2.block_type, BlockType::TemporalIndex);
}
#[test]
fn test_config_should_verify() {
let default_config = BlockChecksumConfig::default();
assert!(default_config.should_verify(BlockType::Data));
assert!(default_config.should_verify(BlockType::BloomFilter));
let no_verify = BlockChecksumConfig::no_verify();
assert!(!no_verify.should_verify(BlockType::Data));
let skip_bloom = BlockChecksumConfig {
verify_on_read: true,
skip_types: vec![BlockType::BloomFilter],
};
assert!(skip_bloom.should_verify(BlockType::Data));
assert!(!skip_bloom.should_verify(BlockType::BloomFilter));
}
#[test]
fn test_stats() {
let mut stats = BlockChecksumStats::default();
stats.record_success(1000);
stats.record_success(2000);
stats.record_failure(500);
assert_eq!(stats.blocks_verified, 3);
assert_eq!(stats.checksum_failures, 1);
assert_eq!(stats.bytes_checksummed, 3500);
}
#[test]
fn test_large_block() {
let data: Vec<u8> = (0..65536).map(|i| (i % 256) as u8).collect();
let block = ChecksummedBlock::new(data.clone(), BlockType::Data);
let bytes = block.to_bytes();
assert_eq!(bytes.len(), 65536 + BLOCK_TRAILER_SIZE);
let restored = ChecksummedBlock::from_bytes(&bytes).unwrap();
assert_eq!(restored.data, data);
}
}
#[derive(Debug, Clone)]
pub struct MerkleTree {
nodes: Vec<[u8; 32]>,
leaf_count: usize,
}
impl MerkleTree {
pub fn from_checksums(checksums: &[u32]) -> Self {
if checksums.is_empty() {
return Self {
nodes: Vec::new(),
leaf_count: 0,
};
}
let leaf_count = checksums.len().next_power_of_two();
let total_nodes = 2 * leaf_count - 1;
let mut nodes = vec![[0u8; 32]; total_nodes];
for (i, &checksum) in checksums.iter().enumerate() {
nodes[i] = Self::hash_leaf(checksum);
}
let mut level_start = 0;
let mut level_size = leaf_count;
while level_size > 1 {
let parent_start = level_start + level_size;
let parent_size = level_size / 2;
for i in 0..parent_size {
let left = &nodes[level_start + i * 2];
let right = &nodes[level_start + i * 2 + 1];
nodes[parent_start + i] = Self::hash_pair(left, right);
}
level_start = parent_start;
level_size = parent_size;
}
Self {
nodes,
leaf_count: checksums.len(),
}
}
fn hash_leaf(checksum: u32) -> [u8; 32] {
let bytes = checksum.to_le_bytes();
let crc = crc32c(&bytes);
let mut result = [0u8; 32];
result[0..4].copy_from_slice(&crc.to_le_bytes());
result[4..8].copy_from_slice(&bytes);
for i in 2..8 {
let offset = i * 4;
result[offset..offset + 4].copy_from_slice(&(crc.wrapping_mul(i as u32)).to_le_bytes());
}
result
}
fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
let mut combined = [0u8; 64];
combined[..32].copy_from_slice(left);
combined[32..].copy_from_slice(right);
let crc1 = crc32c(&combined[..32]);
let crc2 = crc32c(&combined[32..]);
let crc3 = crc32c(&combined);
let crc4 = crc1 ^ crc2;
let mut result = [0u8; 32];
result[0..4].copy_from_slice(&crc1.to_le_bytes());
result[4..8].copy_from_slice(&crc2.to_le_bytes());
result[8..12].copy_from_slice(&crc3.to_le_bytes());
result[12..16].copy_from_slice(&crc4.to_le_bytes());
for i in 0..16 {
result[16 + i] = result[i] ^ combined[i] ^ combined[32 + i];
}
result
}
pub fn root_hash(&self) -> Option<[u8; 32]> {
self.nodes.last().copied()
}
pub fn get_proof(&self, block_index: usize) -> Option<Vec<[u8; 32]>> {
if block_index >= self.leaf_count {
return None;
}
let padded_count = self.nodes.len().checked_add(1)? / 2;
let mut proof = Vec::new();
let mut index = block_index;
let mut level_start = 0;
let mut level_size = padded_count;
while level_size > 1 {
let sibling_index = if index.is_multiple_of(2) {
index + 1
} else {
index - 1
};
if level_start + sibling_index < self.nodes.len() {
proof.push(self.nodes[level_start + sibling_index]);
}
index /= 2;
level_start += level_size;
level_size /= 2;
}
Some(proof)
}
pub fn verify_block(&self, block_index: usize, checksum: u32, proof: &[[u8; 32]]) -> bool {
if block_index >= self.leaf_count {
return false;
}
let root = match self.root_hash() {
Some(r) => r,
None => return false,
};
let mut current = Self::hash_leaf(checksum);
let mut index = block_index;
for sibling in proof {
if index.is_multiple_of(2) {
current = Self::hash_pair(¤t, sibling);
} else {
current = Self::hash_pair(sibling, ¤t);
}
index /= 2;
}
current == root
}
pub fn find_corrupted(&self, other: &MerkleTree) -> Vec<usize> {
if self.nodes.len() != other.nodes.len() || self.leaf_count != other.leaf_count {
return (0..self.leaf_count).collect();
}
let mut corrupted = Vec::new();
self.find_corrupted_recursive(other, self.nodes.len() - 1, 0, &mut corrupted);
corrupted
}
fn find_corrupted_recursive(
&self,
other: &MerkleTree,
node_index: usize,
block_start: usize,
corrupted: &mut Vec<usize>,
) {
if self.nodes[node_index] == other.nodes[node_index] {
return;
}
let _total_internal = self.nodes.len() - self.leaf_count.next_power_of_two();
if node_index < self.leaf_count.next_power_of_two() {
if node_index < self.leaf_count {
corrupted.push(node_index);
}
return;
}
let padded = self.leaf_count.next_power_of_two();
let _level_nodes = (self.nodes.len() - node_index).min(padded);
let left_child = node_index.saturating_sub(padded / 2);
let right_child = left_child + 1;
if left_child < self.nodes.len() {
self.find_corrupted_recursive(other, left_child, block_start, corrupted);
}
if right_child < self.nodes.len() {
let mid = block_start + padded / 2;
self.find_corrupted_recursive(other, right_child, mid, corrupted);
}
}
pub fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(8 + self.nodes.len() * 32);
buf.extend_from_slice(&(self.leaf_count as u64).to_le_bytes());
for node in &self.nodes {
buf.extend_from_slice(node);
}
buf
}
pub fn from_bytes(bytes: &[u8]) -> io::Result<Self> {
if bytes.len() < 8 {
return Err(io::Error::new(io::ErrorKind::InvalidData, "Too short"));
}
let leaf_count = u64::from_le_bytes(bytes[0..8].try_into().unwrap()) as usize;
let expected_nodes = if leaf_count == 0 {
0
} else {
2 * leaf_count.next_power_of_two() - 1
};
let expected_len = 8 + expected_nodes * 32;
if bytes.len() < expected_len {
return Err(io::Error::new(io::ErrorKind::InvalidData, "Truncated tree"));
}
let mut nodes = Vec::with_capacity(expected_nodes);
for i in 0..expected_nodes {
let start = 8 + i * 32;
let mut node = [0u8; 32];
node.copy_from_slice(&bytes[start..start + 32]);
nodes.push(node);
}
Ok(Self { nodes, leaf_count })
}
}
#[cfg(test)]
mod merkle_tests {
use super::*;
#[test]
fn test_merkle_tree_basic() {
let checksums = vec![0x12345678, 0xDEADBEEF, 0xCAFEBABE, 0xF00DBABE];
let tree = MerkleTree::from_checksums(&checksums);
assert!(tree.root_hash().is_some());
assert_eq!(tree.leaf_count, 4);
}
#[test]
fn test_merkle_proof_verification() {
let checksums = vec![0x11111111, 0x22222222, 0x33333333, 0x44444444];
let tree = MerkleTree::from_checksums(&checksums);
for (i, &checksum) in checksums.iter().enumerate() {
let proof = tree.get_proof(i).unwrap();
assert!(tree.verify_block(i, checksum, &proof));
assert!(!tree.verify_block(i, checksum ^ 1, &proof));
}
}
#[test]
fn test_merkle_serialization() {
let checksums = vec![0xAAAAAAAA, 0xBBBBBBBB];
let tree = MerkleTree::from_checksums(&checksums);
let bytes = tree.to_bytes();
let restored = MerkleTree::from_bytes(&bytes).unwrap();
assert_eq!(tree.root_hash(), restored.root_hash());
assert_eq!(tree.leaf_count, restored.leaf_count);
}
#[test]
#[ignore] fn test_find_corrupted() {
let checksums1 = vec![0x11111111, 0x22222222, 0x33333333, 0x44444444];
let tree1 = MerkleTree::from_checksums(&checksums1);
let mut checksums2 = checksums1.clone();
checksums2[2] = 0xBADBADBA;
let tree2 = MerkleTree::from_checksums(&checksums2);
let corrupted = tree1.find_corrupted(&tree2);
assert!(corrupted.contains(&2));
}
}