use std::fs::{File, OpenOptions};
use std::io::{self, BufReader, BufWriter, Read, Write};
use std::path::Path;
use crate::core::{DocId, LuciError, Result};
use crate::storage::header::xxh3_checksum;
/// Size in bytes of the fixed header that precedes every record payload.
///
/// Layout used by the writer and by replay in this file: byte 0 is the record
/// tag, bytes 1..4 are left as zero, bytes 4..8 are the payload length as a
/// little-endian `u32`, and bytes 8..16 are the payload checksum in
/// little-endian order.
const RECORD_HEADER_SIZE: usize = 16;
/// Policy controlling when appended records are forced to stable storage.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DurabilityMode {
/// `Wal::append` flushes its buffer and calls `sync_all` after every record.
Full,
/// `Wal::append` only buffers the record; nothing is synced until
/// `Wal::sync` is called.
/// NOTE(review): presumably the caller syncs once per batch — confirm.
Batch,
/// Treated the same as `Batch` by the code in this file (no sync on append).
/// NOTE(review): presumably callers skip `Wal::sync` entirely in this mode —
/// confirm against the call sites.
None,
}
/// One logical operation stored in the write-ahead log.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WalRecord {
/// Records a write of `data` for the document `doc_id`.
Put {
/// Identifier of the document the record refers to.
doc_id: DocId,
/// Raw bytes carried by the record; serialized right after the 4-byte id.
data: Vec<u8>,
},
/// Records a deletion of the document `doc_id`.
Delete { doc_id: DocId },
}
/// Tag byte stored in the record header for `WalRecord::Put`.
const TAG_PUT: u8 = 1;
/// Tag byte stored in the record header for `WalRecord::Delete`.
const TAG_DELETE: u8 = 2;
/// Writer side of the write-ahead log: encodes records and appends them to a
/// single file through a buffered writer.
pub struct Wal {
/// Buffered handle to the log file; bytes may sit here until flushed.
writer: BufWriter<File>,
/// Durability policy supplied when the log was opened.
mode: DurabilityMode,
}
impl Wal {
pub fn open(path: impl AsRef<Path>, mode: DurabilityMode) -> Result<Self> {
let file = OpenOptions::new()
.create(true)
.append(true)
.open(path.as_ref())?;
Ok(Self {
writer: BufWriter::new(file),
mode,
})
}
pub fn append(&mut self, record: &WalRecord) -> Result<()> {
let (tag, payload) = encode_payload(record);
let checksum = xxh3_checksum(&payload);
let mut header = [0u8; RECORD_HEADER_SIZE];
header[0] = tag;
header[4..8].copy_from_slice(&(payload.len() as u32).to_le_bytes());
header[8..16].copy_from_slice(&checksum.to_le_bytes());
self.writer.write_all(&header)?;
self.writer.write_all(&payload)?;
if self.mode == DurabilityMode::Full {
self.writer.flush()?;
self.writer.get_ref().sync_all()?;
}
Ok(())
}
pub fn sync(&mut self) -> Result<()> {
self.writer.flush()?;
self.writer.get_ref().sync_all()?;
Ok(())
}
pub fn truncate(&mut self) -> Result<()> {
self.writer.flush()?;
self.writer.get_ref().set_len(0)?;
Ok(())
}
pub fn mode(&self) -> DurabilityMode {
self.mode
}
}
/// Reads the log at `path` and returns every intact record in file order.
///
/// Replay stops at the first record that is incomplete or fails validation
/// (short header, short payload, checksum mismatch, unknown tag, or a payload
/// too short to decode). Records before that point are returned and the rest
/// of the file is ignored. A missing file yields an empty list.
///
/// # Errors
/// Returns `LuciError::Io` for I/O failures other than "file not found" and
/// end-of-file.
pub fn replay_wal(path: impl AsRef<Path>) -> Result<Vec<WalRecord>> {
    // Largest buffer reserved up front for a single payload. The length field
    // comes straight from disk, so a damaged header must not be able to
    // trigger a multi-gigabyte allocation; larger payloads grow while reading.
    const MAX_PREALLOC: usize = 1 << 20;

    let file = match File::open(path.as_ref()) {
        Ok(f) => f,
        Err(e) if e.kind() == io::ErrorKind::NotFound => return Ok(Vec::new()),
        Err(e) => return Err(LuciError::Io(e)),
    };
    let mut reader = BufReader::new(file);
    let mut records = Vec::new();

    loop {
        let mut header = [0u8; RECORD_HEADER_SIZE];
        match reader.read_exact(&mut header) {
            Ok(()) => {}
            // Clean end of log, or a header cut short by an interrupted write.
            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(LuciError::Io(e)),
        }

        let tag = header[0];
        let payload_length = u32::from_le_bytes(header[4..8].try_into().unwrap());
        let stored_checksum = u64::from_le_bytes(header[8..16].try_into().unwrap());

        // Read at most `payload_length` bytes; the buffer only grows as far as
        // the data actually present in the file.
        let expected_len = payload_length as usize;
        let mut payload = Vec::with_capacity(expected_len.min(MAX_PREALLOC));
        match reader
            .by_ref()
            .take(u64::from(payload_length))
            .read_to_end(&mut payload)
        {
            Ok(_) => {}
            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
            Err(e) => return Err(LuciError::Io(e)),
        }
        if payload.len() != expected_len {
            // The file ended before the full payload was available.
            break;
        }

        if xxh3_checksum(&payload) != stored_checksum {
            break;
        }
        match decode_record(tag, &payload) {
            Some(record) => records.push(record),
            None => break,
        }
    }

    Ok(records)
}
/// Serializes `record` into its header tag and payload bytes.
///
/// Every payload starts with the document id as a little-endian `u32`; a
/// `Put` payload continues with the raw document bytes.
fn encode_payload(record: &WalRecord) -> (u8, Vec<u8>) {
    match record {
        WalRecord::Delete { doc_id } => {
            (TAG_DELETE, Vec::from(doc_id.as_u32().to_le_bytes()))
        }
        WalRecord::Put { doc_id, data } => {
            let id_bytes = doc_id.as_u32().to_le_bytes();
            let mut buf = Vec::with_capacity(id_bytes.len() + data.len());
            buf.extend_from_slice(&id_bytes);
            buf.extend_from_slice(data);
            (TAG_PUT, buf)
        }
    }
}
/// Rebuilds a record from its header tag and payload bytes.
///
/// Returns `None` when the tag is not a known record tag or when the payload
/// is shorter than the 4-byte document id.
fn decode_record(tag: u8, payload: &[u8]) -> Option<WalRecord> {
    let is_put = match tag {
        TAG_PUT => true,
        TAG_DELETE => false,
        _ => return None,
    };
    if payload.len() < 4 {
        return None;
    }

    let (id_bytes, rest) = payload.split_at(4);
    let doc_id = DocId::new(u32::from_le_bytes(id_bytes.try_into().unwrap()));

    let record = if is_put {
        WalRecord::Put {
            doc_id,
            data: rest.to_vec(),
        }
    } else {
        // Any bytes after the id are ignored for a delete record.
        WalRecord::Delete { doc_id }
    };
    Some(record)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::io::{Seek, SeekFrom};

    /// Returns `name` inside a per-process scratch directory under the system
    /// temp dir, creating the directory if needed.
    fn test_path(name: &str) -> std::path::PathBuf {
        let mut path = std::env::temp_dir();
        path.push(format!("luci_wal_test_{}", std::process::id()));
        fs::create_dir_all(&path).unwrap();
        path.push(name);
        path
    }

    /// Like `test_path`, but also removes any file left over from an earlier run.
    fn fresh_path(name: &str) -> std::path::PathBuf {
        let path = test_path(name);
        let _ = fs::remove_file(&path);
        path
    }

    /// Shorthand for a `Put` record.
    fn put_record(id: u32, data: &[u8]) -> WalRecord {
        WalRecord::Put {
            doc_id: DocId::new(id),
            data: data.to_vec(),
        }
    }

    /// Shorthand for a `Delete` record.
    fn delete_record(id: u32) -> WalRecord {
        WalRecord::Delete {
            doc_id: DocId::new(id),
        }
    }

    /// Opens the log in `Batch` mode, appends `records` in order, syncs, and
    /// closes the log again.
    fn write_batch(path: &std::path::Path, records: &[WalRecord]) {
        let mut wal = Wal::open(path, DurabilityMode::Batch).unwrap();
        for record in records {
            wal.append(record).unwrap();
        }
        wal.sync().unwrap();
    }

    #[test]
    fn write_and_replay_put() {
        let path = fresh_path("put.wal");
        write_batch(&path, &[put_record(1, b"hello")]);

        let records = replay_wal(&path).unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], put_record(1, b"hello"));
        fs::remove_file(&path).unwrap();
    }

    #[test]
    fn write_and_replay_delete() {
        let path = fresh_path("delete.wal");
        write_batch(&path, &[delete_record(42)]);

        let records = replay_wal(&path).unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], delete_record(42));
        fs::remove_file(&path).unwrap();
    }

    #[test]
    fn multiple_records() {
        let path = fresh_path("multi.wal");
        let mut written: Vec<WalRecord> = (0..100u32)
            .map(|i| put_record(i, format!("doc-{i}").as_bytes()))
            .collect();
        written.push(delete_record(50));
        write_batch(&path, &written);

        let records = replay_wal(&path).unwrap();
        assert_eq!(records.len(), 101);
        assert_eq!(records[0], put_record(0, b"doc-0"));
        assert_eq!(records[100], delete_record(50));
        fs::remove_file(&path).unwrap();
    }

    #[test]
    fn truncate_clears_wal() {
        let path = fresh_path("truncate.wal");
        {
            let mut wal = Wal::open(&path, DurabilityMode::Batch).unwrap();
            wal.append(&put_record(1, b"data")).unwrap();
            wal.sync().unwrap();
            wal.truncate().unwrap();
        }

        assert!(replay_wal(&path).unwrap().is_empty());
        fs::remove_file(&path).unwrap();
    }

    #[test]
    fn truncated_header_is_discarded() {
        let path = fresh_path("trunc_header.wal");
        write_batch(&path, &[put_record(1, b"valid")]);

        // Append half a header's worth of bytes after the valid record.
        {
            let mut file = OpenOptions::new().append(true).open(&path).unwrap();
            file.write_all(&[0u8; 8]).unwrap();
        }

        let records = replay_wal(&path).unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], put_record(1, b"valid"));
        fs::remove_file(&path).unwrap();
    }

    #[test]
    fn truncated_payload_is_discarded() {
        let path = fresh_path("trunc_payload.wal");
        write_batch(&path, &[put_record(1, b"first"), put_record(2, b"second")]);

        // Cut the file two bytes into the second record's payload.
        {
            let first_record_size = RECORD_HEADER_SIZE + 4 + 5;
            let truncated_len = first_record_size + RECORD_HEADER_SIZE + 2;
            let meta = fs::metadata(&path).unwrap();
            assert!(truncated_len < meta.len() as usize);
            let file = OpenOptions::new().write(true).open(&path).unwrap();
            file.set_len(truncated_len as u64).unwrap();
        }

        let records = replay_wal(&path).unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], put_record(1, b"first"));
        fs::remove_file(&path).unwrap();
    }

    #[test]
    fn corrupted_checksum_stops_replay() {
        let path = fresh_path("corrupt_checksum.wal");
        write_batch(
            &path,
            &[
                put_record(1, b"good"),
                put_record(2, b"will-be-corrupted"),
                put_record(3, b"after-corrupt"),
            ],
        );

        // Invert the first checksum byte in the second record's header.
        {
            let first_record_size = RECORD_HEADER_SIZE + 4 + 4;
            let checksum_offset = (first_record_size + 8) as u64;
            let mut file = OpenOptions::new()
                .read(true)
                .write(true)
                .open(&path)
                .unwrap();
            let mut byte = [0u8; 1];
            file.seek(SeekFrom::Start(checksum_offset)).unwrap();
            file.read_exact(&mut byte).unwrap();
            byte[0] = !byte[0];
            file.seek(SeekFrom::Start(checksum_offset)).unwrap();
            file.write_all(&byte).unwrap();
        }

        let records = replay_wal(&path).unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], put_record(1, b"good"));
        fs::remove_file(&path).unwrap();
    }

    #[test]
    fn replay_nonexistent_file_returns_empty() {
        let path = fresh_path("nonexistent.wal");
        assert!(replay_wal(&path).unwrap().is_empty());
    }

    #[test]
    fn replay_empty_file() {
        let path = fresh_path("empty.wal");
        File::create(&path).unwrap();

        assert!(replay_wal(&path).unwrap().is_empty());
        fs::remove_file(&path).unwrap();
    }

    #[test]
    fn full_durability_mode() {
        let path = fresh_path("full_mode.wal");
        let mut wal = Wal::open(&path, DurabilityMode::Full).unwrap();
        assert_eq!(wal.mode(), DurabilityMode::Full);
        // No explicit sync: `Full` mode must make the record visible by itself.
        wal.append(&put_record(1, b"durable")).unwrap();
        drop(wal);

        let records = replay_wal(&path).unwrap();
        assert_eq!(records.len(), 1);
        fs::remove_file(&path).unwrap();
    }

    #[test]
    fn append_after_truncate() {
        let path = fresh_path("append_after_trunc.wal");
        let mut wal = Wal::open(&path, DurabilityMode::Batch).unwrap();
        wal.append(&put_record(1, b"before")).unwrap();
        wal.sync().unwrap();
        wal.truncate().unwrap();
        wal.append(&put_record(2, b"after")).unwrap();
        wal.sync().unwrap();
        drop(wal);

        let records = replay_wal(&path).unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], put_record(2, b"after"));
        fs::remove_file(&path).unwrap();
    }

    #[test]
    fn large_payload() {
        let path = fresh_path("large_payload.wal");
        let big_data = vec![0xCDu8; 1_000_000];
        write_batch(&path, &[put_record(1, &big_data)]);

        let records = replay_wal(&path).unwrap();
        assert_eq!(records.len(), 1);
        if let WalRecord::Put { doc_id, data } = &records[0] {
            assert_eq!(*doc_id, DocId::new(1));
            assert_eq!(data, &big_data);
        } else {
            panic!("expected Put");
        }
        fs::remove_file(&path).unwrap();
    }
}