use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::Path;
use crate::error::Result;
use crate::wal::record::WalRecord;
pub struct WalWriter {
file: File,
next_lsn: u64,
next_tx_id: u64,
}
impl WalWriter {
pub fn new(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let exists = path.exists() && path.metadata()?.len() > 0;
let mut file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(path)?;
let (next_lsn, next_tx_id) = if exists {
let records = Self::read_all_records_from_file(&mut file)?;
let max_lsn = records.iter().map(|r| r.lsn).max().unwrap_or(0);
let max_tx = records.iter().map(|r| r.tx_id).max().unwrap_or(0);
file.seek(SeekFrom::End(0))?;
(max_lsn + 1, max_tx + 1)
} else {
(1, 1)
};
Ok(Self {
file,
next_lsn,
next_tx_id,
})
}
pub fn begin_tx(&mut self) -> u64 {
let tx_id = self.next_tx_id;
self.next_tx_id += 1;
tx_id
}
pub fn log_page_write(
&mut self,
tx_id: u64,
page_id: u32,
before: &[u8],
after: &[u8],
) -> Result<u64> {
let lsn = self.alloc_lsn();
let record = WalRecord::page_write(lsn, tx_id, page_id, before, after);
self.append_record(&record)?;
Ok(lsn)
}
pub fn log_commit(&mut self, tx_id: u64) -> Result<u64> {
let lsn = self.alloc_lsn();
let record = WalRecord::commit(lsn, tx_id);
self.append_record(&record)?;
self.file.sync_all()?;
Ok(lsn)
}
pub fn log_checkpoint(&mut self) -> Result<u64> {
let lsn = self.alloc_lsn();
let record = WalRecord::checkpoint(lsn);
self.append_record(&record)?;
self.file.sync_all()?;
Ok(lsn)
}
pub fn current_lsn(&self) -> u64 {
self.next_lsn
}
pub fn truncate(&mut self) -> Result<()> {
self.file.set_len(0)?;
self.file.seek(SeekFrom::Start(0))?;
self.file.sync_all()?;
Ok(())
}
pub fn read_all_records(&mut self) -> Result<Vec<WalRecord>> {
self.file.seek(SeekFrom::Start(0))?;
Self::read_all_records_from_file(&mut self.file)
}
fn alloc_lsn(&mut self) -> u64 {
let lsn = self.next_lsn;
self.next_lsn += 1;
lsn
}
fn append_record(&mut self, record: &WalRecord) -> Result<()> {
let bytes = record.to_bytes();
self.file.write_all(&bytes)?;
Ok(())
}
fn read_all_records_from_file(file: &mut File) -> Result<Vec<WalRecord>> {
file.seek(SeekFrom::Start(0))?;
let mut all_bytes = Vec::new();
file.read_to_end(&mut all_bytes)?;
let mut records = Vec::new();
let mut offset = 0;
while offset < all_bytes.len() {
match WalRecord::from_bytes(&all_bytes[offset..]) {
Ok((record, consumed)) => {
records.push(record);
offset += consumed;
}
Err(_) => {
break;
}
}
}
file.seek(SeekFrom::End(0))?;
Ok(records)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::wal::record::WalOpType;
use tempfile::TempDir;
fn setup() -> (TempDir, WalWriter) {
let dir = TempDir::new().unwrap();
let wal = WalWriter::new(dir.path().join("wal.log")).unwrap();
(dir, wal)
}
#[test]
fn test_wal_writer_new_empty() {
let (_dir, wal) = setup();
assert_eq!(wal.current_lsn(), 1);
}
#[test]
fn test_wal_write_and_read_records() {
let (_dir, mut wal) = setup();
let tx = wal.begin_tx();
wal.log_page_write(tx, 5, &[0; 100], &[1; 100]).unwrap();
wal.log_commit(tx).unwrap();
let records = wal.read_all_records().unwrap();
assert_eq!(records.len(), 2);
assert_eq!(records[0].op_type, WalOpType::PageWrite);
assert_eq!(records[0].page_id, 5);
assert_eq!(records[1].op_type, WalOpType::Commit);
assert_eq!(records[1].tx_id, tx);
}
#[test]
fn test_wal_lsn_increments() {
let (_dir, mut wal) = setup();
let tx = wal.begin_tx();
let lsn1 = wal.log_page_write(tx, 1, &[], &[]).unwrap();
let lsn2 = wal.log_page_write(tx, 2, &[], &[]).unwrap();
let lsn3 = wal.log_commit(tx).unwrap();
assert_eq!(lsn1, 1);
assert_eq!(lsn2, 2);
assert_eq!(lsn3, 3);
}
#[test]
fn test_wal_checkpoint() {
let (_dir, mut wal) = setup();
let lsn = wal.log_checkpoint().unwrap();
assert_eq!(lsn, 1);
let records = wal.read_all_records().unwrap();
assert_eq!(records.len(), 1);
assert_eq!(records[0].op_type, WalOpType::Checkpoint);
}
#[test]
fn test_wal_truncate() {
let (_dir, mut wal) = setup();
let tx = wal.begin_tx();
wal.log_page_write(tx, 1, &[], &[]).unwrap();
wal.log_commit(tx).unwrap();
assert_eq!(wal.read_all_records().unwrap().len(), 2);
wal.truncate().unwrap();
assert_eq!(wal.read_all_records().unwrap().len(), 0);
}
#[test]
fn test_wal_reopen_resumes_lsn() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("wal.log");
{
let mut wal = WalWriter::new(&path).unwrap();
let tx = wal.begin_tx();
wal.log_page_write(tx, 1, &[0; 50], &[1; 50]).unwrap();
wal.log_commit(tx).unwrap();
}
{
let mut wal = WalWriter::new(&path).unwrap();
let lsn = wal.log_checkpoint().unwrap();
assert!(lsn >= 3, "LSN should resume: got {lsn}");
let records = wal.read_all_records().unwrap();
assert_eq!(records.len(), 3);
}
}
#[test]
fn test_wal_multiple_transactions() {
let (_dir, mut wal) = setup();
let tx1 = wal.begin_tx();
wal.log_page_write(tx1, 1, &[0; 50], &[1; 50]).unwrap();
wal.log_commit(tx1).unwrap();
let tx2 = wal.begin_tx();
wal.log_page_write(tx2, 2, &[0; 50], &[2; 50]).unwrap();
let records = wal.read_all_records().unwrap();
assert_eq!(records.len(), 3); assert!(
records
.iter()
.any(|r| r.tx_id == tx1 && r.op_type == WalOpType::Commit)
);
assert!(
!records
.iter()
.any(|r| r.tx_id == tx2 && r.op_type == WalOpType::Commit)
);
}
}