reddb_server/storage/wal/
checkpoint.rs1use std::collections::{HashMap, HashSet};
22use std::io;
23use std::path::Path;
24
25use super::reader::WalReader;
26use super::record::WalRecord;
27use super::writer::WalWriter;
28use crate::storage::engine::{Page, Pager, PAGE_SIZE};
29
/// Selects how a checkpoint run treats the WAL file once pages are applied.
///
/// In `Checkpointer::checkpoint` the modes fall into two groups:
/// `Restart` and `Truncate` truncate the WAL after the pages are synced,
/// while `Passive` and `Full` leave the WAL file as it is.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CheckpointMode {
    /// Apply committed pages; the WAL is left untouched.
    Passive,
    /// Apply committed pages; the WAL is left untouched.
    /// This is the mode chosen by `Checkpointer::default_mode`.
    Full,
    /// Apply committed pages, then truncate the WAL.
    Restart,
    /// Apply committed pages, then truncate the WAL.
    /// This is the mode used by `Checkpointer::recover`.
    Truncate,
}
42
/// Summary of a single checkpoint run.
///
/// `Default` yields an all-zero / `false` result, which is also what
/// `Checkpointer::checkpoint` returns when the WAL file does not exist.
#[derive(Clone, Debug, Default)]
pub struct CheckpointResult {
    /// Number of transactions that were found committed in the WAL.
    pub transactions_processed: u64,
    /// Number of distinct pages written to the pager.
    pub pages_checkpointed: u64,
    /// Total number of WAL records read, of every kind.
    pub records_processed: u64,
    /// LSN of the last WAL record read (0 when no record was read).
    pub checkpoint_lsn: u64,
    /// Whether the WAL file was truncated at the end of the run.
    pub wal_truncated: bool,
}
57
/// Failure modes of a checkpoint run.
#[derive(Debug)]
pub enum CheckpointError {
    /// An I/O operation on the WAL file failed.
    Io(io::Error),
    /// The pager reported an error; the payload is its message text.
    Pager(String),
    /// The WAL contents are unusable (for example a page image whose
    /// length is not `PAGE_SIZE`); the payload describes the problem.
    CorruptedWal(String),
    /// No WAL file was found.
    NoWal,
}
70
71impl std::fmt::Display for CheckpointError {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 Self::Io(e) => write!(f, "I/O error: {}", e),
75 Self::Pager(msg) => write!(f, "Pager error: {}", msg),
76 Self::CorruptedWal(msg) => write!(f, "Corrupted WAL: {}", msg),
77 Self::NoWal => write!(f, "No WAL file found"),
78 }
79 }
80}
81
82impl std::error::Error for CheckpointError {}
83
84impl From<io::Error> for CheckpointError {
85 fn from(e: io::Error) -> Self {
86 Self::Io(e)
87 }
88}
89
/// Last known state of a transaction while scanning the WAL.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TxState {
    /// A `Begin` record was seen and nothing has ended the transaction yet.
    Active,
    /// A `Commit` record was seen; its page writes are eligible for apply.
    Committed,
    /// A `Rollback` record was seen; its page writes are discarded.
    Aborted,
}
97
/// A page image read from the WAL, held until the fate of its transaction
/// is known.
#[derive(Debug)]
struct PendingWrite {
    /// Transaction that produced this write.
    tx_id: u64,
    /// Target page.
    page_id: u32,
    /// Raw page image taken from the WAL record.
    data: Vec<u8>,
    /// LSN of the WAL record this write came from.
    lsn: u64,
}
106
107pub struct Checkpointer {
111 mode: CheckpointMode,
113}
114
115impl Checkpointer {
116 pub fn new(mode: CheckpointMode) -> Self {
118 Self { mode }
119 }
120
121 pub fn default_mode() -> Self {
123 Self::new(CheckpointMode::Full)
124 }
125
126 pub fn checkpoint(
139 &self,
140 pager: &Pager,
141 wal_path: &Path,
142 ) -> Result<CheckpointResult, CheckpointError> {
143 let wal_reader = match WalReader::open(wal_path) {
145 Ok(r) => r,
146 Err(e) if e.kind() == io::ErrorKind::NotFound => {
147 return Ok(CheckpointResult::default());
149 }
150 Err(e) => return Err(CheckpointError::Io(e)),
151 };
152
153 let mut tx_states: HashMap<u64, TxState> = HashMap::new();
155 let mut pending_writes: Vec<PendingWrite> = Vec::new();
156 let mut records_processed: u64 = 0;
157 let mut last_lsn: u64 = 0;
158
159 for record_result in wal_reader.iter() {
160 let (lsn, record) = record_result.map_err(CheckpointError::Io)?;
161 records_processed += 1;
162 last_lsn = lsn;
163
164 match record {
165 WalRecord::Begin { tx_id } => {
166 tx_states.insert(tx_id, TxState::Active);
167 }
168 WalRecord::Commit { tx_id } => {
169 tx_states.insert(tx_id, TxState::Committed);
170 }
171 WalRecord::Rollback { tx_id } => {
172 tx_states.insert(tx_id, TxState::Aborted);
173 }
174 WalRecord::PageWrite {
175 tx_id,
176 page_id,
177 data,
178 } => {
179 pending_writes.push(PendingWrite {
180 tx_id,
181 page_id,
182 data,
183 lsn,
184 });
185 }
186 WalRecord::Checkpoint {
187 lsn: _checkpoint_lsn,
188 } => {
189 }
192 }
193 }
194
195 let committed_txs: HashSet<u64> = tx_states
197 .iter()
198 .filter(|(_, state)| **state == TxState::Committed)
199 .map(|(tx_id, _)| *tx_id)
200 .collect();
201
202 let mut latest_writes: HashMap<u32, Vec<u8>> = HashMap::new();
205
206 for write in pending_writes {
207 if committed_txs.contains(&write.tx_id) {
208 latest_writes.insert(write.page_id, write.data);
210 }
211 }
212
213 if !latest_writes.is_empty() {
215 pager
216 .set_checkpoint_in_progress(true, last_lsn)
217 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
218 }
219
220 let mut pages_checkpointed: u64 = 0;
222
223 for (page_id, data) in &latest_writes {
224 if data.len() != PAGE_SIZE {
226 return Err(CheckpointError::CorruptedWal(format!(
227 "Page {} has wrong size: {} (expected {})",
228 page_id,
229 data.len(),
230 PAGE_SIZE
231 )));
232 }
233
234 let mut page_data = [0u8; PAGE_SIZE];
235 page_data.copy_from_slice(data);
236 let page = Page::from_bytes(page_data);
237
238 pager
240 .write_page(*page_id, page)
241 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
242
243 pages_checkpointed += 1;
244 }
245
246 pager
248 .sync()
249 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
250
251 if !latest_writes.is_empty() {
253 pager
254 .complete_checkpoint(last_lsn)
255 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
256 }
257
258 let wal_truncated = matches!(
260 self.mode,
261 CheckpointMode::Restart | CheckpointMode::Truncate
262 );
263
264 if wal_truncated {
265 let mut wal_writer = WalWriter::open(wal_path)?;
266 wal_writer.truncate()?;
267
268 let checkpoint_record = WalRecord::Checkpoint { lsn: last_lsn };
270 wal_writer.append(&checkpoint_record)?;
271 wal_writer.sync()?;
272 }
273
274 Ok(CheckpointResult {
275 transactions_processed: committed_txs.len() as u64,
276 pages_checkpointed,
277 records_processed,
278 checkpoint_lsn: last_lsn,
279 wal_truncated,
280 })
281 }
282
283 pub fn recover(pager: &Pager, wal_path: &Path) -> Result<CheckpointResult, CheckpointError> {
298 if let Ok(header) = pager.header() {
300 if header.checkpoint_in_progress {
301 let _ = pager.set_checkpoint_in_progress(false, 0);
304 }
305 }
306 let checkpointer = Self::new(CheckpointMode::Truncate);
307 checkpointer.checkpoint(pager, wal_path)
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::engine::PageType;
    use std::fs;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Distinguishes directories created within the same clock tick.
    static NEXT_DIR_ID: AtomicUsize = AtomicUsize::new(0);

    /// Scratch directory that is removed when the guard is dropped, so it is
    /// cleaned up even when an assertion fails.
    struct TestDir {
        root: PathBuf,
    }

    impl TestDir {
        /// Creates a fresh directory whose name combines the process id, a
        /// timestamp and a per-process counter so parallel tests never share
        /// a directory.
        fn new() -> Self {
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("system clock is before the Unix epoch")
                .as_nanos();
            // Only uniqueness matters here, so relaxed ordering is enough.
            let seq = NEXT_DIR_ID.fetch_add(1, Ordering::Relaxed);
            let root = std::env::temp_dir().join(format!(
                "reddb_checkpoint_test_{}_{}_{}",
                std::process::id(),
                nanos,
                seq
            ));
            fs::create_dir_all(&root).expect("failed to create test directory");
            Self { root }
        }

        fn db_path(&self) -> PathBuf {
            self.root.join("test.db")
        }

        fn wal_path(&self) -> PathBuf {
            self.root.join("test.wal")
        }
    }

    impl Drop for TestDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.root);
        }
    }

    /// Writes `records` to a WAL at `wal_path` and syncs it once at the end.
    fn write_wal(wal_path: &Path, records: &[WalRecord]) {
        let mut wal_writer = WalWriter::open(wal_path).unwrap();
        for record in records {
            wal_writer.append(record).unwrap();
        }
        wal_writer.sync().unwrap();
    }

    /// Builds a `PageWrite` record whose page image is all zeroes except for
    /// `marker` in the first byte.
    fn page_write(tx_id: u64, page_id: u32, marker: u8) -> WalRecord {
        let mut image = vec![0u8; PAGE_SIZE];
        image[0] = marker;
        WalRecord::PageWrite {
            tx_id,
            page_id,
            data: image,
        }
    }

    #[test]
    fn test_checkpoint_empty_wal() {
        let dir = TestDir::new();
        let wal_path = dir.wal_path();
        let pager = Pager::open_default(&dir.db_path()).unwrap();

        // No WAL file exists: the checkpoint must be a successful no-op.
        let checkpointer = Checkpointer::default_mode();
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 0);
        assert_eq!(result.pages_checkpointed, 0);
    }

    #[test]
    fn test_checkpoint_committed_transaction() {
        let dir = TestDir::new();
        let wal_path = dir.wal_path();
        let pager = Pager::open_default(&dir.db_path()).unwrap();

        let page = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page_id = page.page_id();

        write_wal(
            &wal_path,
            &[
                WalRecord::Begin { tx_id: 1 },
                page_write(1, page_id, 0x42),
                WalRecord::Commit { tx_id: 1 },
            ],
        );

        let checkpointer = Checkpointer::new(CheckpointMode::Full);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 1);
        assert_eq!(result.pages_checkpointed, 1);
        assert_eq!(result.records_processed, 3);

        // The committed image must now be visible through the pager.
        let read_page = pager.read_page(page_id).unwrap();
        assert_eq!(read_page.as_bytes()[0], 0x42);
    }

    #[test]
    fn test_checkpoint_aborted_transaction() {
        let dir = TestDir::new();
        let wal_path = dir.wal_path();
        let pager = Pager::open_default(&dir.db_path()).unwrap();

        let page = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page_id = page.page_id();

        write_wal(
            &wal_path,
            &[
                WalRecord::Begin { tx_id: 1 },
                page_write(1, page_id, 0x42),
                WalRecord::Rollback { tx_id: 1 },
            ],
        );

        let checkpointer = Checkpointer::new(CheckpointMode::Full);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 0);
        assert_eq!(result.pages_checkpointed, 0);

        // The rolled-back image must not have reached the page.
        let read_page = pager.read_page(page_id).unwrap();
        assert_ne!(read_page.as_bytes()[0], 0x42);
    }

    #[test]
    fn test_checkpoint_mixed_transactions() {
        let dir = TestDir::new();
        let wal_path = dir.wal_path();
        let pager = Pager::open_default(&dir.db_path()).unwrap();

        let page1 = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page2 = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page1_id = page1.page_id();
        let page2_id = page2.page_id();

        write_wal(
            &wal_path,
            &[
                // Tx 1: committed write to page 1.
                WalRecord::Begin { tx_id: 1 },
                page_write(1, page1_id, 0x11),
                WalRecord::Commit { tx_id: 1 },
                // Tx 2: rolled-back write to page 2.
                WalRecord::Begin { tx_id: 2 },
                page_write(2, page2_id, 0x22),
                WalRecord::Rollback { tx_id: 2 },
                // Tx 3: committed write to page 2.
                WalRecord::Begin { tx_id: 3 },
                page_write(3, page2_id, 0x33),
                WalRecord::Commit { tx_id: 3 },
            ],
        );

        let checkpointer = Checkpointer::new(CheckpointMode::Full);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 2);
        assert_eq!(result.pages_checkpointed, 2);

        let read_page1 = pager.read_page(page1_id).unwrap();
        assert_eq!(read_page1.as_bytes()[0], 0x11);

        // Page 2 must hold tx 3's image, not the rolled-back tx 2 image.
        let read_page2 = pager.read_page(page2_id).unwrap();
        assert_eq!(read_page2.as_bytes()[0], 0x33);
    }

    #[test]
    fn test_checkpoint_truncate() {
        let dir = TestDir::new();
        let wal_path = dir.wal_path();
        let pager = Pager::open_default(&dir.db_path()).unwrap();

        write_wal(
            &wal_path,
            &[WalRecord::Begin { tx_id: 1 }, WalRecord::Commit { tx_id: 1 }],
        );

        let checkpointer = Checkpointer::new(CheckpointMode::Truncate);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert!(result.wal_truncated);

        // After truncation only the checkpoint marker should remain.
        let wal_size = fs::metadata(&wal_path).unwrap().len();
        assert!(
            wal_size < 50,
            "WAL should be truncated, but size is {}",
            wal_size
        );
    }
}