use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
use std::path::Path;
/// A single operation recorded in the write-ahead log.
///
/// Entries are persisted with `bincode` (wrapped in a `LogRecord`). bincode
/// identifies enum variants by their position, so reordering variants or their
/// fields would make existing log files decode incorrectly — add new variants at
/// the end only.
///
/// Apart from `Checkpoint`, the log stores entries verbatim and does not interpret
/// them; giving them meaning is up to whoever consumes `WriteAheadLog::replay`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum LogEntry {
    /// Insert of `vector` under record `id`.
    Insert { id: String, vector: Vec<f32> },
    /// Replacement of the vector stored under record `id`.
    Update { id: String, vector: Vec<f32> },
    /// Removal of record `id`.
    Delete { id: String },
    /// Start of transaction `tx_id` (recorded, not interpreted, by the log).
    BeginTx { tx_id: u64 },
    /// Commit of transaction `tx_id` (recorded, not interpreted, by the log).
    CommitTx { tx_id: u64 },
    /// Abort of transaction `tx_id` (recorded, not interpreted, by the log).
    AbortTx { tx_id: u64 },
    /// Marker written by `WriteAheadLog::checkpoint`; `replay` drops every entry
    /// that precedes the last such marker. `sequence` is the last sequence number
    /// the checkpoint covers.
    Checkpoint { sequence: u64 },
}
/// Append-only, length-prefixed log of [`LogEntry`] records backed by one file.
///
/// Each record is stored as a 4-byte little-endian length followed by the
/// `bincode`-encoded record. Appends are flushed to the OS but not fsynced;
/// `checkpoint` and `truncate` call `sync_all`.
#[derive(Debug)]
pub struct WriteAheadLog {
    /// Handle used for replaying, truncating and syncing the log file.
    file: File,
    /// Buffered writer over a `try_clone` of `file`, used for appends. Cloned
    /// `File`s share one OS cursor, so reads/seeks through `file` move it too.
    writer: BufWriter<File>,
    /// Sequence number the next appended entry will receive.
    next_sequence: u64,
    /// Sequence of the most recent checkpoint this handle has written or
    /// observed; 0 when none is known.
    last_checkpoint: u64,
    /// Number of records in the log, checkpoint markers included.
    entry_count: u64,
}
impl WriteAheadLog {
    /// Size in bytes of the little-endian `u32` length prefix in front of every record.
    const FRAME_HEADER_LEN: u64 = 4;

    /// Opens (or creates) the log at `path` and recovers its state from the file.
    ///
    /// The existing contents are scanned to restore the record count, the next
    /// sequence number and the last checkpoint. If the file ends with an
    /// incomplete record (for example a write torn by a crash), that tail is cut
    /// off so new appends start on a record boundary instead of being swallowed
    /// by the broken frame.
    ///
    /// # Errors
    /// Fails if the file cannot be opened, read, cloned, or repaired.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .truncate(false)
            .open(&path)
            .context("Failed to open WAL file")?;

        let mut entry_count = 0u64;
        let mut last_seq = 0u64;
        let mut last_checkpoint = 0u64;
        let valid_len = Self::for_each_frame(&file, |payload| {
            // Undecodable (but correctly framed) records are skipped so a single bad
            // payload does not make the log unopenable; `replay` still reports them.
            if let Ok(record) = bincode::deserialize::<LogRecord>(payload) {
                last_seq = record.sequence;
                if let LogEntry::Checkpoint { sequence } = record.entry {
                    last_checkpoint = sequence;
                }
                entry_count += 1;
            }
            Ok(())
        })?;

        if valid_len < file.metadata()?.len() {
            file.set_len(valid_len)
                .context("Failed to discard incomplete record at end of WAL file")?;
            file.sync_all()?;
        }

        let writer = BufWriter::new(file.try_clone()?);
        let next_sequence = if entry_count > 0 { last_seq + 1 } else { 0 };
        Ok(Self {
            file,
            writer,
            next_sequence,
            last_checkpoint,
            entry_count,
        })
    }

    /// Appends `entry` and returns the sequence number assigned to it.
    ///
    /// The record is flushed to the OS but not fsynced; call [`Self::checkpoint`]
    /// to force durability. Counters advance only after the write succeeded, so a
    /// failed append does not consume a sequence number.
    ///
    /// # Errors
    /// Fails if the entry cannot be serialized or written.
    pub fn append(&mut self, entry: LogEntry) -> Result<u64> {
        let sequence = self.next_sequence;
        let record = LogRecord { sequence, entry };
        let serialized = bincode::serialize(&record).context("Failed to serialize log entry")?;
        self.write_frame(&serialized)?;
        self.next_sequence += 1;
        self.entry_count += 1;
        Ok(sequence)
    }

    /// Reads the whole log and returns, in order, the entries recorded after the
    /// last checkpoint marker (markers themselves are not returned).
    ///
    /// An incomplete trailing record is treated as the end of the log. Also
    /// updates the value reported by [`Self::last_checkpoint`].
    ///
    /// # Errors
    /// Fails on I/O errors or when a complete record cannot be decoded.
    pub fn replay(&mut self) -> Result<Vec<LogEntry>> {
        let mut entries = Vec::new();
        let mut last_checkpoint_seq = 0;
        Self::for_each_frame(&self.file, |payload| {
            let record: LogRecord =
                bincode::deserialize(payload).context("Failed to deserialize log entry")?;
            match record.entry {
                LogEntry::Checkpoint { sequence } => {
                    last_checkpoint_seq = sequence;
                    // Everything before a checkpoint is considered applied.
                    entries.clear();
                }
                entry => entries.push(entry),
            }
            Ok(())
        })?;
        self.last_checkpoint = last_checkpoint_seq;
        Ok(entries)
    }

    /// Appends a checkpoint marker covering every entry written so far and
    /// fsyncs the file. Does nothing when the log holds no records.
    ///
    /// The marker reuses the sequence number of the last entry rather than
    /// consuming a new one. Earlier records stay in the file (only
    /// [`Self::truncate`] removes data); `replay` merely skips them.
    ///
    /// # Errors
    /// Fails if the marker cannot be serialized, written, or synced.
    pub fn checkpoint(&mut self) -> Result<()> {
        if self.entry_count == 0 {
            return Ok(());
        }
        // entry_count > 0 implies at least one sequence number was assigned.
        let checkpoint_seq = self.next_sequence - 1;
        let record = LogRecord {
            sequence: checkpoint_seq,
            entry: LogEntry::Checkpoint {
                sequence: checkpoint_seq,
            },
        };
        let serialized = bincode::serialize(&record).context("Failed to serialize checkpoint")?;
        self.write_frame(&serialized)?;
        // The marker is in the file now, even if the sync below fails.
        self.entry_count += 1;
        self.file.sync_all()?;
        self.last_checkpoint = checkpoint_seq;
        Ok(())
    }

    /// Discards every record and resets all counters; the empty file is fsynced.
    ///
    /// # Errors
    /// Fails if buffered data cannot be flushed or the file cannot be truncated or synced.
    pub fn truncate(&mut self) -> Result<()> {
        // Drain the buffer first so no stale bytes land after the truncation.
        self.writer.flush()?;
        self.file.set_len(0)?;
        // Sync after shrinking so the truncation itself is durable.
        self.file.sync_all()?;
        self.file.seek(SeekFrom::Start(0))?;
        self.next_sequence = 0;
        self.last_checkpoint = 0;
        self.entry_count = 0;
        Ok(())
    }

    /// Number of records in the log, checkpoint markers included.
    pub fn len(&self) -> u64 {
        self.entry_count
    }

    /// Returns `true` when the log holds no records.
    pub fn is_empty(&self) -> bool {
        self.entry_count == 0
    }

    /// Sequence number of the most recent checkpoint this handle has written or
    /// observed (via `open` or `replay`); 0 when none is known.
    pub fn last_checkpoint(&self) -> u64 {
        self.last_checkpoint
    }

    /// Writes one length-prefixed frame at the physical end of the file and flushes it.
    fn write_frame(&mut self, payload: &[u8]) -> Result<()> {
        if payload.len() > u32::MAX as usize {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "WAL record is too large for its u32 length prefix",
            )
            .into());
        }
        // Cannot truncate: bounded by u32::MAX just above.
        let len = payload.len() as u32;
        // `writer` shares its OS cursor with `file`, which replay/scan move while
        // reading; always reposition so an append never overwrites existing records.
        self.writer.seek(SeekFrom::End(0))?;
        self.writer.write_all(&len.to_le_bytes())?;
        self.writer.write_all(payload)?;
        self.writer.flush()?;
        Ok(())
    }

    /// Walks every complete frame from the start of `file`, passing each payload
    /// to `visit`, and returns the byte offset just past the last complete frame.
    ///
    /// A partial length prefix, or a length pointing past the end of the file, is
    /// treated as a torn tail and ends the walk; the length is bounded by the file
    /// size so a corrupt prefix cannot trigger a huge allocation.
    fn for_each_frame(file: &File, mut visit: impl FnMut(&[u8]) -> Result<()>) -> Result<u64> {
        let file_len = file.metadata()?.len();
        let mut reader = BufReader::new(file.try_clone()?);
        reader.seek(SeekFrom::Start(0))?;
        let mut offset = 0u64;
        let mut payload = Vec::new();
        loop {
            let remaining = file_len - offset;
            if remaining < Self::FRAME_HEADER_LEN {
                break;
            }
            let mut len_bytes = [0u8; 4];
            reader.read_exact(&mut len_bytes)?;
            let payload_len = u32::from_le_bytes(len_bytes) as usize;
            if payload_len as u64 > remaining - Self::FRAME_HEADER_LEN {
                break;
            }
            payload.clear();
            payload.resize(payload_len, 0);
            reader.read_exact(&mut payload)?;
            offset += Self::FRAME_HEADER_LEN + payload_len as u64;
            visit(&payload)?;
        }
        Ok(offset)
    }
}
/// Unit stored in each frame of the log file: an entry tagged with its sequence number.
///
/// Each record is encoded with `bincode` and written behind a 4-byte
/// little-endian length prefix. bincode encodes struct fields in declaration
/// order, so reordering the fields breaks compatibility with existing log files.
#[derive(Debug, Serialize, Deserialize)]
struct LogRecord {
/// Sequence number assigned by `WriteAheadLog`; checkpoint records reuse the
/// sequence of the last entry they cover.
sequence: u64,
/// The logged operation.
entry: LogEntry,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Opens a WAL on a brand-new temporary file.
    ///
    /// The temp-file handle is returned alongside the log because dropping it
    /// deletes the file; callers keep it bound for the whole test.
    fn fresh_wal() -> (NamedTempFile, WriteAheadLog) {
        let backing = NamedTempFile::new().unwrap();
        let wal = WriteAheadLog::open(backing.path()).unwrap();
        (backing, wal)
    }

    /// Shorthand for an `Insert` entry carrying a one-component vector.
    fn insert(id: &str, value: f32) -> LogEntry {
        LogEntry::Insert {
            id: id.to_string(),
            vector: vec![value],
        }
    }

    /// Returns the id of an `Insert` entry, failing the test for any other variant.
    fn insert_id(entry: &LogEntry) -> &str {
        match entry {
            LogEntry::Insert { id, .. } => id.as_str(),
            _ => panic!("Expected Insert"),
        }
    }

    #[test]
    fn test_create_wal() {
        let (_backing, wal) = fresh_wal();
        assert!(wal.is_empty());
        assert_eq!(wal.len(), 0);
    }

    #[test]
    fn test_append_entry() {
        let (_backing, mut wal) = fresh_wal();
        let seq = wal
            .append(LogEntry::Insert {
                id: "doc1".to_string(),
                vector: vec![1.0, 2.0, 3.0],
            })
            .unwrap();
        assert_eq!(seq, 0);
        assert_eq!(wal.len(), 1);
    }

    #[test]
    fn test_replay_empty() {
        let (_backing, mut wal) = fresh_wal();
        assert_eq!(wal.replay().unwrap().len(), 0);
    }

    #[test]
    fn test_replay_entries() {
        let (_backing, mut wal) = fresh_wal();
        let ops = vec![
            insert("doc1", 1.0),
            LogEntry::Update {
                id: "doc1".to_string(),
                vector: vec![2.0],
            },
            LogEntry::Delete {
                id: "doc1".to_string(),
            },
        ];
        for op in ops {
            wal.append(op).unwrap();
        }

        let entries = wal.replay().unwrap();
        assert_eq!(entries.len(), 3);
        assert_eq!(insert_id(&entries[0]), "doc1");
        match &entries[1] {
            LogEntry::Update { id, .. } => assert_eq!(id, "doc1"),
            _ => panic!("Expected Update"),
        }
        match &entries[2] {
            LogEntry::Delete { id } => assert_eq!(id, "doc1"),
            _ => panic!("Expected Delete"),
        }
    }

    #[test]
    fn test_checkpoint() {
        let (_backing, mut wal) = fresh_wal();
        wal.append(insert("doc1", 1.0)).unwrap();
        wal.append(insert("doc2", 2.0)).unwrap();
        wal.checkpoint().unwrap();
        wal.append(insert("doc3", 3.0)).unwrap();

        // Only the entry written after the checkpoint should survive replay.
        let entries = wal.replay().unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(insert_id(&entries[0]), "doc3");
    }

    #[test]
    fn test_truncate() {
        let (_backing, mut wal) = fresh_wal();
        for n in 0..10 {
            wal.append(insert(&format!("doc{}", n), n as f32)).unwrap();
        }
        assert_eq!(wal.len(), 10);

        wal.truncate().unwrap();
        assert!(wal.is_empty());
        assert_eq!(wal.len(), 0);
        assert!(wal.replay().unwrap().is_empty());
    }

    #[test]
    fn test_crash_recovery() {
        let backing = NamedTempFile::new().unwrap();

        // First session: write five inserts, then drop the WAL as if the process exited.
        {
            let mut wal = WriteAheadLog::open(backing.path()).unwrap();
            for n in 0..5 {
                wal.append(insert(&format!("doc{}", n), n as f32)).unwrap();
            }
        }

        // Second session: reopening must surface every entry, in order.
        let mut wal = WriteAheadLog::open(backing.path()).unwrap();
        let entries = wal.replay().unwrap();
        assert_eq!(entries.len(), 5);
        for (n, entry) in entries.iter().enumerate() {
            assert_eq!(insert_id(entry), format!("doc{}", n));
        }
    }
}