use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use crate::error::{Result, SQLRiteError};
use crate::sql::pager::page::PAGE_SIZE;
use crate::sql::pager::pager::{AccessMode, acquire_lock};
/// Size in bytes of the file header stored at offset 0 of every WAL file.
pub const WAL_HEADER_SIZE: usize = 32;
/// Magic bytes that occupy the first 8 bytes of the WAL header.
pub const WAL_MAGIC: &[u8; 8] = b"SQLRWAL\0";
/// On-disk format version written into, and required in, the WAL header.
pub const WAL_FORMAT_VERSION: u32 = 1;
/// Size in bytes of a frame header: page number, commit page count, salt and
/// checksum, each a little-endian `u32`.
pub const FRAME_HEADER_SIZE: usize = 16;
/// Size in bytes of one complete frame: its header followed by one page image.
pub const FRAME_SIZE: usize = PAGE_SIZE + FRAME_HEADER_SIZE;
/// In-memory copy of the variable fields of the WAL file header.
///
/// `salt` is stamped into every frame and checked when frames are read back;
/// `checkpoint_seq` is incremented each time the log is reset.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct WalHeader {
    /// Value every frame of the current log generation must carry.
    pub salt: u32,
    /// Number of times the log has been reset (wrapping).
    pub checkpoint_seq: u32,
}
/// Decoded header of a single WAL frame.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct FrameHeader {
    /// Page whose image follows this header.
    pub page_num: u32,
    /// Non-zero on a commit frame (the page count recorded at commit); zero otherwise.
    pub commit_page_count: u32,
    /// Salt copied from the WAL header when the frame was written.
    pub salt: u32,
    /// Checksum stored with the frame.
    pub checksum: u32,
}
impl FrameHeader {
    /// Returns `true` when this frame ends a transaction.
    ///
    /// Commit frames carry a non-zero `commit_page_count`; every other frame
    /// stores zero in that slot.
    pub fn is_commit(&self) -> bool {
        self.commit_page_count > 0
    }
}
/// Handle to an on-disk write-ahead log of full-page frames.
///
/// File layout: a `WAL_HEADER_SIZE`-byte header (magic, format version, page
/// size, salt, checkpoint sequence) followed by back-to-back `FRAME_SIZE`-byte
/// frames. Each frame is a `FRAME_HEADER_SIZE`-byte header (page number,
/// commit page count, salt and checksum, all little-endian `u32`) followed by
/// one page image. A frame whose commit page count is non-zero ends a
/// transaction; only frames up to the last such frame are visible through
/// [`Wal::read_page`].
pub struct Wal {
    file: File,
    /// Path the WAL was opened from (shown in `Debug` output).
    path: PathBuf,
    header: WalHeader,
    /// Offset of the newest *committed* frame for each page.
    latest_frame: HashMap<u32, u64>,
    /// Offset of the newest frame for each page written by this handle's
    /// in-progress transaction; merged into `latest_frame` on commit.
    pending_frames: HashMap<u32, u64>,
    /// File offset just past the last commit frame (`WAL_HEADER_SIZE` if none).
    last_commit_offset: u64,
    /// Page count stored in the last commit frame, if any.
    last_commit_page_count: Option<u32>,
    /// File offset at which the next frame will be written.
    next_frame_offset: u64,
    /// Number of well-formed frames this handle knows about (see [`Wal::frame_count`]).
    frame_count: usize,
}
impl Wal {
    /// Creates the WAL at `path`, emptying any existing file, takes a
    /// read-write lock on it, and durably writes a header with a fresh salt and
    /// `checkpoint_seq == 0`.
    ///
    /// The file is opened without `truncate` and only emptied once the lock is
    /// held. Assuming `acquire_lock` fails while another handle holds a
    /// conflicting lock, a WAL in use elsewhere is therefore left intact.
    ///
    /// # Errors
    /// Propagates I/O errors and any error returned by `acquire_lock`.
    pub fn create(path: &Path) -> Result<Self> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(path)?;
        acquire_lock(&file, path, AccessMode::ReadWrite)?;
        file.set_len(0)?;
        let header = WalHeader {
            salt: random_salt(),
            checkpoint_seq: 0,
        };
        let mut wal = Self::with_header(file, path, header);
        wal.write_header()?;
        wal.file.flush()?;
        wal.file.sync_all()?;
        Ok(wal)
    }
    /// Opens an existing WAL for reading and writing; see [`Wal::open_with_mode`].
    pub fn open(path: &Path) -> Result<Self> {
        Self::open_with_mode(path, AccessMode::ReadWrite)
    }
    /// Opens an existing WAL in `mode`, validates its header and replays its
    /// frames to rebuild the committed-page index.
    ///
    /// # Errors
    /// Fails if the file cannot be opened or locked, if the header is invalid
    /// (see `read_header`), or if replay hits an I/O error. A salt or checksum
    /// mismatch during replay is not an error; it marks the end of the log.
    pub fn open_with_mode(path: &Path, mode: AccessMode) -> Result<Self> {
        let mut file = match mode {
            AccessMode::ReadWrite => OpenOptions::new().read(true).write(true).open(path)?,
            AccessMode::ReadOnly => OpenOptions::new().read(true).open(path)?,
        };
        acquire_lock(&file, path, mode)?;
        let header = read_header(&mut file)?;
        let mut wal = Self::with_header(file, path, header);
        wal.replay_frames()?;
        Ok(wal)
    }
    /// Builds a handle with an empty frame index around an already-locked file.
    fn with_header(file: File, path: &Path, header: WalHeader) -> Self {
        Self {
            file,
            path: path.to_path_buf(),
            header,
            latest_frame: HashMap::new(),
            pending_frames: HashMap::new(),
            last_commit_offset: WAL_HEADER_SIZE as u64,
            last_commit_page_count: None,
            next_frame_offset: WAL_HEADER_SIZE as u64,
            frame_count: 0,
        }
    }
    /// Returns a copy of the current WAL header.
    pub fn header(&self) -> WalHeader {
        self.header
    }
    /// Number of well-formed frames known to this handle.
    ///
    /// After [`Wal::open`] this counts every valid frame found by replay,
    /// including an uncommitted tail. Once the next frame is appended, that
    /// tail is discarded and no longer counted.
    pub fn frame_count(&self) -> usize {
        self.frame_count
    }
    /// Page count recorded by the most recent commit frame, or `None` if the
    /// log holds no committed transaction.
    pub fn last_commit_page_count(&self) -> Option<u32> {
        self.last_commit_page_count
    }
    /// Inserts the latest committed image of every page in the log into
    /// `dest`, overwriting existing entries for those pages.
    ///
    /// # Errors
    /// Propagates the first frame-read error; `dest` may then be partially filled.
    pub fn load_committed_into(
        &mut self,
        dest: &mut HashMap<u32, Box<[u8; PAGE_SIZE]>>,
    ) -> Result<()> {
        let pages: Vec<u32> = self.latest_frame.keys().copied().collect();
        for page_num in pages {
            if let Some(body) = self.read_page(page_num)? {
                dest.insert(page_num, body);
            }
        }
        Ok(())
    }
    /// Appends one frame holding `content` as the new image of `page_num`.
    ///
    /// Pass `Some(page_count)` on the last frame of a transaction. The file is
    /// then flushed and `sync_all`ed, and every frame of the transaction becomes
    /// visible to [`Wal::read_page`]. Frames written with `None` stay invisible
    /// until a later commit frame; until then, reads of their pages keep
    /// returning the previously committed image.
    ///
    /// Before writing, any bytes past the end of this handle's own frames are
    /// cut off. After [`Wal::open`] that is everything past the last commit:
    /// an uncommitted, torn or corrupt tail left by an earlier writer. Without
    /// this, new frames could sit behind a bad frame where replay never reaches
    /// them, or the next commit could adopt the orphaned frames.
    ///
    /// # Errors
    /// Returns `SQLRiteError::General` if `commit_page_count` is `Some(0)`,
    /// because zero is the on-disk "not a commit" marker. Otherwise propagates
    /// I/O errors. In that case the in-memory index is unchanged and the next
    /// append rewrites the same slot.
    pub fn append_frame(
        &mut self,
        page_num: u32,
        content: &[u8; PAGE_SIZE],
        commit_page_count: Option<u32>,
    ) -> Result<()> {
        if commit_page_count == Some(0) {
            return Err(SQLRiteError::General(
                "WAL commit frame needs a non-zero page count (0 marks a non-commit frame)"
                    .to_string(),
            ));
        }
        let mut header_buf = [0u8; FRAME_HEADER_SIZE];
        header_buf[0..4].copy_from_slice(&page_num.to_le_bytes());
        header_buf[4..8].copy_from_slice(&commit_page_count.unwrap_or(0).to_le_bytes());
        header_buf[8..12].copy_from_slice(&self.header.salt.to_le_bytes());
        let sum = compute_checksum(&header_buf[0..12], content);
        header_buf[12..16].copy_from_slice(&sum.to_le_bytes());

        let offset = self.seek_to_append_position()?;
        self.file.write_all(&header_buf)?;
        self.file.write_all(content)?;
        if commit_page_count.is_some() {
            self.file.flush()?;
            self.file.sync_all()?;
        }

        // Update in-memory state only once the frame (and, for a commit, the
        // fsync) has succeeded.
        self.next_frame_offset = offset + FRAME_SIZE as u64;
        self.frame_count += 1;
        self.pending_frames.insert(page_num, offset);
        if let Some(pc) = commit_page_count {
            self.latest_frame.extend(self.pending_frames.drain());
            self.last_commit_offset = self.next_frame_offset;
            self.last_commit_page_count = Some(pc);
        }
        Ok(())
    }
    /// Cuts off any bytes beyond `next_frame_offset`, positions the cursor
    /// there and returns that offset.
    fn seek_to_append_position(&mut self) -> Result<u64> {
        let offset = self.next_frame_offset;
        let file_len = self.file.seek(SeekFrom::End(0))?;
        if file_len > offset {
            self.file.set_len(offset)?;
            // Persist the cut so a later crash cannot resurrect the old tail.
            self.file.sync_all()?;
            self.frame_count =
                ((offset - WAL_HEADER_SIZE as u64) / FRAME_SIZE as u64) as usize;
        }
        self.file.seek(SeekFrom::Start(offset))?;
        Ok(offset)
    }
    /// Returns the latest committed image of `page_num`, or `None` if the log
    /// holds no committed frame for it.
    ///
    /// # Errors
    /// Propagates I/O errors and salt/checksum failures from reading the frame.
    pub fn read_page(&mut self, page_num: u32) -> Result<Option<Box<[u8; PAGE_SIZE]>>> {
        let Some(&offset) = self.latest_frame.get(&page_num) else {
            return Ok(None);
        };
        // `latest_frame` only holds committed frames; this is a defensive check.
        if offset + FRAME_SIZE as u64 > self.last_commit_offset {
            return Ok(None);
        }
        let (_hdr, body) = self.read_frame_at(offset)?;
        Ok(Some(body))
    }
    /// Empties the log: drops all frames, rolls the salt, bumps
    /// `checkpoint_seq`, and durably rewrites the header.
    ///
    /// # Errors
    /// Propagates I/O errors. If cutting the file fails, the handle keeps its
    /// old salt, which still matches the header on disk.
    pub fn truncate(&mut self) -> Result<()> {
        let next_salt = random_salt();
        self.file.set_len(WAL_HEADER_SIZE as u64)?;
        // The frames are gone from the file now; forget them even if a later step fails.
        self.reset_frame_index();
        self.header.salt = next_salt;
        self.header.checkpoint_seq = self.header.checkpoint_seq.wrapping_add(1);
        self.write_header()?;
        self.file.flush()?;
        self.file.sync_all()?;
        Ok(())
    }
    /// Resets the frame index to match a file that holds only its header.
    fn reset_frame_index(&mut self) {
        self.latest_frame.clear();
        self.pending_frames.clear();
        self.last_commit_offset = WAL_HEADER_SIZE as u64;
        self.last_commit_page_count = None;
        self.next_frame_offset = WAL_HEADER_SIZE as u64;
        self.frame_count = 0;
    }
    /// Writes `self.header` plus the fixed magic, version and page size at offset 0.
    ///
    /// Bytes 24..32 are written as zero.
    fn write_header(&mut self) -> Result<()> {
        let mut buf = [0u8; WAL_HEADER_SIZE];
        buf[0..8].copy_from_slice(WAL_MAGIC);
        buf[8..12].copy_from_slice(&WAL_FORMAT_VERSION.to_le_bytes());
        buf[12..16].copy_from_slice(&(PAGE_SIZE as u32).to_le_bytes());
        buf[16..20].copy_from_slice(&self.header.salt.to_le_bytes());
        buf[20..24].copy_from_slice(&self.header.checkpoint_seq.to_le_bytes());
        self.file.seek(SeekFrom::Start(0))?;
        self.file.write_all(&buf)?;
        Ok(())
    }
    /// Reads and validates the frame that starts at byte `offset`.
    ///
    /// # Errors
    /// `SQLRiteError::General` if the frame's salt differs from the header's
    /// salt or its checksum does not match; read errors are propagated.
    fn read_frame_at(&mut self, offset: u64) -> Result<(FrameHeader, Box<[u8; PAGE_SIZE]>)> {
        self.file.seek(SeekFrom::Start(offset))?;
        let mut header_buf = [0u8; FRAME_HEADER_SIZE];
        self.file.read_exact(&mut header_buf)?;
        let mut body = Box::new([0u8; PAGE_SIZE]);
        self.file.read_exact(body.as_mut())?;
        let page_num = u32::from_le_bytes(header_buf[0..4].try_into().unwrap());
        let commit_page_count = u32::from_le_bytes(header_buf[4..8].try_into().unwrap());
        let salt = u32::from_le_bytes(header_buf[8..12].try_into().unwrap());
        let stored_checksum = u32::from_le_bytes(header_buf[12..16].try_into().unwrap());
        if salt != self.header.salt {
            return Err(SQLRiteError::General(format!(
                "WAL frame at offset {offset}: salt mismatch (expected {:x}, got {:x})",
                self.header.salt, salt
            )));
        }
        let computed = compute_checksum(&header_buf[0..12], &body);
        if computed != stored_checksum {
            return Err(SQLRiteError::General(format!(
                "WAL frame at offset {offset}: bad checksum (expected {stored_checksum:x}, got {computed:x})"
            )));
        }
        Ok((
            FrameHeader {
                page_num,
                commit_page_count,
                salt,
                checksum: stored_checksum,
            },
            body,
        ))
    }
    /// Scans frames from the start of the log and rebuilds the committed-page index.
    ///
    /// Scanning stops at the first frame that fails salt/checksum validation,
    /// or when less than a full frame remains. Frames after the last commit are
    /// counted in `frame_count` but never indexed. The next append position is
    /// set to just past the last commit, so a later append discards them.
    ///
    /// # Errors
    /// Propagates I/O errors. Validation failures (`SQLRiteError::General`) only
    /// end the scan.
    fn replay_frames(&mut self) -> Result<()> {
        let file_len = self.file.seek(SeekFrom::End(0))?;
        let mut offset = WAL_HEADER_SIZE as u64;
        let mut pending: HashMap<u32, u64> = HashMap::new();
        while offset + FRAME_SIZE as u64 <= file_len {
            let header = match self.read_frame_at(offset) {
                Ok((header, _body)) => header,
                // Torn write, stale frame from an earlier generation, or corruption.
                Err(SQLRiteError::General(_)) => break,
                // A real I/O failure must not be mistaken for the end of the log.
                Err(other) => return Err(other),
            };
            self.frame_count += 1;
            pending.insert(header.page_num, offset);
            if header.is_commit() {
                self.latest_frame.extend(pending.drain());
                self.last_commit_offset = offset + FRAME_SIZE as u64;
                self.last_commit_page_count = Some(header.commit_page_count);
            }
            offset += FRAME_SIZE as u64;
        }
        self.next_frame_offset = self.last_commit_offset;
        Ok(())
    }
}
impl std::fmt::Debug for Wal {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Wal")
.field("path", &self.path)
.field("salt", &format_args!("{:#x}", self.header.salt))
.field("checkpoint_seq", &self.header.checkpoint_seq)
.field("frame_count", &self.frame_count)
.field("last_commit_page_count", &self.last_commit_page_count)
.finish()
}
}
/// Reads and validates the fixed-size header at the start of a WAL file.
///
/// Checks run in this order:
/// 1. `WAL_HEADER_SIZE` bytes can be read; any read failure is reported as "too short".
/// 2. The first 8 bytes equal `WAL_MAGIC`.
/// 3. The format version equals `WAL_FORMAT_VERSION`.
/// 4. The recorded page size equals `PAGE_SIZE`.
///
/// Salt and checkpoint sequence are then decoded; bytes 24..32 are ignored.
///
/// # Errors
/// `SQLRiteError::General` naming the first failed check; an error from the
/// initial seek is propagated.
fn read_header(file: &mut File) -> Result<WalHeader> {
    file.seek(SeekFrom::Start(0))?;
    let mut raw = [0u8; WAL_HEADER_SIZE];
    let too_short = file.read_exact(&mut raw).is_err();
    if too_short {
        return Err(SQLRiteError::General(
            "file is not a SQLRite WAL (too short / bad magic)".to_string(),
        ));
    }
    if !raw.starts_with(WAL_MAGIC) {
        return Err(SQLRiteError::General(
            "file is not a SQLRite WAL (bad magic)".to_string(),
        ));
    }

    // Little-endian u32 stored at byte offset `at` of the header.
    let word = |at: usize| u32::from_le_bytes([raw[at], raw[at + 1], raw[at + 2], raw[at + 3]]);

    let version = word(8);
    if version != WAL_FORMAT_VERSION {
        return Err(SQLRiteError::General(format!(
            "unsupported WAL format version {version}; this build understands {WAL_FORMAT_VERSION}"
        )));
    }
    let page_size = word(12) as usize;
    if page_size != PAGE_SIZE {
        return Err(SQLRiteError::General(format!(
            "WAL page size {page_size} doesn't match engine's {PAGE_SIZE}"
        )));
    }
    Ok(WalHeader {
        salt: word(16),
        checkpoint_seq: word(20),
    })
}
/// Returns a fresh 32-bit salt for a new WAL generation.
///
/// The wall-clock time is hashed with a `RandomState`, whose keys are seeded
/// from OS randomness. Salts drawn within the same clock tick, or while the
/// clock reads earlier than the Unix epoch, are therefore still almost
/// certainly distinct, so resetting the log reliably changes the salt.
fn random_salt() -> u32 {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let state = std::collections::hash_map::RandomState::new();
    let mut hasher = std::hash::BuildHasher::build_hasher(&state);
    std::hash::Hasher::write_u128(&mut hasher, nanos);
    let digest = std::hash::Hasher::finish(&hasher);
    // Fold both halves of the 64-bit digest into the 32-bit salt.
    (digest ^ (digest >> 32)) as u32
}
/// Order-sensitive 32-bit checksum over a frame-header prefix and a page body.
///
/// Starting from 0, each byte of `header_bytes` and then each byte of `body`
/// is folded in as `sum = sum.rotate_left(1) + byte`, with wrapping addition.
/// This is an integrity check, not a cryptographic hash.
fn compute_checksum(header_bytes: &[u8], body: &[u8; PAGE_SIZE]) -> u32 {
    header_bytes
        .iter()
        .chain(body.iter())
        .fold(0u32, |acc, &byte| acc.rotate_left(1).wrapping_add(u32::from(byte)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a unique path in the system temp dir for a throwaway WAL.
    fn tmp_wal(name: &str) -> PathBuf {
        let mut p = std::env::temp_dir();
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        p.push(format!("sqlrite-wal-{pid}-{nanos}-{name}.wal"));
        p
    }

    /// Builds a page whose bytes count up (wrapping) from `byte`.
    fn page(byte: u8) -> Box<[u8; PAGE_SIZE]> {
        let mut b = Box::new([0u8; PAGE_SIZE]);
        for (i, slot) in b.iter_mut().enumerate() {
            *slot = byte.wrapping_add(i as u8);
        }
        b
    }

    /// Creates an empty WAL at `p`, then overwrites the little-endian `u32`
    /// at header offset `at` with `value`.
    fn create_with_patched_header(p: &Path, at: usize, value: u32) {
        drop(Wal::create(p).unwrap());
        let mut buf = std::fs::read(p).unwrap();
        buf[at..at + 4].copy_from_slice(&value.to_le_bytes());
        std::fs::write(p, &buf).unwrap();
    }

    #[test]
    fn create_then_open_round_trips_an_empty_wal() {
        let p = tmp_wal("empty");
        let w = Wal::create(&p).unwrap();
        assert_eq!(w.frame_count(), 0);
        assert_eq!(w.last_commit_page_count(), None);
        let salt = w.header().salt;
        drop(w);
        let w2 = Wal::open(&p).unwrap();
        assert_eq!(w2.header().salt, salt);
        assert_eq!(w2.frame_count(), 0);
        assert_eq!(w2.last_commit_page_count(), None);
        let _ = std::fs::remove_file(&p);
    }

    #[test]
    fn single_commit_frame_round_trips() {
        let p = tmp_wal("one_frame");
        let mut w = Wal::create(&p).unwrap();
        let content = page(0xab);
        w.append_frame(7, &content, Some(42)).unwrap();
        assert_eq!(w.frame_count(), 1);
        assert_eq!(w.last_commit_page_count(), Some(42));
        drop(w);
        let mut w2 = Wal::open(&p).unwrap();
        assert_eq!(w2.frame_count(), 1);
        assert_eq!(w2.last_commit_page_count(), Some(42));
        let read = w2.read_page(7).unwrap().expect("frame should be visible");
        assert_eq!(read.as_ref(), content.as_ref());
        assert!(
            w2.read_page(99).unwrap().is_none(),
            "untouched page is None"
        );
        let _ = std::fs::remove_file(&p);
    }

    #[test]
    fn multi_frame_commits_and_latest_wins() {
        let p = tmp_wal("latest_wins");
        let mut w = Wal::create(&p).unwrap();
        w.append_frame(1, &page(1), Some(10)).unwrap();
        w.append_frame(1, &page(2), Some(10)).unwrap();
        w.append_frame(1, &page(3), Some(10)).unwrap();
        w.append_frame(2, &page(9), Some(10)).unwrap();
        assert_eq!(w.frame_count(), 4);
        drop(w);
        let mut w2 = Wal::open(&p).unwrap();
        assert_eq!(w2.read_page(1).unwrap().unwrap().as_ref(), page(3).as_ref());
        assert_eq!(w2.read_page(2).unwrap().unwrap().as_ref(), page(9).as_ref());
        let _ = std::fs::remove_file(&p);
    }

    #[test]
    fn orphan_dirty_tail_preserves_previous_commit() {
        let p = tmp_wal("dirty_tail");
        let mut w = Wal::create(&p).unwrap();
        w.append_frame(5, &page(50), Some(10)).unwrap();
        w.append_frame(5, &page(51), None).unwrap();
        drop(w);
        let mut w2 = Wal::open(&p).unwrap();
        let got = w2
            .read_page(5)
            .unwrap()
            .expect("committed V1 should still be visible");
        assert_eq!(got.as_ref(), page(50).as_ref());
        assert_eq!(w2.frame_count(), 2);
        let _ = std::fs::remove_file(&p);
    }

    #[test]
    fn uncommitted_frame_for_untouched_page_returns_none() {
        let p = tmp_wal("dirty_only");
        let mut w = Wal::create(&p).unwrap();
        w.append_frame(7, &page(70), None).unwrap();
        drop(w);
        let mut w2 = Wal::open(&p).unwrap();
        assert_eq!(w2.read_page(7).unwrap(), None);
        let _ = std::fs::remove_file(&p);
    }

    #[test]
    fn truncate_resets_to_empty_and_rolls_salt() {
        let p = tmp_wal("truncate");
        let mut w = Wal::create(&p).unwrap();
        w.append_frame(1, &page(11), Some(5)).unwrap();
        w.append_frame(2, &page(22), Some(5)).unwrap();
        let seq_before = w.header().checkpoint_seq;
        let salt_before = w.header().salt;
        w.truncate().unwrap();
        assert_eq!(w.frame_count(), 0);
        assert_eq!(w.last_commit_page_count(), None);
        assert_eq!(w.header().checkpoint_seq, seq_before + 1);
        assert_ne!(w.header().salt, salt_before, "truncate must roll the salt");
        drop(w);
        let mut w2 = Wal::open(&p).unwrap();
        assert_eq!(w2.frame_count(), 0);
        assert_eq!(w2.read_page(1).unwrap(), None);
        assert_eq!(w2.read_page(2).unwrap(), None);
        let _ = std::fs::remove_file(&p);
    }

    #[test]
    fn bad_magic_file_is_rejected() {
        let p = tmp_wal("bad_magic");
        std::fs::write(&p, b"not a WAL file").unwrap();
        let err = Wal::open(&p).unwrap_err();
        assert!(format!("{err}").contains("bad magic"));
        let _ = std::fs::remove_file(&p);
    }

    #[test]
    fn unsupported_version_is_rejected() {
        let p = tmp_wal("bad_version");
        create_with_patched_header(&p, 8, WAL_FORMAT_VERSION + 1);
        let err = Wal::open(&p).unwrap_err();
        assert!(format!("{err}").contains("unsupported WAL format version"));
        let _ = std::fs::remove_file(&p);
    }

    #[test]
    fn mismatched_page_size_is_rejected() {
        let p = tmp_wal("bad_page_size");
        create_with_patched_header(&p, 12, (PAGE_SIZE as u32).wrapping_mul(2));
        let err = Wal::open(&p).unwrap_err();
        assert!(format!("{err}").contains("doesn't match engine's"));
        let _ = std::fs::remove_file(&p);
    }

    #[test]
    fn corrupt_frame_body_marks_end_of_log() {
        let p = tmp_wal("bit_flip");
        let mut w = Wal::create(&p).unwrap();
        w.append_frame(1, &page(0x11), Some(5)).unwrap();
        w.append_frame(2, &page(0x22), Some(5)).unwrap();
        drop(w);
        let body_offset = WAL_HEADER_SIZE + FRAME_SIZE + FRAME_HEADER_SIZE;
        let mut buf = std::fs::read(&p).unwrap();
        buf[body_offset] ^= 0xff;
        std::fs::write(&p, &buf).unwrap();
        let mut w2 = Wal::open(&p).unwrap();
        assert_eq!(
            w2.read_page(1).unwrap().unwrap().as_ref(),
            page(0x11).as_ref()
        );
        assert_eq!(w2.read_page(2).unwrap(), None);
        assert_eq!(w2.frame_count(), 1);
        let _ = std::fs::remove_file(&p);
    }

    #[test]
    fn partial_trailing_frame_is_ignored() {
        let p = tmp_wal("partial");
        let mut w = Wal::create(&p).unwrap();
        w.append_frame(42, &page(42), Some(1)).unwrap();
        drop(w);
        {
            let mut f = OpenOptions::new().write(true).open(&p).unwrap();
            f.seek(SeekFrom::End(0)).unwrap();
            f.write_all(&[0xaa; 2000]).unwrap();
        }
        let mut w2 = Wal::open(&p).unwrap();
        assert_eq!(
            w2.read_page(42).unwrap().unwrap().as_ref(),
            page(42).as_ref()
        );
        assert_eq!(w2.frame_count(), 1);
        let _ = std::fs::remove_file(&p);
    }
}