reddb_server/storage/wal/
checkpoint.rs1use std::collections::{HashMap, HashSet};
22use std::io;
23use std::path::Path;
24
25use super::reader::WalReader;
26use super::record::WalRecord;
27use super::writer::WalWriter;
28use crate::storage::engine::{Page, Pager, PAGE_SIZE};
29
/// Selects what a checkpoint does with the WAL file once pages are flushed.
///
/// As used by `Checkpointer::checkpoint` in this file, `Passive` and `Full`
/// leave the WAL in place, while `Restart` and `Truncate` truncate it and
/// write a fresh checkpoint marker. Any finer distinction between the
/// variants is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointMode {
    /// Flush committed pages; the WAL file is kept as-is.
    Passive,
    /// Flush committed pages; the WAL file is kept as-is.
    Full,
    /// Flush committed pages, then truncate the WAL.
    Restart,
    /// Flush committed pages, then truncate the WAL.
    Truncate,
}
42
/// Summary statistics returned by a checkpoint run.
///
/// `Default` yields an all-zero result with `wal_truncated == false`.
#[derive(Debug, Clone, Default)]
pub struct CheckpointResult {
    /// Number of distinct transactions whose final state in the WAL was committed.
    pub transactions_processed: u64,
    /// Number of distinct pages written to the pager.
    pub pages_checkpointed: u64,
    /// Total number of WAL records read during the scan.
    pub records_processed: u64,
    /// LSN of the last WAL record read (0 when no record was read).
    pub checkpoint_lsn: u64,
    /// Whether the WAL was truncated after the pages were flushed.
    pub wal_truncated: bool,
}
57
/// Failures that a checkpoint or recovery run can report.
#[derive(Debug)]
pub enum CheckpointError {
    /// An I/O operation on the WAL failed.
    Io(io::Error),
    /// The pager reported an error; the payload is that error rendered as text.
    Pager(String),
    /// The WAL contents are not usable (for example a page image of the wrong size).
    CorruptedWal(String),
    /// No WAL file was found.
    // NOTE(review): nothing in this file constructs this variant — a missing
    // WAL is reported as an empty, successful result instead. Confirm whether
    // other modules rely on it.
    NoWal,
}
70
71impl std::fmt::Display for CheckpointError {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 Self::Io(e) => write!(f, "I/O error: {}", e),
75 Self::Pager(msg) => write!(f, "Pager error: {}", msg),
76 Self::CorruptedWal(msg) => write!(f, "Corrupted WAL: {}", msg),
77 Self::NoWal => write!(f, "No WAL file found"),
78 }
79 }
80}
81
82impl std::error::Error for CheckpointError {}
83
84impl From<io::Error> for CheckpointError {
85 fn from(e: io::Error) -> Self {
86 Self::Io(e)
87 }
88}
89
/// Most recent lifecycle state seen for a transaction while scanning the WAL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TxState {
    /// A `Begin` record was seen and no commit or rollback has followed yet.
    Active,
    /// A `Commit` record was seen.
    Committed,
    /// A `Rollback` record was seen.
    Aborted,
}
97
/// A page image read from the WAL that has not yet been applied to the pager.
#[derive(Debug)]
struct PendingWrite {
    /// Transaction that produced this write.
    tx_id: u64,
    /// Page the image belongs to.
    page_id: u32,
    /// Raw page bytes as stored in the WAL record.
    data: Vec<u8>,
    /// LSN of the WAL record that carried this write.
    lsn: u64,
}
106
107pub struct Checkpointer {
111 mode: CheckpointMode,
113}
114
115impl Checkpointer {
116 pub fn new(mode: CheckpointMode) -> Self {
118 Self { mode }
119 }
120
121 pub fn default_mode() -> Self {
123 Self::new(CheckpointMode::Full)
124 }
125
126 pub fn checkpoint(
139 &self,
140 pager: &Pager,
141 wal_path: &Path,
142 ) -> Result<CheckpointResult, CheckpointError> {
143 let wal_reader = match WalReader::open(wal_path) {
145 Ok(r) => r,
146 Err(e) if e.kind() == io::ErrorKind::NotFound => {
147 return Ok(CheckpointResult::default());
149 }
150 Err(e) => return Err(CheckpointError::Io(e)),
151 };
152
153 let mut tx_states: HashMap<u64, TxState> = HashMap::new();
155 let mut pending_writes: Vec<PendingWrite> = Vec::new();
156 let mut records_processed: u64 = 0;
157 let mut last_lsn: u64 = 0;
158
159 for record_result in wal_reader.iter() {
160 let (lsn, record) = record_result.map_err(CheckpointError::Io)?;
161 records_processed += 1;
162 last_lsn = lsn;
163
164 match record {
165 WalRecord::Begin { tx_id } => {
166 tx_states.insert(tx_id, TxState::Active);
167 }
168 WalRecord::Commit { tx_id } => {
169 tx_states.insert(tx_id, TxState::Committed);
170 }
171 WalRecord::Rollback { tx_id } => {
172 tx_states.insert(tx_id, TxState::Aborted);
173 }
174 WalRecord::PageWrite {
175 tx_id,
176 page_id,
177 data,
178 } => {
179 pending_writes.push(PendingWrite {
180 tx_id,
181 page_id,
182 data,
183 lsn,
184 });
185 }
186 WalRecord::Checkpoint {
187 lsn: _checkpoint_lsn,
188 } => {
189 }
192 WalRecord::TxCommitBatch { .. } => {
193 }
196 WalRecord::FullPageImage { .. } => {
197 }
201 }
202 }
203
204 let committed_txs: HashSet<u64> = tx_states
206 .iter()
207 .filter(|(_, state)| **state == TxState::Committed)
208 .map(|(tx_id, _)| *tx_id)
209 .collect();
210
211 let mut latest_writes: HashMap<u32, Vec<u8>> = HashMap::new();
214
215 for write in pending_writes {
216 if committed_txs.contains(&write.tx_id) {
217 latest_writes.insert(write.page_id, write.data);
219 }
220 }
221
222 if !latest_writes.is_empty() {
224 pager
225 .set_checkpoint_in_progress(true, last_lsn)
226 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
227 }
228
229 let mut pages_checkpointed: u64 = 0;
231
232 for (page_id, data) in &latest_writes {
233 if data.len() != PAGE_SIZE {
235 return Err(CheckpointError::CorruptedWal(format!(
236 "Page {} has wrong size: {} (expected {})",
237 page_id,
238 data.len(),
239 PAGE_SIZE
240 )));
241 }
242
243 let mut page_data = [0u8; PAGE_SIZE];
244 page_data.copy_from_slice(data);
245 let page = Page::from_bytes(page_data);
246
247 pager
249 .write_page(*page_id, page)
250 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
251
252 pages_checkpointed += 1;
253 }
254
255 pager
257 .sync()
258 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
259
260 if !latest_writes.is_empty() {
262 pager
263 .complete_checkpoint(last_lsn)
264 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
265 }
266
267 let wal_truncated = matches!(
269 self.mode,
270 CheckpointMode::Restart | CheckpointMode::Truncate
271 );
272
273 if wal_truncated {
274 let mut wal_writer = WalWriter::open(wal_path)?;
275 wal_writer.truncate()?;
276
277 let checkpoint_record = WalRecord::Checkpoint { lsn: last_lsn };
279 wal_writer.append(&checkpoint_record)?;
280 wal_writer.sync()?;
281 }
282
283 Ok(CheckpointResult {
284 transactions_processed: committed_txs.len() as u64,
285 pages_checkpointed,
286 records_processed,
287 checkpoint_lsn: last_lsn,
288 wal_truncated,
289 })
290 }
291
292 pub fn recover(pager: &Pager, wal_path: &Path) -> Result<CheckpointResult, CheckpointError> {
307 if let Ok(header) = pager.header() {
309 if header.checkpoint_in_progress {
310 let _ = pager.set_checkpoint_in_progress(false, 0);
313 }
314 }
315 let checkpointer = Self::new(CheckpointMode::Truncate);
316 checkpointer.checkpoint(pager, wal_path)
317 }
318}
319
320#[cfg(test)]
321mod tests {
322 use super::*;
323 use crate::storage::engine::PageType;
324 use std::fs;
325 use std::time::{SystemTime, UNIX_EPOCH};
326
327 fn temp_dir() -> std::path::PathBuf {
328 let timestamp = SystemTime::now()
329 .duration_since(UNIX_EPOCH)
330 .unwrap()
331 .as_nanos();
332 std::env::temp_dir().join(format!("reddb_checkpoint_test_{}", timestamp))
333 }
334
335 fn cleanup(dir: &Path) {
336 let _ = fs::remove_dir_all(dir);
337 }
338
339 #[test]
340 fn test_checkpoint_empty_wal() {
341 let dir = temp_dir();
342 let _ = fs::create_dir_all(&dir);
343 let db_path = dir.join("test.db");
344 let wal_path = dir.join("test.wal");
345
346 let pager = Pager::open_default(&db_path).unwrap();
348
349 let checkpointer = Checkpointer::default_mode();
351 let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();
352
353 assert_eq!(result.transactions_processed, 0);
354 assert_eq!(result.pages_checkpointed, 0);
355
356 cleanup(&dir);
357 }
358
359 #[test]
360 fn test_checkpoint_committed_transaction() {
361 let dir = temp_dir();
362 let _ = fs::create_dir_all(&dir);
363 let db_path = dir.join("test.db");
364 let wal_path = dir.join("test.wal");
365
366 let pager = Pager::open_default(&db_path).unwrap();
368
369 let page = pager.allocate_page(PageType::BTreeLeaf).unwrap();
371 let page_id = page.page_id();
372
373 {
375 let mut wal_writer = WalWriter::open(&wal_path).unwrap();
376
377 wal_writer.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
379
380 let mut page_data = [0u8; PAGE_SIZE];
382 page_data[0] = 0x42; wal_writer
384 .append(&WalRecord::PageWrite {
385 tx_id: 1,
386 page_id,
387 data: page_data.to_vec(),
388 })
389 .unwrap();
390
391 wal_writer.append(&WalRecord::Commit { tx_id: 1 }).unwrap();
393
394 wal_writer.sync().unwrap();
395 }
396
397 let checkpointer = Checkpointer::new(CheckpointMode::Full);
399 let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();
400
401 assert_eq!(result.transactions_processed, 1);
402 assert_eq!(result.pages_checkpointed, 1);
403 assert_eq!(result.records_processed, 3);
404
405 let read_page = pager.read_page(page_id).unwrap();
407 assert_eq!(read_page.as_bytes()[0], 0x42);
408
409 cleanup(&dir);
410 }
411
412 #[test]
413 fn test_checkpoint_aborted_transaction() {
414 let dir = temp_dir();
415 let _ = fs::create_dir_all(&dir);
416 let db_path = dir.join("test.db");
417 let wal_path = dir.join("test.wal");
418
419 let pager = Pager::open_default(&db_path).unwrap();
421
422 let page = pager.allocate_page(PageType::BTreeLeaf).unwrap();
424 let page_id = page.page_id();
425
426 {
428 let mut wal_writer = WalWriter::open(&wal_path).unwrap();
429
430 wal_writer.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
432
433 let mut page_data = [0u8; PAGE_SIZE];
435 page_data[0] = 0x42;
436 wal_writer
437 .append(&WalRecord::PageWrite {
438 tx_id: 1,
439 page_id,
440 data: page_data.to_vec(),
441 })
442 .unwrap();
443
444 wal_writer
446 .append(&WalRecord::Rollback { tx_id: 1 })
447 .unwrap();
448
449 wal_writer.sync().unwrap();
450 }
451
452 let checkpointer = Checkpointer::new(CheckpointMode::Full);
454 let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();
455
456 assert_eq!(result.transactions_processed, 0);
458 assert_eq!(result.pages_checkpointed, 0);
459
460 let read_page = pager.read_page(page_id).unwrap();
462 assert_ne!(read_page.as_bytes()[0], 0x42);
463
464 cleanup(&dir);
465 }
466
467 #[test]
468 fn test_checkpoint_mixed_transactions() {
469 let dir = temp_dir();
470 let _ = fs::create_dir_all(&dir);
471 let db_path = dir.join("test.db");
472 let wal_path = dir.join("test.wal");
473
474 let pager = Pager::open_default(&db_path).unwrap();
476
477 let page1 = pager.allocate_page(PageType::BTreeLeaf).unwrap();
479 let page2 = pager.allocate_page(PageType::BTreeLeaf).unwrap();
480 let page1_id = page1.page_id();
481 let page2_id = page2.page_id();
482
483 {
485 let mut wal_writer = WalWriter::open(&wal_path).unwrap();
486
487 wal_writer.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
489 let mut page_data1 = [0u8; PAGE_SIZE];
490 page_data1[0] = 0x11;
491 wal_writer
492 .append(&WalRecord::PageWrite {
493 tx_id: 1,
494 page_id: page1_id,
495 data: page_data1.to_vec(),
496 })
497 .unwrap();
498 wal_writer.append(&WalRecord::Commit { tx_id: 1 }).unwrap();
499
500 wal_writer.append(&WalRecord::Begin { tx_id: 2 }).unwrap();
502 let mut page_data2 = [0u8; PAGE_SIZE];
503 page_data2[0] = 0x22;
504 wal_writer
505 .append(&WalRecord::PageWrite {
506 tx_id: 2,
507 page_id: page2_id,
508 data: page_data2.to_vec(),
509 })
510 .unwrap();
511 wal_writer
512 .append(&WalRecord::Rollback { tx_id: 2 })
513 .unwrap();
514
515 wal_writer.append(&WalRecord::Begin { tx_id: 3 }).unwrap();
517 let mut page_data3 = [0u8; PAGE_SIZE];
518 page_data3[0] = 0x33;
519 wal_writer
520 .append(&WalRecord::PageWrite {
521 tx_id: 3,
522 page_id: page2_id,
523 data: page_data3.to_vec(),
524 })
525 .unwrap();
526 wal_writer.append(&WalRecord::Commit { tx_id: 3 }).unwrap();
527
528 wal_writer.sync().unwrap();
529 }
530
531 let checkpointer = Checkpointer::new(CheckpointMode::Full);
533 let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();
534
535 assert_eq!(result.transactions_processed, 2);
537 assert_eq!(result.pages_checkpointed, 2);
538
539 let read_page1 = pager.read_page(page1_id).unwrap();
541 assert_eq!(read_page1.as_bytes()[0], 0x11);
542
543 let read_page2 = pager.read_page(page2_id).unwrap();
544 assert_eq!(read_page2.as_bytes()[0], 0x33); cleanup(&dir);
547 }
548
549 #[test]
550 fn test_checkpoint_truncate() {
551 let dir = temp_dir();
552 let _ = fs::create_dir_all(&dir);
553 let db_path = dir.join("test.db");
554 let wal_path = dir.join("test.wal");
555
556 let pager = Pager::open_default(&db_path).unwrap();
558
559 {
561 let mut wal_writer = WalWriter::open(&wal_path).unwrap();
562 wal_writer.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
563 wal_writer.append(&WalRecord::Commit { tx_id: 1 }).unwrap();
564 wal_writer.sync().unwrap();
565 }
566
567 let checkpointer = Checkpointer::new(CheckpointMode::Truncate);
569 let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();
570
571 assert!(result.wal_truncated);
572
573 let wal_size = fs::metadata(&wal_path).unwrap().len();
575 assert!(
577 wal_size < 50,
578 "WAL should be truncated, but size is {}",
579 wal_size
580 );
581
582 cleanup(&dir);
583 }
584}