use alloc::vec::Vec;
use core::fmt;
use super::error::BfTreeError;
/// Size in bytes of the little-endian `u32` checksum header that [`wrap_value`]
/// prepends to every stored value.
const CHECKSUM_SIZE: usize = core::mem::size_of::<u32>();
/// Controls whether [`unwrap_value`] callers should recompute and check checksums.
#[derive(Debug, Clone)]
pub enum VerifyMode {
    /// Never verify.
    None,
    /// Verify every value.
    Full,
    /// Verify roughly this fraction of calls to [`should_verify`]. Values `>= 1.0`
    /// behave like [`VerifyMode::Full`] and values `<= 0.0` never verify; `Display`
    /// renders the rate as a percentage.
    Sampled(f32),
}

impl Default for VerifyMode {
    /// Verification is disabled by default.
    fn default() -> Self {
        VerifyMode::None
    }
}

impl VerifyMode {
    /// Returns `true` for every mode other than [`VerifyMode::None`].
    ///
    /// This only inspects the variant: `Sampled(0.0)` still reports `true` even
    /// though it never selects a value for verification.
    pub fn is_enabled(&self) -> bool {
        match self {
            VerifyMode::None => false,
            VerifyMode::Full | VerifyMode::Sampled(_) => true,
        }
    }
}
/// Computes the 32-bit FNV-1a hash of `data`.
///
/// FNV-1a is not cryptographic: it catches accidental byte changes but offers no
/// protection against deliberate tampering.
fn compute_checksum(data: &[u8]) -> u32 {
    const FNV_OFFSET_BASIS: u32 = 0x811c_9dc5;
    const FNV_PRIME: u32 = 0x0100_0193;
    // FNV-1a: xor the byte in first, then multiply (wrapping, mod 2^32).
    data.iter().fold(FNV_OFFSET_BASIS, |hash, &byte| {
        (hash ^ u32::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Returns a new buffer holding `value` prefixed by its checksum.
///
/// Layout: `CHECKSUM_SIZE` bytes of little-endian [`compute_checksum`] output,
/// followed by `value` unchanged. [`unwrap_value`] reverses this.
pub fn wrap_value(value: &[u8]) -> Vec<u8> {
    // The explicit array type also checks at compile time that CHECKSUM_SIZE matches u32.
    let header: [u8; CHECKSUM_SIZE] = compute_checksum(value).to_le_bytes();
    let mut framed = Vec::with_capacity(header.len() + value.len());
    framed.extend_from_slice(&header);
    framed.extend_from_slice(value);
    framed
}
/// Splits a buffer produced by [`wrap_value`] into its payload, optionally verifying it.
///
/// The first `CHECKSUM_SIZE` bytes are read as a little-endian `u32` checksum. The
/// rest is the payload, returned as a borrowed subslice of `wrapped` without copying.
/// When `verify` is `false`, the stored checksum is parsed but never compared.
///
/// # Errors
///
/// Returns [`BfTreeError::Corruption`] if `wrapped` is shorter than the checksum
/// header (regardless of `verify`), or if `verify` is `true` and the recomputed
/// checksum differs from the stored one.
pub fn unwrap_value(wrapped: &[u8], verify: bool) -> Result<&[u8], BfTreeError> {
    if wrapped.len() < CHECKSUM_SIZE {
        return Err(BfTreeError::Corruption("value too short for checksum".into()));
    }

    let (header, data) = wrapped.split_at(CHECKSUM_SIZE);
    let mut header_bytes = [0_u8; CHECKSUM_SIZE];
    header_bytes.copy_from_slice(header);
    let stored_checksum = u32::from_le_bytes(header_bytes);

    if !verify {
        return Ok(data);
    }

    let computed = compute_checksum(data);
    if computed == stored_checksum {
        Ok(data)
    } else {
        Err(BfTreeError::Corruption(alloc::format!(
            "checksum mismatch: stored={stored_checksum:#010x}, computed={computed:#010x}"
        )))
    }
}
/// Owned-output variant of [`unwrap_value`]: performs the same checks, then copies the
/// payload into a fresh `Vec<u8>` so the result does not borrow from `wrapped`.
///
/// # Errors
///
/// Propagates any error from [`unwrap_value`] unchanged.
// NOTE(review): the allow suggests nothing in this crate calls it yet; it keeps the
// convenience API available without a dead-code warning.
#[allow(dead_code)]
pub fn unwrap_value_owned(wrapped: &[u8], verify: bool) -> Result<Vec<u8>, BfTreeError> {
    let payload = unwrap_value(wrapped, verify)?;
    Ok(payload.to_vec())
}
/// Decides whether the current value should have its checksum verified under `mode`.
///
/// - [`VerifyMode::None`] never verifies; [`VerifyMode::Full`] always verifies.
/// - [`VerifyMode::Sampled`] verifies a `rate` fraction of calls. Rates `>= 1.0` always
///   verify; rates `<= 0.0` and NaN never verify. Any positive rate selects at least
///   occasionally: tiny rates are never rounded down to "off".
///
/// Sampling is deterministic, not random. It is driven by one process-wide atomic call
/// counter shared by every caller and thread, so the rate holds across all sampled calls
/// combined. Selected calls are spread evenly through that sequence (e.g. `0.25` picks
/// exactly one call in every four consecutive calls) instead of arriving in bursts.
pub fn should_verify(mode: &VerifyMode) -> bool {
    use core::sync::atomic::{AtomicU64, Ordering};

    // Fixed-point denominator for the rate: 2^32, so the rate keeps ~9 decimal digits.
    const SCALE: u64 = 1 << 32;
    // `SCALE` as f64 (exactly representable).
    const SCALE_F64: f64 = 4_294_967_296.0;

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    match mode {
        VerifyMode::None => false,
        VerifyMode::Full => true,
        VerifyMode::Sampled(rate) => {
            let rate = *rate;
            if rate >= 1.0 {
                return true;
            }
            // NaN fails every comparison, so reject it explicitly together with <= 0.
            if rate.is_nan() || rate <= 0.0 {
                return false;
            }
            // Here 0 < rate < 1, so the product lies in (0, SCALE) and the cast neither
            // wraps nor loses sign. `max(1)` keeps tiny positive rates from truncating
            // to 0, which would silently disable verification.
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            let step = ((f64::from(rate) * SCALE_F64) as u64).max(1);
            let count = COUNTER.fetch_add(1, Ordering::Relaxed);
            // Error diffusion: verify exactly when the running total `count * step`
            // crosses a multiple of SCALE, i.e. when floor((count + 1) * rate) exceeds
            // floor(count * rate). Only the low 32 bits of the product matter and 2^32
            // divides 2^64, so the wrapping multiplication is exact for that purpose.
            let position = count.wrapping_mul(step) & (SCALE - 1);
            position + step >= SCALE
        }
    }
}
impl fmt::Display for VerifyMode {
    /// Renders `none`, `full`, or `sampled(<rate as percent, one decimal>%)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let percent = match self {
            Self::None => return f.write_str("none"),
            Self::Full => return f.write_str("full"),
            Self::Sampled(rate) => rate * 100.0,
        };
        write!(f, "sampled({percent:.1}%)")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn wrap_unwrap_roundtrip() {
        let original = b"hello, world!";
        let wrapped = wrap_value(original);
        assert_eq!(wrapped.len(), CHECKSUM_SIZE + original.len());
        let unwrapped = unwrap_value(&wrapped, true).unwrap();
        assert_eq!(unwrapped, original);
    }

    #[test]
    fn corruption_detected() {
        let original = b"important data";
        let mut wrapped = wrap_value(original);
        // The payload is non-empty, so the first payload byte always exists.
        wrapped[CHECKSUM_SIZE] ^= 0xFF;
        let result = unwrap_value(&wrapped, true);
        assert!(matches!(result, Err(BfTreeError::Corruption(_))));
    }

    #[test]
    fn corrupted_header_detected() {
        let mut wrapped = wrap_value(b"payload");
        wrapped[0] ^= 0x01;
        assert!(matches!(
            unwrap_value(&wrapped, true),
            Err(BfTreeError::Corruption(_))
        ));
    }

    #[test]
    fn no_verify_skips_check() {
        let original = b"data";
        let mut wrapped = wrap_value(original);
        wrapped[CHECKSUM_SIZE] ^= 0xFF;
        let unwrapped = unwrap_value(&wrapped, false).unwrap();
        assert_ne!(unwrapped, original);
    }

    #[test]
    fn too_short_rejected_with_or_without_verify() {
        let zeros = [0_u8; CHECKSUM_SIZE];
        for len in 0..CHECKSUM_SIZE {
            let short = &zeros[..len];
            assert!(matches!(
                unwrap_value(short, true),
                Err(BfTreeError::Corruption(_))
            ));
            assert!(matches!(
                unwrap_value(short, false),
                Err(BfTreeError::Corruption(_))
            ));
        }
    }

    #[test]
    fn empty_value() {
        let wrapped = wrap_value(b"");
        assert_eq!(wrapped.len(), CHECKSUM_SIZE);
        let unwrapped = unwrap_value(&wrapped, true).unwrap();
        assert!(unwrapped.is_empty());
    }

    #[test]
    fn verify_mode_sampling() {
        assert!(!should_verify(&VerifyMode::None));
        assert!(should_verify(&VerifyMode::Full));
        assert!(should_verify(&VerifyMode::Sampled(1.0)));
        assert!(should_verify(&VerifyMode::Sampled(2.0)));
        assert!(!should_verify(&VerifyMode::Sampled(0.0)));
        assert!(!should_verify(&VerifyMode::Sampled(-1.0)));
        assert!(!should_verify(&VerifyMode::Sampled(f32::NAN)));
    }

    #[test]
    fn verify_mode_default_and_display() {
        assert!(!VerifyMode::default().is_enabled());
        assert!(VerifyMode::Full.is_enabled());
        assert_eq!(alloc::format!("{}", VerifyMode::None), "none");
        assert_eq!(alloc::format!("{}", VerifyMode::Full), "full");
        assert_eq!(
            alloc::format!("{}", VerifyMode::Sampled(0.25)),
            "sampled(25.0%)"
        );
    }

    #[test]
    fn checksum_deterministic() {
        let data = b"deterministic";
        let c1 = compute_checksum(data);
        let c2 = compute_checksum(data);
        assert_eq!(c1, c2);
    }

    #[test]
    fn checksum_matches_fnv1a_vectors() {
        // Published FNV-1a 32-bit test vectors; pins the on-disk checksum algorithm.
        assert_eq!(compute_checksum(b""), 0x811c_9dc5);
        assert_eq!(compute_checksum(b"a"), 0xe40c_292c);
        assert_eq!(compute_checksum(b"foobar"), 0xbf9c_f968);
    }

    #[test]
    fn wrapped_header_is_little_endian_checksum() {
        let wrapped = wrap_value(b"a");
        assert_eq!(&wrapped[..CHECKSUM_SIZE], &0xe40c_292c_u32.to_le_bytes());
        assert_eq!(&wrapped[CHECKSUM_SIZE..], b"a");
    }
}