use crate::error::{DbxError, DbxResult};
use crate::wal::{WalRecord, WriteAheadLog};
use std::fs::OpenOptions;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
/// Coordinates checkpointing, recovery, and trimming of a write-ahead log.
///
/// The manager holds a shared handle to the log plus the path of the file
/// backing it, so it can both append checkpoint markers through the log API
/// and rewrite the file on disk when trimming.
pub struct CheckpointManager {
/// Shared handle to the write-ahead log that is replayed and appended to.
wal: Arc<WriteAheadLog>,
/// Configured checkpoint interval (30 seconds unless changed via `with_interval`).
interval: Duration,
/// Path of the WAL file on disk; `trim_before` replaces this file.
wal_path: PathBuf,
}
impl CheckpointManager {
pub fn new(wal: Arc<WriteAheadLog>, wal_path: &Path) -> Self {
Self {
wal,
interval: Duration::from_secs(30),
wal_path: wal_path.to_path_buf(),
}
}
pub fn with_interval(mut self, interval: Duration) -> Self {
self.interval = interval;
self
}
pub fn checkpoint<F>(&self, apply_fn: F) -> DbxResult<u64>
where
F: Fn(&WalRecord) -> DbxResult<()>,
{
let records = self.wal.replay()?;
for record in &records {
if matches!(record, WalRecord::Checkpoint { .. }) {
continue;
}
apply_fn(record)?;
}
let seq = self.wal.current_sequence();
let checkpoint_record = WalRecord::Checkpoint { sequence: seq };
self.wal.append(&checkpoint_record)?;
self.wal.sync()?;
Ok(seq)
}
pub fn recover<F>(wal_path: &Path, apply_fn: F) -> DbxResult<usize>
where
F: Fn(&WalRecord) -> DbxResult<()>,
{
if !wal_path.exists() {
return Ok(0);
}
let wal = WriteAheadLog::open(wal_path)?;
let records = wal.replay()?;
let mut last_checkpoint_idx = None;
for (i, record) in records.iter().enumerate() {
if matches!(record, WalRecord::Checkpoint { .. }) {
last_checkpoint_idx = Some(i);
}
}
let start_idx = last_checkpoint_idx.map(|i| i + 1).unwrap_or(0);
let replay_count = records.len() - start_idx;
for record in &records[start_idx..] {
apply_fn(record)?;
}
Ok(replay_count)
}
pub fn trim_before(&self, sequence: u64) -> DbxResult<()> {
let records = self.wal.replay()?;
let mut last_checkpoint_idx = None;
for (i, record) in records.iter().enumerate() {
if let WalRecord::Checkpoint { sequence: seq } = record
&& *seq >= sequence
{
last_checkpoint_idx = Some(i);
}
}
let trimmed_records: Vec<WalRecord> = if let Some(idx) = last_checkpoint_idx {
records.into_iter().skip(idx).collect()
} else {
records
};
let temp_path = self.wal_path.with_extension("tmp");
let mut temp_file = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&temp_path)?;
for record in &trimmed_records {
let encoded = bincode::serialize(record)
.map_err(|e| DbxError::Wal(format!("serialization failed: {}", e)))?;
let len = (encoded.len() as u32).to_le_bytes();
temp_file.write_all(&len)?;
temp_file.write_all(&encoded)?;
}
temp_file.sync_all()?;
drop(temp_file);
std::fs::rename(&temp_path, &self.wal_path)?;
Ok(())
}
pub fn interval(&self) -> Duration {
self.interval
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use tempfile::NamedTempFile;

    /// Opens a WAL backed by the given temporary file.
    fn open_wal(file: &NamedTempFile) -> Arc<WriteAheadLog> {
        Arc::new(WriteAheadLog::open(file.path()).unwrap())
    }

    /// Insert of `user:1` => `Alice` into the `users` table.
    fn insert_alice() -> WalRecord {
        WalRecord::Insert {
            table: "users".to_string(),
            key: b"user:1".to_vec(),
            value: b"Alice".to_vec(),
            ts: 0,
        }
    }

    /// Insert of `user:2` => `Bob` into the `users` table.
    fn insert_bob() -> WalRecord {
        WalRecord::Insert {
            table: "users".to_string(),
            key: b"user:2".to_vec(),
            value: b"Bob".to_vec(),
            ts: 2,
        }
    }

    /// Delete of `user:2` from the `users` table.
    fn delete_user_two() -> WalRecord {
        WalRecord::Delete {
            table: "users".to_string(),
            key: b"user:2".to_vec(),
            ts: 1,
        }
    }

    /// Appends Alice's insert, a checkpoint with sequence 1, and Bob's insert,
    /// then syncs the log.
    fn write_checkpointed_log(wal: &WriteAheadLog) {
        wal.append(&insert_alice()).unwrap();
        wal.append(&WalRecord::Checkpoint { sequence: 1 }).unwrap();
        wal.append(&insert_bob()).unwrap();
        wal.sync().unwrap();
    }

    // A checkpoint hands every data record in the log to the apply callback.
    #[test]
    fn checkpoint_applies_wal() {
        let file = NamedTempFile::new().unwrap();
        let wal = open_wal(&file);
        let manager = CheckpointManager::new(wal.clone(), file.path());

        wal.append(&insert_alice()).unwrap();
        wal.append(&delete_user_two()).unwrap();
        wal.sync().unwrap();

        let seen = RefCell::new(Vec::new());
        let seq = manager
            .checkpoint(|record: &WalRecord| {
                seen.borrow_mut().push(record.clone());
                Ok(())
            })
            .unwrap();

        assert!(seq > 0);
        assert_eq!(seen.into_inner(), vec![insert_alice(), delete_user_two()]);
    }

    // Recovery only replays what was written after the last checkpoint marker.
    #[test]
    fn recover_replays_after_checkpoint() {
        let file = NamedTempFile::new().unwrap();
        let wal = open_wal(&file);
        write_checkpointed_log(&wal);

        let replayed = RefCell::new(Vec::new());
        let count = CheckpointManager::recover(file.path(), |record: &WalRecord| {
            replayed.borrow_mut().push(record.clone());
            Ok(())
        })
        .unwrap();

        assert_eq!(count, 1);
        assert_eq!(replayed.into_inner(), vec![insert_bob()]);
    }

    // Trimming keeps the checkpoint marker and everything after it.
    #[test]
    fn trim_removes_old_records() {
        let file = NamedTempFile::new().unwrap();
        let wal = open_wal(&file);
        let manager = CheckpointManager::new(wal.clone(), file.path());
        write_checkpointed_log(&wal);

        manager.trim_before(1).unwrap();

        let remaining = WriteAheadLog::open(file.path())
            .unwrap()
            .replay()
            .unwrap();
        assert_eq!(remaining.len(), 2);
        assert!(matches!(
            remaining[0],
            WalRecord::Checkpoint { sequence: 1 }
        ));
        assert_eq!(remaining[1], insert_bob());
    }

    // Recovering from a path with no file replays nothing.
    #[test]
    fn recover_empty_wal() {
        let file = NamedTempFile::new().unwrap();
        std::fs::remove_file(file.path()).unwrap();

        let count = CheckpointManager::recover(file.path(), |_: &WalRecord| Ok(())).unwrap();

        assert_eq!(count, 0);
    }

    // The builder-style setter overrides the default interval.
    #[test]
    fn checkpoint_interval() {
        let file = NamedTempFile::new().unwrap();
        let expected = Duration::from_secs(60);

        let manager = CheckpointManager::new(open_wal(&file), file.path()).with_interval(expected);

        assert_eq!(manager.interval(), expected);
    }
}