use std::io::Write;
use anyhow::{anyhow, bail, Result};
/// Four-byte marker written at the start of every encoded record; `decode_next`
/// rejects any other value with `DecodeError::BadMagic`.
pub const MAGIC: [u8; 4] = *b"DWAL";
/// Wire-format version written (little-endian) right after the magic; the
/// decoder only accepts this exact value.
pub const WIRE_VERSION: u16 = 1;
/// Size in bytes of the fixed header: magic (4) + version (2) + record type (1)
/// + reserved flags (1) + txid (8) + key_len (4) + payload_len (4).
pub const HEADER_LEN: usize = 24;
/// Size in bytes of the trailing CRC32C checksum (little-endian `u32`).
pub const CRC_LEN: usize = 4;
/// Largest key length, in bytes, accepted by both the encoder and the decoder (64 KiB).
pub const MAX_KEY_LEN: u32 = 64 * 1024;
/// Largest payload length, in bytes, accepted by both the encoder and the decoder (64 MiB).
pub const MAX_PAYLOAD_LEN: u32 = 64 * 1024 * 1024;
/// Kind of a WAL record. The explicit discriminants are the byte stored in the
/// record header (see `as_byte` / `from_byte`), so they must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecordType {
/// Encoded as header byte `0`.
Raw = 0,
/// Encoded as header byte `1`.
Put = 1,
/// Encoded as header byte `2`.
Delete = 2,
}
impl RecordType {
    /// Maps a header byte back to its [`RecordType`].
    ///
    /// # Errors
    ///
    /// Returns an error naming the offending byte when it does not correspond
    /// to any known variant.
    pub fn from_byte(b: u8) -> Result<Self> {
        let known = [RecordType::Raw, RecordType::Put, RecordType::Delete];
        known
            .iter()
            .copied()
            .find(|candidate| candidate.as_byte() == b)
            .ok_or_else(|| anyhow!("datawal: unknown record_type byte {}", b))
    }

    /// Returns the byte stored in the record header for this type, i.e. the
    /// enum discriminant.
    pub fn as_byte(self) -> u8 {
        let discriminant = self as u8;
        discriminant
    }
}
/// Returns the total encoded size, in bytes, of a record with the given key
/// and payload lengths: fixed header + key + payload + trailing CRC.
///
/// The sum is computed in `u64`, so it cannot overflow for any `u32` inputs.
pub fn record_wire_size(key_len: u32, payload_len: u32) -> u64 {
    let framing = (HEADER_LEN + CRC_LEN) as u64;
    let body = u64::from(key_len) + u64::from(payload_len);
    framing + body
}
/// Serializes one record into its on-wire byte form.
///
/// Layout (all integers little-endian): magic, wire version, record-type byte,
/// a zero reserved-flags byte, `txid`, key length, payload length, the key
/// bytes, the payload bytes, and finally a CRC32C computed over everything
/// that precedes it.
///
/// # Errors
///
/// Fails when `key` is longer than `MAX_KEY_LEN` or `payload` is longer than
/// `MAX_PAYLOAD_LEN`; the key is checked first.
pub fn encode_record(
    record_type: RecordType,
    txid: u64,
    key: &[u8],
    payload: &[u8],
) -> Result<Vec<u8>> {
    let (key_bytes, payload_bytes) = (key.len(), payload.len());
    if key_bytes as u64 > u64::from(MAX_KEY_LEN) {
        bail!(
            "datawal: key_len {} exceeds MAX_KEY_LEN {}",
            key_bytes,
            MAX_KEY_LEN
        );
    }
    if payload_bytes as u64 > u64::from(MAX_PAYLOAD_LEN) {
        bail!(
            "datawal: payload_len {} exceeds MAX_PAYLOAD_LEN {}",
            payload_bytes,
            MAX_PAYLOAD_LEN
        );
    }

    // Both lengths were just bounded by u32 limits, so these casts are lossless.
    let key_len = key_bytes as u32;
    let payload_len = payload_bytes as u32;
    let total = record_wire_size(key_len, payload_len) as usize;

    let version = WIRE_VERSION.to_le_bytes();
    let type_and_flags = [record_type.as_byte(), 0u8];
    let txid_le = txid.to_le_bytes();
    let key_len_le = key_len.to_le_bytes();
    let payload_len_le = payload_len.to_le_bytes();

    // Everything covered by the checksum, in wire order.
    let sections: [&[u8]; 8] = [
        &MAGIC,
        &version,
        &type_and_flags,
        &txid_le,
        &key_len_le,
        &payload_len_le,
        key,
        payload,
    ];

    let mut out: Vec<u8> = Vec::with_capacity(total);
    for section in sections.iter() {
        out.write_all(section)?;
    }

    let checksum = crc32c::crc32c(&out);
    out.write_all(&checksum.to_le_bytes())?;
    debug_assert_eq!(out.len(), total);
    Ok(out)
}
/// Results that `decode_next` returns on its `Ok` side, i.e. everything that is
/// not a structurally invalid header (those are reported as `DecodeError`).
#[derive(Debug)]
pub enum DecodeOutcome {
/// A complete record whose stored CRC matched the recomputed one.
Ok {
/// Record type parsed from the header.
record_type: RecordType,
/// Transaction id parsed from the header.
txid: u64,
/// Owned copy of the key bytes.
key: Vec<u8>,
/// Owned copy of the payload bytes.
payload: Vec<u8>,
/// Total encoded size of the record (header + key + payload + CRC).
bytes_consumed: u32,
},
/// The buffer ends before the header or before the full record.
Truncated {
/// Bytes present in the buffer starting at the requested offset.
available: u64,
/// Bytes required: `HEADER_LEN` when the header itself is incomplete,
/// otherwise the full record size announced by the header.
needed: u64,
},
/// The record is fully present but its stored CRC does not match.
CrcMismatch {
/// Total encoded size of the record whose checksum failed.
bytes_consumed: u32,
},
}
/// Header-validation failures returned by `decode_next`. Each variant carries
/// the offending value that was read from the header.
#[derive(Debug)]
pub enum DecodeError {
/// The first four bytes are not `MAGIC`.
BadMagic { found: [u8; 4] },
/// The version field differs from `WIRE_VERSION`.
UnknownVersion { found: u16 },
/// The record-type byte does not map to a `RecordType`.
UnknownRecordType { found: u8 },
/// The reserved flags byte is non-zero.
ReservedFlagsSet { found: u8 },
/// The declared key length exceeds `MAX_KEY_LEN`.
KeyTooLarge { found: u32 },
/// The declared payload length exceeds `MAX_PAYLOAD_LEN`.
PayloadTooLarge { found: u32 },
}
impl std::fmt::Display for DecodeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DecodeError::BadMagic { found } => {
write!(f, "datawal: bad magic, expected DWAL, got {:?}", found)
}
DecodeError::UnknownVersion { found } => {
write!(
f,
"datawal: unknown wire version {} (this build supports {})",
found, WIRE_VERSION
)
}
DecodeError::UnknownRecordType { found } => {
write!(f, "datawal: unknown record_type byte {}", found)
}
DecodeError::ReservedFlagsSet { found } => {
write!(f, "datawal: reserved flags byte must be 0, got {}", found)
}
DecodeError::KeyTooLarge { found } => {
write!(
f,
"datawal: key_len {} exceeds MAX_KEY_LEN {}",
found, MAX_KEY_LEN
)
}
DecodeError::PayloadTooLarge { found } => {
write!(
f,
"datawal: payload_len {} exceeds MAX_PAYLOAD_LEN {}",
found, MAX_PAYLOAD_LEN
)
}
}
}
}
// Uses the trait's default methods only: a `DecodeError` wraps no underlying
// source error.
impl std::error::Error for DecodeError {}
pub fn decode_next(buf: &[u8], offset: u64) -> std::result::Result<DecodeOutcome, DecodeError> {
let off = offset as usize;
if off + HEADER_LEN > buf.len() {
return Ok(DecodeOutcome::Truncated {
available: (buf.len().saturating_sub(off)) as u64,
needed: HEADER_LEN as u64,
});
}
let header = &buf[off..off + HEADER_LEN];
let mut magic = [0u8; 4];
magic.copy_from_slice(&header[0..4]);
if magic != MAGIC {
return Err(DecodeError::BadMagic { found: magic });
}
let version = u16::from_le_bytes([header[4], header[5]]);
if version != WIRE_VERSION {
return Err(DecodeError::UnknownVersion { found: version });
}
let rt_byte = header[6];
let record_type = match rt_byte {
0..=2 => rt_byte,
other => return Err(DecodeError::UnknownRecordType { found: other }),
};
let flags = header[7];
if flags != 0 {
return Err(DecodeError::ReservedFlagsSet { found: flags });
}
let txid = u64::from_le_bytes([
header[8], header[9], header[10], header[11], header[12], header[13], header[14],
header[15],
]);
let key_len = u32::from_le_bytes([header[16], header[17], header[18], header[19]]);
let payload_len = u32::from_le_bytes([header[20], header[21], header[22], header[23]]);
if key_len > MAX_KEY_LEN {
return Err(DecodeError::KeyTooLarge { found: key_len });
}
if payload_len > MAX_PAYLOAD_LEN {
return Err(DecodeError::PayloadTooLarge { found: payload_len });
}
let total = record_wire_size(key_len, payload_len) as usize;
if off + total > buf.len() {
return Ok(DecodeOutcome::Truncated {
available: (buf.len() - off) as u64,
needed: total as u64,
});
}
let key_start = off + HEADER_LEN;
let payload_start = key_start + key_len as usize;
let crc_start = payload_start + payload_len as usize;
let crc_expected = u32::from_le_bytes([
buf[crc_start],
buf[crc_start + 1],
buf[crc_start + 2],
buf[crc_start + 3],
]);
let crc_actual = crc32c::crc32c(&buf[off..crc_start]);
if crc_actual != crc_expected {
return Ok(DecodeOutcome::CrcMismatch {
bytes_consumed: total as u32,
});
}
Ok(DecodeOutcome::Ok {
record_type: RecordType::from_byte(record_type).expect("validated above"),
txid,
key: buf[key_start..payload_start].to_vec(),
payload: buf[payload_start..crc_start].to_vec(),
bytes_consumed: total as u32,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the checksum to CRC-32C (Castagnoli) and guards against an
    /// accidental swap to the IEEE CRC-32 polynomial.
    #[test]
    fn crc32c_known_vector() {
        assert_eq!(crc32c::crc32c(b""), 0x0000_0000);
        assert_eq!(crc32c::crc32c(b"123456789"), 0xE306_9283);
        assert_eq!(crc32c::crc32c(&[0u8; 32]), 0x8A91_36AA);
        assert_ne!(crc32c::crc32c(b"123456789"), 0xCBF4_3926);
    }

    #[test]
    fn roundtrip_raw_empty_key() {
        let buf = encode_record(RecordType::Raw, 7, b"", b"hello").unwrap();
        let out = decode_next(&buf, 0).unwrap();
        match out {
            DecodeOutcome::Ok {
                record_type,
                txid,
                key,
                payload,
                bytes_consumed,
            } => {
                assert_eq!(record_type, RecordType::Raw);
                assert_eq!(txid, 7);
                assert!(key.is_empty());
                assert_eq!(payload, b"hello");
                assert_eq!(bytes_consumed as usize, buf.len());
            }
            _ => panic!("expected Ok, got {:?}", out),
        }
    }

    #[test]
    fn roundtrip_put_with_key() {
        let buf = encode_record(RecordType::Put, 42, b"alpha", b"value-1").unwrap();
        let out = decode_next(&buf, 0).unwrap();
        match out {
            DecodeOutcome::Ok {
                record_type,
                txid,
                key,
                payload,
                ..
            } => {
                assert_eq!(record_type, RecordType::Put);
                assert_eq!(txid, 42);
                assert_eq!(key, b"alpha");
                assert_eq!(payload, b"value-1");
            }
            _ => panic!("expected Ok"),
        }
    }

    /// Covers the third record type and an empty payload.
    #[test]
    fn roundtrip_delete_empty_payload() {
        let buf = encode_record(RecordType::Delete, 9, b"gone", b"").unwrap();
        match decode_next(&buf, 0).unwrap() {
            DecodeOutcome::Ok {
                record_type,
                txid,
                key,
                payload,
                bytes_consumed,
            } => {
                assert_eq!(record_type, RecordType::Delete);
                assert_eq!(txid, 9);
                assert_eq!(key, b"gone");
                assert!(payload.is_empty());
                assert_eq!(bytes_consumed as u64, record_wire_size(4, 0));
            }
            other => panic!("expected Ok, got {:?}", other),
        }
    }

    /// Decoding must honour `offset`: the second of two back-to-back records
    /// is found by starting at the first record's length.
    #[test]
    fn decodes_record_at_nonzero_offset() {
        let first = encode_record(RecordType::Put, 1, b"k1", b"v1").unwrap();
        let second = encode_record(RecordType::Delete, 2, b"k2", b"").unwrap();
        let mut log = first.clone();
        log.extend_from_slice(&second);
        match decode_next(&log, first.len() as u64).unwrap() {
            DecodeOutcome::Ok {
                record_type,
                txid,
                key,
                bytes_consumed,
                ..
            } => {
                assert_eq!(record_type, RecordType::Delete);
                assert_eq!(txid, 2);
                assert_eq!(key, b"k2");
                assert_eq!(bytes_consumed as usize, second.len());
            }
            other => panic!("expected Ok, got {:?}", other),
        }
    }

    /// An offset at or beyond the end of the buffer is a clean truncation
    /// with nothing available, not an error or a panic.
    #[test]
    fn offset_at_or_past_end_is_truncated() {
        let buf = encode_record(RecordType::Raw, 1, b"", b"x").unwrap();
        for extra in &[0u64, 1, 100] {
            match decode_next(&buf, buf.len() as u64 + *extra).unwrap() {
                DecodeOutcome::Truncated { available, needed } => {
                    assert_eq!(available, 0);
                    assert_eq!(needed, HEADER_LEN as u64);
                }
                other => panic!("expected Truncated, got {:?}", other),
            }
        }
    }

    #[test]
    fn truncated_header_reports_needed() {
        let buf = encode_record(RecordType::Raw, 1, b"", b"x").unwrap();
        let short = &buf[..10];
        match decode_next(short, 0).unwrap() {
            DecodeOutcome::Truncated { available, needed } => {
                assert_eq!(available, 10);
                assert_eq!(needed, HEADER_LEN as u64);
            }
            other => panic!("expected Truncated, got {:?}", other),
        }
    }

    #[test]
    fn truncated_body_reports_total() {
        let buf = encode_record(RecordType::Raw, 1, b"", b"hello").unwrap();
        let short = &buf[..HEADER_LEN + 2];
        match decode_next(short, 0).unwrap() {
            DecodeOutcome::Truncated { needed, .. } => {
                assert_eq!(needed, buf.len() as u64);
            }
            other => panic!("expected Truncated, got {:?}", other),
        }
    }

    #[test]
    fn crc_mismatch_detected() {
        let mut buf = encode_record(RecordType::Raw, 1, b"", b"hello").unwrap();
        let n = buf.len();
        // Flip the last payload byte (the final four bytes are the CRC).
        buf[n - 5] ^= 0xff;
        match decode_next(&buf, 0).unwrap() {
            DecodeOutcome::CrcMismatch { bytes_consumed } => {
                assert_eq!(bytes_consumed as usize, n);
            }
            other => panic!("expected CrcMismatch, got {:?}", other),
        }
    }

    #[test]
    fn bad_magic_is_hard_error() {
        let mut buf = encode_record(RecordType::Raw, 1, b"", b"x").unwrap();
        buf[0] = b'X';
        match decode_next(&buf, 0) {
            Err(DecodeError::BadMagic { .. }) => {}
            other => panic!("expected BadMagic, got {:?}", other),
        }
    }

    #[test]
    fn unknown_version_is_hard_error() {
        let mut buf = encode_record(RecordType::Raw, 1, b"", b"x").unwrap();
        buf[4] = 99;
        buf[5] = 0;
        match decode_next(&buf, 0) {
            Err(DecodeError::UnknownVersion { found: 99 }) => {}
            other => panic!("expected UnknownVersion, got {:?}", other),
        }
    }

    #[test]
    fn unknown_record_type_is_hard_error() {
        let mut buf = encode_record(RecordType::Raw, 1, b"", b"x").unwrap();
        buf[6] = 200;
        match decode_next(&buf, 0) {
            Err(DecodeError::UnknownRecordType { found: 200 }) => {}
            other => panic!("expected UnknownRecordType, got {:?}", other),
        }
    }

    #[test]
    fn reserved_flags_must_be_zero() {
        let mut buf = encode_record(RecordType::Raw, 1, b"", b"x").unwrap();
        buf[7] = 1;
        match decode_next(&buf, 0) {
            Err(DecodeError::ReservedFlagsSet { found: 1 }) => {}
            other => panic!("expected ReservedFlagsSet, got {:?}", other),
        }
    }

    /// A header announcing a key longer than the limit is rejected before any
    /// body bytes or the CRC are looked at.
    #[test]
    fn oversize_key_len_in_header_is_hard_error() {
        let mut buf = encode_record(RecordType::Put, 1, b"k", b"v").unwrap();
        buf[16..20].copy_from_slice(&(MAX_KEY_LEN + 1).to_le_bytes());
        match decode_next(&buf, 0) {
            Err(DecodeError::KeyTooLarge { found }) => assert_eq!(found, MAX_KEY_LEN + 1),
            other => panic!("expected KeyTooLarge, got {:?}", other),
        }
    }

    /// Same as above for the payload length field.
    #[test]
    fn oversize_payload_len_in_header_is_hard_error() {
        let mut buf = encode_record(RecordType::Put, 1, b"k", b"v").unwrap();
        buf[20..24].copy_from_slice(&(MAX_PAYLOAD_LEN + 1).to_le_bytes());
        match decode_next(&buf, 0) {
            Err(DecodeError::PayloadTooLarge { found }) => {
                assert_eq!(found, MAX_PAYLOAD_LEN + 1)
            }
            other => panic!("expected PayloadTooLarge, got {:?}", other),
        }
    }

    #[test]
    fn encode_rejects_oversize_key() {
        let big = vec![0u8; (MAX_KEY_LEN as usize) + 1];
        let err = encode_record(RecordType::Put, 1, &big, b"").unwrap_err();
        assert!(format!("{err}").contains("MAX_KEY_LEN"));
    }

    #[test]
    fn encode_rejects_oversize_payload() {
        let big = vec![0u8; (MAX_PAYLOAD_LEN as usize) + 1];
        let err = encode_record(RecordType::Raw, 1, b"", &big).unwrap_err();
        assert!(format!("{err}").contains("MAX_PAYLOAD_LEN"));
    }
}