use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::Path;
use mentedb_core::error::{MenteError, MenteResult};
use tracing::{debug, info, trace};
/// Sequence number attached to every WAL entry.
///
/// `Wal::append` hands these out starting at 1 for a fresh log and returns the
/// value it assigned; `Wal::truncate` takes one as its cut-off.
pub type Lsn = u64;
/// Kind tag stored as a single byte inside each WAL entry.
///
/// The explicit discriminants are the on-disk encoding (the enum is cast with
/// `as u8` when written and decoded through `TryFrom<u8>` when read), so they
/// must never be renumbered.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WalEntryType {
    /// On-disk tag `1`. Presumably carries the contents of a page — confirm with the replay code.
    PageWrite = 1,
    /// On-disk tag `2`. Presumably marks a transaction commit — confirm with the replay code.
    Commit = 2,
    /// On-disk tag `3`. Presumably marks a checkpoint — confirm with the replay code.
    Checkpoint = 3,
}
/// Decodes the one-byte on-disk tag back into a [`WalEntryType`].
impl TryFrom<u8> for WalEntryType {
    type Error = MenteError;

    /// Maps `1`, `2` and `3` to `PageWrite`, `Commit` and `Checkpoint`.
    ///
    /// # Errors
    /// Any other byte yields `MenteError::Storage` naming the offending value.
    fn try_from(v: u8) -> MenteResult<Self> {
        let decoded = match v {
            1 => Self::PageWrite,
            2 => Self::Commit,
            3 => Self::Checkpoint,
            // Unknown tag: reject rather than guess.
            _ => return Err(MenteError::Storage(format!("invalid WAL entry type: {v}"))),
        };
        Ok(decoded)
    }
}
/// One decoded record read back from the WAL file.
#[derive(Debug, Clone)]
pub struct WalEntry {
    /// Sequence number stored in the first 8 payload bytes (little-endian).
    pub lsn: u64,
    /// Kind tag decoded from the single byte that follows the LSN.
    pub entry_type: WalEntryType,
    /// Page identifier stored after the kind tag (8 bytes, little-endian).
    pub page_id: u64,
    /// Entry body after LZ4 decompression, i.e. the bytes originally passed to `Wal::append`.
    pub data: Vec<u8>,
    /// CRC32 value that was stored on disk for this entry's payload.
    pub checksum: u32,
}
/// Handle to the write-ahead log file (`wal.log`) of one directory.
pub struct Wal {
    /// The log file, opened for both reading and writing.
    file: File,
    /// LSN that the next call to `append` will assign.
    next_lsn: u64,
}
/// Smallest payload length the reader accepts: the fixed header alone, made of
/// the LSN (8 bytes), the entry-type tag (1 byte) and the page id (8 bytes).
const MIN_PAYLOAD: usize = 8 + 1 + 8;
impl Wal {
pub fn open(dir_path: &Path) -> MenteResult<Self> {
let wal_path = dir_path.join("wal.log");
let exists = wal_path.exists()
&& std::fs::metadata(&wal_path)
.map(|m| m.len() > 0)
.unwrap_or(false);
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&wal_path)?;
let mut wal = Self { file, next_lsn: 1 };
if exists {
let entries = wal.read_all_entries()?;
if let Some(last) = entries.last() {
wal.next_lsn = last.lsn + 1;
}
info!(
next_lsn = wal.next_lsn,
entries = entries.len(),
"opened existing WAL"
);
} else {
info!("created new WAL");
}
Ok(wal)
}
pub fn append(
&mut self,
entry_type: WalEntryType,
page_id: u64,
data: &[u8],
) -> MenteResult<Lsn> {
let lsn = self.next_lsn;
self.next_lsn += 1;
let compressed = lz4_flex::compress_prepend_size(data);
let payload_len = 8 + 1 + 8 + compressed.len();
let mut payload = Vec::with_capacity(payload_len);
payload.extend_from_slice(&lsn.to_le_bytes());
payload.push(entry_type as u8);
payload.extend_from_slice(&page_id.to_le_bytes());
payload.extend_from_slice(&compressed);
let crc = {
let mut h = crc32fast::Hasher::new();
h.update(&payload);
h.finalize()
};
self.file.seek(SeekFrom::End(0))?;
self.file.write_all(&(payload_len as u32).to_le_bytes())?;
self.file.write_all(&payload)?;
self.file.write_all(&crc.to_le_bytes())?;
trace!(lsn, page_id, "appended WAL entry");
Ok(lsn)
}
pub fn sync(&mut self) -> MenteResult<()> {
self.file.sync_data()?;
debug!("WAL synced");
Ok(())
}
pub fn iterate(&mut self) -> MenteResult<Vec<WalEntry>> {
self.read_all_entries()
}
pub fn truncate(&mut self, before_lsn: Lsn) -> MenteResult<()> {
let entries = self.read_all_entries()?;
let to_keep: Vec<&WalEntry> = entries.iter().filter(|e| e.lsn >= before_lsn).collect();
self.file.seek(SeekFrom::Start(0))?;
self.file.set_len(0)?;
for entry in to_keep {
let compressed = lz4_flex::compress_prepend_size(&entry.data);
let payload_len = 8 + 1 + 8 + compressed.len();
let mut payload = Vec::with_capacity(payload_len);
payload.extend_from_slice(&entry.lsn.to_le_bytes());
payload.push(entry.entry_type as u8);
payload.extend_from_slice(&entry.page_id.to_le_bytes());
payload.extend_from_slice(&compressed);
let crc = {
let mut h = crc32fast::Hasher::new();
h.update(&payload);
h.finalize()
};
self.file.write_all(&(payload_len as u32).to_le_bytes())?;
self.file.write_all(&payload)?;
self.file.write_all(&crc.to_le_bytes())?;
}
self.file.sync_data()?;
debug!(before_lsn, "WAL truncated");
Ok(())
}
pub fn next_lsn(&self) -> Lsn {
self.next_lsn
}
fn read_all_entries(&mut self) -> MenteResult<Vec<WalEntry>> {
self.file.seek(SeekFrom::Start(0))?;
let file_len = self.file.metadata()?.len();
let mut offset: u64 = 0;
let mut entries = Vec::new();
while offset + 4 <= file_len {
let mut len_buf = [0u8; 4];
if self.file.read_exact(&mut len_buf).is_err() {
break;
}
let payload_len = u32::from_le_bytes(len_buf) as usize;
offset += 4;
if payload_len < MIN_PAYLOAD || offset + payload_len as u64 + 4 > file_len {
break;
}
let mut payload = vec![0u8; payload_len];
if self.file.read_exact(&mut payload).is_err() {
break;
}
offset += payload_len as u64;
let mut crc_buf = [0u8; 4];
if self.file.read_exact(&mut crc_buf).is_err() {
break;
}
let stored_crc = u32::from_le_bytes(crc_buf);
offset += 4;
let computed_crc = {
let mut h = crc32fast::Hasher::new();
h.update(&payload);
h.finalize()
};
if computed_crc != stored_crc {
break; }
let lsn = u64::from_le_bytes(payload[0..8].try_into().unwrap());
let entry_type = match WalEntryType::try_from(payload[8]) {
Ok(t) => t,
Err(_) => break,
};
let page_id = u64::from_le_bytes(payload[9..17].try_into().unwrap());
let compressed_data = &payload[17..];
let data = lz4_flex::decompress_size_prepended(compressed_data)
.map_err(|e| MenteError::Storage(format!("LZ4 decompress failed: {e}")))?;
entries.push(WalEntry {
lsn,
entry_type,
page_id,
data,
checksum: stored_crc,
});
}
Ok(entries)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a brand-new WAL inside a fresh temporary directory.
    ///
    /// The directory guard is returned alongside the log so the directory
    /// stays alive for the whole test.
    fn setup() -> (tempfile::TempDir, Wal) {
        let tmp = tempfile::tempdir().unwrap();
        let log = Wal::open(tmp.path()).unwrap();
        (tmp, log)
    }

    /// Appended entries come back in order with their LSNs and bodies intact.
    #[test]
    fn test_append_and_iterate() {
        let (_guard, mut log) = setup();

        let first = log.append(WalEntryType::PageWrite, 1, b"hello").unwrap();
        let second = log.append(WalEntryType::PageWrite, 2, b"world").unwrap();
        assert_eq!(first, 1);
        assert_eq!(second, 2);

        let replayed = log.iterate().unwrap();
        assert_eq!(replayed.len(), 2);
        let (head, tail) = (&replayed[0], &replayed[1]);
        assert_eq!(head.lsn, 1);
        assert_eq!(head.data, b"hello");
        assert_eq!(tail.lsn, 2);
        assert_eq!(tail.data, b"world");
    }

    /// Syncing after an append succeeds.
    #[test]
    fn test_sync() {
        let (_guard, mut log) = setup();
        log.append(WalEntryType::Commit, 0, b"").unwrap();
        log.sync().unwrap();
    }

    /// Truncating at LSN 3 keeps only the entry with LSN 3.
    #[test]
    fn test_truncate() {
        let (_guard, mut log) = setup();
        log.append(WalEntryType::PageWrite, 1, b"a").unwrap();
        log.append(WalEntryType::PageWrite, 2, b"b").unwrap();
        log.append(WalEntryType::Checkpoint, 0, b"").unwrap();

        log.truncate(3).unwrap();

        let survivors = log.iterate().unwrap();
        assert_eq!(survivors.len(), 1);
        assert_eq!(survivors[0].lsn, 3);
    }

    /// A log that was written, synced and closed is fully readable after reopening.
    #[test]
    fn test_recovery_reopen() {
        let tmp = tempfile::tempdir().unwrap();

        let mut writer = Wal::open(tmp.path()).unwrap();
        writer
            .append(WalEntryType::PageWrite, 10, b"recovery-data")
            .unwrap();
        writer.sync().unwrap();
        // Close the first handle before opening the log again.
        drop(writer);

        let mut reader = Wal::open(tmp.path()).unwrap();
        assert_eq!(reader.next_lsn(), 2);
        let recovered = reader.iterate().unwrap();
        assert_eq!(recovered.len(), 1);
        let only = &recovered[0];
        assert_eq!(only.page_id, 10);
        assert_eq!(only.data, b"recovery-data");
    }

    /// An entry with an empty body round-trips as an empty body.
    #[test]
    fn test_empty_data_entry() {
        let (_guard, mut log) = setup();
        let assigned = log.append(WalEntryType::Checkpoint, 0, b"").unwrap();

        let replayed = log.iterate().unwrap();
        assert_eq!(replayed.len(), 1);
        let only = &replayed[0];
        assert_eq!(only.lsn, assigned);
        assert!(only.data.is_empty());
    }

    /// An 8 KiB body survives compression and decompression unchanged.
    #[test]
    fn test_large_data_compression() {
        let (_guard, mut log) = setup();
        let page_image = vec![0xABu8; 8192];
        log.append(WalEntryType::PageWrite, 5, &page_image).unwrap();

        let replayed = log.iterate().unwrap();
        assert_eq!(replayed.len(), 1);
        assert_eq!(replayed[0].data, page_image);
    }
}