use crate::error::{GrumpyError, Result};
use crate::page::PAGE_SIZE;
use crate::wal::hlc::Hlc;
use crate::wal::vclock::VectorClock;
/// Magic bytes written to the first 8 bytes of the header by `build_wal_header`
/// and verified by `parse_wal_header`.
pub const WAL_MAGIC: &[u8; 8] = b"GRUMPWAL";
/// WAL format version 1 (presumably the layout read by `WalRecord::from_bytes_v1`
/// — confirm against the caller that dispatches on the header version).
pub const WAL_VERSION_V1: u16 = 1;
/// WAL format version 2 (presumably the layout read by `WalRecord::from_bytes_v2`
/// — confirm against the caller that dispatches on the header version).
pub const WAL_VERSION_V2: u16 = 2;
/// Highest version accepted by `parse_wal_header`; currently version 2.
pub const WAL_VERSION_CURRENT: u16 = WAL_VERSION_V2;
/// Size in bytes of the header array returned by `build_wal_header`
/// (equal to `PAGE_SIZE`).
pub const WAL_HEADER_SIZE: usize = PAGE_SIZE;
/// Node id assigned to records decoded by `WalRecord::from_bytes_v1`, both as
/// `origin_node_id` and as the key of their singleton vector clock.
pub const NIL_NODE_ID: u128 = 0;
/// Kind of operation a WAL record describes.
///
/// The discriminant is the single byte stored in the encoded record.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum WalOpType {
    /// Record carrying a page payload (encoded as byte `1`).
    PageWrite = 1,
    /// Commit marker (encoded as byte `2`).
    Commit = 2,
    /// Rollback marker (encoded as byte `3`).
    Rollback = 3,
    /// Checkpoint marker (encoded as byte `4`).
    Checkpoint = 4,
}

impl WalOpType {
    /// Maps an encoded op-type byte back to its variant.
    ///
    /// Returns `None` for any byte that is not the discriminant of a variant.
    pub fn from_u8(v: u8) -> Option<Self> {
        // Every variant, so the lookup is driven by the discriminants themselves.
        const VARIANTS: [WalOpType; 4] = [
            WalOpType::PageWrite,
            WalOpType::Commit,
            WalOpType::Rollback,
            WalOpType::Checkpoint,
        ];
        VARIANTS.iter().copied().find(|op| *op as u8 == v)
    }
}
/// Bytes in a V1 record excluding the payload, as written by `encode_v1_record`:
/// record_len (4) + lsn (8) + tx_id (8) + op_type (1) + page_id (4) +
/// data_len (4) + checksum (4). Also the minimum length `from_bytes_v1` accepts.
pub const WAL_RECORD_HEADER_SIZE_V1: usize = 33;
/// Fixed bytes in a V2 record, as written by `WalRecord::to_bytes`:
/// record_len (4) + lsn (8) + tx_id (8) + op_type (1) + origin_node_id (16) +
/// hlc (8) + checksum (4). The encoded vector clock and any page payload come
/// on top of this. Also the minimum length `from_bytes_v2` accepts.
pub const WAL_RECORD_HEADER_SIZE_V2: usize = 49;
/// In-memory form of one WAL record.
///
/// Records are serialized in the V2 layout by `to_bytes` and parsed by
/// `from_bytes_v2`; `from_bytes_v1` parses the V1 layout into this same struct.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WalRecord {
/// Length in bytes of the encoded record. The constructors set it from
/// `encoded_v2_len()`; `from_bytes_v2` stores the value read from the buffer.
pub record_len: u32,
/// Log sequence number of the record.
pub lsn: u64,
/// Transaction id; `checkpoint` records are built with `0`.
pub tx_id: u64,
/// Operation described by this record.
pub op_type: WalOpType,
/// Id of the originating node; `NIL_NODE_ID` for records decoded from V1.
pub origin_node_id: u128,
/// `Hlc` value serialized as 8 bytes; V1 records get `Hlc::from_packed(lsn)`.
pub hlc: Hlc,
/// Vector clock serialized after the fixed header fields.
pub vector_clock: VectorClock,
/// Page id; only serialized for `PageWrite` records and `0` for the other
/// record kinds built or decoded in this file.
pub page_id: u32,
/// Payload bytes. `page_write` stores the after image followed by the before
/// image; `commit` and `checkpoint` records carry an empty payload.
pub data: Vec<u8>,
/// CRC32 computed by `compute_checksum_v2` over the record fields
/// (excluding `record_len` and the checksum itself).
pub checksum: u32,
}
impl WalRecord {
    /// Builds a V2 `PageWrite` record carrying both images of page `page_id`.
    ///
    /// The payload stores `after` first and `before` second. [`Self::page_images`]
    /// splits the payload at its midpoint, so the image accessors only return
    /// the exact inputs when both images have the same length.
    ///
    /// `record_len` and `checksum` are derived from the other fields. Lengths
    /// are narrowed with `as u32`, so a payload of 4 GiB or more would be
    /// truncated silently.
    #[allow(clippy::too_many_arguments)] // every record field is supplied by the caller
    pub fn page_write(
        lsn: u64,
        tx_id: u64,
        origin_node_id: u128,
        hlc: Hlc,
        vector_clock: VectorClock,
        page_id: u32,
        before: &[u8],
        after: &[u8],
    ) -> Self {
        // After image first, then before image: `page_images` relies on this order.
        let mut data = Vec::with_capacity(after.len() + before.len());
        data.extend_from_slice(after);
        data.extend_from_slice(before);
        Self {
            record_len: 0,
            lsn,
            tx_id,
            op_type: WalOpType::PageWrite,
            origin_node_id,
            hlc,
            vector_clock,
            page_id,
            data,
            checksum: 0,
        }
        .sealed()
    }

    /// Builds a V2 `Commit` record for transaction `tx_id`.
    ///
    /// The record has no payload and `page_id` is `0`; `record_len` and
    /// `checksum` are derived from the other fields.
    pub fn commit(
        lsn: u64,
        tx_id: u64,
        origin_node_id: u128,
        hlc: Hlc,
        vector_clock: VectorClock,
    ) -> Self {
        Self {
            record_len: 0,
            lsn,
            tx_id,
            op_type: WalOpType::Commit,
            origin_node_id,
            hlc,
            vector_clock,
            page_id: 0,
            data: Vec::new(),
            checksum: 0,
        }
        .sealed()
    }

    /// Builds a V2 `Checkpoint` record.
    ///
    /// The record has no payload, and both `tx_id` and `page_id` are `0`;
    /// `record_len` and `checksum` are derived from the other fields.
    pub fn checkpoint(lsn: u64, origin_node_id: u128, hlc: Hlc, vector_clock: VectorClock) -> Self {
        Self {
            record_len: 0,
            lsn,
            tx_id: 0,
            op_type: WalOpType::Checkpoint,
            origin_node_id,
            hlc,
            vector_clock,
            page_id: 0,
            data: Vec::new(),
            checksum: 0,
        }
        .sealed()
    }

    /// Fills in `record_len` and `checksum` from the remaining fields.
    fn sealed(mut self) -> Self {
        self.record_len = self.encoded_v2_len() as u32;
        self.checksum = self.compute_checksum_v2();
        self
    }

    /// Reads a little-endian `u32` starting at `at`.
    ///
    /// Panics if `buf` has fewer than `at + 4` bytes; callers bounds-check first.
    fn le_u32_at(buf: &[u8], at: usize) -> u32 {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(&buf[at..at + 4]);
        u32::from_le_bytes(raw)
    }

    /// Reads a little-endian `u64` starting at `at`.
    ///
    /// Panics if `buf` has fewer than `at + 8` bytes; callers bounds-check first.
    fn le_u64_at(buf: &[u8], at: usize) -> u64 {
        let mut raw = [0u8; 8];
        raw.copy_from_slice(&buf[at..at + 8]);
        u64::from_le_bytes(raw)
    }

    /// Returns the number of bytes [`Self::to_bytes`] produces for this record.
    ///
    /// That is the fixed V2 header, the encoded vector clock and, for
    /// `PageWrite` records only, `page_id` (4) + `data_len` (4) + the payload.
    pub fn encoded_v2_len(&self) -> usize {
        let payload = match self.op_type {
            WalOpType::PageWrite => 4 + 4 + self.data.len(),
            _ => 0,
        };
        WAL_RECORD_HEADER_SIZE_V2 + self.vector_clock.encoded_len() + payload
    }

    /// Serializes the record in the V2 layout (all integers little-endian):
    ///
    /// `record_len | lsn | tx_id | op_type | origin_node_id | hlc | vector clock |
    /// [page_id | data_len | data] | checksum`
    ///
    /// The bracketed part is only written for `PageWrite` records. The stored
    /// `record_len` and `checksum` fields are written as they are, not recomputed.
    pub fn to_bytes(&self) -> Vec<u8> {
        let total = self.encoded_v2_len();
        let mut buf = Vec::with_capacity(total);
        buf.extend_from_slice(&self.record_len.to_le_bytes());
        buf.extend_from_slice(&self.lsn.to_le_bytes());
        buf.extend_from_slice(&self.tx_id.to_le_bytes());
        buf.push(self.op_type as u8);
        buf.extend_from_slice(&self.origin_node_id.to_le_bytes());
        buf.extend_from_slice(&self.hlc.to_le_bytes());
        self.vector_clock.encode_to(&mut buf);
        if self.op_type == WalOpType::PageWrite {
            buf.extend_from_slice(&self.page_id.to_le_bytes());
            buf.extend_from_slice(&(self.data.len() as u32).to_le_bytes());
            buf.extend_from_slice(&self.data);
        }
        buf.extend_from_slice(&self.checksum.to_le_bytes());
        debug_assert_eq!(buf.len(), total);
        buf
    }

    /// Decodes one V2 record from the start of `buf`.
    ///
    /// On success returns the record and the number of bytes it occupies
    /// (its `record_len`), so the caller can advance to the next record.
    ///
    /// # Errors
    ///
    /// * `GrumpyError::WalCorrupted(0)` if `buf` is shorter than 4 bytes, or the
    ///   stored `record_len` is below the V2 header size or exceeds `buf.len()`.
    /// * `GrumpyError::WalCorrupted(lsn)` if the op-type byte is unknown, the
    ///   payload does not end exactly where the checksum starts, or the stored
    ///   checksum does not match the recomputed one.
    /// * `GrumpyError::VectorClock` if the vector clock cannot be decoded.
    pub fn from_bytes_v2(buf: &[u8]) -> Result<(Self, usize)> {
        // Offset just past the fixed-width fields (record_len through hlc).
        const FIXED_END: usize = WAL_RECORD_HEADER_SIZE_V2 - 4;

        if buf.len() < 4 {
            return Err(GrumpyError::WalCorrupted(0));
        }
        let record_len = Self::le_u32_at(buf, 0) as usize;
        if record_len < WAL_RECORD_HEADER_SIZE_V2 || buf.len() < record_len {
            return Err(GrumpyError::WalCorrupted(0));
        }

        let lsn = Self::le_u64_at(buf, 4);
        let tx_id = Self::le_u64_at(buf, 12);
        let op_type = WalOpType::from_u8(buf[20]).ok_or(GrumpyError::WalCorrupted(lsn))?;
        let mut origin_bytes = [0u8; 16];
        origin_bytes.copy_from_slice(&buf[21..37]);
        let origin_node_id = u128::from_le_bytes(origin_bytes);
        let mut hlc_bytes = [0u8; 8];
        hlc_bytes.copy_from_slice(&buf[37..FIXED_END]);
        let hlc = Hlc::from_le_bytes(hlc_bytes);

        // The trailing 4 bytes are the checksum; the guard above rules out underflow.
        let payload_end = record_len - 4;
        let (vector_clock, vc_consumed) = VectorClock::decode(&buf[FIXED_END..payload_end])
            .map_err(|e| GrumpyError::VectorClock(e.to_string()))?;
        let mut cursor = FIXED_END + vc_consumed;

        let (page_id, data) = if op_type == WalOpType::PageWrite {
            // `data_len` comes from the buffer, so every bound uses checked
            // arithmetic: a corrupted length must yield an error, not wrap.
            let data_start = cursor
                .checked_add(8)
                .filter(|&end| end <= payload_end)
                .ok_or(GrumpyError::WalCorrupted(lsn))?;
            let page_id = Self::le_u32_at(buf, cursor);
            let data_len = Self::le_u32_at(buf, cursor + 4) as usize;
            let data_end = data_start
                .checked_add(data_len)
                .filter(|&end| end <= payload_end)
                .ok_or(GrumpyError::WalCorrupted(lsn))?;
            let data = buf[data_start..data_end].to_vec();
            cursor = data_end;
            (page_id, data)
        } else {
            (0u32, Vec::new())
        };

        // Everything between the header and the checksum must have been consumed.
        if cursor != payload_end {
            return Err(GrumpyError::WalCorrupted(lsn));
        }
        let checksum = Self::le_u32_at(buf, cursor);

        let rec = Self {
            record_len: record_len as u32,
            lsn,
            tx_id,
            op_type,
            origin_node_id,
            hlc,
            vector_clock,
            page_id,
            data,
            checksum,
        };
        if rec.compute_checksum_v2() != checksum {
            return Err(GrumpyError::WalCorrupted(lsn));
        }
        Ok((rec, record_len))
    }

    /// Decodes one V1 record from the start of `buf` and upgrades it to the
    /// in-memory V2 form.
    ///
    /// The fields V1 does not carry are synthesized: `origin_node_id` is
    /// `NIL_NODE_ID`, `hlc` is `Hlc::from_packed(lsn)` and the vector clock is
    /// `VectorClock::singleton(NIL_NODE_ID, lsn)`. `record_len` and `checksum`
    /// of the returned record are recomputed for the V2 encoding, while the
    /// returned byte count is the length of the V1 record in `buf`.
    ///
    /// NOTE(review): the payload is copied verbatim. The V1 page-write test in
    /// this file lays the payload out as before-then-after, whereas
    /// `page_write` stores after-then-before — confirm which order real V1 files
    /// use before relying on `after_image`/`before_image` for V1 records.
    ///
    /// # Errors
    ///
    /// * `GrumpyError::WalCorrupted(0)` if `buf` is shorter than the V1 header,
    ///   or the stored `record_len` is below the V1 header size or exceeds
    ///   `buf.len()`.
    /// * `GrumpyError::WalCorrupted(lsn)` if the op-type byte is unknown, the
    ///   payload length does not fit inside the record, or the checksum differs.
    pub fn from_bytes_v1(buf: &[u8]) -> Result<(Self, usize)> {
        // Offset of the payload: everything in the V1 header except the checksum.
        const DATA_START: usize = WAL_RECORD_HEADER_SIZE_V1 - 4;

        if buf.len() < WAL_RECORD_HEADER_SIZE_V1 {
            return Err(GrumpyError::WalCorrupted(0));
        }
        let record_len = Self::le_u32_at(buf, 0) as usize;
        if record_len < WAL_RECORD_HEADER_SIZE_V1 || buf.len() < record_len {
            return Err(GrumpyError::WalCorrupted(0));
        }

        let lsn = Self::le_u64_at(buf, 4);
        let tx_id = Self::le_u64_at(buf, 12);
        let op_type = WalOpType::from_u8(buf[20]).ok_or(GrumpyError::WalCorrupted(lsn))?;
        let page_id = Self::le_u32_at(buf, 21);
        let data_len = Self::le_u32_at(buf, 25) as usize;

        // Same condition as `29 + data_len + 4 > record_len`, written as a
        // subtraction so a corrupted `data_len` cannot overflow `usize`.
        // `record_len >= WAL_RECORD_HEADER_SIZE_V1` was checked above.
        if data_len > record_len - WAL_RECORD_HEADER_SIZE_V1 {
            return Err(GrumpyError::WalCorrupted(lsn));
        }
        let checksum_off = DATA_START + data_len;
        let data = buf[DATA_START..checksum_off].to_vec();
        let checksum = Self::le_u32_at(buf, checksum_off);

        let computed = compute_v1_checksum(record_len as u32, lsn, tx_id, op_type, page_id, &data);
        if computed != checksum {
            return Err(GrumpyError::WalCorrupted(lsn));
        }

        let rec = Self {
            record_len: 0,
            lsn,
            tx_id,
            op_type,
            origin_node_id: NIL_NODE_ID,
            hlc: Hlc::from_packed(lsn),
            vector_clock: VectorClock::singleton(NIL_NODE_ID, lsn),
            page_id,
            data,
            checksum: 0,
        }
        .sealed();
        Ok((rec, record_len))
    }

    /// Splits the payload of a `PageWrite` record at its midpoint.
    ///
    /// Returns `(first_half, second_half)`, which for records built by
    /// [`Self::page_write`] is `(after, before)`. When the payload length is
    /// odd the second half is one byte longer. Returns `None` for every other
    /// op type.
    pub fn page_images(&self) -> Option<(&[u8], &[u8])> {
        if self.op_type != WalOpType::PageWrite {
            return None;
        }
        Some(self.data.split_at(self.data.len() / 2))
    }

    /// Returns the first half of the payload (the after image for records built
    /// by [`Self::page_write`]), or `None` if this is not a `PageWrite` record.
    pub fn after_image(&self) -> Option<&[u8]> {
        self.page_images().map(|(a, _)| a)
    }

    /// Returns the second half of the payload (the before image for records
    /// built by [`Self::page_write`]), or `None` if this is not a `PageWrite`
    /// record.
    pub fn before_image(&self) -> Option<&[u8]> {
        self.page_images().map(|(_, b)| b)
    }

    /// Returns `true` when the stored `checksum` matches the checksum
    /// recomputed from the current field values.
    pub fn is_valid(&self) -> bool {
        self.compute_checksum_v2() == self.checksum
    }

    /// Computes the CRC32 of the record in V2 field order.
    ///
    /// Covered: `lsn`, `tx_id`, `op_type`, `origin_node_id`, `hlc`, the encoded
    /// vector clock and, for `PageWrite` records, `page_id`, the payload length
    /// and the payload. `record_len` and `checksum` are not part of the hash.
    fn compute_checksum_v2(&self) -> u32 {
        let mut hasher = crc32fast::Hasher::new();
        hasher.update(&self.lsn.to_le_bytes());
        hasher.update(&self.tx_id.to_le_bytes());
        hasher.update(&[self.op_type as u8]);
        hasher.update(&self.origin_node_id.to_le_bytes());
        hasher.update(&self.hlc.to_le_bytes());
        let mut vc_buf = Vec::with_capacity(self.vector_clock.encoded_len());
        self.vector_clock.encode_to(&mut vc_buf);
        hasher.update(&vc_buf);
        if self.op_type == WalOpType::PageWrite {
            hasher.update(&self.page_id.to_le_bytes());
            hasher.update(&(self.data.len() as u32).to_le_bytes());
            hasher.update(&self.data);
        }
        hasher.finalize()
    }
}
/// Computes the CRC32 stored at the end of a V1 record.
///
/// The hash covers, in order and little-endian: `record_len`, `lsn`, `tx_id`,
/// the op-type byte, `page_id`, the payload length as `u32`, and the payload.
fn compute_v1_checksum(
    record_len: u32,
    lsn: u64,
    tx_id: u64,
    op_type: WalOpType,
    page_id: u32,
    data: &[u8],
) -> u32 {
    let len_le = record_len.to_le_bytes();
    let lsn_le = lsn.to_le_bytes();
    let tx_le = tx_id.to_le_bytes();
    let op_byte = [op_type as u8];
    let page_le = page_id.to_le_bytes();
    let data_len_le = (data.len() as u32).to_le_bytes();

    // Fed to the hasher in exactly the order the fields appear on disk.
    let parts: [&[u8]; 7] = [
        &len_le,
        &lsn_le,
        &tx_le,
        &op_byte,
        &page_le,
        &data_len_le,
        data,
    ];

    let mut crc = crc32fast::Hasher::new();
    for part in parts.iter() {
        crc.update(part);
    }
    crc.finalize()
}
/// Encodes a record in the V1 layout (all integers little-endian):
///
/// `record_len | lsn | tx_id | op_type | page_id | data_len | data | checksum`
///
/// `record_len` is `WAL_RECORD_HEADER_SIZE_V1 + data.len()` and the checksum
/// is produced by `compute_v1_checksum`.
#[doc(hidden)]
pub fn encode_v1_record(
    lsn: u64,
    tx_id: u64,
    op_type: WalOpType,
    page_id: u32,
    data: &[u8],
) -> Vec<u8> {
    let total_len = (WAL_RECORD_HEADER_SIZE_V1 + data.len()) as u32;
    let crc = compute_v1_checksum(total_len, lsn, tx_id, op_type, page_id, data);

    let len_le = total_len.to_le_bytes();
    let lsn_le = lsn.to_le_bytes();
    let tx_le = tx_id.to_le_bytes();
    let op_byte = [op_type as u8];
    let page_le = page_id.to_le_bytes();
    let data_len_le = (data.len() as u32).to_le_bytes();
    let crc_le = crc.to_le_bytes();

    // On-disk field order; the checksum always comes last.
    let parts: [&[u8]; 8] = [
        &len_le,
        &lsn_le,
        &tx_le,
        &op_byte,
        &page_le,
        &data_len_le,
        data,
        &crc_le,
    ];

    let mut out = Vec::with_capacity(total_len as usize);
    for part in parts.iter() {
        out.extend_from_slice(part);
    }
    out
}
/// Builds a zero-filled header of `WAL_HEADER_SIZE` bytes that starts with
/// `WAL_MAGIC`, followed by `version` as a little-endian `u16`.
pub fn build_wal_header(version: u16) -> [u8; WAL_HEADER_SIZE] {
    let mut header = [0u8; WAL_HEADER_SIZE];
    {
        let (magic, rest) = header.split_at_mut(WAL_MAGIC.len());
        magic.copy_from_slice(WAL_MAGIC);
        rest[..2].copy_from_slice(&version.to_le_bytes());
    }
    header
}
pub fn parse_wal_header(hdr: &[u8]) -> Result<u16> {
if hdr.len() < 10 {
return Err(GrumpyError::WalCorrupted(0));
}
if &hdr[..8] != WAL_MAGIC {
return Err(GrumpyError::WalCorrupted(0));
}
let version = u16::from_le_bytes([hdr[8], hdr[9]]);
if version > WAL_VERSION_CURRENT {
return Err(GrumpyError::UnsupportedWalVersion(version));
}
Ok(version)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Arbitrary non-nil node id shared by the tests.
    fn dummy_origin() -> u128 {
        0x0123_4567_89ab_cdef_0123_4567_89ab_cdef
    }

    /// A commit record survives encode + decode unchanged and reports the
    /// full encoded length as consumed.
    #[test]
    fn test_v2_commit_round_trip() {
        let rec = WalRecord::commit(
            1,
            42,
            dummy_origin(),
            Hlc::pack(1_700_000_000_000, 3),
            VectorClock::singleton(dummy_origin(), 1),
        );
        let bytes = rec.to_bytes();
        let (decoded, consumed) = WalRecord::from_bytes_v2(&bytes).unwrap();
        assert_eq!(consumed, bytes.len());
        assert_eq!(decoded, rec);
        assert!(decoded.is_valid());
    }

    /// A checkpoint record with an empty vector clock round-trips.
    #[test]
    fn test_v2_checkpoint_round_trip() {
        let rec = WalRecord::checkpoint(
            100,
            dummy_origin(),
            Hlc::pack(1_700_000_000_000, 0),
            VectorClock::new(),
        );
        let bytes = rec.to_bytes();
        let (decoded, _) = WalRecord::from_bytes_v2(&bytes).unwrap();
        assert_eq!(decoded, rec);
    }

    /// A page-write record round-trips and hands back the right image from
    /// each accessor.
    #[test]
    fn test_v2_encode_decode_round_trip() {
        let before = vec![0xAA; 8192];
        let after = vec![0xBB; 8192];
        let rec = WalRecord::page_write(
            5,
            1,
            dummy_origin(),
            Hlc::pack(1, 2),
            VectorClock::singleton(dummy_origin(), 7),
            42,
            &before,
            &after,
        );
        let bytes = rec.to_bytes();
        let (decoded, _) = WalRecord::from_bytes_v2(&bytes).unwrap();
        assert_eq!(decoded.lsn, 5);
        assert_eq!(decoded.page_id, 42);
        assert_eq!(decoded.after_image().unwrap(), after.as_slice());
        assert_eq!(decoded.before_image().unwrap(), before.as_slice());
        assert_eq!(decoded.origin_node_id, dummy_origin());
        assert_eq!(decoded.hlc, Hlc::pack(1, 2));
    }

    /// A vector clock with several entries is preserved by the codec.
    #[test]
    fn test_v2_record_with_vclock() {
        let mut vc = VectorClock::new();
        vc.set(1, 100);
        vc.set(2, 200);
        vc.set(3, 300);
        let rec = WalRecord::commit(7, 7, dummy_origin(), Hlc::pack(99, 1), vc.clone());
        let bytes = rec.to_bytes();
        let (decoded, _) = WalRecord::from_bytes_v2(&bytes).unwrap();
        assert_eq!(decoded.vector_clock, vc);
    }

    /// A V1 commit record decodes with the synthesized V2-only fields.
    #[test]
    fn test_v1_record_decoded_as_v2() {
        let v1 = encode_v1_record(11, 3, WalOpType::Commit, 0, &[]);
        let (rec, consumed) = WalRecord::from_bytes_v1(&v1).unwrap();
        assert_eq!(consumed, v1.len());
        assert_eq!(rec.lsn, 11);
        assert_eq!(rec.tx_id, 3);
        assert_eq!(rec.op_type, WalOpType::Commit);
        assert_eq!(rec.origin_node_id, NIL_NODE_ID);
        assert_eq!(rec.hlc, Hlc::from_packed(11));
        assert_eq!(rec.vector_clock, VectorClock::singleton(NIL_NODE_ID, 11));
    }

    /// A V1 page-write keeps its payload verbatim and can be re-encoded and
    /// re-decoded as V2.
    #[test]
    fn test_v1_page_write_decoded_as_v2() {
        let before = vec![1u8; 64];
        let after = vec![2u8; 64];
        let mut payload = Vec::new();
        payload.extend_from_slice(&before);
        payload.extend_from_slice(&after);
        let v1 = encode_v1_record(5, 2, WalOpType::PageWrite, 99, &payload);
        let (rec, _) = WalRecord::from_bytes_v1(&v1).unwrap();
        assert_eq!(rec.op_type, WalOpType::PageWrite);
        assert_eq!(rec.page_id, 99);
        assert_eq!(&rec.data[..64], before.as_slice());
        assert_eq!(&rec.data[64..], after.as_slice());
        let bytes = rec.to_bytes();
        let (re, _) = WalRecord::from_bytes_v2(&bytes).unwrap();
        assert_eq!(re, rec);
    }

    /// Flipping a checksum byte makes decoding fail.
    #[test]
    fn test_v2_record_checksum_mismatch_returns_error() {
        let rec = WalRecord::commit(1, 1, dummy_origin(), Hlc::pack(1, 0), VectorClock::new());
        let mut bytes = rec.to_bytes();
        let last = bytes.len() - 1;
        bytes[last] ^= 0xFF;
        assert!(WalRecord::from_bytes_v2(&bytes).is_err());
    }

    /// A buffer shorter than the declared record length is rejected.
    #[test]
    fn test_truncated_v2_record_detected() {
        let rec = WalRecord::commit(1, 1, dummy_origin(), Hlc::pack(1, 0), VectorClock::new());
        let bytes = rec.to_bytes();
        assert!(WalRecord::from_bytes_v2(&bytes[..10]).is_err());
    }

    /// Every discriminant maps to its variant; bytes outside 1..=4 map to `None`.
    #[test]
    fn test_op_type_from_u8() {
        assert_eq!(WalOpType::from_u8(1), Some(WalOpType::PageWrite));
        assert_eq!(WalOpType::from_u8(2), Some(WalOpType::Commit));
        assert_eq!(WalOpType::from_u8(3), Some(WalOpType::Rollback));
        assert_eq!(WalOpType::from_u8(4), Some(WalOpType::Checkpoint));
        assert_eq!(WalOpType::from_u8(0), None);
        assert_eq!(WalOpType::from_u8(99), None);
    }

    /// Records written back to back can be decoded one after another using
    /// the consumed byte count as the offset of the next record.
    #[test]
    fn test_v2_multiple_records_sequential() {
        let r1 = WalRecord::page_write(
            1,
            1,
            dummy_origin(),
            Hlc::pack(1, 0),
            VectorClock::new(),
            5,
            &[0; 100],
            &[1; 100],
        );
        let r2 = WalRecord::commit(2, 1, dummy_origin(), Hlc::pack(2, 0), VectorClock::new());
        let r3 = WalRecord::checkpoint(3, dummy_origin(), Hlc::pack(3, 0), VectorClock::new());
        let mut buf = Vec::new();
        buf.extend_from_slice(&r1.to_bytes());
        buf.extend_from_slice(&r2.to_bytes());
        buf.extend_from_slice(&r3.to_bytes());
        let (d1, c1) = WalRecord::from_bytes_v2(&buf).unwrap();
        let (d2, c2) = WalRecord::from_bytes_v2(&buf[c1..]).unwrap();
        let (d3, _) = WalRecord::from_bytes_v2(&buf[c1 + c2..]).unwrap();
        assert_eq!(d1.lsn, 1);
        assert_eq!(d2.op_type, WalOpType::Commit);
        assert_eq!(d3.op_type, WalOpType::Checkpoint);
    }

    /// The header starts with the magic and version, is zero-filled after
    /// that, and parses back to the version it was built with.
    #[test]
    fn test_build_and_parse_wal_header() {
        let hdr = build_wal_header(WAL_VERSION_V2);
        assert_eq!(&hdr[..8], WAL_MAGIC);
        assert_eq!(u16::from_le_bytes([hdr[8], hdr[9]]), WAL_VERSION_V2);
        for b in &hdr[10..18] {
            assert_eq!(*b, 0);
        }
        assert_eq!(hdr.len(), WAL_HEADER_SIZE);
        assert_eq!(parse_wal_header(&hdr).unwrap(), WAL_VERSION_V2);
    }

    /// A version above `WAL_VERSION_CURRENT` is reported as unsupported.
    #[test]
    fn test_parse_wal_header_rejects_unknown_version() {
        // `build_wal_header` already writes 99 as little-endian [99, 0].
        let hdr = build_wal_header(99);
        assert!(matches!(
            parse_wal_header(&hdr),
            Err(GrumpyError::UnsupportedWalVersion(99))
        ));
    }
}