use std::io::{self, Read};
use crate::error::WalError;
/// Largest payload, in bytes, that a single frame may carry (16 MiB).
///
/// [`encode`] rejects longer payloads, and [`decode`] reports a frame whose
/// length field exceeds this value as corrupt.
pub const MAX_ENTRY_SIZE: usize = 16 * 1024 * 1024;
/// Size in bytes of the fixed frame header: a 4-byte payload length, a 4-byte
/// CRC and an 8-byte LSN, each stored little-endian.
pub const FRAME_HEADER_SIZE: usize = 16;
/// Result of attempting to decode one frame from a reader.
///
/// I/O failures are not represented here; [`decode`] reports them through its
/// surrounding `io::Result`.
///
/// Implements `Clone`, `PartialEq` and `Eq` so outcomes can be compared
/// directly, for example with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeOutcome {
    /// A complete frame was read and its checksum matched.
    Entry {
        /// LSN read from the frame header.
        lsn: u64,
        /// Payload bytes carried by the frame.
        data: Vec<u8>,
        /// Number of bytes the frame occupied: header plus payload.
        bytes_consumed: u64,
    },
    /// Fewer than [`FRAME_HEADER_SIZE`] bytes remained in the reader (possibly
    /// none), so no further frame header could be read.
    EndOfLog,
    /// A full header was read but the frame failed validation: its length
    /// field exceeded [`MAX_ENTRY_SIZE`], the payload ended early, or the CRC
    /// did not match.
    Corrupt,
}
/// Serializes one entry into a self-contained frame.
///
/// Frame layout (all integers little-endian):
///
/// | bytes   | field                                   |
/// |---------|-----------------------------------------|
/// | `0..4`  | payload length as `u32`                 |
/// | `4..8`  | CRC-32 over length, LSN and payload     |
/// | `8..16` | `lsn` as `u64`                          |
/// | `16..`  | payload                                 |
///
/// The CRC is computed over the length bytes, then the LSN bytes, then the
/// payload; the CRC field itself is not covered.
///
/// # Errors
///
/// Returns `WalError::EntryTooLarge` when `data` is longer than
/// [`MAX_ENTRY_SIZE`].
pub fn encode(lsn: u64, data: &[u8]) -> Result<Vec<u8>, WalError> {
    let payload_len = data.len();
    if payload_len > MAX_ENTRY_SIZE {
        return Err(WalError::EntryTooLarge {
            size: payload_len,
            max: MAX_ENTRY_SIZE,
        });
    }

    // The guard above bounds the length by `MAX_ENTRY_SIZE`, which fits in a
    // `u32`, so this narrowing cast cannot truncate.
    let len_field = (payload_len as u32).to_le_bytes();
    let lsn_field = lsn.to_le_bytes();

    let checksum = {
        let mut digest = crc32fast::Hasher::new();
        for part in [&len_field[..], &lsn_field[..], data] {
            digest.update(part);
        }
        digest.finalize()
    };

    let mut frame = Vec::with_capacity(FRAME_HEADER_SIZE + payload_len);
    frame.extend_from_slice(&len_field);
    frame.extend_from_slice(&checksum.to_le_bytes());
    frame.extend_from_slice(&lsn_field);
    frame.extend_from_slice(data);
    Ok(frame)
}
/// Reads from `reader` until `buf` is full or the reader reports end of input.
///
/// Returns the number of bytes written to the front of `buf`; a value smaller
/// than `buf.len()` means the reader returned `Ok(0)` before the buffer was
/// filled. Reads that fail with `ErrorKind::Interrupted` are retried; any
/// other error is returned to the caller.
fn read_full<R: Read>(reader: &mut R, buf: &mut [u8]) -> io::Result<usize> {
    let wanted = buf.len();
    let mut filled = 0;
    while filled < wanted {
        let got = match reader.read(&mut buf[filled..]) {
            Ok(got) => got,
            // An interrupted read transferred nothing; simply try again.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(err) => return Err(err),
        };
        if got == 0 {
            // End of input: report the short count instead of an error.
            break;
        }
        filled += got;
    }
    Ok(filled)
}
/// Reads and validates the next frame from `reader`.
///
/// Outcomes:
///
/// * [`DecodeOutcome::EndOfLog`] — fewer than [`FRAME_HEADER_SIZE`] bytes were
///   available (including none at all).
/// * [`DecodeOutcome::Corrupt`] — the length field exceeds [`MAX_ENTRY_SIZE`],
///   the payload ended before `len` bytes were read, or the stored CRC does
///   not match the one recomputed over length, LSN and payload.
/// * [`DecodeOutcome::Entry`] — the frame is intact; `bytes_consumed` is the
///   header size plus the payload length.
///
/// The reader is not rewound after a non-`Entry` outcome.
///
/// # Errors
///
/// Returns any I/O error from `reader` other than `ErrorKind::Interrupted`,
/// which `read_full` retries.
pub fn decode<R: Read>(reader: &mut R) -> io::Result<DecodeOutcome> {
    let mut header = [0u8; FRAME_HEADER_SIZE];
    // A missing header and a truncated header are both treated as the end of
    // the log rather than as corruption.
    if read_full(reader, &mut header)? < FRAME_HEADER_SIZE {
        return Ok(DecodeOutcome::EndOfLog);
    }

    // Split the header into its three little-endian fields.
    let mut len_field = [0u8; 4];
    let mut crc_field = [0u8; 4];
    let mut lsn_field = [0u8; 8];
    len_field.copy_from_slice(&header[0..4]);
    crc_field.copy_from_slice(&header[4..8]);
    lsn_field.copy_from_slice(&header[8..16]);

    let payload_len = u32::from_le_bytes(len_field) as usize;
    let stored_crc = u32::from_le_bytes(crc_field);
    let lsn = u64::from_le_bytes(lsn_field);

    // Reject absurd lengths before allocating the payload buffer.
    if payload_len > MAX_ENTRY_SIZE {
        return Ok(DecodeOutcome::Corrupt);
    }

    let mut data = vec![0u8; payload_len];
    if read_full(reader, &mut data)? < payload_len {
        return Ok(DecodeOutcome::Corrupt);
    }

    // Recompute the checksum over the same fields, in the same order, that
    // `encode` hashed: length, LSN, payload.
    let mut digest = crc32fast::Hasher::new();
    digest.update(&len_field);
    digest.update(&lsn_field);
    digest.update(&data);
    if digest.finalize() != stored_crc {
        return Ok(DecodeOutcome::Corrupt);
    }

    let bytes_consumed = (FRAME_HEADER_SIZE + payload_len) as u64;
    Ok(DecodeOutcome::Entry {
        lsn,
        data,
        bytes_consumed,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Decodes a single frame from the start of `bytes`, panicking on I/O errors.
    fn decode_one(bytes: Vec<u8>) -> DecodeOutcome {
        decode(&mut Cursor::new(bytes)).unwrap()
    }

    /// Encodes `(lsn, data)` and XORs `mask` into the frame byte at `index`.
    fn tampered_frame(lsn: u64, data: &[u8], index: usize, mask: u8) -> Vec<u8> {
        let mut frame = encode(lsn, data).unwrap();
        frame[index] ^= mask;
        frame
    }

    /// Fails unless `out` is an `Entry` carrying the expected LSN and payload.
    fn assert_entry(out: DecodeOutcome, expected_lsn: u64, expected_data: &[u8]) {
        let (lsn, data) = match out {
            DecodeOutcome::Entry { lsn, data, .. } => (lsn, data),
            other => panic!("expected Entry, got {:?}", other),
        };
        assert_eq!(lsn, expected_lsn);
        assert_eq!(data, expected_data);
    }

    /// Fails unless `out` is `EndOfLog`.
    fn assert_eol(out: DecodeOutcome) {
        match out {
            DecodeOutcome::EndOfLog => {}
            other => panic!("expected EndOfLog, got {:?}", other),
        }
    }

    /// Fails unless `out` is `Corrupt`.
    fn assert_corrupt(out: DecodeOutcome) {
        match out {
            DecodeOutcome::Corrupt => {}
            other => panic!("expected Corrupt, got {:?}", other),
        }
    }

    #[test]
    fn f1_roundtrip_small() {
        let frame = encode(42, b"hello").unwrap();
        assert_entry(decode_one(frame), 42, b"hello");
    }

    #[test]
    fn f2_roundtrip_empty_data() {
        // An empty payload yields a header-only frame.
        let frame = encode(1, b"").unwrap();
        assert_eq!(frame.len(), FRAME_HEADER_SIZE);
        assert_entry(decode_one(frame), 1, b"");
    }

    #[test]
    fn f3_roundtrip_max_size() {
        let payload = vec![0xAB; MAX_ENTRY_SIZE];
        let frame = encode(1, &payload).unwrap();
        assert_entry(decode_one(frame), 1, &payload);
    }

    #[test]
    fn f4_encode_oversize_errors() {
        let oversized = vec![0u8; MAX_ENTRY_SIZE + 1];
        let result = encode(1, &oversized);
        match result {
            Err(WalError::EntryTooLarge { size, max }) => {
                assert_eq!(size, MAX_ENTRY_SIZE + 1);
                assert_eq!(max, MAX_ENTRY_SIZE);
            }
            other => panic!("expected EntryTooLarge, got {:?}", other),
        }
    }

    #[test]
    fn f5_decode_clean_eof() {
        assert_eol(decode_one(Vec::new()));
    }

    #[test]
    fn f6_decode_partial_header() {
        // Any header shorter than FRAME_HEADER_SIZE reads as end of log.
        for truncated_len in [1usize, 7, 15] {
            assert_eol(decode_one(vec![0u8; truncated_len]));
        }
    }

    #[test]
    fn f7_decode_partial_body() {
        let mut frame = encode(1, b"hello").unwrap();
        // Drop the final payload byte so the body is one byte short.
        frame.truncate(frame.len() - 1);
        assert_corrupt(decode_one(frame));
    }

    #[test]
    fn f8_decode_oversize_len() {
        // Header with len = u32::MAX, crc = 0 and lsn = 0, and no payload.
        let mut header = vec![0u8; FRAME_HEADER_SIZE];
        header[..4].copy_from_slice(&u32::MAX.to_le_bytes());
        assert_corrupt(decode_one(header));
    }

    #[test]
    fn f9_decode_flipped_len_byte() {
        assert_corrupt(decode_one(tampered_frame(1, b"hello", 0, 0x01)));
    }

    #[test]
    fn f10_decode_flipped_lsn_byte() {
        // Byte 8 is the first byte of the LSN field.
        assert_corrupt(decode_one(tampered_frame(1, b"hello", 8, 0x01)));
    }

    #[test]
    fn f11_decode_flipped_data_byte() {
        assert_corrupt(decode_one(tampered_frame(
            1,
            b"hello",
            FRAME_HEADER_SIZE,
            0x01,
        )));
    }

    #[test]
    fn f12_decode_flipped_crc_byte() {
        // Byte 4 is the first byte of the CRC field.
        assert_corrupt(decode_one(tampered_frame(1, b"hello", 4, 0x01)));
    }

    #[test]
    fn f13_decode_two_concatenated_frames() {
        let mut stream = encode(1, b"a").unwrap();
        stream.extend(encode(2, b"bb").unwrap());
        let mut cursor = Cursor::new(stream);
        assert_entry(decode(&mut cursor).unwrap(), 1, b"a");
        assert_entry(decode(&mut cursor).unwrap(), 2, b"bb");
        assert_eol(decode(&mut cursor).unwrap());
    }

    #[test]
    fn f14_decode_valid_then_corrupt() {
        let mut stream = encode(1, b"good").unwrap();
        // The second frame has its first payload byte inverted.
        stream.extend(tampered_frame(2, b"bad!", FRAME_HEADER_SIZE, 0xFF));
        let mut cursor = Cursor::new(stream);
        assert_entry(decode(&mut cursor).unwrap(), 1, b"good");
        assert_corrupt(decode(&mut cursor).unwrap());
    }
}