use anyhow::{Context, Result};
use bytemuck::{bytes_of, Pod, Zeroable};
use crc32fast::Hasher;
use parking_lot::Mutex;
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::storage::data_structures::WalEntry;
/// Magic number at the start of every WAL file: the ASCII bytes `WALG` read as a
/// big-endian `u32` (`0x57414C47`).
const WAL_MAGIC: u32 = u32::from_be_bytes(*b"WALG");
/// On-disk format version; `Wal::open` rejects files whose header carries any other value.
const WAL_VERSION: u32 = 2;
/// Fixed-size header stored at offset 0 of the WAL file.
///
/// The struct is `#[repr(C)]` and `Pod` and is written and read as raw bytes
/// (`bytes_of`). Its field order and sizes therefore *are* the on-disk format, and
/// the bytes are in native byte order. Any layout change alters the file format.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub(crate) struct WalFileHeader {
    // Checked against `WAL_MAGIC` when an existing file is opened.
    magic: u32,
    // Checked against `WAL_VERSION` when an existing file is opened.
    version: u32,
    // Number of entries recorded as flushed; rewritten after each flush and used
    // to seed the in-memory count on open.
    entry_count: u64,
    // Checkpoint LSN; the visible code only ever writes 0 here.
    // NOTE(review): presumably set by a checkpointing path elsewhere — confirm.
    last_checkpoint_lsn: u64,
    // Reserved bytes, always written as zero.
    _padding: [u8; 32],
}
/// One on-disk WAL record: a `WalEntry` followed by a CRC32 of that entry's bytes.
///
/// The struct is `#[repr(C)]` and `Pod` and is written and read as raw bytes, so
/// the layout is part of the file format.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub(crate) struct WalEntryWithCrc {
    // The logged operation.
    entry: WalEntry,
    // CRC32 over `bytes_of(&entry)`, computed when the record is built.
    crc32: u32,
    // Explicit trailing padding, always 0. Presumably it exists so the struct has no
    // implicit padding, which `Pod` does not allow. TODO confirm against
    // `WalEntry`'s size and alignment.
    _padding: u32,
}
impl WalEntryWithCrc {
    /// Wraps `entry` together with the checksum of its raw bytes.
    fn new(entry: WalEntry) -> Self {
        Self {
            // Evaluated before `entry` is moved into the struct below.
            crc32: Self::compute_crc(&entry),
            _padding: 0,
            entry,
        }
    }

    /// CRC32 (via `crc32fast::Hasher`) over the in-memory bytes of `entry`.
    fn compute_crc(entry: &WalEntry) -> u32 {
        let raw = bytes_of(entry);
        let mut digest = Hasher::new();
        digest.update(raw);
        digest.finalize()
    }

    /// Returns `true` when the stored checksum matches the entry's current bytes.
    fn verify_crc(&self) -> bool {
        let stored = self.crc32;
        stored == Self::compute_crc(&self.entry)
    }
}
/// Append-only write-ahead log backed by a single file.
///
/// Entries are buffered in memory and written to the file in batches. All mutable
/// state sits behind mutexes, so the log is used through `&self`.
pub struct Wal {
    /// Records appended but not yet written to `file`.
    pending_entries: Arc<Mutex<Vec<WalEntryWithCrc>>>,
    /// Buffer length at which `append` triggers a flush.
    batch_size: usize,
    /// Open read/write handle to the WAL file.
    file: Arc<Mutex<File>>,
    /// Number of entries already flushed to `file`.
    entry_count: Arc<Mutex<u64>>,
    /// Directory containing the WAL file; fsynced after flushes.
    dir_path: PathBuf,
    /// Path of the WAL file (stored but not otherwise read).
    _path: PathBuf,
}
impl Wal {
pub fn open<P: AsRef<Path>>(path: P, batch_size: usize) -> Result<Self> {
let path_buf = path.as_ref().to_path_buf();
let dir_path = path_buf
.parent()
.ok_or_else(|| anyhow::anyhow!("WAL path has no parent directory"))?
.to_path_buf();
let mut file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&path_buf)
.context("Failed to open WAL file")?;
let metadata = file.metadata().context("Failed to get WAL file metadata")?;
let entry_count = if metadata.len() == 0 {
Self::write_header(&mut file, 0, 0)?;
0
} else {
let header = Self::read_header(&mut file)?;
if header.magic != WAL_MAGIC {
anyhow::bail!(
"Invalid WAL magic number: expected {:#x}, got {:#x}",
WAL_MAGIC,
header.magic
);
}
if header.version != WAL_VERSION {
anyhow::bail!(
"Incompatible WAL version: expected {}, got {}",
WAL_VERSION,
header.version
);
}
header.entry_count
};
Ok(Self {
file: Arc::new(Mutex::new(file)),
_path: path_buf,
dir_path,
entry_count: Arc::new(Mutex::new(entry_count)),
pending_entries: Arc::new(Mutex::new(Vec::new())),
batch_size,
})
}
pub fn append(&self, entry: WalEntry) -> Result<()> {
let entry_with_crc = WalEntryWithCrc::new(entry);
let mut pending = self.pending_entries.lock();
pending.push(entry_with_crc);
if pending.len() >= self.batch_size {
drop(pending);
self.flush()?;
}
Ok(())
}
pub fn flush(&self) -> Result<()> {
let mut pending = self.pending_entries.lock();
if pending.is_empty() {
return Ok(());
}
let mut file = self.file.lock();
let mut count = self.entry_count.lock();
file.seek(SeekFrom::End(0))
.context("Failed to seek to end of WAL")?;
for entry_with_crc in pending.iter() {
file.write_all(bytes_of(entry_with_crc))
.context("Failed to write WAL entry")?;
}
file.sync_data().context("Failed to fsync WAL file")?;
*count += pending.len() as u64;
Self::update_header_entry_count(&mut file, *count)?;
Self::fsync_directory(&self.dir_path)?;
pending.clear();
Ok(())
}
pub fn replay(&self) -> Result<Vec<WalEntry>> {
let mut file = self.file.lock();
let _header = Self::read_header(&mut file)?;
let entry_size = std::mem::size_of::<WalEntryWithCrc>();
let header_size = std::mem::size_of::<WalFileHeader>();
let mut entries = Vec::new();
file.seek(SeekFrom::Start(header_size as u64))
.context("Failed to seek past WAL header")?;
let mut buffer = vec![0u8; entry_size];
while let Ok(()) = file.read_exact(&mut buffer) {
let entry_with_crc: WalEntryWithCrc = *bytemuck::try_from_bytes(&buffer)
.map_err(|e| anyhow::anyhow!("Invalid WAL entry bytes: {}", e))?;
if !entry_with_crc.verify_crc() {
eprintln!("WAL corruption detected: CRC mismatch, stopping replay");
break;
}
entries.push(entry_with_crc.entry);
}
Ok(entries)
}
pub fn truncate(&self) -> Result<()> {
let mut pending = self.pending_entries.lock();
pending.clear();
let mut count = self.entry_count.lock();
*count = 0;
let mut file = self.file.lock();
file.set_len(0)?;
file.seek(SeekFrom::Start(0))?;
Self::write_header(&mut file, 0, 0)?;
file.sync_data()?;
Ok(())
}
pub fn entry_count(&self) -> u64 {
*self.entry_count.lock() + self.pending_entries.lock().len() as u64
}
fn write_header(file: &mut File, entry_count: u64, checkpoint_lsn: u64) -> Result<()> {
let header = WalFileHeader {
magic: WAL_MAGIC,
version: WAL_VERSION,
entry_count,
last_checkpoint_lsn: checkpoint_lsn,
_padding: [0u8; 32],
};
file.seek(SeekFrom::Start(0))?;
file.write_all(bytes_of(&header))?;
file.flush()?;
Ok(())
}
fn read_header(file: &mut File) -> Result<WalFileHeader> {
let mut buffer = [0u8; std::mem::size_of::<WalFileHeader>()];
file.seek(SeekFrom::Start(0))?;
file.read_exact(&mut buffer)?;
let header: WalFileHeader = *bytemuck::try_from_bytes(&buffer)
.map_err(|e| anyhow::anyhow!("Invalid WAL header: {}", e))?;
Ok(header)
}
fn update_header_entry_count(file: &mut File, entry_count: u64) -> Result<()> {
let mut header = Self::read_header(file)?;
header.entry_count = entry_count;
file.seek(SeekFrom::Start(0))?;
file.write_all(bytes_of(&header))?;
file.flush()?;
Ok(())
}
fn fsync_directory(dir_path: &Path) -> Result<()> {
#[cfg(unix)]
{
use std::os::unix::io::AsRawFd;
let dir = File::open(dir_path).context("Failed to open directory for fsync")?;
unsafe {
if libc::fsync(dir.as_raw_fd()) != 0 {
return Err(anyhow::anyhow!(
"fsync directory failed: {}",
std::io::Error::last_os_error()
));
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::data_structures::{WAL_ENTRY_EDGE_CREATE, WAL_ENTRY_NODE_CREATE};
    use tempfile::tempdir;

    /// Builds an entry whose fields are all zero except `node_id` and `entry_type`.
    fn make_wal_entry(node_id: u64, entry_type: u8) -> WalEntry {
        WalEntry {
            node_id,
            entry_type,
            ..WalEntry::zeroed()
        }
    }

    #[test]
    fn test_wal_create_and_open() {
        let scratch = tempdir().unwrap();
        let log_path = scratch.path().join("test.wal");

        let first = Wal::open(&log_path, 10);
        assert!(first.is_ok());
        assert!(log_path.exists());

        // Re-opening the freshly created file must accept its header.
        let second = Wal::open(&log_path, 10);
        assert!(second.is_ok());
    }

    #[test]
    fn test_wal_append_and_replay() {
        let scratch = tempdir().unwrap();
        let log_path = scratch.path().join("test.wal");
        let log = Wal::open(&log_path, 100).unwrap();

        (0..5u64).for_each(|id| log.append(make_wal_entry(id, WAL_ENTRY_NODE_CREATE)).unwrap());
        log.flush().unwrap();
        assert_eq!(log.entry_count(), 5);

        let replayed = log.replay().unwrap();
        assert_eq!(replayed.len(), 5);
        for (expected_id, record) in (0u64..).zip(replayed.iter()) {
            assert_eq!(record.node_id, expected_id);
            assert_eq!(record.entry_type, WAL_ENTRY_NODE_CREATE);
        }
    }

    #[test]
    fn test_wal_auto_flush_on_batch_size() {
        let scratch = tempdir().unwrap();
        let log_path = scratch.path().join("test.wal");
        let log = Wal::open(&log_path, 3).unwrap();

        for id in [0u64, 1] {
            log.append(make_wal_entry(id, WAL_ENTRY_NODE_CREATE)).unwrap();
        }
        // Still buffered, but counted.
        assert_eq!(log.entry_count(), 2);

        // The third append reaches the batch size and triggers a flush.
        log.append(make_wal_entry(2, WAL_ENTRY_NODE_CREATE)).unwrap();
        assert_eq!(log.entry_count(), 3);
        assert_eq!(log.replay().unwrap().len(), 3);
    }

    #[test]
    fn test_wal_truncate() {
        let scratch = tempdir().unwrap();
        let log_path = scratch.path().join("test.wal");
        let log = Wal::open(&log_path, 100).unwrap();

        (0..5u64).for_each(|id| log.append(make_wal_entry(id, WAL_ENTRY_EDGE_CREATE)).unwrap());
        log.flush().unwrap();
        log.truncate().unwrap();

        assert_eq!(log.entry_count(), 0);
        assert_eq!(log.replay().unwrap().len(), 0);
    }

    #[test]
    fn test_wal_crc_corruption() {
        let scratch = tempdir().unwrap();
        let log_path = scratch.path().join("test.wal");
        let log = Wal::open(&log_path, 1).unwrap();
        log.append(make_wal_entry(0, WAL_ENTRY_NODE_CREATE)).unwrap();
        log.flush().unwrap();

        // Flip one byte inside the first record, just past the header.
        let target = (std::mem::size_of::<WalFileHeader>() + 4) as u64;
        {
            let mut raw = OpenOptions::new()
                .read(true)
                .write(true)
                .open(&log_path)
                .unwrap();
            let mut byte = [0u8; 1];
            raw.seek(SeekFrom::Start(target)).unwrap();
            raw.read_exact(&mut byte).unwrap();
            byte[0] = byte[0].wrapping_add(1);
            raw.seek(SeekFrom::Start(target)).unwrap();
            raw.write_all(&byte).unwrap();
        }

        let reopened = Wal::open(&log_path, 100).unwrap();
        let survivors = reopened.replay().unwrap();
        assert_eq!(survivors.len(), 0);
    }
}