use crate::persistence::entry::WalEntry;
use crate::persistence::storage::StorageBackend;
use crate::persistence::PersistenceError;
use crc32fast::Hasher;
use std::io::{self, Read};
use thiserror::Error;
/// Size in bytes of the fixed per-record header.
///
/// Layout (little-endian): bytes 0..8 sequence number, byte 8 entry type,
/// bytes 9..12 reserved (written as zero), bytes 12..16 payload length.
pub const WAL_HEADER_SIZE: usize = 16;
/// Size in bytes of the CRC-32 trailer written after each payload.
pub const CRC_SIZE: usize = 4;
/// Largest payload (16 MiB) accepted both when appending and when replaying.
pub const MAX_PAYLOAD_SIZE: usize = 16 << 20;
/// Errors raised while appending to or replaying the write-ahead log.
#[derive(Debug, Error)]
pub enum WalError {
    /// An underlying I/O read failed.
    #[error("io error: {0}")]
    Io(#[from] io::Error),

    /// The storage layer reported a failure.
    #[error("persistence error: {0}")]
    Persistence(#[from] PersistenceError),

    /// A record's CRC did not verify: `expected` is the CRC read from the log,
    /// `actual` is the CRC recomputed over the record's header and payload.
    #[error("checksum mismatch: expected {expected:#010x}, got {actual:#010x}")]
    ChecksumMismatch { expected: u32, actual: u32 },

    /// The input ended partway through a record: `expected` bytes were needed
    /// for the current field but only `actual` were available.
    #[error("file truncated: expected {expected} bytes, got {actual}")]
    Truncated { expected: usize, actual: usize },

    /// A payload length of `size` bytes exceeds the configured limit `max`.
    #[error("payload too large: size {size} exceeds max {max}")]
    PayloadTooLarge { size: usize, max: usize },
}
/// Streaming decoder that yields WAL records from any [`Read`] source.
///
/// Each record is a [`WAL_HEADER_SIZE`]-byte header, followed by `payload_len`
/// payload bytes, followed by a [`CRC_SIZE`]-byte little-endian CRC-32 computed
/// over header + payload.
///
/// The iterator yields `Ok((entry, payload))` for every valid record and `None`
/// on a clean EOF at a record boundary. After the first `Err` it is fused and
/// only returns `None`. Once a frame fails to decode, the position of the next
/// record boundary cannot be trusted:
/// - an oversized payload is rejected without being read, and
/// - a CRC mismatch may mean the length field itself is corrupt.
pub struct WalIterator<R> {
    reader: R,
    /// Set after EOF or the first error; keeps the iterator fused.
    finished: bool,
}

impl<R: Read> WalIterator<R> {
    /// Wraps `reader`; decoding starts at the reader's current position.
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            finished: false,
        }
    }

    /// Fills `buf` completely, distinguishing a clean EOF from a short read.
    ///
    /// Returns `Ok(true)` when `buf` was filled and `Ok(false)` when the reader
    /// hit EOF before yielding any byte. Returns [`WalError::Truncated`] when
    /// EOF arrived partway through `buf`. `Interrupted` reads are retried; any
    /// other I/O error is returned as [`WalError::Io`].
    fn read_exact_or_eof(&mut self, buf: &mut [u8]) -> Result<bool, WalError> {
        let mut total_read = 0;
        while total_read < buf.len() {
            match self.reader.read(&mut buf[total_read..]) {
                Ok(0) => {
                    if total_read == 0 {
                        return Ok(false);
                    }
                    return Err(WalError::Truncated {
                        expected: buf.len(),
                        actual: total_read,
                    });
                }
                Ok(n) => total_read += n,
                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => return Err(WalError::Io(e)),
            }
        }
        Ok(true)
    }

    /// Like [`Self::read_exact_or_eof`], but EOF before the first byte is also
    /// an error, because the caller is already in the middle of a record.
    fn read_required(&mut self, buf: &mut [u8]) -> Result<(), WalError> {
        if self.read_exact_or_eof(buf)? {
            Ok(())
        } else {
            Err(WalError::Truncated {
                expected: buf.len(),
                actual: 0,
            })
        }
    }

    /// Decodes one complete record; `Ok(None)` means a clean EOF at a record boundary.
    fn read_record(&mut self) -> Result<Option<(WalEntry, Vec<u8>)>, WalError> {
        let mut header_bytes = [0u8; WAL_HEADER_SIZE];
        if !self.read_exact_or_eof(&mut header_bytes)? {
            return Ok(None);
        }

        let sequence = u64::from_le_bytes(
            header_bytes[0..8]
                .try_into()
                .expect("slice is exactly 8 bytes from 16-byte header"),
        );
        let entry_type = header_bytes[8];
        // Bytes 9..12 are reserved: not decoded, but still covered by the CRC below.
        let payload_len_u32 = u32::from_le_bytes(
            header_bytes[12..16]
                .try_into()
                .expect("slice is exactly 4 bytes from 16-byte header"),
        );
        let entry = WalEntry::new(sequence, entry_type, payload_len_u32);

        // Validate the length before allocating so a corrupt header cannot
        // force an arbitrarily large allocation.
        let payload_len = entry.payload_len as usize;
        if payload_len > MAX_PAYLOAD_SIZE {
            return Err(WalError::PayloadTooLarge {
                size: payload_len,
                max: MAX_PAYLOAD_SIZE,
            });
        }

        let mut payload = vec![0u8; payload_len];
        self.read_required(&mut payload)?;

        let mut crc_bytes = [0u8; CRC_SIZE];
        self.read_required(&mut crc_bytes)?;
        let stored_crc = u32::from_le_bytes(crc_bytes);

        let mut hasher = Hasher::new();
        hasher.update(&header_bytes);
        hasher.update(&payload);
        let calculated_crc = hasher.finalize();
        if calculated_crc != stored_crc {
            return Err(WalError::ChecksumMismatch {
                expected: stored_crc,
                actual: calculated_crc,
            });
        }

        Ok(Some((entry, payload)))
    }
}

impl<R: Read> Iterator for WalIterator<R> {
    type Item = Result<(WalEntry, Vec<u8>), WalError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.finished {
            return None;
        }
        let item = self.read_record().transpose();
        // Stop for good on EOF or on any error. After a failed frame, the next
        // record boundary is unknown, so further reads would decode garbage.
        if !matches!(item, Some(Ok(_))) {
            self.finished = true;
        }
        item
    }
}

impl<R: Read> std::iter::FusedIterator for WalIterator<R> {}
/// Appends framed, CRC-protected records to a [`StorageBackend`].
///
/// Each record is encoded as a [`WAL_HEADER_SIZE`]-byte header (sequence,
/// entry type, reserved zeros, payload length), the payload, and a
/// [`CRC_SIZE`]-byte CRC-32 over header + payload. The whole record is passed
/// to the backend in a single `append` call.
pub struct WalAppender {
    backend: Box<dyn StorageBackend>,
    /// Sequence number that the next successfully appended record will carry.
    next_sequence: u64,
}

impl WalAppender {
    /// Creates an appender writing to `backend`, numbering records from `next_sequence`.
    #[must_use]
    pub fn new(backend: Box<dyn StorageBackend>, next_sequence: u64) -> Self {
        Self {
            backend,
            next_sequence,
        }
    }

    /// Encodes `payload` as one WAL record of type `entry_type` and appends it.
    ///
    /// The sequence number is consumed only if the backend append succeeds.
    /// A failed write therefore does not leave a gap in the sequence.
    ///
    /// # Errors
    /// - [`WalError::PayloadTooLarge`] if `payload` exceeds [`MAX_PAYLOAD_SIZE`].
    /// - Any error returned by the backend's `append`, converted into [`WalError`].
    pub fn append(&mut self, entry_type: u8, payload: &[u8]) -> Result<(), WalError> {
        let payload_len = payload.len();
        if payload_len > MAX_PAYLOAD_SIZE {
            return Err(WalError::PayloadTooLarge {
                size: payload_len,
                max: MAX_PAYLOAD_SIZE,
            });
        }
        // MAX_PAYLOAD_SIZE (16 MiB) fits in u32, so after the check above this cannot fail.
        let payload_len_u32 = u32::try_from(payload_len)
            .expect("payload length bounded by MAX_PAYLOAD_SIZE fits in u32");

        let sequence = self.next_sequence;
        let mut header_bytes = [0u8; WAL_HEADER_SIZE];
        header_bytes[0..8].copy_from_slice(&sequence.to_le_bytes());
        header_bytes[8] = entry_type;
        // Bytes 9..12 are reserved and stay zero from the array initialiser;
        // they are still covered by the CRC.
        header_bytes[12..16].copy_from_slice(&payload_len_u32.to_le_bytes());

        let mut hasher = Hasher::new();
        hasher.update(&header_bytes);
        hasher.update(payload);
        let crc = hasher.finalize();

        let mut buffer = Vec::with_capacity(WAL_HEADER_SIZE + payload_len + CRC_SIZE);
        buffer.extend_from_slice(&header_bytes);
        buffer.extend_from_slice(payload);
        buffer.extend_from_slice(&crc.to_le_bytes());
        self.backend.append(&buffer)?;

        // Advance only after the backend accepted the record. A failed append
        // must not burn a sequence number and leave a gap.
        self.next_sequence += 1;
        Ok(())
    }

    /// Currently a no-op that always returns `Ok(())`; it does not call into the backend.
    ///
    /// NOTE(review): durability therefore depends entirely on the backend's
    /// `append`. Confirm whether `StorageBackend` exposes a flush/fsync that
    /// should be invoked here.
    ///
    /// # Errors
    /// Never fails in the current implementation.
    pub fn sync(&mut self) -> Result<(), WalError> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::{align_of, size_of};

    /// The on-disk framing sizes must not drift silently.
    #[test]
    fn test_wal_constants() {
        assert_eq!(CRC_SIZE, 4);
        assert_eq!(WAL_HEADER_SIZE, 16);
    }

    /// `WalEntry` is expected to mirror the 16-byte header exactly.
    #[test]
    fn test_wal_entry_layout() {
        assert_eq!(align_of::<WalEntry>(), 8);
        assert_eq!(size_of::<WalEntry>(), WAL_HEADER_SIZE);
    }

    /// Round trip: appended records replay in order with intact sequence,
    /// type, and payload, and nothing extra is produced.
    #[test]
    fn test_wal_replay_integrity() {
        use crate::persistence::storage::{MemoryBackend, StorageBackend};
        use std::io::Cursor;

        const RECORDS: u32 = 100;

        let memory = MemoryBackend::new();
        let mut appender = WalAppender::new(Box::new(memory.clone()), 0);
        for value in 0..RECORDS {
            let bytes = value.to_le_bytes();
            appender.append(0, &bytes).expect("append failed");
        }

        let read_backend = Box::new(memory);
        let data = read_backend.read().expect("read failed");

        let mut replayed: u32 = 0;
        for result in WalIterator::new(Cursor::new(data)) {
            let (entry, payload) = result.expect("replay failed");
            assert_eq!(entry.sequence, u64::from(replayed));
            assert_eq!(entry.entry_type, 0);
            assert_eq!(payload, replayed.to_le_bytes());
            replayed += 1;
        }
        assert_eq!(replayed, RECORDS);
    }
}