use std::collections::{HashMap, HashSet};
use std::io;
use std::path::Path;
use super::reader::WalReader;
use super::record::WalRecord;
use super::writer::WalWriter;
use crate::storage::engine::{Page, Pager, PAGE_SIZE};
/// Selects what [`Checkpointer::checkpoint`] does with the WAL file after the
/// committed pages have been applied.
///
/// The code in this file distinguishes the modes in one way only: `Restart`
/// and `Truncate` truncate the WAL afterwards, `Passive` and `Full` do not.
// NOTE(review): the variant names presumably mirror SQLite's checkpoint modes;
// confirm whether finer-grained differences are intended elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointMode {
/// Apply committed pages; the WAL file is not truncated.
Passive,
/// Apply committed pages; the WAL file is not truncated.
Full,
/// Apply committed pages, then truncate the WAL file.
Restart,
/// Apply committed pages, then truncate the WAL file.
Truncate,
}
/// Summary of a single checkpoint run.
///
/// The `Default` value (all counters zero, `wal_truncated == false`) is what
/// `Checkpointer::checkpoint` returns when the WAL file does not exist.
#[derive(Debug, Clone, Default)]
pub struct CheckpointResult {
/// Number of distinct transaction ids whose final state in the WAL was committed.
pub transactions_processed: u64,
/// Number of pages written to the pager.
pub pages_checkpointed: u64,
/// Number of WAL records read during the scan.
pub records_processed: u64,
/// LSN of the last WAL record read (0 when no record was read).
pub checkpoint_lsn: u64,
/// Whether the WAL was truncated as part of this run.
pub wal_truncated: bool,
}
/// Failures reported by checkpointing and recovery.
#[derive(Debug)]
pub enum CheckpointError {
    /// An underlying I/O operation failed.
    Io(io::Error),
    /// The pager reported a failure; carries the pager's error rendered as text.
    Pager(String),
    /// The WAL contained unusable data; carries a human-readable description.
    CorruptedWal(String),
    /// No WAL file was found.
    NoWal,
}

impl std::fmt::Display for CheckpointError {
    /// Renders a short description: a fixed prefix followed by the carried detail, if any.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The only variant without a payload needs no formatting machinery.
            Self::NoWal => f.write_str("No WAL file found"),
            Self::Io(source) => write!(f, "I/O error: {}", source),
            Self::Pager(detail) => write!(f, "Pager error: {}", detail),
            Self::CorruptedWal(detail) => write!(f, "Corrupted WAL: {}", detail),
        }
    }
}

impl std::error::Error for CheckpointError {}

impl From<io::Error> for CheckpointError {
    /// Wraps an I/O error so that `?` can be used on `io::Result` values.
    fn from(source: io::Error) -> Self {
        CheckpointError::Io(source)
    }
}
/// State of a transaction as determined by the most recent control record
/// seen for its id during the WAL scan.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TxState {
/// A `Begin` record was seen and nothing has superseded it.
Active,
/// A `Commit` record was seen; only these transactions have their page writes applied.
Committed,
/// A `Rollback` record was seen; the transaction's page writes are discarded.
Aborted,
}
/// A page image taken from a `WalRecord::PageWrite`, buffered until the scan
/// has finished and it is known which transactions committed.
#[derive(Debug)]
struct PendingWrite {
/// Transaction that issued the write.
tx_id: u64,
/// Page the image is destined for.
page_id: u32,
/// Page image bytes; checked against `PAGE_SIZE` before being handed to the pager.
data: Vec<u8>,
/// LSN of the WAL record this write came from. Recorded during the scan but
/// not read again by the code visible in this file.
lsn: u64,
}
/// Applies committed page writes recorded in a WAL file to the database
/// through a `Pager`.
pub struct Checkpointer {
/// Decides whether the WAL is truncated once the pages have been applied.
mode: CheckpointMode,
}
impl Checkpointer {
    /// Creates a checkpointer that runs in the given `mode`.
    pub fn new(mode: CheckpointMode) -> Self {
        Self { mode }
    }

    /// Creates a checkpointer that runs in [`CheckpointMode::Full`].
    pub fn default_mode() -> Self {
        Self::new(CheckpointMode::Full)
    }

    /// Replays the WAL at `wal_path` and writes the pages of committed
    /// transactions to `pager`.
    ///
    /// The WAL is scanned once to learn the final state of every transaction
    /// and to buffer all page writes. For each page, the last write (in WAL
    /// order) belonging to a committed transaction is then written to the
    /// pager and the pager is synced. In `Restart` and `Truncate` mode the WAL
    /// is afterwards truncated and a single `Checkpoint` record is appended.
    ///
    /// A missing WAL file is not an error: an all-zero `CheckpointResult` is
    /// returned and the pager is not touched.
    ///
    /// # Errors
    ///
    /// * [`CheckpointError::Io`] if the WAL cannot be opened (other than
    ///   "not found"), read, truncated or appended to.
    /// * [`CheckpointError::CorruptedWal`] if a page image selected for
    ///   writing is not exactly `PAGE_SIZE` bytes long. This is detected
    ///   before the pager is modified.
    /// * [`CheckpointError::Pager`] if the pager fails to mark the checkpoint,
    ///   write a page, sync, or complete the checkpoint.
    pub fn checkpoint(
        &self,
        pager: &Pager,
        wal_path: &Path,
    ) -> Result<CheckpointResult, CheckpointError> {
        let wal_reader = match WalReader::open(wal_path) {
            Ok(reader) => reader,
            // No WAL means there is nothing to replay.
            Err(e) if e.kind() == io::ErrorKind::NotFound => {
                return Ok(CheckpointResult::default());
            }
            Err(e) => return Err(CheckpointError::Io(e)),
        };

        let mut tx_states: HashMap<u64, TxState> = HashMap::new();
        let mut pending_writes: Vec<PendingWrite> = Vec::new();
        let mut records_processed: u64 = 0;
        let mut last_lsn: u64 = 0;

        // Pass 1: scan the whole WAL. Writes must be buffered because the
        // commit/rollback decision for a transaction follows its page writes.
        for record_result in wal_reader.iter() {
            let (lsn, record) = record_result.map_err(CheckpointError::Io)?;
            records_processed += 1;
            last_lsn = lsn;
            match record {
                WalRecord::Begin { tx_id } => {
                    tx_states.insert(tx_id, TxState::Active);
                }
                WalRecord::Commit { tx_id } => {
                    tx_states.insert(tx_id, TxState::Committed);
                }
                WalRecord::Rollback { tx_id } => {
                    tx_states.insert(tx_id, TxState::Aborted);
                }
                WalRecord::PageWrite {
                    tx_id,
                    page_id,
                    data,
                } => {
                    pending_writes.push(PendingWrite {
                        tx_id,
                        page_id,
                        data,
                        lsn,
                    });
                }
                // Earlier checkpoint markers carry nothing to replay.
                WalRecord::Checkpoint { .. } => {}
                // NOTE(review): batch-commit records are ignored here. If such a
                // record is the only commit marker for its transactions, their
                // page writes are never applied. Confirm against the WAL writer.
                WalRecord::TxCommitBatch { .. } => {}
            }
        }

        let committed_txs: HashSet<u64> = tx_states
            .into_iter()
            .filter(|&(_, state)| state == TxState::Committed)
            .map(|(tx_id, _)| tx_id)
            .collect();

        // Pass 2: keep only the last committed image of each page.
        // `pending_writes` is in WAL order, so later inserts win.
        let mut latest_writes: HashMap<u32, Vec<u8>> = HashMap::new();
        for write in pending_writes {
            if committed_txs.contains(&write.tx_id) {
                latest_writes.insert(write.page_id, write.data);
            }
        }

        // Validate every page image before touching the pager, so that a
        // corrupted record cannot leave a half-applied checkpoint behind with
        // the in-progress flag still set.
        if let Some((page_id, data)) = latest_writes
            .iter()
            .find(|(_, data)| data.len() != PAGE_SIZE)
        {
            return Err(CheckpointError::CorruptedWal(format!(
                "Page {} has wrong size: {} (expected {})",
                page_id,
                data.len(),
                PAGE_SIZE
            )));
        }

        let has_writes = !latest_writes.is_empty();
        if has_writes {
            pager
                .set_checkpoint_in_progress(true, last_lsn)
                .map_err(|e| CheckpointError::Pager(e.to_string()))?;
        }

        let mut pages_checkpointed: u64 = 0;
        for (page_id, data) in &latest_writes {
            // Cannot panic: every image was verified to be PAGE_SIZE bytes above.
            let mut page_data = [0u8; PAGE_SIZE];
            page_data.copy_from_slice(data);
            pager
                .write_page(*page_id, Page::from_bytes(page_data))
                .map_err(|e| CheckpointError::Pager(e.to_string()))?;
            pages_checkpointed += 1;
        }

        pager
            .sync()
            .map_err(|e| CheckpointError::Pager(e.to_string()))?;

        if has_writes {
            pager
                .complete_checkpoint(last_lsn)
                .map_err(|e| CheckpointError::Pager(e.to_string()))?;
        }

        let wal_truncated = matches!(
            self.mode,
            CheckpointMode::Restart | CheckpointMode::Truncate
        );
        if wal_truncated {
            // Pages are synced at this point, so the WAL contents are no longer
            // needed; leave a single checkpoint marker behind.
            let mut wal_writer = WalWriter::open(wal_path)?;
            wal_writer.truncate()?;
            wal_writer.append(&WalRecord::Checkpoint { lsn: last_lsn })?;
            wal_writer.sync()?;
        }

        Ok(CheckpointResult {
            transactions_processed: committed_txs.len() as u64,
            pages_checkpointed,
            records_processed,
            checkpoint_lsn: last_lsn,
            wal_truncated,
        })
    }

    /// Recovers the database by replaying the WAL in `Truncate` mode.
    ///
    /// If the pager header can be read and reports a checkpoint in progress,
    /// the flag is cleared first. Both reading the header and clearing the
    /// flag are best-effort: failures there are ignored and the replay is
    /// attempted regardless.
    ///
    /// # Errors
    ///
    /// Returns whatever [`Checkpointer::checkpoint`] returns.
    pub fn recover(pager: &Pager, wal_path: &Path) -> Result<CheckpointResult, CheckpointError> {
        if let Ok(header) = pager.header() {
            if header.checkpoint_in_progress {
                // Best-effort: the replay below sets the flag again if it has work to do.
                let _ = pager.set_checkpoint_in_progress(false, 0);
            }
        }
        Self::new(CheckpointMode::Truncate).checkpoint(pager, wal_path)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::engine::PageType;
    use std::fs;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Returns a directory path that is unique per call.
    ///
    /// The timestamp alone is not enough: tests run on parallel threads and
    /// some platforms have a coarse clock. The process id and a per-process
    /// counter ensure two tests never share database or WAL files.
    fn temp_dir() -> PathBuf {
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let sequence = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        std::env::temp_dir().join(format!(
            "reddb_checkpoint_test_{}_{}_{}",
            std::process::id(),
            timestamp,
            sequence
        ))
    }

    /// Creates a fresh test directory and returns `(dir, db_path, wal_path)`.
    fn test_paths() -> (PathBuf, PathBuf, PathBuf) {
        let dir = temp_dir();
        fs::create_dir_all(&dir).expect("failed to create test directory");
        let db_path = dir.join("test.db");
        let wal_path = dir.join("test.wal");
        (dir, db_path, wal_path)
    }

    /// Removes the test directory; failures are ignored on purpose.
    fn cleanup(dir: &Path) {
        let _ = fs::remove_dir_all(dir);
    }

    /// Builds a zeroed page image whose first byte is `marker`.
    fn marked_page(marker: u8) -> Vec<u8> {
        let mut bytes = vec![0u8; PAGE_SIZE];
        bytes[0] = marker;
        bytes
    }

    /// Appends `records` to the WAL at `wal_path`, syncs, and closes the writer.
    fn write_wal(wal_path: &Path, records: &[WalRecord]) {
        let mut wal_writer = WalWriter::open(wal_path).unwrap();
        for record in records {
            wal_writer.append(record).unwrap();
        }
        wal_writer.sync().unwrap();
    }

    #[test]
    fn test_checkpoint_empty_wal() {
        // No WAL file is created here, so this exercises the "WAL not found" path.
        let (dir, db_path, wal_path) = test_paths();
        let pager = Pager::open_default(&db_path).unwrap();

        let checkpointer = Checkpointer::default_mode();
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 0);
        assert_eq!(result.pages_checkpointed, 0);
        cleanup(&dir);
    }

    #[test]
    fn test_checkpoint_committed_transaction() {
        let (dir, db_path, wal_path) = test_paths();
        let pager = Pager::open_default(&db_path).unwrap();
        let page = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page_id = page.page_id();

        write_wal(
            &wal_path,
            &[
                WalRecord::Begin { tx_id: 1 },
                WalRecord::PageWrite {
                    tx_id: 1,
                    page_id,
                    data: marked_page(0x42),
                },
                WalRecord::Commit { tx_id: 1 },
            ],
        );

        let checkpointer = Checkpointer::new(CheckpointMode::Full);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 1);
        assert_eq!(result.pages_checkpointed, 1);
        assert_eq!(result.records_processed, 3);
        let read_page = pager.read_page(page_id).unwrap();
        assert_eq!(read_page.as_bytes()[0], 0x42);
        cleanup(&dir);
    }

    #[test]
    fn test_checkpoint_aborted_transaction() {
        let (dir, db_path, wal_path) = test_paths();
        let pager = Pager::open_default(&db_path).unwrap();
        let page = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page_id = page.page_id();

        write_wal(
            &wal_path,
            &[
                WalRecord::Begin { tx_id: 1 },
                WalRecord::PageWrite {
                    tx_id: 1,
                    page_id,
                    data: marked_page(0x42),
                },
                WalRecord::Rollback { tx_id: 1 },
            ],
        );

        let checkpointer = Checkpointer::new(CheckpointMode::Full);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 0);
        assert_eq!(result.pages_checkpointed, 0);
        // The rolled-back write must not have reached the page.
        let read_page = pager.read_page(page_id).unwrap();
        assert_ne!(read_page.as_bytes()[0], 0x42);
        cleanup(&dir);
    }

    #[test]
    fn test_checkpoint_mixed_transactions() {
        let (dir, db_path, wal_path) = test_paths();
        let pager = Pager::open_default(&db_path).unwrap();
        let page1 = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page2 = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page1_id = page1.page_id();
        let page2_id = page2.page_id();

        write_wal(
            &wal_path,
            &[
                // tx 1: committed write to page 1.
                WalRecord::Begin { tx_id: 1 },
                WalRecord::PageWrite {
                    tx_id: 1,
                    page_id: page1_id,
                    data: marked_page(0x11),
                },
                WalRecord::Commit { tx_id: 1 },
                // tx 2: rolled-back write to page 2.
                WalRecord::Begin { tx_id: 2 },
                WalRecord::PageWrite {
                    tx_id: 2,
                    page_id: page2_id,
                    data: marked_page(0x22),
                },
                WalRecord::Rollback { tx_id: 2 },
                // tx 3: committed write to page 2.
                WalRecord::Begin { tx_id: 3 },
                WalRecord::PageWrite {
                    tx_id: 3,
                    page_id: page2_id,
                    data: marked_page(0x33),
                },
                WalRecord::Commit { tx_id: 3 },
            ],
        );

        let checkpointer = Checkpointer::new(CheckpointMode::Full);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 2);
        assert_eq!(result.pages_checkpointed, 2);
        let read_page1 = pager.read_page(page1_id).unwrap();
        assert_eq!(read_page1.as_bytes()[0], 0x11);
        let read_page2 = pager.read_page(page2_id).unwrap();
        assert_eq!(read_page2.as_bytes()[0], 0x33);
        cleanup(&dir);
    }

    #[test]
    fn test_checkpoint_truncate() {
        let (dir, db_path, wal_path) = test_paths();
        let pager = Pager::open_default(&db_path).unwrap();

        write_wal(
            &wal_path,
            &[
                WalRecord::Begin { tx_id: 1 },
                WalRecord::Commit { tx_id: 1 },
            ],
        );

        let checkpointer = Checkpointer::new(CheckpointMode::Truncate);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert!(result.wal_truncated);
        let wal_size = fs::metadata(&wal_path).unwrap().len();
        assert!(
            wal_size < 50,
            "WAL should be truncated, but size is {}",
            wal_size
        );
        cleanup(&dir);
    }
}