// reddb_server/storage/wal/checkpoint.rs
use std::collections::{HashMap, HashSet};
22use std::io;
23use std::path::Path;
24
25use super::reader::WalReader;
26use super::record::WalRecord;
27use super::writer::WalWriter;
28use crate::storage::engine::{Page, Pager, PAGE_SIZE};
29
/// Selects how a [`Checkpointer`] treats the WAL file once committed pages
/// have been applied to the database.
///
/// Within `Checkpointer::checkpoint` the only observable distinction is
/// between the two groups below:
///
/// * `Passive` / `Full` — the WAL file is left untouched after the pages are
///   written.
/// * `Restart` / `Truncate` — the WAL is truncated and a single fresh
///   `Checkpoint` record is appended.
///
/// NOTE(review): the variant names mirror SQLite's WAL checkpoint modes; any
/// finer-grained semantics (e.g. blocking behaviour) are not implemented by
/// the code in this file — confirm against callers before relying on them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CheckpointMode {
    /// Apply committed pages; keep the WAL as is.
    Passive,
    /// Apply committed pages; keep the WAL as is (currently handled exactly
    /// like `Passive`).
    Full,
    /// Apply committed pages, then reset the WAL.
    Restart,
    /// Apply committed pages, then reset the WAL (currently handled exactly
    /// like `Restart`).
    Truncate,
}
42
/// Summary of a single checkpoint run.
///
/// `Default` yields an all-zero result with `wal_truncated == false`; that is
/// what `Checkpointer::checkpoint` returns when the WAL file does not exist.
#[derive(Clone, Debug, Default)]
pub struct CheckpointResult {
    /// Number of distinct transactions whose final state in the WAL was
    /// "committed".
    pub transactions_processed: u64,
    /// Number of distinct pages written to the pager.
    pub pages_checkpointed: u64,
    /// Total number of WAL records read, of every kind.
    pub records_processed: u64,
    /// LSN of the last WAL record read (0 if the WAL held no records).
    pub checkpoint_lsn: u64,
    /// `true` when the checkpoint mode caused the WAL to be truncated.
    pub wal_truncated: bool,
}
57
/// Failure modes of a checkpoint or recovery run.
#[derive(Debug)]
pub enum CheckpointError {
    /// An I/O error raised while reading or rewriting the WAL file.
    Io(io::Error),
    /// A pager operation failed; the payload is the pager error rendered as
    /// text.
    Pager(String),
    /// The WAL contained data that cannot be applied (for example a page
    /// image whose length differs from the page size).
    CorruptedWal(String),
    /// No WAL file is available.
    ///
    /// NOTE(review): `Checkpointer::checkpoint` treats a missing WAL as an
    /// empty, successful run rather than returning this variant — confirm
    /// which callers construct it.
    NoWal,
}
70
71impl std::fmt::Display for CheckpointError {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 Self::Io(e) => write!(f, "I/O error: {}", e),
75 Self::Pager(msg) => write!(f, "Pager error: {}", msg),
76 Self::CorruptedWal(msg) => write!(f, "Corrupted WAL: {}", msg),
77 Self::NoWal => write!(f, "No WAL file found"),
78 }
79 }
80}
81
82impl std::error::Error for CheckpointError {}
83
84impl From<io::Error> for CheckpointError {
85 fn from(e: io::Error) -> Self {
86 Self::Io(e)
87 }
88}
89
/// Last state recorded in the WAL for a transaction id.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TxState {
    /// A `Begin` record was seen and no terminating record has followed yet.
    Active,
    /// A `Commit` record was seen.
    Committed,
    /// A `Rollback` record was seen.
    Aborted,
}
97
/// A page image read from the WAL whose owning transaction has not yet been
/// classified as committed or not.
#[derive(Debug)]
struct PendingWrite {
    /// Transaction that produced the write.
    tx_id: u64,
    /// Target page in the database file.
    page_id: u32,
    /// Page image carried by the WAL record.
    data: Vec<u8>,
    /// LSN of the WAL record the write came from.
    lsn: u64,
}
106
107pub struct Checkpointer {
111 mode: CheckpointMode,
113}
114
115impl Checkpointer {
116 pub fn new(mode: CheckpointMode) -> Self {
118 Self { mode }
119 }
120
121 pub fn default_mode() -> Self {
123 Self::new(CheckpointMode::Full)
124 }
125
126 pub fn checkpoint(
139 &self,
140 pager: &Pager,
141 wal_path: &Path,
142 ) -> Result<CheckpointResult, CheckpointError> {
143 let wal_reader = match WalReader::open(wal_path) {
145 Ok(r) => r,
146 Err(e) if e.kind() == io::ErrorKind::NotFound => {
147 return Ok(CheckpointResult::default());
149 }
150 Err(e) => return Err(CheckpointError::Io(e)),
151 };
152
153 let mut tx_states: HashMap<u64, TxState> = HashMap::new();
155 let mut pending_writes: Vec<PendingWrite> = Vec::new();
156 let mut records_processed: u64 = 0;
157 let mut last_lsn: u64 = 0;
158
159 for record_result in wal_reader.iter() {
160 let (lsn, record) = record_result.map_err(CheckpointError::Io)?;
161 records_processed += 1;
162 last_lsn = lsn;
163
164 match record {
165 WalRecord::Begin { tx_id } => {
166 tx_states.insert(tx_id, TxState::Active);
167 }
168 WalRecord::Commit { tx_id } => {
169 tx_states.insert(tx_id, TxState::Committed);
170 }
171 WalRecord::Rollback { tx_id } => {
172 tx_states.insert(tx_id, TxState::Aborted);
173 }
174 WalRecord::PageWrite {
175 tx_id,
176 page_id,
177 data,
178 } => {
179 pending_writes.push(PendingWrite {
180 tx_id,
181 page_id,
182 data,
183 lsn,
184 });
185 }
186 WalRecord::Checkpoint {
187 lsn: _checkpoint_lsn,
188 } => {
189 }
192 WalRecord::TxCommitBatch { .. } => {
193 }
196 }
197 }
198
199 let committed_txs: HashSet<u64> = tx_states
201 .iter()
202 .filter(|(_, state)| **state == TxState::Committed)
203 .map(|(tx_id, _)| *tx_id)
204 .collect();
205
206 let mut latest_writes: HashMap<u32, Vec<u8>> = HashMap::new();
209
210 for write in pending_writes {
211 if committed_txs.contains(&write.tx_id) {
212 latest_writes.insert(write.page_id, write.data);
214 }
215 }
216
217 if !latest_writes.is_empty() {
219 pager
220 .set_checkpoint_in_progress(true, last_lsn)
221 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
222 }
223
224 let mut pages_checkpointed: u64 = 0;
226
227 for (page_id, data) in &latest_writes {
228 if data.len() != PAGE_SIZE {
230 return Err(CheckpointError::CorruptedWal(format!(
231 "Page {} has wrong size: {} (expected {})",
232 page_id,
233 data.len(),
234 PAGE_SIZE
235 )));
236 }
237
238 let mut page_data = [0u8; PAGE_SIZE];
239 page_data.copy_from_slice(data);
240 let page = Page::from_bytes(page_data);
241
242 pager
244 .write_page(*page_id, page)
245 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
246
247 pages_checkpointed += 1;
248 }
249
250 pager
252 .sync()
253 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
254
255 if !latest_writes.is_empty() {
257 pager
258 .complete_checkpoint(last_lsn)
259 .map_err(|e| CheckpointError::Pager(e.to_string()))?;
260 }
261
262 let wal_truncated = matches!(
264 self.mode,
265 CheckpointMode::Restart | CheckpointMode::Truncate
266 );
267
268 if wal_truncated {
269 let mut wal_writer = WalWriter::open(wal_path)?;
270 wal_writer.truncate()?;
271
272 let checkpoint_record = WalRecord::Checkpoint { lsn: last_lsn };
274 wal_writer.append(&checkpoint_record)?;
275 wal_writer.sync()?;
276 }
277
278 Ok(CheckpointResult {
279 transactions_processed: committed_txs.len() as u64,
280 pages_checkpointed,
281 records_processed,
282 checkpoint_lsn: last_lsn,
283 wal_truncated,
284 })
285 }
286
287 pub fn recover(pager: &Pager, wal_path: &Path) -> Result<CheckpointResult, CheckpointError> {
302 if let Ok(header) = pager.header() {
304 if header.checkpoint_in_progress {
305 let _ = pager.set_checkpoint_in_progress(false, 0);
308 }
309 }
310 let checkpointer = Self::new(CheckpointMode::Truncate);
311 checkpointer.checkpoint(pager, wal_path)
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::engine::PageType;
    use std::fs;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Per-process counter that keeps directory names unique even when two
    /// tests read the same clock value.
    static NEXT_DIR_ID: AtomicU64 = AtomicU64::new(0);

    /// Scratch directory for one test; removed on drop so that a failing
    /// assertion does not leak files into the system temp directory.
    ///
    /// Declare it before the pager in each test: locals drop in reverse
    /// order, so the pager is closed before the directory is removed.
    struct TestDir {
        root: PathBuf,
    }

    impl TestDir {
        /// Creates a fresh, uniquely named directory under the system temp
        /// directory.
        fn new() -> Self {
            let timestamp = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            let seq = NEXT_DIR_ID.fetch_add(1, Ordering::Relaxed);
            let root = std::env::temp_dir().join(format!(
                "reddb_checkpoint_test_{}_{}_{}",
                std::process::id(),
                timestamp,
                seq
            ));
            fs::create_dir_all(&root).expect("failed to create test directory");
            Self { root }
        }

        /// Path of the database file inside the scratch directory.
        fn db_path(&self) -> PathBuf {
            self.root.join("test.db")
        }

        /// Path of the WAL file inside the scratch directory.
        fn wal_path(&self) -> PathBuf {
            self.root.join("test.wal")
        }
    }

    impl Drop for TestDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp dir must not fail a test.
            let _ = fs::remove_dir_all(&self.root);
        }
    }

    /// Builds a zeroed page image whose first byte is `marker`.
    fn page_image(marker: u8) -> Vec<u8> {
        let mut bytes = vec![0u8; PAGE_SIZE];
        bytes[0] = marker;
        bytes
    }

    #[test]
    fn test_checkpoint_empty_wal() {
        let dir = TestDir::new();
        let wal_path = dir.wal_path();

        let pager = Pager::open_default(&dir.db_path()).unwrap();

        // The WAL file is never created, so this exercises the
        // "missing WAL" path.
        let checkpointer = Checkpointer::default_mode();
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 0);
        assert_eq!(result.pages_checkpointed, 0);
    }

    #[test]
    fn test_checkpoint_committed_transaction() {
        let dir = TestDir::new();
        let wal_path = dir.wal_path();

        let pager = Pager::open_default(&dir.db_path()).unwrap();

        let page = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page_id = page.page_id();

        {
            let mut wal_writer = WalWriter::open(&wal_path).unwrap();

            wal_writer.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
            wal_writer
                .append(&WalRecord::PageWrite {
                    tx_id: 1,
                    page_id,
                    data: page_image(0x42),
                })
                .unwrap();
            wal_writer.append(&WalRecord::Commit { tx_id: 1 }).unwrap();

            wal_writer.sync().unwrap();
        }

        let checkpointer = Checkpointer::new(CheckpointMode::Full);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 1);
        assert_eq!(result.pages_checkpointed, 1);
        assert_eq!(result.records_processed, 3);

        // The committed image must now be visible through the pager.
        let read_page = pager.read_page(page_id).unwrap();
        assert_eq!(read_page.as_bytes()[0], 0x42);
    }

    #[test]
    fn test_checkpoint_aborted_transaction() {
        let dir = TestDir::new();
        let wal_path = dir.wal_path();

        let pager = Pager::open_default(&dir.db_path()).unwrap();

        let page = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page_id = page.page_id();

        {
            let mut wal_writer = WalWriter::open(&wal_path).unwrap();

            wal_writer.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
            wal_writer
                .append(&WalRecord::PageWrite {
                    tx_id: 1,
                    page_id,
                    data: page_image(0x42),
                })
                .unwrap();
            wal_writer
                .append(&WalRecord::Rollback { tx_id: 1 })
                .unwrap();

            wal_writer.sync().unwrap();
        }

        let checkpointer = Checkpointer::new(CheckpointMode::Full);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 0);
        assert_eq!(result.pages_checkpointed, 0);

        // The rolled-back image must not have reached the page.
        let read_page = pager.read_page(page_id).unwrap();
        assert_ne!(read_page.as_bytes()[0], 0x42);
    }

    #[test]
    fn test_checkpoint_mixed_transactions() {
        let dir = TestDir::new();
        let wal_path = dir.wal_path();

        let pager = Pager::open_default(&dir.db_path()).unwrap();

        let page1 = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page2 = pager.allocate_page(PageType::BTreeLeaf).unwrap();
        let page1_id = page1.page_id();
        let page2_id = page2.page_id();

        {
            let mut wal_writer = WalWriter::open(&wal_path).unwrap();

            // tx 1: committed write to page 1.
            wal_writer.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
            wal_writer
                .append(&WalRecord::PageWrite {
                    tx_id: 1,
                    page_id: page1_id,
                    data: page_image(0x11),
                })
                .unwrap();
            wal_writer.append(&WalRecord::Commit { tx_id: 1 }).unwrap();

            // tx 2: rolled-back write to page 2.
            wal_writer.append(&WalRecord::Begin { tx_id: 2 }).unwrap();
            wal_writer
                .append(&WalRecord::PageWrite {
                    tx_id: 2,
                    page_id: page2_id,
                    data: page_image(0x22),
                })
                .unwrap();
            wal_writer
                .append(&WalRecord::Rollback { tx_id: 2 })
                .unwrap();

            // tx 3: committed write to page 2.
            wal_writer.append(&WalRecord::Begin { tx_id: 3 }).unwrap();
            wal_writer
                .append(&WalRecord::PageWrite {
                    tx_id: 3,
                    page_id: page2_id,
                    data: page_image(0x33),
                })
                .unwrap();
            wal_writer.append(&WalRecord::Commit { tx_id: 3 }).unwrap();

            wal_writer.sync().unwrap();
        }

        let checkpointer = Checkpointer::new(CheckpointMode::Full);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert_eq!(result.transactions_processed, 2);
        assert_eq!(result.pages_checkpointed, 2);

        let read_page1 = pager.read_page(page1_id).unwrap();
        assert_eq!(read_page1.as_bytes()[0], 0x11);

        // Page 2 must hold tx 3's image, not the rolled-back one from tx 2.
        let read_page2 = pager.read_page(page2_id).unwrap();
        assert_eq!(read_page2.as_bytes()[0], 0x33);
    }

    #[test]
    fn test_checkpoint_truncate() {
        let dir = TestDir::new();
        let wal_path = dir.wal_path();

        let pager = Pager::open_default(&dir.db_path()).unwrap();

        {
            let mut wal_writer = WalWriter::open(&wal_path).unwrap();
            wal_writer.append(&WalRecord::Begin { tx_id: 1 }).unwrap();
            wal_writer.append(&WalRecord::Commit { tx_id: 1 }).unwrap();
            wal_writer.sync().unwrap();
        }

        let checkpointer = Checkpointer::new(CheckpointMode::Truncate);
        let result = checkpointer.checkpoint(&pager, &wal_path).unwrap();

        assert!(result.wal_truncated);

        // After truncation only the fresh checkpoint record should remain.
        let wal_size = fs::metadata(&wal_path).unwrap().len();
        assert!(
            wal_size < 50,
            "WAL should be truncated, but size is {}",
            wal_size
        );
    }
}