use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use crate::pager::codec::PAGE_SIZE;
/// Magic bytes stored in header bytes 0..8; a header without them is read as holding zero frames.
// NOTE(review): the magic says "SQL4" while the file extension used by `Wal::open` is "sql5wal" — confirm this is intentional.
const WAL_MAGIC: &[u8; 8] = b"SQL4WAL\0";
/// Format version written into header bytes 8..12.
const WAL_VERSION: u32 = 1;
/// Size of the file header in bytes; bytes 0..20 are populated, the rest stay zero.
const WAL_HEADER_SIZE: usize = 32;
/// Size of the per-frame header in bytes: four little-endian `u32` fields followed by 8 zero bytes.
const FRAME_HEADER_SIZE: usize = 24;
/// On-disk size of one frame: frame header plus one page of payload.
const FRAME_SIZE: usize = FRAME_HEADER_SIZE + PAGE_SIZE;
/// Frame type for a frame that carries page contents.
const FRAME_TYPE_DATA: u32 = 1;
/// Frame type for a frame that marks its transaction as committed.
const FRAME_TYPE_COMMIT: u32 = 2;
/// Frame count at which `Wal::needs_checkpoint` starts returning `true`.
const CHECKPOINT_THRESHOLD: usize = 100;
/// One WAL record: four `u32` header fields plus the page payload in `data`.
#[derive(Debug, Clone)]
struct Frame {
/// Page the payload belongs to; commit markers are written with `u32::MAX`.
page_id: u32,
/// `FRAME_TYPE_DATA` or `FRAME_TYPE_COMMIT`.
frame_type: u32,
/// Transaction the frame belongs to; replay applies a data frame only if a commit frame with the same id exists.
txn_id: u32,
/// `compute_checksum` of `data`; the header fields themselves are not covered.
checksum: u32,
/// Payload bytes; `decode` always yields `PAGE_SIZE` bytes here.
data: Vec<u8>,
}
impl Frame {
    /// Serializes the frame into its on-disk form.
    ///
    /// Layout: `page_id`, `frame_type`, `txn_id` and `checksum` as
    /// little-endian `u32`s, then 8 zero bytes, then the payload.
    fn encode(&self) -> Vec<u8> {
        let header_words = [self.page_id, self.frame_type, self.txn_id, self.checksum];
        let mut out = Vec::with_capacity(FRAME_SIZE);
        for word in header_words.iter() {
            out.extend_from_slice(&word.to_le_bytes());
        }
        // The remaining header bytes are always written as zero.
        out.extend_from_slice(&[0u8; 8]);
        out.extend_from_slice(&self.data);
        out
    }

    /// Parses a frame from the first `FRAME_SIZE` bytes of `buf`.
    ///
    /// Returns `None` when `buf` is shorter than `FRAME_SIZE` or when the
    /// stored checksum does not match the checksum of the payload.
    fn decode(buf: &[u8]) -> Option<Frame> {
        if buf.len() < FRAME_SIZE {
            return None;
        }
        // The length check above guarantees these header offsets are in bounds.
        let word_at = |offset: usize| {
            u32::from_le_bytes([
                buf[offset],
                buf[offset + 1],
                buf[offset + 2],
                buf[offset + 3],
            ])
        };
        let page_id = word_at(0);
        let frame_type = word_at(4);
        let txn_id = word_at(8);
        let checksum = word_at(12);
        let data = buf[FRAME_HEADER_SIZE..FRAME_SIZE].to_vec();
        if compute_checksum(&data) != checksum {
            return None;
        }
        Some(Frame {
            page_id,
            frame_type,
            txn_id,
            checksum,
            data,
        })
    }
}
/// XOR-folds `data` as little-endian 32-bit words.
///
/// A trailing group of fewer than four bytes is zero-padded to a full word.
/// Empty input yields `0`.
fn compute_checksum(data: &[u8]) -> u32 {
    let words = data.chunks_exact(4);
    let tail = words.remainder();
    let mut sum = words.fold(0u32, |acc, w| {
        acc ^ u32::from_le_bytes([w[0], w[1], w[2], w[3]])
    });
    if !tail.is_empty() {
        let mut padded = [0u8; 4];
        padded[..tail.len()].copy_from_slice(tail);
        sum ^= u32::from_le_bytes(padded);
    }
    sum
}
/// Writes a fresh WAL header at the start of `file`.
///
/// Header layout: magic (bytes 0..8), then version, page size and
/// `frame_count` as little-endian `u32`s (bytes 8..20); the rest is zero.
///
/// # Errors
/// Propagates any error from seeking, writing or flushing `file`.
fn write_wal_header(file: &mut File, frame_count: u32) -> std::io::Result<()> {
    let mut header = [0u8; WAL_HEADER_SIZE];
    header[..8].copy_from_slice(WAL_MAGIC);
    let fields = [WAL_VERSION, PAGE_SIZE as u32, frame_count];
    for (slot, value) in header[8..20].chunks_exact_mut(4).zip(fields.iter()) {
        slot.copy_from_slice(&value.to_le_bytes());
    }
    file.seek(SeekFrom::Start(0))?;
    file.write_all(&header)?;
    file.flush()
}
/// Reads the frame count stored in the WAL header of `file`.
///
/// Returns `Ok(0)` when the magic bytes do not match. The version and page
/// size fields are not inspected.
///
/// # Errors
/// Fails if seeking fails or fewer than `WAL_HEADER_SIZE` bytes can be read.
fn read_wal_frame_count(file: &mut File) -> std::io::Result<u32> {
    let mut header = [0u8; WAL_HEADER_SIZE];
    file.seek(SeekFrom::Start(0))?;
    file.read_exact(&mut header)?;
    let count = if header[..8] == WAL_MAGIC[..] {
        u32::from_le_bytes([header[16], header[17], header[18], header[19]])
    } else {
        0
    };
    Ok(count)
}
/// Write-ahead log: an on-disk frame file plus in-memory page caches.
pub struct Wal {
/// Location of the WAL file (the database path with extension `sql5wal`).
wal_path: PathBuf,
/// Open read/write handle to the WAL file.
wal_file: File,
/// Number of frames accounted for in the WAL file; also mirrored in the file header.
frame_count: usize,
/// Latest committed contents per page id, kept until `checkpoint` clears them.
committed: HashMap<u32, Vec<u8>>,
/// Pages written inside the current transaction and not yet committed.
dirty: HashMap<u32, Vec<u8>>,
/// Original page contents recorded through `save_original` during the current transaction.
pre_image: HashMap<u32, Vec<u8>>,
/// Id handed to the next transaction (explicit or auto-commit).
next_txn_id: u32,
/// Whether a transaction started by `begin` is currently open.
in_txn: bool,
}
impl Wal {
pub fn open<P: AsRef<Path>>(db_path: P) -> std::io::Result<Self> {
let db_path = db_path.as_ref();
let wal_path = db_path.with_extension("sql5wal");
let wal_exists = wal_path.exists();
let wal_file = OpenOptions::new()
.read(true).write(true).create(true)
.open(&wal_path)?;
let mut wal = Wal {
wal_path,
wal_file,
frame_count: 0,
committed: HashMap::new(),
dirty: HashMap::new(),
pre_image: HashMap::new(),
next_txn_id: 1,
in_txn: false,
};
if wal_exists {
wal.replay()?;
} else {
write_wal_header(&mut wal.wal_file, 0)?;
}
Ok(wal)
}
/// Starts a transaction.
///
/// Any uncommitted pages and saved pre-images are discarded first, so calling
/// this while a transaction is already open silently drops its pending writes.
pub fn begin(&mut self) {
    self.in_txn = true;
    self.pre_image.clear();
    self.dirty.clear();
}
/// Commits the open transaction by appending its pages and a commit marker
/// to the WAL, then publishing the pages in the committed cache.
///
/// Does nothing and returns `Ok(())` when no transaction is open.
///
/// On success the saved pre-images are discarded, so a later `rollback`
/// cannot restore stale originals over the data that was just committed.
///
/// # Errors
/// Returns the first I/O error from writing a frame. In that case the
/// transaction stays open and its dirty pages are kept, so the caller can
/// still roll back. The transaction id is consumed either way.
pub fn commit(&mut self) -> std::io::Result<()> {
    if !self.in_txn {
        return Ok(());
    }
    let txn_id = self.next_txn_id;
    self.next_txn_id += 1;

    // Move the dirty set out so frames can be written without cloning every
    // page; it is put back below if anything fails.
    let pending = std::mem::take(&mut self.dirty);
    let mut outcome: std::io::Result<()> = Ok(());
    for (page_id, data) in &pending {
        outcome = self.write_frame(*page_id, FRAME_TYPE_DATA, txn_id, data);
        if outcome.is_err() {
            break;
        }
    }
    if outcome.is_ok() {
        // The commit marker carries an all-zero payload and no real page id.
        outcome = self.write_frame(u32::MAX, FRAME_TYPE_COMMIT, txn_id, &vec![0u8; PAGE_SIZE]);
    }
    if let Err(err) = outcome {
        self.dirty = pending;
        return Err(err);
    }

    self.committed.extend(pending);
    self.pre_image.clear();
    self.in_txn = false;
    Ok(())
}
/// Abandons the open transaction.
///
/// Every pre-image recorded through `save_original` is put back into the
/// committed cache, the dirty pages are dropped and the transaction is closed.
///
/// When no transaction is open, nothing is restored: pre-images left over
/// from an already finished transaction are simply discarded, so they cannot
/// overwrite pages that have since been committed.
pub fn rollback(&mut self) {
    if self.in_txn {
        for (page_id, original_data) in self.pre_image.drain() {
            self.committed.insert(page_id, original_data);
        }
    } else {
        self.pre_image.clear();
    }
    self.dirty.clear();
    self.in_txn = false;
}
/// Records `original_data` as the pre-image of `page_id` for the open
/// transaction.
///
/// Only the first pre-image per page is kept; later calls for the same page
/// are ignored. Outside a transaction the call has no effect.
pub fn save_original(&mut self, page_id: u32, original_data: Vec<u8>) {
    if self.in_txn {
        self.pre_image.entry(page_id).or_insert(original_data);
    }
}
/// Returns the pre-image recorded for `page_id` via `save_original`, if there is one.
pub fn get_original(&self, page_id: u32) -> Option<&Vec<u8>> {
self.pre_image.get(&page_id)
}
/// Reports whether a transaction is currently open (set by `begin`, cleared by `commit` and `rollback`).
pub fn in_txn(&self) -> bool {
self.in_txn
}
/// Returns an owned copy of the newest version of `page_id` known to the WAL.
///
/// Despite the name, an uncommitted (dirty) version takes precedence over
/// the committed one. Returns `None` when the page is in neither cache.
pub fn get_committed_copy(&self, page_id: u32) -> Option<Vec<u8>> {
    self.dirty
        .get(&page_id)
        .or_else(|| self.committed.get(&page_id))
        .cloned()
}
/// Stores a new version of `page_id`.
///
/// Inside a transaction the page is only buffered as dirty. Outside one it
/// is auto-committed: a data frame and a commit frame are appended under a
/// fresh transaction id and the page is published in the committed cache.
///
/// Frame write failures in auto-commit mode are reported on stderr only;
/// the page is still placed in the committed cache.
pub fn write_page(&mut self, page_id: u32, data: Vec<u8>) {
    if self.in_txn {
        self.dirty.insert(page_id, data);
        return;
    }

    let txn_id = self.next_txn_id;
    self.next_txn_id += 1;
    self.write_frame(page_id, FRAME_TYPE_DATA, txn_id, &data)
        .unwrap_or_else(|e| eprintln!("WAL write_frame error: {}", e));
    let marker = vec![0u8; PAGE_SIZE];
    self.write_frame(u32::MAX, FRAME_TYPE_COMMIT, txn_id, &marker)
        .unwrap_or_else(|e| eprintln!("WAL write_commit error: {}", e));
    // Writing frames never consults the committed cache, so publishing the
    // page last lets it be moved in without an extra copy.
    self.committed.insert(page_id, data);
}
/// Borrows the newest version of `page_id` held by the WAL.
///
/// A dirty page of the open transaction wins over the committed version.
/// Returns `None` when the WAL holds neither.
pub fn read_page(&self, page_id: u32) -> Option<&[u8]> {
    let page = match self.dirty.get(&page_id) {
        Some(pending) => pending,
        None => self.committed.get(&page_id)?,
    };
    Some(page.as_slice())
}
/// Returns `true` once the WAL holds at least `CHECKPOINT_THRESHOLD` frames.
pub fn needs_checkpoint(&self) -> bool {
self.frame_count >= CHECKPOINT_THRESHOLD
}
/// Returns the number of frames currently accounted for in the WAL file.
pub fn frame_count(&self) -> usize {
self.frame_count
}
/// Hands every committed page to `write_back`, then resets the WAL to an
/// empty state.
///
/// `write_back` is called once per cached committed page, in no particular
/// order. Pages of an open transaction are not included.
/// NOTE(review): `write_back` is presumably expected to make the pages
/// durable before returning, since the WAL copy is discarded afterwards —
/// confirm against the caller.
///
/// The header is reset to zero frames and synced *before* the file is
/// truncated, so a crash part-way through never leaves a header that
/// advertises frames which no longer exist. The truncation is synced too.
///
/// # Errors
/// Stops at the first error from `write_back` or from resetting the file.
/// The in-memory frame count and committed cache are cleared only after
/// every step has succeeded.
pub fn checkpoint<F>(&mut self, mut write_back: F) -> std::io::Result<()>
where
    F: FnMut(u32, &[u8]) -> std::io::Result<()>,
{
    for (page_id, data) in &self.committed {
        write_back(*page_id, data)?;
    }
    write_wal_header(&mut self.wal_file, 0)?;
    self.wal_file.sync_all()?;
    self.wal_file.set_len(WAL_HEADER_SIZE as u64)?;
    self.wal_file.sync_all()?;
    self.frame_count = 0;
    self.committed.clear();
    Ok(())
}
fn write_frame(&mut self, page_id: u32, frame_type: u32, txn_id: u32, data: &[u8]) -> std::io::Result<()> {
let checksum = compute_checksum(data);
let frame = Frame { page_id, frame_type, txn_id, checksum, data: data.to_vec() };
let encoded = frame.encode();
let target_size = WAL_HEADER_SIZE as u64 + (self.frame_count as u64 + 1) * FRAME_SIZE as u64;
self.wal_file.set_len(target_size)?;
self.wal_file.sync_all()?;
let frame_offset = WAL_HEADER_SIZE as u64 + (self.frame_count as u64) * FRAME_SIZE as u64;
self.wal_file.seek(SeekFrom::Start(frame_offset))?;
self.wal_file.write_all(&encoded)?;
self.wal_file.sync_all()?;
self.frame_count += 1;
write_wal_header(&mut self.wal_file, self.frame_count as u32)?;
self.wal_file.sync_all()?;
Ok(())
}
/// Rebuilds the in-memory state from the WAL file.
///
/// Reads up to the number of frames named in the header, skipping frames
/// that fail to decode and stopping at the first frame that cannot be read
/// in full. Data frames are applied to the committed cache, in file order,
/// only when a commit frame with the same transaction id exists.
///
/// `next_txn_id` is advanced past every transaction id seen in a data or
/// commit frame. Without this, a transaction started after reopening could
/// reuse the id of an uncommitted transaction still present in the file, and
/// its commit frame would make those stale data frames look committed on the
/// next replay.
///
/// # Errors
/// Propagates errors from reading the header or seeking in the file.
fn replay(&mut self) -> std::io::Result<()> {
    let frame_count = read_wal_frame_count(&mut self.wal_file)? as usize;
    if frame_count == 0 {
        return Ok(());
    }

    let mut frames: Vec<Frame> = Vec::new();
    let mut buf = vec![0u8; FRAME_SIZE];
    for i in 0..frame_count {
        let offset = WAL_HEADER_SIZE as u64 + (i as u64) * FRAME_SIZE as u64;
        self.wal_file.seek(SeekFrom::Start(offset))?;
        if self.wal_file.read_exact(&mut buf).is_err() {
            break;
        }
        if let Some(frame) = Frame::decode(&buf) {
            frames.push(frame);
        }
    }

    let committed_txns: std::collections::HashSet<u32> = frames
        .iter()
        .filter(|f| f.frame_type == FRAME_TYPE_COMMIT)
        .map(|f| f.txn_id)
        .collect();

    // Only frames of a known type count, so zero-filled slots are ignored.
    let highest_txn = frames
        .iter()
        .filter(|f| f.frame_type == FRAME_TYPE_DATA || f.frame_type == FRAME_TYPE_COMMIT)
        .map(|f| f.txn_id)
        .max();
    if let Some(highest) = highest_txn {
        self.next_txn_id = self.next_txn_id.max(highest.saturating_add(1));
    }

    for frame in frames {
        if frame.frame_type == FRAME_TYPE_DATA && committed_txns.contains(&frame.txn_id) {
            self.committed.insert(frame.page_id, frame.data);
        }
    }
    self.frame_count = frame_count;
    Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds the scratch database path used by the test called `name`.
///
/// The file is placed in the platform's temporary directory instead of a
/// hard-coded `/tmp`, so the tests also run where `/tmp` does not exist or
/// `TMPDIR` points elsewhere.
fn tmp_path(name: &str) -> PathBuf {
    std::env::temp_dir().join(format!("sql5_wal_{}.db", name))
}
fn cleanup(name: &str) {
let _ = std::fs::remove_file(tmp_path(name));
let _ = std::fs::remove_file(tmp_path(name).with_extension("sql5wal"));
}
/// A page written inside a transaction is readable both before and after commit.
#[test]
fn write_and_read_in_txn() {
    const NAME: &str = "txn";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    wal.begin();
    wal.write_page(0, vec![1u8; PAGE_SIZE]);
    let before_commit = wal.read_page(0).unwrap()[0];
    assert_eq!(before_commit, 1);
    wal.commit().unwrap();
    let after_commit = wal.read_page(0).unwrap()[0];
    assert_eq!(after_commit, 1);
    cleanup(NAME);
}
/// A page written in a transaction is no longer readable after rollback.
#[test]
fn rollback_discards_dirty() {
    const NAME: &str = "rollback";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    wal.begin();
    wal.write_page(1, vec![42u8; PAGE_SIZE]);
    wal.rollback();
    let after_rollback = wal.read_page(1);
    assert!(after_rollback.is_none());
    cleanup(NAME);
}
/// A committed page is visible again after the WAL is closed and reopened.
#[test]
fn replay_after_crash() {
    const NAME: &str = "replay";
    cleanup(NAME);

    // First session: commit one page, then drop the handle.
    let mut writer = Wal::open(tmp_path(NAME)).unwrap();
    writer.begin();
    writer.write_page(5, vec![99u8; PAGE_SIZE]);
    writer.commit().unwrap();
    drop(writer);

    // Second session: opening the existing file replays it.
    let reopened = Wal::open(tmp_path(NAME)).unwrap();
    let data = reopened
        .read_page(5)
        .expect("page 5 should be in WAL after replay");
    assert_eq!(data[0], 99);
    drop(reopened);

    cleanup(NAME);
}
/// A data frame without a matching commit frame is not applied on replay.
#[test]
fn partial_txn_not_replayed() {
    const NAME: &str = "partial";
    cleanup(NAME);

    // Write a lone data frame for transaction 99 and never commit it.
    let mut writer = Wal::open(tmp_path(NAME)).unwrap();
    writer.begin();
    let payload = vec![0xABu8; PAGE_SIZE];
    writer.write_frame(3, FRAME_TYPE_DATA, 99, &payload).unwrap();
    drop(writer);

    let reopened = Wal::open(tmp_path(NAME)).unwrap();
    assert!(reopened.read_page(3).is_none(), "uncommitted page should not be visible");
    drop(reopened);

    cleanup(NAME);
}
/// A page written outside any transaction is immediately readable.
#[test]
fn auto_commit_mode() {
    const NAME: &str = "auto";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    wal.write_page(0, vec![77u8; PAGE_SIZE]);
    let first_byte = wal.read_page(0).unwrap()[0];
    assert_eq!(first_byte, 77);
    cleanup(NAME);
}
/// Checkpointing hands every committed page to the callback and empties the WAL state.
#[test]
fn checkpoint_clears_wal() {
    const NAME: &str = "chkpt";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    wal.begin();
    wal.write_page(0, vec![55u8; PAGE_SIZE]);
    wal.write_page(1, vec![66u8; PAGE_SIZE]);
    wal.commit().unwrap();

    // Record (page id, first byte) for each page the checkpoint writes back.
    let mut written: Vec<(u32, u8)> = Vec::new();
    wal.checkpoint(|pid, data| {
        written.push((pid, data[0]));
        Ok(())
    })
    .unwrap();

    assert_eq!(wal.frame_count, 0);
    assert!(wal.committed.is_empty());
    assert!(written.contains(&(0, 55)));
    assert!(written.contains(&(1, 66)));
    cleanup(NAME);
}
/// A data frame whose stored checksum was overwritten is skipped on replay.
#[test]
fn checksum_corruption_ignored() {
    const NAME: &str = "corrupt";
    cleanup(NAME);
    {
        let mut wal = Wal::open(tmp_path(NAME)).unwrap();
        wal.begin();
        wal.write_page(7, vec![11u8; PAGE_SIZE]);
        wal.commit().unwrap();
    }
    {
        // Bytes 12..16 of a frame hold its checksum; clobber the first frame's.
        let checksum_offset = (WAL_HEADER_SIZE + 12) as u64;
        let wal_path = tmp_path(NAME).with_extension("sql5wal");
        let mut raw = OpenOptions::new().write(true).open(&wal_path).unwrap();
        raw.seek(SeekFrom::Start(checksum_offset)).unwrap();
        raw.write_all(&[0xFFu8; 4]).unwrap();
    }
    {
        let wal = Wal::open(tmp_path(NAME)).unwrap();
        assert!(wal.read_page(7).is_none(), "corrupted frame should be ignored");
    }
    cleanup(NAME);
}
/// Several pages written in one transaction are all readable before and after commit.
#[test]
fn multiple_pages_in_transaction() {
    const NAME: &str = "multi";
    let expected: [(u32, u8); 3] = [(0, 10), (1, 20), (2, 30)];
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    wal.begin();
    for (page_id, fill) in expected.iter() {
        wal.write_page(*page_id, vec![*fill; PAGE_SIZE]);
    }
    for (page_id, fill) in expected.iter() {
        assert_eq!(wal.read_page(*page_id).unwrap()[0], *fill);
    }
    wal.commit().unwrap();
    for (page_id, fill) in expected.iter() {
        assert_eq!(wal.read_page(*page_id).unwrap()[0], *fill);
    }
    cleanup(NAME);
}
/// After a rollback, a fresh transaction on the same page commits its own value.
#[test]
fn commit_after_rollback() {
    const NAME: &str = "commit_after_rb";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();

    // First attempt is abandoned.
    wal.begin();
    wal.write_page(0, vec![1u8; PAGE_SIZE]);
    wal.rollback();

    // Second attempt is committed.
    wal.begin();
    wal.write_page(0, vec![2u8; PAGE_SIZE]);
    wal.commit().unwrap();

    let first_byte = wal.read_page(0).unwrap()[0];
    assert_eq!(first_byte, 2);
    cleanup(NAME);
}
/// Calling `begin` twice in a row is accepted: the write that follows is
/// readable before and after commit. (Despite the test's name, nothing here
/// is expected to fail.)
#[test]
fn nested_transaction_fails() {
    const NAME: &str = "nested";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    for _ in 0..2 {
        wal.begin();
    }
    wal.write_page(0, vec![99u8; PAGE_SIZE]);
    let before_commit = wal.read_page(0).unwrap()[0];
    assert_eq!(before_commit, 99);
    wal.commit().unwrap();
    let after_commit = wal.read_page(0).unwrap()[0];
    assert_eq!(after_commit, 99);
    cleanup(NAME);
}
/// A second committed write to the same page replaces the first.
#[test]
fn update_page_in_wal() {
    const NAME: &str = "update";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    for fill in [1u8, 2u8].iter() {
        wal.begin();
        wal.write_page(0, vec![*fill; PAGE_SIZE]);
        wal.commit().unwrap();
    }
    let first_byte = wal.read_page(0).unwrap()[0];
    assert_eq!(first_byte, 2);
    cleanup(NAME);
}
/// Pages committed by separate transactions all stay readable.
#[test]
fn wal_preserves_multiple_commits() {
    const NAME: &str = "multi_commit";
    let commits: [(u32, u8); 2] = [(0, 10), (1, 20)];
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    for (page_id, fill) in commits.iter() {
        wal.begin();
        wal.write_page(*page_id, vec![*fill; PAGE_SIZE]);
        wal.commit().unwrap();
    }
    for (page_id, fill) in commits.iter() {
        assert_eq!(wal.read_page(*page_id).unwrap()[0], *fill);
    }
    cleanup(NAME);
}
/// A page whose first write was rolled back takes the value of the next committed write.
#[test]
fn rollback_then_write_same_page() {
    const NAME: &str = "rb_write";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();

    wal.begin();
    wal.write_page(0, vec![1u8; PAGE_SIZE]);
    wal.rollback();

    wal.begin();
    wal.write_page(0, vec![2u8; PAGE_SIZE]);
    wal.commit().unwrap();

    let page = wal.read_page(0).unwrap();
    assert_eq!(page[0], 2);
    cleanup(NAME);
}
/// Each additional commit increases the reported frame count.
#[test]
fn frame_count_after_commits() {
    const NAME: &str = "frame_count";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();

    wal.begin();
    wal.write_page(0, vec![1u8; PAGE_SIZE]);
    wal.commit().unwrap();
    let after_first = wal.frame_count();

    wal.begin();
    wal.write_page(1, vec![2u8; PAGE_SIZE]);
    wal.commit().unwrap();
    let after_second = wal.frame_count();

    assert!(after_second > after_first);
    cleanup(NAME);
}
/// The checkpoint callback receives the id of every committed page.
#[test]
fn checkpoint_callback_called_correctly() {
    const NAME: &str = "callback";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    wal.begin();
    wal.write_page(5, vec![42u8; PAGE_SIZE]);
    wal.write_page(10, vec![43u8; PAGE_SIZE]);
    wal.commit().unwrap();

    let mut seen_ids: Vec<u32> = Vec::new();
    wal.checkpoint(|page_id, _data| {
        seen_ids.push(page_id);
        Ok(())
    })
    .unwrap();

    for wanted in [5u32, 10u32].iter() {
        assert!(seen_ids.contains(wanted));
    }
    cleanup(NAME);
}
/// The committed cache holds the page after commit and is empty after checkpoint.
#[test]
fn committed_cache_cleared_after_checkpoint() {
    const NAME: &str = "clear_cache";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    wal.begin();
    wal.write_page(0, vec![55u8; PAGE_SIZE]);
    wal.commit().unwrap();
    let cached_before = wal.committed.len();
    assert!(cached_before != 0);
    wal.checkpoint(|_, _| Ok(())).unwrap();
    let cached_after = wal.committed.len();
    assert!(cached_after == 0);
    cleanup(NAME);
}
/// A fresh WAL does not ask for a checkpoint; after 120 single-page commits it does.
#[test]
fn needs_checkpoint_threshold() {
    const NAME: &str = "threshold";
    cleanup(NAME);
    let mut wal = Wal::open(tmp_path(NAME)).unwrap();
    assert!(!wal.needs_checkpoint());
    for page_id in 0u32..120 {
        wal.begin();
        wal.write_page(page_id, vec![page_id as u8; PAGE_SIZE]);
        wal.commit().unwrap();
    }
    assert!(wal.needs_checkpoint());
    cleanup(NAME);
}
}