use anyhow::{Context, Result};
use bytemuck::{Pod, Zeroable};
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
/// Magic number identifying a WAL file (value `0x4C41_5747`).
///
/// The header is written in native byte order, so on little-endian hosts a WAL
/// file begins with the ASCII bytes `GWAL`.
const WAL_MAGIC: u32 = u32::from_le_bytes(*b"GWAL");
/// On-disk format version; opening an existing file with any other version fails.
const WAL_VERSION: u32 = 1;
/// Fixed-size header stored at offset 0 of every WAL file.
///
/// It is written and read as raw bytes through `bytemuck` in native byte order.
/// The field order below therefore *is* the on-disk format: 40 bytes, with no
/// implicit padding (a `Pod` requirement). Do not reorder or resize fields.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
struct WalFileHeader {
    // Must equal `WAL_MAGIC`; checked when an existing file is opened.
    magic: u32,
    // Must equal `WAL_VERSION`; checked when an existing file is opened.
    version: u32,
    // Count of entries flushed to disk; rewritten after each successful flush.
    entry_count: u64,
    // Written as 0 by this module and preserved by entry-count updates.
    // NOTE(review): not interpreted anywhere in this file — confirm intended use.
    last_checkpoint_lsn: u64,
    // Unused by this module; zeroed on write. Keeps the header at 40 bytes.
    _padding: [u8; 16],
}
impl WalFileHeader {
    /// Header describing an empty log at the current format version.
    ///
    /// The entry count and checkpoint LSN are 0, and the reserved bytes are zeroed.
    fn new() -> Self {
        Self {
            _padding: [0u8; 16],
            last_checkpoint_lsn: 0,
            entry_count: 0,
            version: WAL_VERSION,
            magic: WAL_MAGIC,
        }
    }
}
/// Kind of change recorded by a [`WalEntry`].
///
/// Callers store the raw `u8` discriminant in `WalEntry::entry_type`. That byte
/// is persisted to disk, so the explicit values below must never change.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[repr(u8)]
pub enum WalEntryType {
    NodeInsert = 0x1,
    EdgeInsert = 0x2,
    NodeUpdate = 0x3,
    MetadataInsert = 0x4,
}
/// A single fixed-size log entry, persisted byte-for-byte via `bytemuck`.
///
/// `#[repr(C)]` with no implicit padding (40 bytes). The field order is part of
/// the on-disk format, so do not reorder fields.
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
#[repr(C)]
pub struct WalEntry {
    /// Raw discriminant, expected to be a `WalEntryType as u8`.
    pub entry_type: u8,
    /// Explicit padding that 8-aligns `node_id` without implicit padding bytes.
    pub _padding: [u8; 7],
    pub node_id: u64,
    /// Payload location/size; interpretation is left to callers (not validated here).
    pub data_offset: u64,
    pub data_length: u64,
    pub timestamp: u64,
}
/// On-disk record: a [`WalEntry`] followed by a CRC-32 of that entry's bytes.
///
/// The layout is `#[repr(C)]`: a 40-byte entry, a 4-byte CRC and 4 bytes of explicit
/// padding, 48 bytes in total. This gives the struct no implicit padding (a `Pod`
/// requirement) and a size that is a multiple of its 8-byte alignment. Field order
/// is part of the on-disk format.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
struct WalEntryWithCrc {
    // The logged entry.
    entry: WalEntry,
    // crc32fast checksum over `bytemuck::bytes_of(&entry)`; see `compute_crc`.
    crc32: u32,
    // Set to 0 on construction; only rounds the record up to 48 bytes.
    _padding: u32,
}
impl WalEntryWithCrc {
fn new(entry: WalEntry) -> Self {
let crc32 = Self::compute_crc(&entry);
Self {
entry,
crc32,
_padding: 0,
}
}
fn compute_crc(entry: &WalEntry) -> u32 {
use crc32fast::Hasher;
let mut hasher = Hasher::new();
hasher.update(bytemuck::bytes_of(entry));
hasher.finalize()
}
fn verify_crc(&self) -> bool {
self.crc32 == Self::compute_crc(&self.entry)
}
}
/// Append-only write-ahead log backed by a single file.
///
/// Appended entries are buffered in memory. They are written and `fsync`ed once
/// `batch_size` of them are pending, or on an explicit [`Wal::flush`].
pub struct Wal {
    // Read/write handle to the WAL file, opened in `Wal::open`.
    file: File,
    // Location of the WAL file, as passed to `Wal::open`.
    path: PathBuf,
    // Entries flushed to disk; loaded from and written back to the header's `entry_count`.
    entry_count: u64,
    // Entries appended but not yet written to disk.
    pending: Vec<WalEntryWithCrc>,
    // Number of pending entries that triggers an automatic flush in `append`.
    batch_size: usize,
}
impl Wal {
pub fn open<P: AsRef<Path>>(path: P, batch_size: usize) -> Result<Self> {
let path_buf = path.as_ref().to_path_buf();
let mut file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&path_buf)
.context("Failed to open WAL file")?;
let metadata = file.metadata().context("Failed to get WAL metadata")?;
let entry_count = if metadata.len() == 0 {
Self::write_header(&mut file, 0, 0)?;
0
} else {
let header = Self::read_header(&mut file)?;
if header.magic != WAL_MAGIC {
anyhow::bail!("Invalid WAL magic number");
}
if header.version != WAL_VERSION {
anyhow::bail!("Incompatible WAL version");
}
header.entry_count
};
Ok(Self {
file,
path: path_buf,
entry_count,
pending: Vec::new(),
batch_size,
})
}
pub fn wal_path_for_db(db_path: &Path) -> PathBuf {
let mut path = db_path.to_path_buf();
let stem = path.file_stem().unwrap_or_default();
let ext = path.extension().unwrap_or_default();
let new_name = format!("{}_wal.{}", stem.to_string_lossy(), ext.to_string_lossy());
path.set_file_name(new_name);
path
}
pub fn append(&mut self, entry: WalEntry) -> Result<()> {
let entry_with_crc = WalEntryWithCrc::new(entry);
self.pending.push(entry_with_crc);
if self.pending.len() >= self.batch_size {
self.flush()?;
}
Ok(())
}
pub fn flush(&mut self) -> Result<()> {
if self.pending.is_empty() {
return Ok(());
}
self.file
.seek(SeekFrom::End(0))
.context("Failed to seek to end of WAL")?;
for entry_with_crc in &self.pending {
self.file
.write_all(bytemuck::bytes_of(entry_with_crc))
.context("Failed to write WAL entry")?;
}
self.file.sync_data().context("Failed to fsync WAL")?;
self.entry_count += self.pending.len() as u64;
Self::update_header_entry_count(&mut self.file, self.entry_count)?;
self.pending.clear();
Ok(())
}
pub fn replay(&mut self) -> Result<Vec<WalEntry>> {
let header = Self::read_header(&mut self.file)?;
let entry_size = std::mem::size_of::<WalEntryWithCrc>();
let header_size = std::mem::size_of::<WalFileHeader>();
let mut entries = Vec::new();
self.file
.seek(SeekFrom::Start(header_size as u64))
.context("Failed to seek past WAL header")?;
let mut buffer = vec![0u8; entry_size];
while let Ok(()) = self.file.read_exact(&mut buffer) {
let entry_with_crc: WalEntryWithCrc = *bytemuck::try_from_bytes(&buffer)
.map_err(|e| anyhow::anyhow!("Invalid WAL entry bytes: {}", e))?;
if !entry_with_crc.verify_crc() {
eprintln!("WAL corruption detected: CRC mismatch, stopping replay");
break;
}
entries.push(entry_with_crc.entry);
}
Ok(entries)
}
pub fn truncate(&mut self) -> Result<()> {
self.pending.clear();
self.entry_count = 0;
self.file.set_len(0)?;
self.file.seek(SeekFrom::Start(0))?;
Self::write_header(&mut self.file, 0, 0)?;
self.file.sync_data()?;
Ok(())
}
pub fn entry_count(&self) -> u64 {
self.entry_count + self.pending.len() as u64
}
fn write_header(file: &mut File, entry_count: u64, checkpoint_lsn: u64) -> Result<()> {
let header = WalFileHeader {
magic: WAL_MAGIC,
version: WAL_VERSION,
entry_count,
last_checkpoint_lsn: checkpoint_lsn,
_padding: [0; 16],
};
file.seek(SeekFrom::Start(0))?;
file.write_all(bytemuck::bytes_of(&header))?;
file.flush()?;
Ok(())
}
fn read_header(file: &mut File) -> Result<WalFileHeader> {
let mut buffer = [0u8; std::mem::size_of::<WalFileHeader>()];
file.seek(SeekFrom::Start(0))?;
file.read_exact(&mut buffer)?;
let header: WalFileHeader = *bytemuck::try_from_bytes(&buffer)
.map_err(|e| anyhow::anyhow!("Invalid WAL header: {}", e))?;
Ok(header)
}
fn update_header_entry_count(file: &mut File, entry_count: u64) -> Result<()> {
let mut header = Self::read_header(file)?;
header.entry_count = entry_count;
file.seek(SeekFrom::Start(0))?;
file.write_all(bytemuck::bytes_of(&header))?;
file.flush()?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Builds a `NodeInsert` entry for `node_id` with every other field zeroed.
    fn node_insert(node_id: u64) -> WalEntry {
        WalEntry {
            entry_type: WalEntryType::NodeInsert as u8,
            _padding: [0; 7],
            node_id,
            data_offset: 0,
            data_length: 0,
            timestamp: 0,
        }
    }

    #[test]
    fn test_wal_create_and_open() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");

        let first = Wal::open(&wal_path, 10);
        assert!(first.is_ok());
        assert!(wal_path.exists());

        // Opening the now-existing file must accept the header written above.
        let second = Wal::open(&wal_path, 10);
        assert!(second.is_ok());
    }

    #[test]
    fn test_wal_append_and_replay() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");
        let mut wal = Wal::open(&wal_path, 100).unwrap();

        for id in 0..5u64 {
            let entry = WalEntry {
                data_offset: id * 100,
                data_length: 72,
                timestamp: id * 1000,
                ..node_insert(id)
            };
            wal.append(entry).unwrap();
        }
        wal.flush().unwrap();
        assert_eq!(wal.entry_count(), 5);

        let replayed = wal.replay().unwrap();
        assert_eq!(replayed.len(), 5);
        for (expected_id, entry) in (0u64..).zip(&replayed) {
            assert_eq!(entry.node_id, expected_id);
            assert_eq!(entry.entry_type, WalEntryType::NodeInsert as u8);
        }
    }

    #[test]
    fn test_wal_auto_flush_on_batch_size() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");
        let mut wal = Wal::open(&wal_path, 3).unwrap();

        for id in 0..2 {
            wal.append(node_insert(id)).unwrap();
        }
        assert_eq!(wal.entry_count(), 2);

        // The third append reaches `batch_size`, so it flushes without an explicit call.
        wal.append(node_insert(2)).unwrap();
        assert_eq!(wal.replay().unwrap().len(), 3);
    }

    #[test]
    fn test_wal_corruption_detection() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");
        let mut wal = Wal::open(&wal_path, 100).unwrap();

        for id in 0..3 {
            wal.append(node_insert(id)).unwrap();
        }
        wal.flush().unwrap();

        // Clobber the trailing 8 bytes of the last record (its CRC and padding).
        {
            let mut raw = OpenOptions::new().write(true).open(&wal_path).unwrap();
            raw.seek(SeekFrom::End(-8)).unwrap();
            raw.write_all(&[0xFF; 8]).unwrap();
        }

        let mut wal = Wal::open(&wal_path, 100).unwrap();
        assert_eq!(wal.replay().unwrap().len(), 2);
    }

    #[test]
    fn test_wal_truncate() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");
        let mut wal = Wal::open(&wal_path, 100).unwrap();

        for id in 0..5 {
            wal.append(node_insert(id)).unwrap();
        }
        wal.flush().unwrap();
        assert_eq!(wal.entry_count(), 5);

        wal.truncate().unwrap();
        assert_eq!(wal.entry_count(), 0);
        assert!(wal.replay().unwrap().is_empty());
    }

    #[test]
    fn test_wal_path_for_db() {
        let derived = Wal::wal_path_for_db(Path::new("/tmp/test.db"));
        assert_eq!(derived, PathBuf::from("/tmp/test_wal.db"));
    }
}