Skip to main content

dbx_core/wal/
checkpoint.rs

1//! Checkpoint manager for WAL maintenance and crash recovery.
2//!
3//! The checkpoint manager coordinates with the WAL to:
4//! - Apply WAL changes to the persistent storage
5//! - Trim old WAL records after successful checkpoint
6//! - Recover database state by replaying WAL records
7//!
8//! # Example
9//!
10//! ```rust
11//! use dbx_core::wal::WriteAheadLog;
12//! use dbx_core::wal::checkpoint::CheckpointManager;
13//! use std::sync::Arc;
14//! use std::path::Path;
15//!
16//! # fn main() -> dbx_core::DbxResult<()> {
17//! let wal = Arc::new(WriteAheadLog::open(Path::new("./wal.log"))?);
18//! let checkpoint_mgr = CheckpointManager::new(wal, Path::new("./wal.log"));
19//!
//! // Perform a checkpoint: each WAL record is handed to the supplied apply callback.
//! // checkpoint_mgr.checkpoint(|record| { /* apply `record` to storage */ Ok(()) })?;
22//! # Ok(())
23//! # }
24//! ```
25
26use crate::error::{DbxError, DbxResult};
27use crate::wal::{WalRecord, WriteAheadLog};
28use std::fs::OpenOptions;
29use std::io::Write;
30use std::path::{Path, PathBuf};
31use std::sync::Arc;
32use std::time::Duration;
33
34/// Checkpoint manager for WAL maintenance.
35///
36/// Manages periodic checkpoints and WAL trimming to keep the WAL file size bounded.
37pub struct CheckpointManager {
38    /// Reference to the WAL
39    wal: Arc<WriteAheadLog>,
40
41    /// Checkpoint interval (default: 30 seconds)
42    interval: Duration,
43
44    /// Path to the WAL file (for trimming)
45    wal_path: PathBuf,
46}
47
48impl CheckpointManager {
49    /// Creates a new checkpoint manager.
50    ///
51    /// # Arguments
52    ///
53    /// * `wal` - Shared reference to the WAL
54    /// * `wal_path` - Path to the WAL file
55    ///
56    /// # Example
57    ///
58    /// ```rust
59    /// # use dbx_core::wal::WriteAheadLog;
60    /// # use dbx_core::wal::checkpoint::CheckpointManager;
61    /// # use std::sync::Arc;
62    /// # use std::path::Path;
63    /// # fn main() -> dbx_core::DbxResult<()> {
64    /// let wal = Arc::new(WriteAheadLog::open(Path::new("./wal.log"))?);
65    /// let checkpoint_mgr = CheckpointManager::new(wal, Path::new("./wal.log"));
66    /// # Ok(())
67    /// # }
68    /// ```
69    pub fn new(wal: Arc<WriteAheadLog>, wal_path: &Path) -> Self {
70        Self {
71            wal,
72            interval: Duration::from_secs(30),
73            wal_path: wal_path.to_path_buf(),
74        }
75    }
76
77    /// Sets the checkpoint interval.
78    ///
79    /// # Arguments
80    ///
81    /// * `interval` - Duration between checkpoints
82    pub fn with_interval(mut self, interval: Duration) -> Self {
83        self.interval = interval;
84        self
85    }
86
87    /// Performs a checkpoint.
88    ///
89    /// Applies all WAL records to the database and writes a checkpoint marker.
90    /// This method should be called by the Database engine.
91    ///
92    /// # Arguments
93    ///
94    /// * `apply_fn` - Function to apply a WAL record to the database
95    ///
96    /// # Returns
97    ///
98    /// The sequence number of the checkpoint
99    ///
100    /// # Example
101    ///
102    /// ```rust,no_run
103    /// # use dbx_core::wal::WriteAheadLog;
104    /// # use dbx_core::wal::checkpoint::CheckpointManager;
105    /// # use dbx_core::wal::WalRecord;
106    /// # use std::sync::Arc;
107    /// # use std::path::Path;
108    /// # fn main() -> dbx_core::DbxResult<()> {
109    /// let wal = Arc::new(WriteAheadLog::open(Path::new("./wal.log"))?);
110    /// let checkpoint_mgr = CheckpointManager::new(wal, Path::new("./wal.log"));
111    ///
112    /// let apply_fn = |record: &WalRecord| -> dbx_core::DbxResult<()> {
113    ///     // Apply record to database
114    ///     Ok(())
115    /// };
116    ///
117    /// let checkpoint_seq = checkpoint_mgr.checkpoint(apply_fn)?;
118    /// # Ok(())
119    /// # }
120    /// ```
121    pub fn checkpoint<F>(&self, apply_fn: F) -> DbxResult<u64>
122    where
123        F: Fn(&WalRecord) -> DbxResult<()>,
124    {
125        // Replay all WAL records
126        let records = self.wal.replay()?;
127
128        for record in &records {
129            // Skip checkpoint markers
130            if matches!(record, WalRecord::Checkpoint { .. }) {
131                continue;
132            }
133
134            // Apply record to database
135            apply_fn(record)?;
136        }
137
138        // Write checkpoint marker
139        let seq = self.wal.current_sequence();
140        let checkpoint_record = WalRecord::Checkpoint { sequence: seq };
141        self.wal.append(&checkpoint_record)?;
142        self.wal.sync()?;
143
144        Ok(seq)
145    }
146
147    /// Recovers the database by replaying WAL records.
148    ///
149    /// This is called during database startup to restore the state after a crash.
150    ///
151    /// # Arguments
152    ///
153    /// * `wal_path` - Path to the WAL file
154    /// * `apply_fn` - Function to apply a WAL record to the database
155    ///
156    /// # Returns
157    ///
158    /// The number of records replayed
159    ///
160    /// # Example
161    ///
162    /// ```rust,no_run
163    /// # use dbx_core::wal::checkpoint::CheckpointManager;
164    /// # use dbx_core::wal::WalRecord;
165    /// # use std::path::Path;
166    /// # fn main() -> dbx_core::DbxResult<()> {
167    /// let apply_fn = |record: &WalRecord| -> dbx_core::DbxResult<()> {
168    ///     // Apply record to database
169    ///     Ok(())
170    /// };
171    ///
172    /// let count = CheckpointManager::recover(Path::new("./wal.log"), apply_fn)?;
173    /// println!("Replayed {} records", count);
174    /// # Ok(())
175    /// # }
176    /// ```
177    pub fn recover<F>(wal_path: &Path, apply_fn: F) -> DbxResult<usize>
178    where
179        F: Fn(&WalRecord) -> DbxResult<()>,
180    {
181        // Check if WAL file exists
182        if !wal_path.exists() {
183            return Ok(0);
184        }
185
186        let wal = WriteAheadLog::open(wal_path)?;
187        let records = wal.replay()?;
188
189        // Find the last checkpoint
190        let mut last_checkpoint_idx = None;
191        for (i, record) in records.iter().enumerate() {
192            if matches!(record, WalRecord::Checkpoint { .. }) {
193                last_checkpoint_idx = Some(i);
194            }
195        }
196
197        // Replay records after the last checkpoint
198        let start_idx = last_checkpoint_idx.map(|i| i + 1).unwrap_or(0);
199        let replay_count = records.len() - start_idx;
200
201        for record in &records[start_idx..] {
202            apply_fn(record)?;
203        }
204
205        Ok(replay_count)
206    }
207
208    /// Trims the WAL file by removing records before the specified sequence.
209    ///
210    /// This is called after a successful checkpoint to keep the WAL file size bounded.
211    ///
212    /// # Arguments
213    ///
214    /// * `sequence` - Sequence number to trim before
215    ///
216    /// # Example
217    ///
218    /// ```rust,no_run
219    /// # use dbx_core::wal::WriteAheadLog;
220    /// # use dbx_core::wal::checkpoint::CheckpointManager;
221    /// # use std::sync::Arc;
222    /// # use std::path::Path;
223    /// # fn main() -> dbx_core::DbxResult<()> {
224    /// let wal = Arc::new(WriteAheadLog::open(Path::new("./wal.log"))?);
225    /// let checkpoint_mgr = CheckpointManager::new(wal, Path::new("./wal.log"));
226    ///
227    /// // Trim records before sequence 100
228    /// checkpoint_mgr.trim_before(100)?;
229    /// # Ok(())
230    /// # }
231    /// ```
232    pub fn trim_before(&self, sequence: u64) -> DbxResult<()> {
233        // Read all records
234        let records = self.wal.replay()?;
235
236        // Find the last checkpoint with sequence >= target
237        let mut last_checkpoint_idx = None;
238        for (i, record) in records.iter().enumerate() {
239            if let WalRecord::Checkpoint { sequence: seq } = record
240                && *seq >= sequence
241            {
242                last_checkpoint_idx = Some(i);
243            }
244        }
245
246        // Keep only records from the last checkpoint onwards
247        let trimmed_records: Vec<WalRecord> = if let Some(idx) = last_checkpoint_idx {
248            records.into_iter().skip(idx).collect()
249        } else {
250            // No checkpoint found, keep all records
251            records
252        };
253
254        // Write trimmed records to a temporary file
255        let temp_path = self.wal_path.with_extension("tmp");
256        let mut temp_file = OpenOptions::new()
257            .create(true)
258            .write(true)
259            .truncate(true)
260            .open(&temp_path)?;
261
262        for record in &trimmed_records {
263            let encoded = bincode::serialize(record)
264                .map_err(|e| DbxError::Wal(format!("serialization failed: {}", e)))?;
265            let len = (encoded.len() as u32).to_le_bytes();
266            temp_file.write_all(&len)?;
267            temp_file.write_all(&encoded)?;
268        }
269
270        temp_file.sync_all()?;
271        drop(temp_file);
272
273        // Replace the original WAL file with the trimmed one
274        std::fs::rename(&temp_path, &self.wal_path)?;
275
276        Ok(())
277    }
278
279    /// Returns the checkpoint interval.
280    pub fn interval(&self) -> Duration {
281        self.interval
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use tempfile::NamedTempFile;

    /// A checkpoint hands every data record to the callback, in order, and reports a
    /// non-zero sequence.
    #[test]
    fn checkpoint_applies_wal() {
        let wal_file = NamedTempFile::new().unwrap();
        let wal = Arc::new(WriteAheadLog::open(wal_file.path()).unwrap());
        let manager = CheckpointManager::new(Arc::clone(&wal), wal_file.path());

        // Seed the log with one insert and one delete.
        let insert_alice = WalRecord::Insert {
            table: String::from("users"),
            key: b"user:1".to_vec(),
            value: b"Alice".to_vec(),
            ts: 0,
        };
        let delete_user2 = WalRecord::Delete {
            table: String::from("users"),
            key: b"user:2".to_vec(),
            ts: 1,
        };
        for record in [&insert_alice, &delete_user2] {
            wal.append(record).unwrap();
        }
        wal.sync().unwrap();

        // Collect whatever the checkpoint passes to the callback.
        let seen = RefCell::new(Vec::new());
        let seq = manager
            .checkpoint(|record: &WalRecord| {
                seen.borrow_mut().push(record.clone());
                Ok(())
            })
            .unwrap();

        assert_ne!(seq, 0);
        assert_eq!(seen.into_inner(), vec![insert_alice, delete_user2]);
    }

    /// Recovery replays only what was logged after the last checkpoint marker.
    #[test]
    fn recover_replays_after_checkpoint() {
        let wal_file = NamedTempFile::new().unwrap();
        let wal = Arc::new(WriteAheadLog::open(wal_file.path()).unwrap());

        // Log layout: insert, checkpoint marker, insert.
        let before_marker = WalRecord::Insert {
            table: String::from("users"),
            key: b"user:1".to_vec(),
            value: b"Alice".to_vec(),
            ts: 0,
        };
        let marker = WalRecord::Checkpoint { sequence: 1 };
        let after_marker = WalRecord::Insert {
            table: String::from("users"),
            key: b"user:2".to_vec(),
            value: b"Bob".to_vec(),
            ts: 2, // After checkpoint
        };
        for record in [&before_marker, &marker, &after_marker] {
            wal.append(record).unwrap();
        }
        wal.sync().unwrap();

        let replayed = RefCell::new(Vec::new());
        let count = CheckpointManager::recover(wal_file.path(), |record: &WalRecord| {
            replayed.borrow_mut().push(record.clone());
            Ok(())
        })
        .unwrap();

        // Only the record written after the marker comes back.
        assert_eq!(count, 1);
        assert_eq!(replayed.into_inner(), vec![after_marker]);
    }

    /// Trimming drops everything that precedes the qualifying checkpoint marker.
    #[test]
    fn trim_removes_old_records() {
        let wal_file = NamedTempFile::new().unwrap();
        let wal = Arc::new(WriteAheadLog::open(wal_file.path()).unwrap());
        let manager = CheckpointManager::new(Arc::clone(&wal), wal_file.path());

        // Log layout: insert, checkpoint marker, insert.
        let stale_insert = WalRecord::Insert {
            table: String::from("users"),
            key: b"user:1".to_vec(),
            value: b"Alice".to_vec(),
            ts: 0,
        };
        let marker = WalRecord::Checkpoint { sequence: 1 };
        let live_insert = WalRecord::Insert {
            table: String::from("users"),
            key: b"user:2".to_vec(),
            value: b"Bob".to_vec(),
            ts: 2,
        };
        for record in [&stale_insert, &marker, &live_insert] {
            wal.append(record).unwrap();
        }
        wal.sync().unwrap();

        manager.trim_before(1).unwrap();

        // Read the file back through a fresh handle.
        let reopened = WriteAheadLog::open(wal_file.path()).unwrap();
        let remaining = reopened.replay().unwrap();

        // The marker and the record after it are all that is left.
        assert_eq!(remaining.len(), 2);
        assert!(matches!(remaining[0], WalRecord::Checkpoint { sequence: 1 }));
        assert_eq!(remaining[1], live_insert);
    }

    /// Recovering from a path with no WAL file replays nothing.
    #[test]
    fn recover_empty_wal() {
        let placeholder = NamedTempFile::new().unwrap();
        std::fs::remove_file(placeholder.path()).unwrap();

        let replayed =
            CheckpointManager::recover(placeholder.path(), |_: &WalRecord| Ok(())).unwrap();

        assert_eq!(replayed, 0);
    }

    /// `with_interval` replaces the default interval reported by `interval`.
    #[test]
    fn checkpoint_interval() {
        let wal_file = NamedTempFile::new().unwrap();
        let wal = Arc::new(WriteAheadLog::open(wal_file.path()).unwrap());
        let one_minute = Duration::from_secs(60);

        let manager = CheckpointManager::new(wal, wal_file.path()).with_interval(one_minute);

        assert_eq!(manager.interval(), one_minute);
    }
}