use crate::error::{Result, WalError};
pub const WAL_MAGIC: u32 = 0x5359_4E57;
pub const WAL_FORMAT_VERSION: u16 = 1;
pub const MAX_WAL_PAYLOAD_SIZE: usize = 64 * 1024 * 1024;
pub const HEADER_SIZE: usize = 30;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecordHeader {
pub magic: u32,
pub format_version: u16,
pub record_type: u16,
pub lsn: u64,
pub tenant_id: u32,
pub vshard_id: u16,
pub payload_len: u32,
pub crc32c: u32,
}
impl RecordHeader {
pub fn to_bytes(&self) -> [u8; HEADER_SIZE] {
let mut buf = [0u8; HEADER_SIZE];
buf[0..4].copy_from_slice(&self.magic.to_le_bytes());
buf[4..6].copy_from_slice(&self.format_version.to_le_bytes());
buf[6..8].copy_from_slice(&self.record_type.to_le_bytes());
buf[8..16].copy_from_slice(&self.lsn.to_le_bytes());
buf[16..20].copy_from_slice(&self.tenant_id.to_le_bytes());
buf[20..22].copy_from_slice(&self.vshard_id.to_le_bytes());
buf[22..26].copy_from_slice(&self.payload_len.to_le_bytes());
buf[26..30].copy_from_slice(&self.crc32c.to_le_bytes());
buf
}
pub fn from_bytes(buf: &[u8; HEADER_SIZE]) -> Self {
Self {
magic: u32::from_le_bytes([buf[0], buf[1], buf[2], buf[3]]),
format_version: u16::from_le_bytes([buf[4], buf[5]]),
record_type: u16::from_le_bytes([buf[6], buf[7]]),
lsn: u64::from_le_bytes([
buf[8], buf[9], buf[10], buf[11], buf[12], buf[13], buf[14], buf[15],
]),
tenant_id: u32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]),
vshard_id: u16::from_le_bytes([buf[20], buf[21]]),
payload_len: u32::from_le_bytes([buf[22], buf[23], buf[24], buf[25]]),
crc32c: u32::from_le_bytes([buf[26], buf[27], buf[28], buf[29]]),
}
}
pub fn compute_checksum(&self, payload: &[u8]) -> u32 {
let header_bytes = self.to_bytes();
let mut digest = crc32c::crc32c(&header_bytes[..HEADER_SIZE - 4]);
digest = crc32c::crc32c_append(digest, payload);
digest
}
pub fn logical_record_type(&self) -> u16 {
self.record_type & !ENCRYPTED_FLAG
}
pub fn validate(&self, offset: u64) -> Result<()> {
if self.magic != WAL_MAGIC {
return Err(WalError::InvalidMagic {
offset,
expected: WAL_MAGIC,
actual: self.magic,
});
}
if self.format_version > WAL_FORMAT_VERSION {
return Err(WalError::UnsupportedVersion {
version: self.format_version,
supported: WAL_FORMAT_VERSION,
});
}
Ok(())
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u16)]
pub enum RecordType {
Noop = 0,
Put = 1 | 0x8000,
Delete = 2 | 0x8000,
VectorPut = 10 | 0x8000,
VectorDelete = 11 | 0x8000,
VectorParams = 12 | 0x8000,
CrdtDelta = 20 | 0x8000,
TimeseriesBatch = 30,
LogBatch = 31,
Transaction = 50 | 0x8000,
Checkpoint = 100 | 0x8000,
}
impl RecordType {
pub fn is_required(raw: u16) -> bool {
raw & 0x8000 != 0
}
pub fn from_raw(raw: u16) -> Option<Self> {
match raw {
0 => Some(Self::Noop),
x if x == 1 | 0x8000 => Some(Self::Put),
x if x == 2 | 0x8000 => Some(Self::Delete),
x if x == 10 | 0x8000 => Some(Self::VectorPut),
x if x == 11 | 0x8000 => Some(Self::VectorDelete),
x if x == 12 | 0x8000 => Some(Self::VectorParams),
x if x == 20 | 0x8000 => Some(Self::CrdtDelta),
x if x == 50 | 0x8000 => Some(Self::Transaction),
30 => Some(Self::TimeseriesBatch),
31 => Some(Self::LogBatch),
x if x == 100 | 0x8000 => Some(Self::Checkpoint),
_ => None,
}
}
}
#[derive(Debug, Clone)]
pub struct WalRecord {
pub header: RecordHeader,
pub payload: Vec<u8>,
}
impl WalRecord {
pub fn new(
record_type: u16,
lsn: u64,
tenant_id: u32,
vshard_id: u16,
payload: Vec<u8>,
encryption_key: Option<&crate::crypto::WalEncryptionKey>,
) -> Result<Self> {
if payload.len() > MAX_WAL_PAYLOAD_SIZE {
return Err(WalError::PayloadTooLarge {
size: payload.len(),
max: MAX_WAL_PAYLOAD_SIZE,
});
}
let (final_payload, encrypted) = if let Some(key) = encryption_key {
let temp_header = RecordHeader {
magic: WAL_MAGIC,
format_version: WAL_FORMAT_VERSION,
record_type,
lsn,
tenant_id,
vshard_id,
payload_len: 0, crc32c: 0,
};
let header_bytes = temp_header.to_bytes();
let ciphertext = key.encrypt(lsn, &header_bytes, &payload)?;
(ciphertext, true)
} else {
(payload, false)
};
let record_type = if encrypted {
record_type | ENCRYPTED_FLAG
} else {
record_type
};
let mut header = RecordHeader {
magic: WAL_MAGIC,
format_version: WAL_FORMAT_VERSION,
record_type,
lsn,
tenant_id,
vshard_id,
payload_len: final_payload.len() as u32,
crc32c: 0,
};
header.crc32c = header.compute_checksum(&final_payload);
Ok(Self {
header,
payload: final_payload,
})
}
pub fn decrypt_payload(
&self,
encryption_key: Option<&crate::crypto::WalEncryptionKey>,
) -> Result<Vec<u8>> {
if !self.is_encrypted() {
return Ok(self.payload.clone());
}
let key = encryption_key.ok_or_else(|| WalError::EncryptionError {
detail: "record is encrypted but no decryption key provided".into(),
})?;
let mut aad_header = self.header;
aad_header.record_type &= !ENCRYPTED_FLAG;
aad_header.payload_len = 0;
aad_header.crc32c = 0;
let header_bytes = aad_header.to_bytes();
key.decrypt(self.header.lsn, &header_bytes, &self.payload)
}
pub fn decrypt_payload_ring(&self, ring: Option<&crate::crypto::KeyRing>) -> Result<Vec<u8>> {
if !self.is_encrypted() {
return Ok(self.payload.clone());
}
let ring = ring.ok_or_else(|| WalError::EncryptionError {
detail: "record is encrypted but no decryption key ring provided".into(),
})?;
let mut aad_header = self.header;
aad_header.record_type &= !ENCRYPTED_FLAG;
aad_header.payload_len = 0;
aad_header.crc32c = 0;
let header_bytes = aad_header.to_bytes();
ring.decrypt(self.header.lsn, &header_bytes, &self.payload)
}
pub fn is_encrypted(&self) -> bool {
self.header.record_type & ENCRYPTED_FLAG != 0
}
pub fn logical_record_type(&self) -> u16 {
self.header.record_type & !ENCRYPTED_FLAG
}
pub fn verify_checksum(&self) -> Result<()> {
let expected = self.header.crc32c;
let actual = self.header.compute_checksum(&self.payload);
if expected != actual {
return Err(WalError::ChecksumMismatch {
lsn: self.header.lsn,
expected,
actual,
});
}
Ok(())
}
pub fn wire_size(&self) -> usize {
HEADER_SIZE + self.payload.len()
}
}
pub const ENCRYPTED_FLAG: u16 = 0x4000;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn header_roundtrip() {
let header = RecordHeader {
magic: WAL_MAGIC,
format_version: WAL_FORMAT_VERSION,
record_type: RecordType::Put as u16,
lsn: 42,
tenant_id: 7,
vshard_id: 3,
payload_len: 100,
crc32c: 0xDEAD_BEEF,
};
let bytes = header.to_bytes();
let decoded = RecordHeader::from_bytes(&bytes);
assert_eq!(header, decoded);
}
#[test]
fn checksum_roundtrip() {
let payload = b"hello nodedb";
let record =
WalRecord::new(RecordType::Put as u16, 1, 0, 0, payload.to_vec(), None).unwrap();
record.verify_checksum().unwrap();
}
#[test]
fn checksum_detects_corruption() {
let payload = b"hello nodedb";
let mut record =
WalRecord::new(RecordType::Put as u16, 1, 0, 0, payload.to_vec(), None).unwrap();
record.payload[0] ^= 0xFF;
assert!(matches!(
record.verify_checksum(),
Err(WalError::ChecksumMismatch { .. })
));
}
#[test]
fn invalid_magic_detected() {
let header = RecordHeader {
magic: 0xBAD0_F00D,
format_version: WAL_FORMAT_VERSION,
record_type: 0,
lsn: 0,
tenant_id: 0,
vshard_id: 0,
payload_len: 0,
crc32c: 0,
};
assert!(matches!(
header.validate(0),
Err(WalError::InvalidMagic { .. })
));
}
#[test]
fn payload_too_large_rejected() {
let big_payload = vec![0u8; MAX_WAL_PAYLOAD_SIZE + 1];
assert!(matches!(
WalRecord::new(RecordType::Put as u16, 1, 0, 0, big_payload, None),
Err(WalError::PayloadTooLarge { .. })
));
}
#[test]
fn record_type_required_flag() {
assert!(RecordType::is_required(RecordType::Put as u16));
assert!(RecordType::is_required(RecordType::Delete as u16));
assert!(RecordType::is_required(RecordType::Checkpoint as u16));
assert!(!RecordType::is_required(RecordType::Noop as u16));
assert!(!RecordType::is_required(RecordType::TimeseriesBatch as u16));
assert!(!RecordType::is_required(RecordType::LogBatch as u16));
}
}