#[cfg(test)]
mod tests;
use std::{
fs::{File, OpenOptions},
io::{self, Read, Seek, SeekFrom, Write},
path::{Path, PathBuf},
sync::{Arc, Mutex},
};
use crate::encoding::{self, EncodingError};
use crc32fast::Hasher as Crc32;
use std::ffi::OsStr;
use thiserror::Error;
use tracing::{debug, error, info, trace, warn};
/// Size in bytes of a `u32`; the width of the record length prefix and of every stored CRC.
const U32_SIZE: usize = std::mem::size_of::<u32>();
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum WalError {
#[error("I/O error: {0}")]
Io(#[from] io::Error),
#[error("Encoding error: {0}")]
Encoding(#[from] EncodingError),
#[error("Checksum mismatch")]
ChecksumMismatch,
#[error("Record size exceeds limit ({0} bytes)")]
RecordTooLarge(usize),
#[error("Unexpected end of file")]
UnexpectedEof,
#[error("Internal header: {0}")]
InvalidHeader(String),
#[error("Internal error: {0}")]
Internal(String),
}
/// Header stored at the start of every WAL file, followed on disk by its CRC.
#[derive(Debug)]
pub struct WalHeader {
/// Format marker; validation compares it against `WalHeader::MAGIC`.
magic: [u8; 4],
/// Format version; validation rejects anything other than `WalHeader::VERSION`.
version: u32,
/// Largest encoded record, in bytes, accepted by `Wal::append` and by replay.
max_record_size: u32,
/// Sequence number of this log file; `Wal::open` requires it to match the file name.
wal_seq: u64,
}
impl WalHeader {
    /// Marker written at the very start of a WAL file.
    pub const MAGIC: [u8; 4] = *b"AWAL";
    /// The only header version this module accepts.
    pub const VERSION: u32 = 1;
    /// Record size limit (1 MiB) used when `Wal::open` is not given one.
    pub const DEFAULT_MAX_RECORD_SIZE: u32 = 1 << 20;
    /// Size in bytes of the encoded header fields.
    pub const ENCODED_SIZE: usize = 4 // magic
        + 4 // version
        + 4 // max_record_size
        + 8; // wal_seq
    /// Bytes the header occupies on disk: the encoded fields plus the trailing CRC.
    pub const HEADER_DISK_SIZE: usize = Self::ENCODED_SIZE + U32_SIZE;

    /// Builds a header for the current format with the given record limit and sequence number.
    pub fn new(max_record_size: u32, wal_seq: u64) -> Self {
        let (magic, version) = (Self::MAGIC, Self::VERSION);
        Self {
            magic,
            version,
            max_record_size,
            wal_seq,
        }
    }

    /// Sequence number recorded in the header.
    #[allow(dead_code)]
    pub fn wal_seq(&self) -> u64 {
        self.wal_seq
    }

    /// Maximum encoded record size, in bytes, recorded in the header.
    #[allow(dead_code)]
    pub fn max_record_size(&self) -> u32 {
        self.max_record_size
    }

    /// Format version recorded in the header.
    #[allow(dead_code)]
    pub fn version(&self) -> u32 {
        self.version
    }
}
impl encoding::Encode for WalHeader {
    /// Appends the header fields to `buf` in declaration order:
    /// magic, version, max record size, sequence number.
    fn encode_to(&self, buf: &mut Vec<u8>) -> Result<(), EncodingError> {
        // Destructuring makes the compiler flag this function when a field is added.
        let Self {
            magic,
            version,
            max_record_size,
            wal_seq,
        } = self;
        encoding::Encode::encode_to(magic, buf)?;
        encoding::Encode::encode_to(version, buf)?;
        encoding::Encode::encode_to(max_record_size, buf)?;
        encoding::Encode::encode_to(wal_seq, buf)
    }
}
impl encoding::Decode for WalHeader {
    /// Decodes the four header fields from the front of `buf`.
    ///
    /// Returns the header together with the total number of bytes consumed.
    /// No validation of magic or version happens here.
    fn decode_from(buf: &[u8]) -> Result<(Self, usize), EncodingError> {
        let (magic, magic_len) = <[u8; 4]>::decode_from(buf)?;
        let version_at = magic_len;

        let (version, version_len) = u32::decode_from(&buf[version_at..])?;
        let limit_at = version_at + version_len;

        let (max_record_size, limit_len) = u32::decode_from(&buf[limit_at..])?;
        let seq_at = limit_at + limit_len;

        let (wal_seq, seq_len) = u64::decode_from(&buf[seq_at..])?;
        let consumed = seq_at + seq_len;

        let header = Self {
            magic,
            version,
            max_record_size,
            wal_seq,
        };
        Ok((header, consumed))
    }
}
/// Bound alias for record types that can be stored in a [`Wal`]: encodable,
/// decodable, debuggable and shareable across threads.
pub trait WalData:
    encoding::Encode + encoding::Decode + std::fmt::Debug + Send + Sync
{
}

/// Every type that satisfies the bounds is automatically a `WalData`.
impl<T: encoding::Encode + encoding::Decode + std::fmt::Debug + Send + Sync> WalData for T {}
/// Handle to one write-ahead-log file holding records of type `T`.
#[derive(Debug)]
pub struct Wal<T: WalData> {
/// The open log file; the same `Arc` is handed to iterators created by `replay_iter`.
inner_file: Arc<Mutex<File>>,
/// Path the log was opened with.
path: PathBuf,
/// Header written to, or read from, the start of the file when it was opened.
header: WalHeader,
/// Ties the handle to its record type without storing a `T`.
_phantom: std::marker::PhantomData<T>,
}
impl<T: WalData> Wal<T> {
/// Opens the WAL file at `path`, creating and initialising it if it is empty or missing.
///
/// The file name must have the form `<sequence>.log`. For a new (zero-length) file a
/// header is written using `max_record_size`, or the default limit when `None`, and
/// synced to disk. For an existing file the header is read and validated, and
/// `max_record_size` is ignored in favour of the stored value.
///
/// # Errors
///
/// * [`WalError::Internal`] if the file name is not `<sequence>.log`.
/// * [`WalError::Io`] if the file cannot be opened, read, written or synced.
/// * [`WalError::InvalidHeader`] if an existing header fails validation or its
///   sequence number differs from the one in the file name.
/// * [`WalError::Encoding`] if the header cannot be encoded or decoded.
pub fn open<P: AsRef<Path>>(path: P, max_record_size: Option<u32>) -> Result<Self, WalError> {
    let path_ref = path.as_ref();
    // Validate the name before touching the file system: `create(true)` below would
    // otherwise leave a stray empty file behind for a path that is then rejected.
    let wal_seq = Self::parse_seq_from_path(path_ref)
        .ok_or_else(|| WalError::Internal("WAL name incorrect".into()))?;
    let mut file = OpenOptions::new()
        .create(true)
        .read(true)
        .append(true)
        .open(path_ref)?;
    let header = if file.metadata()?.len() == 0 {
        let header = WalHeader::new(
            max_record_size.unwrap_or(WalHeader::DEFAULT_MAX_RECORD_SIZE),
            wal_seq,
        );
        write_header(&mut file, &header)?;
        file.sync_all()?;
        info!(path = %path_ref.display(), seq = wal_seq, "WAL created with new header");
        header
    } else {
        file.seek(SeekFrom::Start(0))?;
        let header = read_and_validate_header(&mut file)?;
        if header.wal_seq != wal_seq {
            return Err(WalError::InvalidHeader("sequence number mismatch".into()));
        }
        debug!(
            path = %path_ref.display(),
            max_record_size = header.max_record_size,
            seq = header.wal_seq,
            "WAL header validated"
        );
        header
    };
    info!(path = %path_ref.display(), seq = header.wal_seq, "WAL opened");
    Ok(Self {
        inner_file: Arc::new(Mutex::new(file)),
        path: path_ref.to_path_buf(),
        header,
        _phantom: std::marker::PhantomData,
    })
}
/// Extracts the sequence number from a path whose file name is `<number>.log`.
///
/// Returns `None` when the path has no UTF-8 file name, the name does not end in
/// `.log`, or the part before the suffix does not parse as a `u64`.
fn parse_seq_from_path(path: &Path) -> Option<u64> {
    path.file_name()
        .and_then(OsStr::to_str)
        .and_then(|file_name| file_name.strip_suffix(".log"))
        .and_then(|digits| digits.parse::<u64>().ok())
}
pub fn append(&self, record: &T) -> Result<(), WalError> {
let record_bytes = encoding::encode_to_vec(record)?;
let record_len = u32::try_from(record_bytes.len())
.map_err(|_| WalError::RecordTooLarge(record_bytes.len()))?;
if record_len > self.header.max_record_size {
return Err(WalError::RecordTooLarge(record_len as usize));
}
let len_bytes = record_len.to_le_bytes();
let checksum = compute_crc(&[&len_bytes, &record_bytes]);
let mut guard = self
.inner_file
.lock()
.map_err(|_| WalError::Internal("Mutex poisoned".into()))?;
guard.write_all(&len_bytes)?;
guard.write_all(&record_bytes)?;
guard.write_all(&checksum.to_le_bytes())?;
guard.sync_all()?;
trace!(
len = record_len,
crc = format_args!("{checksum:08x}"),
"WAL record appended"
);
Ok(())
}
/// Creates an iterator over the records in the log, starting right after the header.
///
/// The iterator shares this handle's file and lock. The current implementation
/// always returns `Ok`.
pub fn replay_iter(&self) -> Result<WalIter<T>, WalError> {
    debug!(path = %self.path.display(), "WAL replay started");
    let iter = WalIter {
        file: Arc::clone(&self.inner_file),
        // Records begin immediately after the header and its CRC.
        offset: WalHeader::HEADER_DISK_SIZE as u64,
        max_record_size: self.header.max_record_size as usize,
        _phantom: std::marker::PhantomData,
    };
    Ok(iter)
}
pub fn truncate(&mut self) -> Result<(), WalError> {
let mut guard = self
.inner_file
.lock()
.map_err(|_| WalError::Internal("Mutex poisoned".into()))?;
guard.set_len(0)?;
guard.seek(SeekFrom::Start(0))?;
write_header(&mut *guard, &self.header)?;
guard.sync_all()?;
info!(path = %self.path.display(), "WAL truncated");
Ok(())
}
#[allow(dead_code)]
pub fn rotate_next(&mut self) -> Result<u64, WalError> {
{
let guard = self
.inner_file
.lock()
.map_err(|_| WalError::Internal("Mutex poisoned".into()))?;
guard.sync_all()?;
}
let next_seq = self
.header
.wal_seq
.checked_add(1)
.ok_or_else(|| WalError::Internal("WAL sequence number overflow".into()))?;
let cur_path = PathBuf::from(&self.path);
let dir = cur_path.parent().unwrap_or_else(|| Path::new("."));
let next_path = dir.join(format!("{next_seq:06}.log"));
let new_wal = Wal::<T>::open(&next_path, Some(self.header.max_record_size))?;
*self = new_wal;
Ok(next_seq)
}
/// Path this log was opened with.
#[allow(dead_code)]
pub fn path(&self) -> &Path {
    self.path.as_path()
}
pub fn wal_seq(&self) -> u64 {
self.header.wal_seq
}
#[allow(dead_code)]
pub fn max_record_size(&self) -> u32 {
self.header.max_record_size
}
/// Current length of the log file in bytes, as reported by the file's metadata.
///
/// # Errors
///
/// Returns [`WalError::Internal`] on a poisoned lock and [`WalError::Io`] if the
/// metadata cannot be read.
#[allow(dead_code)]
pub fn file_size(&self) -> Result<u64, WalError> {
    let file = self
        .inner_file
        .lock()
        .map_err(|_| WalError::Internal("Mutex poisoned".into()))?;
    let metadata = file.metadata()?;
    Ok(metadata.len())
}
}
impl<T: WalData> Drop for Wal<T> {
    /// Makes a final attempt to sync the file when the handle goes away.
    ///
    /// A poisoned lock is recovered so the sync still happens. Failures are only
    /// logged, because `drop` cannot return an error.
    fn drop(&mut self) {
        let (file, was_poisoned) = match self.inner_file.lock() {
            Ok(guard) => (guard, false),
            Err(poisoned) => (poisoned.into_inner(), true),
        };
        match (file.sync_all(), was_poisoned) {
            (Ok(()), false) => {}
            (Ok(()), true) => {
                warn!(path = %self.path.display(), "WAL recovered and synced after poisoned lock");
            }
            (Err(e), false) => {
                error!(path = %self.path.display(), error = %e, "WAL sync failed on drop");
            }
            (Err(e), true) => {
                error!(path = %self.path.display(), error = %e, "WAL sync failed on drop (poisoned lock)");
            }
        }
    }
}
/// Iterator over the records of a WAL file, created by `Wal::replay_iter`.
pub struct WalIter<T: WalData> {
/// File shared with the `Wal` that created this iterator; locked for each `next` call.
file: Arc<Mutex<File>>,
/// Byte offset in the file at which the next record is read.
offset: u64,
/// Largest record length accepted; a longer length prefix yields `RecordTooLarge`.
max_record_size: usize,
/// Ties the iterator to its record type without storing a `T`.
_phantom: std::marker::PhantomData<T>,
}
impl<T: WalData> std::fmt::Debug for WalIter<T> {
    /// Shows the cursor and the size limit; the remaining fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WalIter");
        out.field("offset", &self.offset);
        out.field("max_record_size", &self.max_record_size);
        out.finish_non_exhaustive()
    }
}
impl<T: WalData> Iterator for WalIter<T> {
type Item = Result<T, WalError>;
fn next(&mut self) -> Option<Self::Item> {
let mut guard = match self.file.lock() {
Ok(g) => g,
Err(_) => return Some(Err(WalError::Internal("Mutex poisoned".into()))),
};
if let Err(e) = guard.seek(SeekFrom::Start(self.offset)) {
return Some(Err(WalError::Io(e)));
}
let mut len_bytes = [0u8; U32_SIZE];
match guard.read_exact(&mut len_bytes) {
Ok(_) => {}
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
trace!(offset = self.offset, "WAL replay reached end of file");
return None;
}
Err(e) => return Some(Err(WalError::Io(e))),
}
let record_len = u32::from_le_bytes(len_bytes) as usize;
if record_len > self.max_record_size {
return Some(Err(WalError::RecordTooLarge(record_len)));
}
trace!(offset = self.offset, len = record_len, "WAL reading record");
let mut record_bytes = vec![0u8; record_len];
if let Err(e) = guard.read_exact(&mut record_bytes) {
if e.kind() == io::ErrorKind::UnexpectedEof {
warn!(
offset = self.offset,
len = record_len,
"WAL truncated record (partial payload)"
);
return Some(Err(WalError::UnexpectedEof));
}
return Some(Err(WalError::Io(e)));
}
let mut checksum_bytes = [0u8; U32_SIZE];
if let Err(e) = guard.read_exact(&mut checksum_bytes) {
if e.kind() == io::ErrorKind::UnexpectedEof {
warn!(
offset = self.offset,
len = record_len,
"WAL truncated record (partial checksum)"
);
return Some(Err(WalError::UnexpectedEof));
}
return Some(Err(WalError::Io(e)));
}
let stored_checksum = u32::from_le_bytes(checksum_bytes);
if let Ok(pos) = guard.stream_position() {
self.offset = pos;
}
if let Err(e) = verify_crc(&[&len_bytes, &record_bytes], stored_checksum) {
warn!(
offset = self.offset,
len = record_len,
"WAL record checksum mismatch"
);
return Some(Err(e));
}
match encoding::decode_from_slice::<T>(&record_bytes) {
Ok((record, _)) => Some(Ok(record)),
Err(e) => Some(Err(WalError::Encoding(e))),
}
}
}
fn write_header<W: Write>(writer: &mut W, header: &WalHeader) -> Result<(), WalError> {
let header_bytes = encoding::encode_to_vec(header)?;
let checksum = compute_crc(&[&header_bytes]);
writer.write_all(&header_bytes)?;
writer.write_all(&checksum.to_le_bytes())?;
Ok(())
}
/// Reads a header and its CRC from `reader` and validates them.
///
/// Checks run in this order: CRC of the encoded bytes, decoding, magic, version.
///
/// # Errors
///
/// * [`WalError::Io`] if the header or CRC bytes cannot be read in full.
/// * [`WalError::InvalidHeader`] on a CRC mismatch, wrong magic or unsupported version.
/// * [`WalError::Encoding`] if the header bytes cannot be decoded.
fn read_and_validate_header<R: Read>(reader: &mut R) -> Result<WalHeader, WalError> {
    let mut encoded = vec![0u8; WalHeader::ENCODED_SIZE];
    reader.read_exact(&mut encoded)?;
    let mut crc_buf = [0u8; U32_SIZE];
    reader.read_exact(&mut crc_buf)?;

    if verify_crc(&[&encoded], u32::from_le_bytes(crc_buf)).is_err() {
        return Err(WalError::InvalidHeader("header checksum mismatch".into()));
    }

    let (header, _) = encoding::decode_from_slice::<WalHeader>(&encoded)?;
    if header.magic != WalHeader::MAGIC {
        Err(WalError::InvalidHeader("bad magic".into()))
    } else if header.version != WalHeader::VERSION {
        Err(WalError::InvalidHeader(format!(
            "unsupported version {}",
            header.version
        )))
    } else {
        Ok(header)
    }
}
/// Computes one CRC32 over all `parts`, fed to the hasher in order.
fn compute_crc(parts: &[&[u8]]) -> u32 {
    parts
        .iter()
        .fold(Crc32::new(), |mut hasher, chunk| {
            hasher.update(chunk);
            hasher
        })
        .finalize()
}
/// Checks that the CRC32 of `parts` equals `expected`.
///
/// Returns [`WalError::ChecksumMismatch`] when it does not.
fn verify_crc(parts: &[&[u8]], expected: u32) -> Result<(), WalError> {
    if compute_crc(parts) == expected {
        Ok(())
    } else {
        Err(WalError::ChecksumMismatch)
    }
}