use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use crate::error::{Result, WalError};
use crate::record::{HEADER_SIZE, RecordHeader, WAL_MAGIC, WalRecord};
/// Number of record slots in the ring; a record goes to slot `write_pos % DWB_CAPACITY`.
const DWB_CAPACITY: usize = 64;
/// Size of the file header: magic, count and write position, each a little-endian `u32`.
const DWB_HEADER_SIZE: usize = 3 * std::mem::size_of::<u32>();
/// Header magic: the ASCII bytes `"DWBF"` read as a big-endian `u32` (`0x4457_4246`).
const DWB_MAGIC: u32 = u32::from_be_bytes(*b"DWBF");
/// A fixed-size, file-backed ring of WAL record copies.
///
/// Records are written into one of `DWB_CAPACITY` slots (oldest slot reused
/// first once the ring is full) and can later be looked up by LSN via
/// `recover_record`. NOTE(review): presumably this protects against torn
/// writes in the main WAL — confirm against the caller.
pub struct DoubleWriteBuffer {
    // Backing file, opened for both reading and writing.
    file: File,
    // Location of the backing file, returned by `path()`.
    path: PathBuf,
    // Number of records written so far (wrapping at `u32::MAX`); the next
    // record goes to slot `write_pos % DWB_CAPACITY`.
    write_pos: u32,
    // Number of occupied slots, never more than `DWB_CAPACITY`.
    count: u32,
    // True when slot data was written since the header was last rewritten and synced.
    dirty: bool,
}
impl std::fmt::Debug for DoubleWriteBuffer {
    /// Formats the buffer's bookkeeping state (path, ring position, occupancy,
    /// and whether unflushed data is pending). The raw `File` handle is
    /// deliberately left out, which `finish_non_exhaustive` makes visible in
    /// the output as `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DoubleWriteBuffer")
            .field("path", &self.path)
            .field("write_pos", &self.write_pos)
            .field("count", &self.count)
            .field("dirty", &self.dirty)
            .finish_non_exhaustive()
    }
}
impl DoubleWriteBuffer {
    /// Largest encoded record (header + payload) a slot accepts. Larger
    /// records are skipped by [`write_record_deferred`](Self::write_record_deferred).
    const MAX_RECORD_SIZE: usize = 64 * 1024;

    /// On-disk stride of one slot: a 4-byte little-endian length prefix plus
    /// room for `HEADER_SIZE + MAX_RECORD_SIZE` bytes. This is the layout
    /// existing files were written with, so it must not change.
    const SLOT_SIZE: u64 = 4 + HEADER_SIZE as u64 + Self::MAX_RECORD_SIZE as u64;

    /// Opens (creating if necessary) the double-write buffer at `path`.
    ///
    /// If the file already begins with a valid header, the persisted `count`
    /// and `write_pos` are restored, so new records continue after the last
    /// flushed slot. Otherwise the buffer starts empty. Existing file contents
    /// are never truncated.
    ///
    /// # Errors
    /// Returns [`WalError::Io`] if the file cannot be opened or if seeking to
    /// the header fails.
    pub fn open(path: &Path) -> Result<Self> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(path)
            .map_err(|e| {
                tracing::warn!(path = %path.display(), error = %e, "failed to open double-write buffer");
                WalError::Io(e)
            })?;
        let mut dwb = Self {
            file,
            path: path.to_path_buf(),
            write_pos: 0,
            count: 0,
            dirty: false,
        };
        dwb.load_header()?;
        Ok(dwb)
    }

    /// Restores `count`/`write_pos` from the on-disk header, best effort.
    ///
    /// A missing, short, unreadable or foreign (wrong magic) header leaves the
    /// buffer empty. Only a failing seek is reported as an error.
    fn load_header(&mut self) -> Result<()> {
        let file_len = self.file.metadata().map(|m| m.len()).unwrap_or(0);
        if file_len < DWB_HEADER_SIZE as u64 {
            return Ok(());
        }
        let mut header = [0u8; DWB_HEADER_SIZE];
        self.file.seek(SeekFrom::Start(0)).map_err(WalError::Io)?;
        if self.file.read_exact(&mut header).is_err() {
            return Ok(());
        }
        let word = |at: usize| {
            let mut bytes = [0u8; 4];
            bytes.copy_from_slice(&header[at..at + 4]);
            u32::from_le_bytes(bytes)
        };
        if word(0) != DWB_MAGIC {
            return Ok(());
        }
        // Clamp a corrupt count to the same bound the write path maintains.
        self.count = word(4).min(DWB_CAPACITY as u32);
        self.write_pos = word(8);
        Ok(())
    }

    /// Byte offset of slot `index` (expected to be `< DWB_CAPACITY`).
    fn slot_offset(index: usize) -> u64 {
        DWB_HEADER_SIZE as u64 + index as u64 * Self::SLOT_SIZE
    }

    /// Writes `record` into the next slot and immediately [`flush`](Self::flush)es
    /// the buffer to stable storage.
    ///
    /// # Errors
    /// Returns [`WalError::Io`] on any seek, write or sync failure.
    pub fn write_record(&mut self, record: &WalRecord) -> Result<()> {
        self.write_record_deferred(record)?;
        self.flush()
    }

    /// Writes `record` into the next ring slot without syncing. Call
    /// [`flush`](Self::flush) afterwards to persist the header and fsync.
    ///
    /// Once all `DWB_CAPACITY` slots are in use, the oldest slot is
    /// overwritten. A record whose header plus payload exceeds
    /// `MAX_RECORD_SIZE` is **not** stored. A warning is logged and `Ok(())`
    /// is returned, which preserves the existing best-effort contract.
    ///
    /// # Errors
    /// Returns [`WalError::Io`] if seeking or writing the slot fails. In that
    /// case the ring position is not advanced.
    pub fn write_record_deferred(&mut self, record: &WalRecord) -> Result<()> {
        let total_size = HEADER_SIZE + record.payload.len();
        if total_size > Self::MAX_RECORD_SIZE {
            tracing::warn!(
                path = %self.path.display(),
                lsn = record.header.lsn,
                size = total_size,
                max = Self::MAX_RECORD_SIZE,
                "record too large for a double-write slot; it is not protected"
            );
            return Ok(());
        }

        // Assemble the whole slot (length prefix, header, payload) so it goes
        // out in a single write call instead of three unbuffered ones.
        let mut slot = Vec::with_capacity(4 + total_size);
        // Cannot truncate: total_size <= MAX_RECORD_SIZE was checked above.
        slot.extend_from_slice(&(total_size as u32).to_le_bytes());
        slot.extend_from_slice(&record.header.to_bytes());
        slot.extend_from_slice(&record.payload);

        let index = (self.write_pos as usize) % DWB_CAPACITY;
        self.file
            .seek(SeekFrom::Start(Self::slot_offset(index)))
            .map_err(WalError::Io)?;
        self.file.write_all(&slot).map_err(WalError::Io)?;

        self.write_pos = self.write_pos.wrapping_add(1);
        self.count = self.count.saturating_add(1).min(DWB_CAPACITY as u32);
        self.dirty = true;
        Ok(())
    }

    /// Rewrites the header (magic, count, write position) and fsyncs the
    /// file with `File::sync_all`. Does nothing if no record has been
    /// written since the last successful flush.
    ///
    /// # Errors
    /// Returns [`WalError::Io`] on seek, write or sync failure. The buffer
    /// then stays dirty, so a later flush retries.
    pub fn flush(&mut self) -> Result<()> {
        if !self.dirty {
            return Ok(());
        }
        let mut header = [0u8; DWB_HEADER_SIZE];
        header[0..4].copy_from_slice(&DWB_MAGIC.to_le_bytes());
        header[4..8].copy_from_slice(&self.count.to_le_bytes());
        header[8..12].copy_from_slice(&self.write_pos.to_le_bytes());
        self.file.seek(SeekFrom::Start(0)).map_err(WalError::Io)?;
        self.file.write_all(&header).map_err(WalError::Io)?;
        self.file.sync_all().map_err(WalError::Io)?;
        self.dirty = false;
        Ok(())
    }

    /// Path of the backing file.
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Looks up an intact copy of the record with LSN `target_lsn`.
    ///
    /// Every slot is examined, newest first relative to `write_pos`, so when
    /// the same LSN occurs in several slots the most recently written copy
    /// wins. A slot is skipped if it is:
    /// - short,
    /// - carries an implausible length,
    /// - has a foreign magic,
    /// - has a different LSN, or
    /// - fails its checksum.
    ///
    /// # Errors
    /// Returns [`WalError::Io`] if seeking within the file fails. Returns
    /// `Ok(None)` when no intact copy exists.
    pub fn recover_record(&mut self, target_lsn: u64) -> Result<Option<WalRecord>> {
        for age in 0..DWB_CAPACITY {
            // age 0 is the slot written last. DWB_CAPACITY divides 2^32, so the
            // wrapping u32 arithmetic yields the same slot sequence the write
            // path produced.
            let index = (self.write_pos.wrapping_sub(1 + age as u32) as usize) % DWB_CAPACITY;
            if let Some(record) = self.read_slot(index, target_lsn)? {
                return Ok(Some(record));
            }
        }
        Ok(None)
    }

    /// Reads slot `index` and returns its record if it is intact and has LSN
    /// `target_lsn`. Unusable slots yield `Ok(None)`. Only seek errors propagate.
    fn read_slot(&mut self, index: usize, target_lsn: u64) -> Result<Option<WalRecord>> {
        self.file
            .seek(SeekFrom::Start(Self::slot_offset(index)))
            .map_err(WalError::Io)?;

        let mut size_buf = [0u8; 4];
        if self.file.read_exact(&mut size_buf).is_err() {
            return Ok(None);
        }
        let total_size = u32::from_le_bytes(size_buf) as usize;
        if !(HEADER_SIZE..=Self::MAX_RECORD_SIZE).contains(&total_size) {
            return Ok(None);
        }

        let mut header_buf = [0u8; HEADER_SIZE];
        if self.file.read_exact(&mut header_buf).is_err() {
            return Ok(None);
        }
        let header = RecordHeader::from_bytes(&header_buf);
        if header.magic != WAL_MAGIC || header.lsn != target_lsn {
            return Ok(None);
        }

        let mut payload = vec![0u8; total_size - HEADER_SIZE];
        if self.file.read_exact(&mut payload).is_err() {
            return Ok(None);
        }
        let record = WalRecord { header, payload };
        if record.verify_checksum().is_ok() {
            Ok(Some(record))
        } else {
            Ok(None)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::record::RecordType;

    /// Builds a `Put` record with the given LSN and payload.
    fn put(lsn: u64, payload: &[u8]) -> WalRecord {
        WalRecord::new(RecordType::Put as u16, lsn, 1, 0, payload.to_vec(), None).unwrap()
    }

    #[test]
    fn write_and_recover() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("test.dwb");
        let mut dwb = DoubleWriteBuffer::open(&dwb_path).unwrap();
        dwb.write_record(&put(42, b"hello double-write")).unwrap();
        let rec = dwb
            .recover_record(42)
            .unwrap()
            .expect("LSN 42 should be recoverable");
        assert_eq!(rec.header.lsn, 42);
        assert_eq!(rec.payload, b"hello double-write");
    }

    #[test]
    fn recover_nonexistent_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("test2.dwb");
        let mut dwb = DoubleWriteBuffer::open(&dwb_path).unwrap();
        assert!(dwb.recover_record(999).unwrap().is_none());
    }

    #[test]
    fn survives_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("reopen.dwb");
        {
            let mut dwb = DoubleWriteBuffer::open(&dwb_path).unwrap();
            dwb.write_record(&put(7, b"durable")).unwrap();
        }
        let mut dwb = DoubleWriteBuffer::open(&dwb_path).unwrap();
        let recovered = dwb.recover_record(7).unwrap();
        assert!(recovered.is_some());
        assert_eq!(recovered.unwrap().payload, b"durable");
    }

    #[test]
    fn reopen_continues_after_last_slot() {
        // The persisted write position must be restored; otherwise a write
        // after reopen would overwrite slot 0 and lose LSN 1.
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("continue.dwb");
        {
            let mut dwb = DoubleWriteBuffer::open(&dwb_path).unwrap();
            for lsn in 1..=3u64 {
                dwb.write_record(&put(lsn, format!("first-{lsn}").as_bytes()))
                    .unwrap();
            }
        }
        let mut dwb = DoubleWriteBuffer::open(&dwb_path).unwrap();
        assert_eq!(dwb.write_pos, 3);
        dwb.write_record(&put(4, b"first-4")).unwrap();
        for lsn in 1..=4u64 {
            let recovered = dwb.recover_record(lsn).unwrap();
            assert!(recovered.is_some(), "LSN {lsn} should survive reopen");
            assert_eq!(
                recovered.unwrap().payload,
                format!("first-{lsn}").into_bytes()
            );
        }
    }

    #[test]
    fn batch_deferred_writes_and_flush() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("batch.dwb");
        let mut dwb = DoubleWriteBuffer::open(&dwb_path).unwrap();
        for lsn in 1..=5u64 {
            dwb.write_record_deferred(&put(lsn, format!("batch-{lsn}").as_bytes()))
                .unwrap();
        }
        assert!(dwb.dirty);
        dwb.flush().unwrap();
        assert!(!dwb.dirty);
        for lsn in 1..=5u64 {
            let recovered = dwb.recover_record(lsn).unwrap();
            assert!(recovered.is_some(), "LSN {lsn} should be recoverable");
            assert_eq!(
                recovered.unwrap().payload,
                format!("batch-{lsn}").into_bytes()
            );
        }
    }

    #[test]
    fn flush_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("idem.dwb");
        let mut dwb = DoubleWriteBuffer::open(&dwb_path).unwrap();
        dwb.flush().unwrap();
        assert!(!dwb.dirty);
        dwb.write_record_deferred(&put(1, b"data")).unwrap();
        dwb.flush().unwrap();
        dwb.flush().unwrap();
        assert!(!dwb.dirty);
    }

    #[test]
    fn recover_after_wraparound() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("wrap.dwb");
        let mut dwb = DoubleWriteBuffer::open(&dwb_path).unwrap();
        let total = DWB_CAPACITY as u64 + 5;
        for lsn in 1..=total {
            dwb.write_record_deferred(&put(lsn, format!("wrap-{lsn}").as_bytes()))
                .unwrap();
        }
        dwb.flush().unwrap();
        for lsn in (total - 4)..=total {
            let recovered = dwb.recover_record(lsn).unwrap();
            assert!(
                recovered.is_some(),
                "LSN {lsn} should be recoverable after wrap-around"
            );
            assert_eq!(
                recovered.unwrap().payload,
                format!("wrap-{lsn}").into_bytes()
            );
        }
        for lsn in 1..=5u64 {
            let recovered = dwb.recover_record(lsn).unwrap();
            assert!(
                recovered.is_none(),
                "LSN {lsn} should have been overwritten by wrap-around"
            );
        }
    }
}