use std::sync::Arc;
use wal_db::Wal;
use crate::error::{Result, TxnError};
use crate::store::WriteEntry;
use crate::timestamp::Timestamp;
/// One commit record decoded from the log while the log is being opened.
pub(crate) struct RecoveredCommit {
/// Timestamp rebuilt from the raw `u64` at the start of the record.
pub(crate) commit_ts: Timestamp,
/// Writes carried by the record, in the order they appear in its bytes.
pub(crate) writes: Vec<WriteEntry>,
}
/// Commit log backed by a `wal_db::Wal`: records are appended through it and
/// decoded back out of it when the log is opened.
pub(crate) struct CommitLog {
// Underlying `Wal` handle; every append and the open-time replay go through it.
wal: Wal,
}
impl CommitLog {
pub(crate) fn open(path: impl AsRef<std::path::Path>) -> Result<(Self, Vec<RecoveredCommit>)> {
let wal = Wal::open(path).map_err(TxnError::durability)?;
let mut recovered = Vec::new();
for entry in wal.iter().map_err(TxnError::durability)? {
let entry = entry.map_err(TxnError::durability)?;
recovered.push(decode_commit(entry.data())?);
}
Ok((CommitLog { wal }, recovered))
}
pub(crate) fn append_committed(&self, record: &[u8]) -> Result<()> {
self.wal
.append_and_sync(record)
.map(|_lsn| ())
.map_err(TxnError::durability)
}
}
/// Smallest possible encoding of a single write: a 4-byte key length plus the
/// 1-byte value tag (an empty key with no value). `decode_commit` divides the
/// remaining record size by this to reject write counts that cannot fit.
const MIN_WRITE_BYTES: usize = 4 + 1;
/// Serializes one commit into the byte layout that `decode_commit` parses.
///
/// Layout, with every integer little-endian:
///
/// ```text
/// raw commit timestamp (the buffer reserves 8 bytes; the decoder reads a u64)
/// write count            : u32
/// repeated per write:
///   key length           : u32
///   key bytes
///   tag                  : u8   (1 = value follows, 0 = no value)
///   value length + bytes : u32 + bytes, present only when tag == 1
/// ```
///
/// The write count and every length are narrowed with `as u32`, so anything
/// above `u32::MAX` would be silently truncated.
/// NOTE(review): callers presumably keep sizes far below that limit — confirm.
pub(crate) fn encode_for_log(commit_ts: Timestamp, writes: &[WriteEntry]) -> Vec<u8> {
    // Exact size of the per-write section, so the buffer is allocated once.
    let payload_len = writes.iter().fold(0usize, |acc, (key, value)| {
        let value_len = match value {
            Some(v) => 4 + v.len(),
            None => 0,
        };
        acc + 4 + key.len() + 1 + value_len
    });

    // 8 bytes of timestamp + 4 bytes of write count precede the writes.
    let mut out = Vec::with_capacity(8 + 4 + payload_len);
    out.extend_from_slice(&commit_ts.get().to_le_bytes());
    out.extend_from_slice(&(writes.len() as u32).to_le_bytes());

    for (key, value) in writes {
        out.extend_from_slice(&(key.len() as u32).to_le_bytes());
        out.extend_from_slice(key);
        if let Some(payload) = value {
            out.push(1);
            out.extend_from_slice(&(payload.len() as u32).to_le_bytes());
            out.extend_from_slice(payload);
        } else {
            out.push(0);
        }
    }

    out
}
fn decode_commit(bytes: &[u8]) -> Result<RecoveredCommit> {
let mut reader = Reader::new(bytes);
let commit_ts = Timestamp::from_raw(reader.read_u64()?);
let count = reader.read_u32()? as usize;
if count > reader.remaining() / MIN_WRITE_BYTES {
return Err(corrupt("write count exceeds record size"));
}
let mut writes = Vec::with_capacity(count);
for _ in 0..count {
let key_len = reader.read_u32()? as usize;
let key: Arc<[u8]> = Arc::from(reader.read_bytes(key_len)?);
let value = match reader.read_u8()? {
0 => None,
1 => {
let value_len = reader.read_u32()? as usize;
Some(Arc::from(reader.read_bytes(value_len)?))
}
other => return Err(corrupt_tag(other)),
};
writes.push((key, value));
}
if reader.remaining() != 0 {
return Err(corrupt("trailing bytes after commit record"));
}
Ok(RecoveredCommit { commit_ts, writes })
}
/// Forward-only cursor over a borrowed byte slice, used to decode records.
struct Reader<'a> {
// Full record being decoded.
buf: &'a [u8],
// Offset into `buf` of the next unread byte.
pos: usize,
}
impl<'a> Reader<'a> {
    /// Starts a cursor at the first byte of `buf`.
    fn new(buf: &'a [u8]) -> Self {
        Self { buf, pos: 0 }
    }

    /// Number of bytes not yet consumed.
    fn remaining(&self) -> usize {
        self.buf.len() - self.pos
    }

    /// Consumes exactly `len` bytes and returns them as a slice of the input.
    ///
    /// Fails with "record ends mid-field" (cursor left unchanged) when fewer
    /// than `len` bytes remain or when `pos + len` would overflow `usize`.
    fn read_bytes(&mut self, len: usize) -> Result<&'a [u8]> {
        match self.pos.checked_add(len) {
            Some(end) if end <= self.buf.len() => {
                let field = &self.buf[self.pos..end];
                self.pos = end;
                Ok(field)
            }
            _ => Err(corrupt("record ends mid-field")),
        }
    }

    /// Consumes one byte.
    fn read_u8(&mut self) -> Result<u8> {
        self.read_bytes(1).map(|field| field[0])
    }

    /// Consumes four bytes and interprets them as a little-endian `u32`.
    fn read_u32(&mut self) -> Result<u32> {
        let mut raw = [0u8; 4];
        // `read_bytes(4)` yields exactly four bytes, so the lengths match.
        raw.copy_from_slice(self.read_bytes(4)?);
        Ok(u32::from_le_bytes(raw))
    }

    /// Consumes eight bytes and interprets them as a little-endian `u64`.
    fn read_u64(&mut self) -> Result<u64> {
        let mut raw = [0u8; 8];
        // `read_bytes(8)` yields exactly eight bytes, so the lengths match.
        raw.copy_from_slice(self.read_bytes(8)?);
        Ok(u64::from_le_bytes(raw))
    }
}
/// Builds the durability error reported for a structurally invalid record,
/// with `reason` appended to the shared "malformed commit record" prefix.
fn corrupt(reason: &str) -> TxnError {
    let message = format!("malformed commit record: {reason}");
    TxnError::durability(message)
}
/// Builds the durability error reported when a write's value tag is not one
/// of the values the decoder accepts; the offending `tag` is included.
fn corrupt_tag(tag: u8) -> TxnError {
    let message = format!("malformed commit record: invalid value tag {tag}");
    TxnError::durability(message)
}
#[cfg(all(test, not(loom)))]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Builds a `WriteEntry` from borrowed key bytes and optional value bytes.
    fn entry(key: &[u8], value: Option<&[u8]>) -> WriteEntry {
        (Arc::from(key), value.map(Arc::from))
    }

    /// A value, a missing value, and an empty key/value all survive a round trip.
    #[test]
    fn test_encode_decode_roundtrip() {
        let writes = vec![
            entry(b"alice", Some(b"100")),
            entry(b"bob", None),
            entry(b"", Some(b"")),
        ];

        let encoded = encode_for_log(Timestamp::from_raw(42), &writes);
        let decoded = decode_commit(&encoded).unwrap();

        assert_eq!(decoded.commit_ts, Timestamp::from_raw(42));
        assert_eq!(decoded.writes, writes);
    }

    /// A record with zero writes decodes to an empty write list.
    #[test]
    fn test_decode_empty_write_set() {
        let encoded = encode_for_log(Timestamp::from_raw(7), &[]);
        let decoded = decode_commit(&encoded).unwrap();

        assert_eq!(decoded.commit_ts, Timestamp::from_raw(7));
        assert!(decoded.writes.is_empty());
    }

    /// Every strict prefix of a valid record must fail to decode.
    #[test]
    fn test_decode_truncated_record_is_rejected() {
        let encoded = encode_for_log(Timestamp::from_raw(1), &[entry(b"k", Some(b"v"))]);
        assert!((0..encoded.len()).all(|cut| decode_commit(&encoded[..cut]).is_err()));
    }

    /// A header claiming `u32::MAX` writes with no body is rejected up front.
    #[test]
    fn test_decode_rejects_implausible_count() {
        // commit_ts = 0, then a write count of u32::MAX and nothing else.
        let mut record = 0u64.to_le_bytes().to_vec();
        record.extend_from_slice(&u32::MAX.to_le_bytes());

        assert!(decode_commit(&record).is_err());
    }

    /// Bytes left over after the declared writes make the record invalid.
    #[test]
    fn test_decode_rejects_trailing_bytes() {
        let mut record = encode_for_log(Timestamp::from_raw(1), &[]);
        record.push(0xff);

        assert!(decode_commit(&record).is_err());
    }

    /// A value tag other than 0 or 1 is rejected.
    #[test]
    fn test_decode_rejects_bad_value_tag() {
        let mut record = 1u64.to_le_bytes().to_vec(); // commit_ts = 1
        record.extend_from_slice(&1u32.to_le_bytes()); // one write
        record.extend_from_slice(&1u32.to_le_bytes()); // key length = 1
        record.push(b'k'); // key bytes
        record.push(9); // invalid value tag

        assert!(decode_commit(&record).is_err());
    }
}