use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::os::unix::fs::OpenOptionsExt;
use std::os::unix::io::AsRawFd;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use crate::align::{AlignedBuf, DEFAULT_ALIGNMENT, is_aligned};
use crate::error::{Result, WalError};
use crate::record::{HEADER_SIZE, RecordHeader, WAL_MAGIC, WalRecord};
/// Number of slots in the double-write ring.
const DWB_CAPACITY: usize = 64;
/// Size gate for records entering a slot. NOTE(review): despite the name,
/// both the write path and recovery compare the full header+payload size
/// against this, so the effective payload cap is this minus `HEADER_SIZE`.
const DWB_SLOT_PAYLOAD_MAX: usize = 64 * 1024;
/// Raw slot content: u32 length prefix + record header + max payload.
const DWB_SLOT_RAW: usize = 4 + HEADER_SIZE + DWB_SLOT_PAYLOAD_MAX;
/// On-disk slot stride: raw size padded up to the I/O alignment so every
/// slot offset satisfies O_DIRECT constraints.
const DWB_SLOT_STRIDE: usize = round_up_const(DWB_SLOT_RAW, DEFAULT_ALIGNMENT);
/// One aligned block at the front of the file reserved for the ring header.
const DWB_HEADER_STRIDE: usize = DEFAULT_ALIGNMENT;
/// Meaningful header bytes: magic, count, write_pos (three LE u32s).
const DWB_HEADER_FIELDS: usize = 12;
/// Header magic, ASCII "DWBF" read as a big-endian u32.
const DWB_MAGIC: u32 = 0x4457_4246;
/// Process-wide count of bytes pushed through any double-write buffer
/// (slot writes plus flushed headers), exported as a metric.
static DWB_BYTES_WRITTEN_TOTAL: AtomicU64 = AtomicU64::new(0);
/// Snapshot of the global double-write byte counter.
///
/// Relaxed ordering suffices: this is a monotonic statistic with no
/// synchronization role.
pub fn wal_dwb_bytes_written_total() -> u64 {
DWB_BYTES_WRITTEN_TOTAL.load(Ordering::Relaxed)
}
/// How (or whether) the double-write buffer performs its I/O.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DwbMode {
/// Double-writing disabled; `DoubleWriteBuffer::open` rejects this mode.
Off,
/// Ordinary page-cache writes via seek + `write_all`.
Buffered,
/// `O_DIRECT` writes built in alignment-padded scratch buffers.
Direct,
}
impl DwbMode {
    /// Pick a default mode that mirrors the parent WAL file's I/O strategy:
    /// a direct-I/O parent gets a direct-I/O buffer, anything else buffered.
    pub fn default_for_parent(parent_uses_direct_io: bool) -> Self {
        match parent_uses_direct_io {
            true => Self::Direct,
            false => Self::Buffered,
        }
    }
}
/// Round `value` up to the next multiple of `align`, at compile time.
///
/// Unlike the classic `(value + align - 1) & !(align - 1)` mask trick, this
/// remainder form is correct for any non-zero `align` (not just powers of
/// two) and cannot overflow unless the rounded result itself exceeds
/// `usize::MAX`. Results are identical for power-of-two alignments, which is
/// what the stride constants use.
const fn round_up_const(value: usize, align: usize) -> usize {
    let rem = value % align;
    if rem == 0 { value } else { value + (align - rem) }
}
/// On-disk size in bytes of one ring slot (length prefix + header + max
/// payload, padded up to the I/O alignment).
pub const fn slot_stride() -> usize {
DWB_SLOT_STRIDE
}
/// Byte offset of ring slot `idx` in the backing file.
///
/// The index wraps modulo the ring capacity; slots start one full stride
/// apart, after the header block at the front of the file.
fn slot_offset(idx: u32) -> u64 {
    let wrapped = u64::from(idx) % DWB_CAPACITY as u64;
    let base = DWB_HEADER_STRIDE as u64;
    base + wrapped * DWB_SLOT_STRIDE as u64
}
/// Ring-buffer file that holds a copy of each WAL record so a torn or
/// incomplete WAL write can be recovered from an intact slot copy.
pub struct DoubleWriteBuffer {
// Backing file: header block followed by DWB_CAPACITY fixed-stride slots.
file: File,
// Location of the backing file, exposed via `path()`.
path: PathBuf,
// Buffered vs O_DIRECT I/O; `Off` never constructs a buffer.
mode: DwbMode,
// Next slot index to write; wraps modulo DWB_CAPACITY inside slot_offset().
write_pos: u32,
// Number of live slots, saturating at DWB_CAPACITY once the ring is full.
count: u32,
// True when slot writes exist that the flushed header does not yet cover.
dirty: bool,
// Aligned scratch for slot writes (populated in Direct mode only).
slot_buf: Option<AlignedBuf>,
// Aligned scratch for header writes (populated in Direct mode only).
header_buf: Option<AlignedBuf>,
}
impl std::fmt::Debug for DoubleWriteBuffer {
    /// Manual `Debug`: the raw file handle and the aligned scratch buffers
    /// carry no diagnostic value, so only the bookkeeping fields appear.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("DoubleWriteBuffer");
        out.field("path", &self.path);
        out.field("mode", &self.mode);
        out.field("write_pos", &self.write_pos);
        out.field("count", &self.count);
        out.finish()
    }
}
impl DoubleWriteBuffer {
/// Open (or create) the double-write file at `path` in the given mode.
///
/// `Off` is rejected with `WalError::DwbOffNotOpenable` — callers decide up
/// front whether a buffer exists at all. `Direct` mode opens the file with
/// `O_DIRECT` and preallocates aligned scratch buffers for slot and header
/// writes.
///
/// If the file already starts with a valid header (magic matches), the
/// persisted `count` and `write_pos` are restored so the ring resumes where
/// the previous process stopped. Header recovery is best-effort: metadata or
/// read failures fall back to an empty ring rather than erroring.
///
/// # Errors
/// `WalError::DwbOffNotOpenable` for `Off`; `WalError::Io` if the file
/// cannot be opened, a seek fails, or aligned buffers cannot be allocated.
pub fn open(path: &Path, mode: DwbMode) -> Result<Self> {
if mode == DwbMode::Off {
return Err(WalError::DwbOffNotOpenable);
}
let mut opts = OpenOptions::new();
opts.read(true).write(true).create(true).truncate(false);
if mode == DwbMode::Direct {
opts.custom_flags(libc::O_DIRECT);
}
let file = opts.open(path).map_err(|e| {
tracing::warn!(path = %path.display(), error = %e, mode = ?mode, "failed to open double-write buffer");
WalError::Io(e)
})?;
// Direct I/O needs alignment-padded scratch buffers; buffered mode writes
// straight from the record's own bytes.
let (slot_buf, header_buf) = if mode == DwbMode::Direct {
(
Some(AlignedBuf::new(DWB_SLOT_STRIDE, DEFAULT_ALIGNMENT)?),
Some(AlignedBuf::new(DWB_HEADER_STRIDE, DEFAULT_ALIGNMENT)?),
)
} else {
(None, None)
};
let mut dwb = Self {
file,
path: path.to_path_buf(),
mode,
write_pos: 0,
count: 0,
dirty: false,
slot_buf,
header_buf,
};
// Best-effort header recovery: a metadata failure is treated as an empty
// file.
let file_len = dwb.file.metadata().map(|m| m.len()).unwrap_or(0);
if file_len >= DWB_HEADER_STRIDE as u64 {
// NOTE(review): in Direct mode this fd carries O_DIRECT, and this
// buffered read into an unaligned Vec may fail with EINVAL on some
// kernels/filesystems, silently discarding the persisted ring state —
// consider pread() into an AlignedBuf here instead.
let mut block = vec![0u8; DWB_HEADER_STRIDE];
dwb.file.seek(SeekFrom::Start(0)).map_err(WalError::Io)?;
if dwb.file.read_exact(&mut block).is_ok() {
let mut arr4 = [0u8; 4];
arr4.copy_from_slice(&block[0..4]);
let magic = u32::from_le_bytes(arr4);
if magic == DWB_MAGIC {
// Header layout: magic (0..4), count (4..8), write_pos (8..12), all LE.
arr4.copy_from_slice(&block[4..8]);
dwb.count = u32::from_le_bytes(arr4);
arr4.copy_from_slice(&block[8..12]);
dwb.write_pos = u32::from_le_bytes(arr4);
}
}
}
Ok(dwb)
}
/// The I/O mode this buffer was opened with (never `Off`).
pub fn mode(&self) -> DwbMode {
self.mode
}
/// Write one record into the next slot, then flush the header and fsync.
/// Convenience wrapper over the deferred write path.
pub fn write_record(&mut self, record: &WalRecord) -> Result<()> {
self.write_record_deferred(record)?;
self.flush()
}
/// Write one record into the next ring slot without flushing the header.
///
/// Oversized records are silently skipped (best-effort protection — the WAL
/// itself still carries the data). Slot layout on disk:
/// `u32 total_size (LE) | record header | payload`, padded to the full
/// stride in Direct mode.
pub fn write_record_deferred(&mut self, record: &WalRecord) -> Result<()> {
let total_size = HEADER_SIZE + record.payload.len();
// NOTE(review): the full header+payload size is compared against
// DWB_SLOT_PAYLOAD_MAX (mirrored in recover_record), so the effective
// payload cap is DWB_SLOT_PAYLOAD_MAX - HEADER_SIZE.
if total_size > DWB_SLOT_PAYLOAD_MAX {
return Ok(()); }
let header_bytes = record.header.to_bytes();
let offset = slot_offset(self.write_pos);
match self.mode {
DwbMode::Off => unreachable!("Off never opens a DoubleWriteBuffer"),
DwbMode::Buffered => {
// Buffered path: seek + write only the bytes actually used (no
// stride padding needed without O_DIRECT).
self.file
.seek(SeekFrom::Start(offset))
.map_err(WalError::Io)?;
self.file
.write_all(&(total_size as u32).to_le_bytes())
.map_err(WalError::Io)?;
self.file.write_all(&header_bytes).map_err(WalError::Io)?;
self.file.write_all(&record.payload).map_err(WalError::Io)?;
DWB_BYTES_WRITTEN_TOTAL.fetch_add(
(4 + header_bytes.len() + record.payload.len()) as u64,
Ordering::Relaxed,
);
}
DwbMode::Direct => {
// Direct path: assemble the whole slot in the aligned scratch
// buffer, zero the slack, and pwrite the full stride so O_DIRECT
// size/alignment constraints hold.
let buf = self
.slot_buf
.as_mut()
.expect("slot_buf present in Direct mode");
buf.clear();
buf.write(&(total_size as u32).to_le_bytes());
buf.write(&header_bytes);
buf.write(&record.payload);
zero_tail(buf);
let slice = full_capacity_slice(buf);
debug_assert_eq!(slice.len(), DWB_SLOT_STRIDE);
debug_assert!(is_aligned(offset as usize, DEFAULT_ALIGNMENT));
pwrite_all(&self.file, slice, offset)?;
DWB_BYTES_WRITTEN_TOTAL.fetch_add(slice.len() as u64, Ordering::Relaxed);
}
}
// Advance the ring: write_pos wraps via slot_offset's modulo; count
// saturates at capacity once the ring is full.
self.write_pos = self.write_pos.wrapping_add(1);
self.count = self.count.saturating_add(1).min(DWB_CAPACITY as u32);
self.dirty = true;
Ok(())
}
/// Persist the ring header (magic, count, write_pos) and fsync the file.
///
/// No-op when nothing was written since the last flush, so repeated calls
/// are cheap and idempotent.
pub fn flush(&mut self) -> Result<()> {
if !self.dirty {
return Ok(());
}
let mut header = [0u8; DWB_HEADER_FIELDS];
header[0..4].copy_from_slice(&DWB_MAGIC.to_le_bytes());
header[4..8].copy_from_slice(&self.count.to_le_bytes());
header[8..12].copy_from_slice(&self.write_pos.to_le_bytes());
match self.mode {
DwbMode::Off => unreachable!("invariant: flush() is gated on mode != Off by caller"),
DwbMode::Buffered => {
self.file.seek(SeekFrom::Start(0)).map_err(WalError::Io)?;
self.file.write_all(&header).map_err(WalError::Io)?;
DWB_BYTES_WRITTEN_TOTAL.fetch_add(header.len() as u64, Ordering::Relaxed);
}
DwbMode::Direct => {
// Pad the 12 meaningful header bytes out to a full aligned block for
// O_DIRECT, mirroring the slot write path.
let buf = self
.header_buf
.as_mut()
.expect("header_buf present in Direct mode");
buf.clear();
buf.write(&header);
zero_tail(buf);
let slice = full_capacity_slice(buf);
debug_assert_eq!(slice.len(), DWB_HEADER_STRIDE);
pwrite_all(&self.file, slice, 0)?;
DWB_BYTES_WRITTEN_TOTAL.fetch_add(slice.len() as u64, Ordering::Relaxed);
}
}
// Durability point: slots written earlier and the header both reach stable
// storage before the dirty flag is cleared.
self.file.sync_all().map_err(WalError::Io)?;
self.dirty = false;
Ok(())
}
/// Filesystem path of the backing file.
pub fn path(&self) -> &Path {
&self.path
}
/// Scan every slot for an intact copy of the record with `target_lsn`.
///
/// A slot is accepted only if its length prefix is plausible, the record
/// magic and LSN match, and the checksum verifies — a torn or stale slot is
/// never returned. Yields `Ok(None)` when no valid copy exists.
pub fn recover_record(&mut self, target_lsn: u64) -> Result<Option<WalRecord>> {
let mut slot = AlignedBuf::new(DWB_SLOT_STRIDE, DEFAULT_ALIGNMENT)?;
for i in 0..DWB_CAPACITY as u32 {
let offset = slot_offset(i);
// SAFETY: `slot` owns at least DWB_SLOT_STRIDE writable bytes and the
// fd remains open for the duration of the call.
let read = unsafe {
libc::pread(
self.file.as_raw_fd(),
slot.as_mut_ptr() as *mut libc::c_void,
DWB_SLOT_STRIDE,
offset as libc::off_t,
)
};
// Short files, holes, and read errors all just skip the slot — recovery
// is best-effort. NOTE(review): EINTR lands here too and skips a slot
// rather than retrying the pread.
if read <= 0 {
continue;
}
// SAFETY: pread initialized exactly `read` bytes at `slot`'s base.
let bytes: &[u8] = unsafe { std::slice::from_raw_parts(slot.as_ptr(), read as usize) };
if bytes.len() < 4 + HEADER_SIZE {
continue;
}
let mut arr4 = [0u8; 4];
arr4.copy_from_slice(&bytes[0..4]);
let total_size = u32::from_le_bytes(arr4) as usize;
// Reject length prefixes that write_record_deferred could not have
// produced (mirrors its size gate).
if !(HEADER_SIZE..=DWB_SLOT_PAYLOAD_MAX).contains(&total_size)
|| bytes.len() < 4 + total_size
{
continue;
}
let mut header_buf = [0u8; HEADER_SIZE];
header_buf.copy_from_slice(&bytes[4..4 + HEADER_SIZE]);
let header = RecordHeader::from_bytes(&header_buf);
if header.magic != WAL_MAGIC || header.lsn != target_lsn {
continue;
}
let payload_len = total_size - HEADER_SIZE;
let payload = bytes[4 + HEADER_SIZE..4 + HEADER_SIZE + payload_len].to_vec();
let record = WalRecord { header, payload };
// Only a checksum-clean copy is trusted.
if record.verify_checksum().is_ok() {
return Ok(Some(record));
}
}
Ok(None)
}
}
/// Zero the unwritten tail of `buf`, from its current length to capacity.
///
/// Direct-I/O writes always push the buffer's full capacity to disk, so the
/// slack after the logical payload must not leak stale heap contents.
fn zero_tail(buf: &mut AlignedBuf) {
let written = buf.len();
let cap = buf.capacity();
if written < cap {
// SAFETY: `written..cap` lies inside the single allocation owned by `buf`
// (assumes len() <= capacity(), the AlignedBuf invariant), so writing
// zeros there is in-bounds.
unsafe {
std::ptr::write_bytes(buf.as_mut_ptr().add(written), 0, cap - written);
}
}
}
/// View the buffer's entire capacity as a byte slice, not just the written
/// prefix. Callers must have initialized the full capacity first — both call
/// sites run `zero_tail` immediately beforehand.
fn full_capacity_slice(buf: &AlignedBuf) -> &[u8] {
// SAFETY: relies on the caller having initialized all `capacity()` bytes;
// otherwise this slice would expose uninitialized memory.
unsafe { std::slice::from_raw_parts(buf.as_ptr(), buf.capacity()) }
}
fn pwrite_all(file: &File, mut data: &[u8], mut offset: u64) -> Result<()> {
let fd = file.as_raw_fd();
while !data.is_empty() {
let n = unsafe {
libc::pwrite(
fd,
data.as_ptr() as *const libc::c_void,
data.len(),
offset as libc::off_t,
)
};
if n < 0 {
return Err(WalError::Io(std::io::Error::last_os_error()));
}
let n = n as usize;
data = &data[n..];
offset += n as u64;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::record::RecordType;
// All tests use Buffered mode so they run on filesystems without O_DIRECT
// support (e.g. tmpfs in CI).
fn open_buffered(path: &Path) -> DoubleWriteBuffer {
DoubleWriteBuffer::open(path, DwbMode::Buffered).unwrap()
}
// A record written via the immediate path is recoverable by LSN with its
// payload intact.
#[test]
fn write_and_recover() {
let dir = tempfile::tempdir().unwrap();
let dwb_path = dir.path().join("test.dwb");
let mut dwb = open_buffered(&dwb_path);
let record = WalRecord::new(
RecordType::Put as u32,
42,
1,
0,
0,
b"hello double-write".to_vec(),
None,
None,
)
.unwrap();
dwb.write_record(&record).unwrap();
let recovered = dwb.recover_record(42).unwrap();
assert!(recovered.is_some());
let rec = recovered.unwrap();
assert_eq!(rec.header.lsn, 42);
assert_eq!(rec.payload, b"hello double-write");
}
// Looking up an LSN that was never written yields Ok(None), not an error.
#[test]
fn recover_nonexistent_returns_none() {
let dir = tempfile::tempdir().unwrap();
let dwb_path = dir.path().join("test2.dwb");
let mut dwb = open_buffered(&dwb_path);
let result = dwb.recover_record(999).unwrap();
assert!(result.is_none());
}
// Slot contents persist across drop + reopen of the buffer file.
#[test]
fn survives_reopen() {
let dir = tempfile::tempdir().unwrap();
let dwb_path = dir.path().join("reopen.dwb");
{
let mut dwb = open_buffered(&dwb_path);
let record = WalRecord::new(
RecordType::Put as u32,
7,
1,
0,
0,
b"durable".to_vec(),
None,
None,
)
.unwrap();
dwb.write_record(&record).unwrap();
}
let mut dwb = open_buffered(&dwb_path);
let recovered = dwb.recover_record(7).unwrap();
assert!(recovered.is_some());
assert_eq!(recovered.unwrap().payload, b"durable");
}
// Deferred writes accumulate with dirty=true until one flush covers all of
// them; every deferred record is then recoverable.
#[test]
fn batch_deferred_writes_and_flush() {
let dir = tempfile::tempdir().unwrap();
let dwb_path = dir.path().join("batch.dwb");
let mut dwb = open_buffered(&dwb_path);
for lsn in 1..=5u64 {
let record = WalRecord::new(
RecordType::Put as u32,
lsn,
1,
0,
0,
format!("batch-{lsn}").into_bytes(),
None,
None,
)
.unwrap();
dwb.write_record_deferred(&record).unwrap();
}
assert!(dwb.dirty);
dwb.flush().unwrap();
assert!(!dwb.dirty);
for lsn in 1..=5u64 {
let recovered = dwb.recover_record(lsn).unwrap();
assert!(recovered.is_some(), "LSN {lsn} should be recoverable");
assert_eq!(
recovered.unwrap().payload,
format!("batch-{lsn}").into_bytes()
);
}
}
// flush() on a clean buffer is a cheap no-op; repeated flushes keep the
// buffer clean.
#[test]
fn flush_is_idempotent() {
let dir = tempfile::tempdir().unwrap();
let dwb_path = dir.path().join("idem.dwb");
let mut dwb = open_buffered(&dwb_path);
dwb.flush().unwrap();
assert!(!dwb.dirty);
let record = WalRecord::new(
RecordType::Put as u32,
1,
1,
0,
0,
b"data".to_vec(),
None,
None,
)
.unwrap();
dwb.write_record_deferred(&record).unwrap();
dwb.flush().unwrap();
dwb.flush().unwrap();
assert!(!dwb.dirty);
}
// Every on-disk stride and slot offset must satisfy the O_DIRECT alignment
// that Direct mode relies on.
#[test]
fn slot_stride_is_o_direct_aligned() {
assert!(
is_aligned(DWB_SLOT_STRIDE, DEFAULT_ALIGNMENT),
"DWB slot stride {DWB_SLOT_STRIDE} bytes is not a multiple of {DEFAULT_ALIGNMENT}"
);
assert!(is_aligned(DWB_HEADER_STRIDE, DEFAULT_ALIGNMENT));
for i in 0..DWB_CAPACITY as u32 {
assert!(is_aligned(slot_offset(i) as usize, DEFAULT_ALIGNMENT));
}
}
// Writing more records than the ring holds overwrites the oldest slots:
// the newest records survive, the overwritten ones are gone.
#[test]
fn recover_after_wraparound() {
let dir = tempfile::tempdir().unwrap();
let dwb_path = dir.path().join("wrap.dwb");
let mut dwb = open_buffered(&dwb_path);
let total = DWB_CAPACITY as u64 + 5;
for lsn in 1..=total {
let record = WalRecord::new(
RecordType::Put as u32,
lsn,
1,
0,
0,
format!("wrap-{lsn}").into_bytes(),
None,
None,
)
.unwrap();
dwb.write_record_deferred(&record).unwrap();
}
dwb.flush().unwrap();
for lsn in (total - 4)..=total {
let recovered = dwb.recover_record(lsn).unwrap();
assert!(
recovered.is_some(),
"LSN {lsn} should be recoverable after wrap-around"
);
assert_eq!(
recovered.unwrap().payload,
format!("wrap-{lsn}").into_bytes()
);
}
for lsn in 1..=5u64 {
let recovered = dwb.recover_record(lsn).unwrap();
assert!(
recovered.is_none(),
"LSN {lsn} should have been overwritten by wrap-around"
);
}
}
// The global byte counter moves after a write. (Only monotonic growth is
// asserted since other tests share the same process-wide counter.)
#[test]
fn bytes_written_counter_increments() {
let dir = tempfile::tempdir().unwrap();
let dwb_path = dir.path().join("counter.dwb");
let before = wal_dwb_bytes_written_total();
let mut dwb = open_buffered(&dwb_path);
let rec = WalRecord::new(
RecordType::Put as u32,
1,
1,
0,
0,
b"counted".to_vec(),
None,
None,
)
.unwrap();
dwb.write_record(&rec).unwrap();
assert!(wal_dwb_bytes_written_total() > before);
}
}