1use std::fs::{File, OpenOptions};
39use std::io::{Read, Seek, SeekFrom, Write};
40use std::os::unix::fs::OpenOptionsExt;
41use std::os::unix::io::AsRawFd;
42use std::path::{Path, PathBuf};
43use std::sync::atomic::{AtomicU64, Ordering};
44
45use crate::align::{AlignedBuf, DEFAULT_ALIGNMENT, is_aligned};
46use crate::error::{Result, WalError};
47use crate::record::{HEADER_SIZE, RecordHeader, WAL_MAGIC, WalRecord};
48
49const DWB_CAPACITY: usize = 64;
57
58const DWB_SLOT_PAYLOAD_MAX: usize = 64 * 1024;
60
61const DWB_SLOT_RAW: usize = 4 + HEADER_SIZE + DWB_SLOT_PAYLOAD_MAX;
63
64const DWB_SLOT_STRIDE: usize = round_up_const(DWB_SLOT_RAW, DEFAULT_ALIGNMENT);
68
69const DWB_HEADER_STRIDE: usize = DEFAULT_ALIGNMENT;
73const DWB_HEADER_FIELDS: usize = 12;
74const DWB_MAGIC: u32 = 0x4457_4246; static DWB_BYTES_WRITTEN_TOTAL: AtomicU64 = AtomicU64::new(0);
80
81pub fn wal_dwb_bytes_written_total() -> u64 {
83 DWB_BYTES_WRITTEN_TOTAL.load(Ordering::Relaxed)
84}
85
/// How a double-write buffer performs its file I/O.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DwbMode {
    /// Double-writing disabled; `DoubleWriteBuffer::open` refuses this mode.
    Off,
    /// Regular file I/O (`seek` + `write_all`) without `O_DIRECT`.
    Buffered,
    /// `O_DIRECT` I/O issued from aligned scratch buffers.
    Direct,
}

impl DwbMode {
    /// Chooses the mode that mirrors the parent WAL's I/O style.
    ///
    /// Returns `Direct` when the parent uses direct I/O and `Buffered` otherwise;
    /// it never returns `Off`.
    pub fn default_for_parent(parent_uses_direct_io: bool) -> Self {
        match parent_uses_direct_io {
            true => Self::Direct,
            false => Self::Buffered,
        }
    }
}
112
/// Rounds `value` up to the next multiple of `align`; usable in const contexts.
///
/// `align` must be a non-zero power of two, because the bit-mask trick is wrong
/// otherwise. `value + align - 1` must not overflow `usize`.
const fn round_up_const(value: usize, align: usize) -> usize {
    let mask = align - 1;
    (value + mask) & !mask
}
116
117pub const fn slot_stride() -> usize {
120 DWB_SLOT_STRIDE
121}
122
123fn slot_offset(idx: u32) -> u64 {
125 DWB_HEADER_STRIDE as u64 + (idx as u64 % DWB_CAPACITY as u64) * DWB_SLOT_STRIDE as u64
126}
127
128pub struct DoubleWriteBuffer {
130 file: File,
131 path: PathBuf,
132 mode: DwbMode,
133 write_pos: u32,
135 count: u32,
137 dirty: bool,
139 slot_buf: Option<AlignedBuf>,
142 header_buf: Option<AlignedBuf>,
144}
145
146impl std::fmt::Debug for DoubleWriteBuffer {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 f.debug_struct("DoubleWriteBuffer")
149 .field("path", &self.path)
150 .field("mode", &self.mode)
151 .field("write_pos", &self.write_pos)
152 .field("count", &self.count)
153 .finish()
154 }
155}
156
157impl DoubleWriteBuffer {
158 pub fn open(path: &Path, mode: DwbMode) -> Result<Self> {
163 if mode == DwbMode::Off {
164 return Err(WalError::DwbOffNotOpenable);
165 }
166
167 let mut opts = OpenOptions::new();
168 opts.read(true).write(true).create(true).truncate(false);
169 if mode == DwbMode::Direct {
170 opts.custom_flags(libc::O_DIRECT);
171 }
172
173 let file = opts.open(path).map_err(|e| {
174 tracing::warn!(path = %path.display(), error = %e, mode = ?mode, "failed to open double-write buffer");
175 WalError::Io(e)
176 })?;
177
178 let (slot_buf, header_buf) = if mode == DwbMode::Direct {
179 (
180 Some(AlignedBuf::new(DWB_SLOT_STRIDE, DEFAULT_ALIGNMENT)?),
181 Some(AlignedBuf::new(DWB_HEADER_STRIDE, DEFAULT_ALIGNMENT)?),
182 )
183 } else {
184 (None, None)
185 };
186
187 let mut dwb = Self {
188 file,
189 path: path.to_path_buf(),
190 mode,
191 write_pos: 0,
192 count: 0,
193 dirty: false,
194 slot_buf,
195 header_buf,
196 };
197
198 let file_len = dwb.file.metadata().map(|m| m.len()).unwrap_or(0);
200 if file_len >= DWB_HEADER_STRIDE as u64 {
201 let mut block = vec![0u8; DWB_HEADER_STRIDE];
202 dwb.file.seek(SeekFrom::Start(0)).map_err(WalError::Io)?;
203 if dwb.file.read_exact(&mut block).is_ok() {
204 let mut arr4 = [0u8; 4];
205 arr4.copy_from_slice(&block[0..4]);
206 let magic = u32::from_le_bytes(arr4);
207 if magic == DWB_MAGIC {
208 arr4.copy_from_slice(&block[4..8]);
209 dwb.count = u32::from_le_bytes(arr4);
210 arr4.copy_from_slice(&block[8..12]);
211 dwb.write_pos = u32::from_le_bytes(arr4);
212 }
213 }
214 }
215
216 Ok(dwb)
217 }
218
219 pub fn mode(&self) -> DwbMode {
221 self.mode
222 }
223
224 pub fn write_record(&mut self, record: &WalRecord) -> Result<()> {
230 self.write_record_deferred(record)?;
231 self.flush()
232 }
233
234 pub fn write_record_deferred(&mut self, record: &WalRecord) -> Result<()> {
242 let total_size = HEADER_SIZE + record.payload.len();
243
244 if total_size > DWB_SLOT_PAYLOAD_MAX {
247 return Ok(()); }
249
250 let header_bytes = record.header.to_bytes();
251 let offset = slot_offset(self.write_pos);
252
253 match self.mode {
254 DwbMode::Off => unreachable!("Off never opens a DoubleWriteBuffer"),
255 DwbMode::Buffered => {
256 self.file
257 .seek(SeekFrom::Start(offset))
258 .map_err(WalError::Io)?;
259 self.file
260 .write_all(&(total_size as u32).to_le_bytes())
261 .map_err(WalError::Io)?;
262 self.file.write_all(&header_bytes).map_err(WalError::Io)?;
263 self.file.write_all(&record.payload).map_err(WalError::Io)?;
264 DWB_BYTES_WRITTEN_TOTAL.fetch_add(
265 (4 + header_bytes.len() + record.payload.len()) as u64,
266 Ordering::Relaxed,
267 );
268 }
269 DwbMode::Direct => {
270 let buf = self
271 .slot_buf
272 .as_mut()
273 .expect("slot_buf present in Direct mode");
274 buf.clear();
275 buf.write(&(total_size as u32).to_le_bytes());
276 buf.write(&header_bytes);
277 buf.write(&record.payload);
278 zero_tail(buf);
281 let slice = full_capacity_slice(buf);
282 debug_assert_eq!(slice.len(), DWB_SLOT_STRIDE);
283 debug_assert!(is_aligned(offset as usize, DEFAULT_ALIGNMENT));
284 pwrite_all(&self.file, slice, offset)?;
285 DWB_BYTES_WRITTEN_TOTAL.fetch_add(slice.len() as u64, Ordering::Relaxed);
286 }
287 }
288
289 self.write_pos = self.write_pos.wrapping_add(1);
290 self.count = self.count.saturating_add(1).min(DWB_CAPACITY as u32);
291 self.dirty = true;
292
293 Ok(())
294 }
295
296 pub fn flush(&mut self) -> Result<()> {
302 if !self.dirty {
303 return Ok(());
304 }
305
306 let mut header = [0u8; DWB_HEADER_FIELDS];
307 header[0..4].copy_from_slice(&DWB_MAGIC.to_le_bytes());
308 header[4..8].copy_from_slice(&self.count.to_le_bytes());
309 header[8..12].copy_from_slice(&self.write_pos.to_le_bytes());
310
311 match self.mode {
312 DwbMode::Off => unreachable!(),
313 DwbMode::Buffered => {
314 self.file.seek(SeekFrom::Start(0)).map_err(WalError::Io)?;
315 self.file.write_all(&header).map_err(WalError::Io)?;
316 DWB_BYTES_WRITTEN_TOTAL.fetch_add(header.len() as u64, Ordering::Relaxed);
317 }
318 DwbMode::Direct => {
319 let buf = self
320 .header_buf
321 .as_mut()
322 .expect("header_buf present in Direct mode");
323 buf.clear();
324 buf.write(&header);
325 zero_tail(buf);
326 let slice = full_capacity_slice(buf);
327 debug_assert_eq!(slice.len(), DWB_HEADER_STRIDE);
328 pwrite_all(&self.file, slice, 0)?;
329 DWB_BYTES_WRITTEN_TOTAL.fetch_add(slice.len() as u64, Ordering::Relaxed);
330 }
331 }
332
333 self.file.sync_all().map_err(WalError::Io)?;
334 self.dirty = false;
335
336 Ok(())
337 }
338
339 pub fn path(&self) -> &Path {
341 &self.path
342 }
343
344 pub fn recover_record(&mut self, target_lsn: u64) -> Result<Option<WalRecord>> {
352 let mut slot = AlignedBuf::new(DWB_SLOT_STRIDE, DEFAULT_ALIGNMENT)?;
355
356 for i in 0..DWB_CAPACITY as u32 {
357 let offset = slot_offset(i);
358 let read = unsafe {
360 libc::pread(
361 self.file.as_raw_fd(),
362 slot.as_mut_ptr() as *mut libc::c_void,
363 DWB_SLOT_STRIDE,
364 offset as libc::off_t,
365 )
366 };
367 if read <= 0 {
368 continue;
369 }
370 let bytes: &[u8] = unsafe { std::slice::from_raw_parts(slot.as_ptr(), read as usize) };
372 if bytes.len() < 4 + HEADER_SIZE {
373 continue;
374 }
375
376 let mut arr4 = [0u8; 4];
377 arr4.copy_from_slice(&bytes[0..4]);
378 let total_size = u32::from_le_bytes(arr4) as usize;
379 if !(HEADER_SIZE..=DWB_SLOT_PAYLOAD_MAX).contains(&total_size)
380 || bytes.len() < 4 + total_size
381 {
382 continue;
383 }
384
385 let mut header_buf = [0u8; HEADER_SIZE];
386 header_buf.copy_from_slice(&bytes[4..4 + HEADER_SIZE]);
387 let header = RecordHeader::from_bytes(&header_buf);
388 if header.magic != WAL_MAGIC || header.lsn != target_lsn {
389 continue;
390 }
391
392 let payload_len = total_size - HEADER_SIZE;
393 let payload = bytes[4 + HEADER_SIZE..4 + HEADER_SIZE + payload_len].to_vec();
394 let record = WalRecord { header, payload };
395 if record.verify_checksum().is_ok() {
396 return Ok(Some(record));
397 }
398 }
399
400 Ok(None)
401 }
402}
403
404fn zero_tail(buf: &mut AlignedBuf) {
407 let written = buf.len();
408 let cap = buf.capacity();
409 if written < cap {
410 unsafe {
413 std::ptr::write_bytes(buf.as_mut_ptr().add(written), 0, cap - written);
414 }
415 }
416}
417
418fn full_capacity_slice(buf: &AlignedBuf) -> &[u8] {
421 unsafe { std::slice::from_raw_parts(buf.as_ptr(), buf.capacity()) }
424}
425
426fn pwrite_all(file: &File, mut data: &[u8], mut offset: u64) -> Result<()> {
428 let fd = file.as_raw_fd();
429 while !data.is_empty() {
430 let n = unsafe {
431 libc::pwrite(
432 fd,
433 data.as_ptr() as *const libc::c_void,
434 data.len(),
435 offset as libc::off_t,
436 )
437 };
438 if n < 0 {
439 return Err(WalError::Io(std::io::Error::last_os_error()));
440 }
441 let n = n as usize;
442 data = &data[n..];
443 offset += n as u64;
444 }
445 Ok(())
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use crate::record::RecordType;

    fn open_buffered(path: &Path) -> DoubleWriteBuffer {
        DoubleWriteBuffer::open(path, DwbMode::Buffered).unwrap()
    }

    #[test]
    fn write_and_recover() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("test.dwb");

        let mut dwb = open_buffered(&dwb_path);

        let record = WalRecord::new(
            RecordType::Put as u16,
            42,
            1,
            0,
            b"hello double-write".to_vec(),
            None,
        )
        .unwrap();

        dwb.write_record(&record).unwrap();

        let recovered = dwb.recover_record(42).unwrap();
        assert!(recovered.is_some());
        let rec = recovered.unwrap();
        assert_eq!(rec.header.lsn, 42);
        assert_eq!(rec.payload, b"hello double-write");
    }

    #[test]
    fn recover_nonexistent_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("test2.dwb");

        let mut dwb = open_buffered(&dwb_path);
        let result = dwb.recover_record(999).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn survives_reopen() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("reopen.dwb");

        {
            let mut dwb = open_buffered(&dwb_path);
            let record =
                WalRecord::new(RecordType::Put as u16, 7, 1, 0, b"durable".to_vec(), None).unwrap();
            dwb.write_record(&record).unwrap();
        }

        let mut dwb = open_buffered(&dwb_path);
        let recovered = dwb.recover_record(7).unwrap();
        assert!(recovered.is_some());
        assert_eq!(recovered.unwrap().payload, b"durable");
    }

    /// The header's count and write position must be restored on reopen, so the
    /// ring continues where it stopped instead of restarting at slot 0.
    #[test]
    fn reopen_restores_header_state() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("header.dwb");

        {
            let mut dwb = open_buffered(&dwb_path);
            for lsn in 1..=3u64 {
                let record =
                    WalRecord::new(RecordType::Put as u16, lsn, 1, 0, b"h".to_vec(), None)
                        .unwrap();
                dwb.write_record_deferred(&record).unwrap();
            }
            dwb.flush().unwrap();
        }

        let dwb = open_buffered(&dwb_path);
        assert_eq!(dwb.count, 3);
        assert_eq!(dwb.write_pos, 3);
        assert!(!dwb.dirty);
    }

    #[test]
    fn off_mode_is_rejected_without_creating_file() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("off.dwb");

        let result = DoubleWriteBuffer::open(&dwb_path, DwbMode::Off);
        assert!(matches!(result, Err(WalError::DwbOffNotOpenable)));
        assert!(!dwb_path.exists());
    }

    #[test]
    fn default_mode_follows_parent() {
        assert_eq!(DwbMode::default_for_parent(true), DwbMode::Direct);
        assert_eq!(DwbMode::default_for_parent(false), DwbMode::Buffered);
    }

    #[test]
    fn batch_deferred_writes_and_flush() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("batch.dwb");

        let mut dwb = open_buffered(&dwb_path);

        for lsn in 1..=5u64 {
            let record = WalRecord::new(
                RecordType::Put as u16,
                lsn,
                1,
                0,
                format!("batch-{lsn}").into_bytes(),
                None,
            )
            .unwrap();
            dwb.write_record_deferred(&record).unwrap();
        }

        assert!(dwb.dirty);
        dwb.flush().unwrap();
        assert!(!dwb.dirty);

        for lsn in 1..=5u64 {
            let recovered = dwb.recover_record(lsn).unwrap();
            assert!(recovered.is_some(), "LSN {lsn} should be recoverable");
            assert_eq!(
                recovered.unwrap().payload,
                format!("batch-{lsn}").into_bytes()
            );
        }
    }

    #[test]
    fn flush_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("idem.dwb");

        let mut dwb = open_buffered(&dwb_path);

        dwb.flush().unwrap();
        assert!(!dwb.dirty);

        let record =
            WalRecord::new(RecordType::Put as u16, 1, 1, 0, b"data".to_vec(), None).unwrap();
        dwb.write_record_deferred(&record).unwrap();
        dwb.flush().unwrap();
        dwb.flush().unwrap();
        assert!(!dwb.dirty);
    }

    #[test]
    fn slot_stride_is_o_direct_aligned() {
        assert!(
            is_aligned(DWB_SLOT_STRIDE, DEFAULT_ALIGNMENT),
            "DWB slot stride {DWB_SLOT_STRIDE} bytes is not a multiple of {DEFAULT_ALIGNMENT}"
        );
        assert!(is_aligned(DWB_HEADER_STRIDE, DEFAULT_ALIGNMENT));
        for i in 0..DWB_CAPACITY as u32 {
            assert!(is_aligned(slot_offset(i) as usize, DEFAULT_ALIGNMENT));
        }
    }

    #[test]
    fn recover_after_wraparound() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("wrap.dwb");

        let mut dwb = open_buffered(&dwb_path);

        let total = DWB_CAPACITY as u64 + 5;
        for lsn in 1..=total {
            let record = WalRecord::new(
                RecordType::Put as u16,
                lsn,
                1,
                0,
                format!("wrap-{lsn}").into_bytes(),
                None,
            )
            .unwrap();
            dwb.write_record_deferred(&record).unwrap();
        }
        dwb.flush().unwrap();

        for lsn in (total - 4)..=total {
            let recovered = dwb.recover_record(lsn).unwrap();
            assert!(
                recovered.is_some(),
                "LSN {lsn} should be recoverable after wrap-around"
            );
            assert_eq!(
                recovered.unwrap().payload,
                format!("wrap-{lsn}").into_bytes()
            );
        }

        for lsn in 1..=5u64 {
            let recovered = dwb.recover_record(lsn).unwrap();
            assert!(
                recovered.is_none(),
                "LSN {lsn} should have been overwritten by wrap-around"
            );
        }
    }

    #[test]
    fn bytes_written_counter_increments() {
        let dir = tempfile::tempdir().unwrap();
        let dwb_path = dir.path().join("counter.dwb");
        let before = wal_dwb_bytes_written_total();

        let mut dwb = open_buffered(&dwb_path);
        let rec =
            WalRecord::new(RecordType::Put as u16, 1, 1, 0, b"counted".to_vec(), None).unwrap();
        dwb.write_record(&rec).unwrap();

        assert!(wal_dwb_bytes_written_total() > before);
    }
}