dbx_core/wal/checkpoint.rs
1//! Checkpoint manager for WAL maintenance and crash recovery.
2//!
3//! The checkpoint manager coordinates with the WAL to:
4//! - Apply WAL changes to the persistent storage
5//! - Trim old WAL records after successful checkpoint
6//! - Recover database state by replaying WAL records
7//!
8//! # Example
9//!
//! ```rust,no_run
11//! use dbx_core::wal::WriteAheadLog;
12//! use dbx_core::wal::checkpoint::CheckpointManager;
13//! use std::sync::Arc;
14//! use std::path::Path;
15//!
16//! # fn main() -> dbx_core::DbxResult<()> {
17//! let wal = Arc::new(WriteAheadLog::open(Path::new("./wal.log"))?);
18//! let checkpoint_mgr = CheckpointManager::new(wal, Path::new("./wal.log"));
19//!
//! // Perform a checkpoint: the closure is called with WAL records to apply to storage.
//! // let seq = checkpoint_mgr.checkpoint(|record| apply_to_storage(record))?;
22//! # Ok(())
23//! # }
24//! ```
25
26use crate::error::{DbxError, DbxResult};
27use crate::wal::{WalRecord, WriteAheadLog};
28use std::fs::OpenOptions;
29use std::io::Write;
30use std::path::{Path, PathBuf};
31use std::sync::Arc;
32use std::time::Duration;
33
34/// Checkpoint manager for WAL maintenance.
35///
36/// Manages periodic checkpoints and WAL trimming to keep the WAL file size bounded.
37pub struct CheckpointManager {
38 /// Reference to the WAL
39 wal: Arc<WriteAheadLog>,
40
41 /// Checkpoint interval (default: 30 seconds)
42 interval: Duration,
43
44 /// Path to the WAL file (for trimming)
45 wal_path: PathBuf,
46}
47
48impl CheckpointManager {
49 /// Creates a new checkpoint manager.
50 ///
51 /// # Arguments
52 ///
53 /// * `wal` - Shared reference to the WAL
54 /// * `wal_path` - Path to the WAL file
55 ///
56 /// # Example
57 ///
58 /// ```rust
59 /// # use dbx_core::wal::WriteAheadLog;
60 /// # use dbx_core::wal::checkpoint::CheckpointManager;
61 /// # use std::sync::Arc;
62 /// # use std::path::Path;
63 /// # fn main() -> dbx_core::DbxResult<()> {
64 /// let wal = Arc::new(WriteAheadLog::open(Path::new("./wal.log"))?);
65 /// let checkpoint_mgr = CheckpointManager::new(wal, Path::new("./wal.log"));
66 /// # Ok(())
67 /// # }
68 /// ```
69 pub fn new(wal: Arc<WriteAheadLog>, wal_path: &Path) -> Self {
70 Self {
71 wal,
72 interval: Duration::from_secs(30),
73 wal_path: wal_path.to_path_buf(),
74 }
75 }
76
77 /// Sets the checkpoint interval.
78 ///
79 /// # Arguments
80 ///
81 /// * `interval` - Duration between checkpoints
82 pub fn with_interval(mut self, interval: Duration) -> Self {
83 self.interval = interval;
84 self
85 }
86
87 /// Performs a checkpoint.
88 ///
89 /// Applies all WAL records to the database and writes a checkpoint marker.
90 /// This method should be called by the Database engine.
91 ///
92 /// # Arguments
93 ///
94 /// * `apply_fn` - Function to apply a WAL record to the database
95 ///
96 /// # Returns
97 ///
98 /// The sequence number of the checkpoint
99 ///
100 /// # Example
101 ///
102 /// ```rust,no_run
103 /// # use dbx_core::wal::WriteAheadLog;
104 /// # use dbx_core::wal::checkpoint::CheckpointManager;
105 /// # use dbx_core::wal::WalRecord;
106 /// # use std::sync::Arc;
107 /// # use std::path::Path;
108 /// # fn main() -> dbx_core::DbxResult<()> {
109 /// let wal = Arc::new(WriteAheadLog::open(Path::new("./wal.log"))?);
110 /// let checkpoint_mgr = CheckpointManager::new(wal, Path::new("./wal.log"));
111 ///
112 /// let apply_fn = |record: &WalRecord| -> dbx_core::DbxResult<()> {
113 /// // Apply record to database
114 /// Ok(())
115 /// };
116 ///
117 /// let checkpoint_seq = checkpoint_mgr.checkpoint(apply_fn)?;
118 /// # Ok(())
119 /// # }
120 /// ```
121 pub fn checkpoint<F>(&self, apply_fn: F) -> DbxResult<u64>
122 where
123 F: Fn(&WalRecord) -> DbxResult<()>,
124 {
125 // Replay all WAL records
126 let records = self.wal.replay()?;
127
128 for record in &records {
129 // Skip checkpoint markers
130 if matches!(record, WalRecord::Checkpoint { .. }) {
131 continue;
132 }
133
134 // Apply record to database
135 apply_fn(record)?;
136 }
137
138 // Write checkpoint marker
139 let seq = self.wal.current_sequence();
140 let checkpoint_record = WalRecord::Checkpoint { sequence: seq };
141 self.wal.append(&checkpoint_record)?;
142 self.wal.sync()?;
143
144 Ok(seq)
145 }
146
147 /// Recovers the database by replaying WAL records.
148 ///
149 /// This is called during database startup to restore the state after a crash.
150 ///
151 /// # Arguments
152 ///
153 /// * `wal_path` - Path to the WAL file
154 /// * `apply_fn` - Function to apply a WAL record to the database
155 ///
156 /// # Returns
157 ///
158 /// The number of records replayed
159 ///
160 /// # Example
161 ///
162 /// ```rust,no_run
163 /// # use dbx_core::wal::checkpoint::CheckpointManager;
164 /// # use dbx_core::wal::WalRecord;
165 /// # use std::path::Path;
166 /// # fn main() -> dbx_core::DbxResult<()> {
167 /// let apply_fn = |record: &WalRecord| -> dbx_core::DbxResult<()> {
168 /// // Apply record to database
169 /// Ok(())
170 /// };
171 ///
172 /// let count = CheckpointManager::recover(Path::new("./wal.log"), apply_fn)?;
173 /// println!("Replayed {} records", count);
174 /// # Ok(())
175 /// # }
176 /// ```
177 pub fn recover<F>(wal_path: &Path, apply_fn: F) -> DbxResult<usize>
178 where
179 F: Fn(&WalRecord) -> DbxResult<()>,
180 {
181 // Check if WAL file exists
182 if !wal_path.exists() {
183 return Ok(0);
184 }
185
186 let wal = WriteAheadLog::open(wal_path)?;
187 let records = wal.replay()?;
188
189 // Find the last checkpoint
190 let mut last_checkpoint_idx = None;
191 for (i, record) in records.iter().enumerate() {
192 if matches!(record, WalRecord::Checkpoint { .. }) {
193 last_checkpoint_idx = Some(i);
194 }
195 }
196
197 // Replay records after the last checkpoint
198 let start_idx = last_checkpoint_idx.map(|i| i + 1).unwrap_or(0);
199 let replay_count = records.len() - start_idx;
200
201 for record in &records[start_idx..] {
202 apply_fn(record)?;
203 }
204
205 Ok(replay_count)
206 }
207
208 /// Trims the WAL file by removing records before the specified sequence.
209 ///
210 /// This is called after a successful checkpoint to keep the WAL file size bounded.
211 ///
212 /// # Arguments
213 ///
214 /// * `sequence` - Sequence number to trim before
215 ///
216 /// # Example
217 ///
218 /// ```rust,no_run
219 /// # use dbx_core::wal::WriteAheadLog;
220 /// # use dbx_core::wal::checkpoint::CheckpointManager;
221 /// # use std::sync::Arc;
222 /// # use std::path::Path;
223 /// # fn main() -> dbx_core::DbxResult<()> {
224 /// let wal = Arc::new(WriteAheadLog::open(Path::new("./wal.log"))?);
225 /// let checkpoint_mgr = CheckpointManager::new(wal, Path::new("./wal.log"));
226 ///
227 /// // Trim records before sequence 100
228 /// checkpoint_mgr.trim_before(100)?;
229 /// # Ok(())
230 /// # }
231 /// ```
232 pub fn trim_before(&self, sequence: u64) -> DbxResult<()> {
233 // Read all records
234 let records = self.wal.replay()?;
235
236 // Find the last checkpoint with sequence >= target
237 let mut last_checkpoint_idx = None;
238 for (i, record) in records.iter().enumerate() {
239 if let WalRecord::Checkpoint { sequence: seq } = record
240 && *seq >= sequence
241 {
242 last_checkpoint_idx = Some(i);
243 }
244 }
245
246 // Keep only records from the last checkpoint onwards
247 let trimmed_records: Vec<WalRecord> = if let Some(idx) = last_checkpoint_idx {
248 records.into_iter().skip(idx).collect()
249 } else {
250 // No checkpoint found, keep all records
251 records
252 };
253
254 // Write trimmed records to a temporary file
255 let temp_path = self.wal_path.with_extension("tmp");
256 let mut temp_file = OpenOptions::new()
257 .create(true)
258 .write(true)
259 .truncate(true)
260 .open(&temp_path)?;
261
262 for record in &trimmed_records {
263 let encoded = bincode::serialize(record)
264 .map_err(|e| DbxError::Wal(format!("serialization failed: {}", e)))?;
265 let len = (encoded.len() as u32).to_le_bytes();
266 temp_file.write_all(&len)?;
267 temp_file.write_all(&encoded)?;
268 }
269
270 temp_file.sync_all()?;
271 drop(temp_file);
272
273 // Replace the original WAL file with the trimmed one
274 std::fs::rename(&temp_path, &self.wal_path)?;
275
276 Ok(())
277 }
278
279 /// Returns the checkpoint interval.
280 pub fn interval(&self) -> Duration {
281 self.interval
282 }
283}
284
285#[cfg(test)]
286mod tests {
287 use super::*;
288 use tempfile::NamedTempFile;
289
290 #[test]
291 fn checkpoint_applies_wal() {
292 use std::cell::RefCell;
293
294 let temp_file = NamedTempFile::new().unwrap();
295 let wal = Arc::new(WriteAheadLog::open(temp_file.path()).unwrap());
296 let checkpoint_mgr = CheckpointManager::new(wal.clone(), temp_file.path());
297
298 // Add some records
299 let record1 = WalRecord::Insert {
300 table: "users".to_string(),
301 key: b"user:1".to_vec(),
302 value: b"Alice".to_vec(),
303 ts: 0,
304 };
305 let record2 = WalRecord::Delete {
306 table: "users".to_string(),
307 key: b"user:2".to_vec(),
308 ts: 1,
309 };
310
311 wal.append(&record1).unwrap();
312 wal.append(&record2).unwrap();
313 wal.sync().unwrap();
314
315 // Checkpoint
316 let applied_records = RefCell::new(Vec::new());
317 let apply_fn = |record: &WalRecord| {
318 applied_records.borrow_mut().push(record.clone());
319 Ok(())
320 };
321
322 let checkpoint_seq = checkpoint_mgr.checkpoint(apply_fn).unwrap();
323 assert!(checkpoint_seq > 0);
324 let records = applied_records.borrow();
325 assert_eq!(records.len(), 2);
326 assert_eq!(records[0], record1);
327 assert_eq!(records[1], record2);
328 }
329
330 #[test]
331 fn recover_replays_after_checkpoint() {
332 use std::cell::RefCell;
333
334 let temp_file = NamedTempFile::new().unwrap();
335 let wal = Arc::new(WriteAheadLog::open(temp_file.path()).unwrap());
336
337 // Add records before checkpoint
338 let record1 = WalRecord::Insert {
339 table: "users".to_string(),
340 key: b"user:1".to_vec(),
341 value: b"Alice".to_vec(),
342 ts: 0,
343 };
344 wal.append(&record1).unwrap();
345
346 // Checkpoint
347 let checkpoint = WalRecord::Checkpoint { sequence: 1 };
348 wal.append(&checkpoint).unwrap();
349
350 // Add records after checkpoint
351 let record2 = WalRecord::Insert {
352 table: "users".to_string(),
353 key: b"user:2".to_vec(),
354 value: b"Bob".to_vec(),
355 ts: 2, // After checkpoint
356 };
357 wal.append(&record2).unwrap();
358 wal.sync().unwrap();
359
360 // Recover
361 let recovered_records = RefCell::new(Vec::new());
362 let apply_fn = |record: &WalRecord| {
363 recovered_records.borrow_mut().push(record.clone());
364 Ok(())
365 };
366
367 let count = CheckpointManager::recover(temp_file.path(), apply_fn).unwrap();
368
369 // Should only replay record2 (after checkpoint)
370 assert_eq!(count, 1);
371 let records = recovered_records.borrow();
372 assert_eq!(records.len(), 1);
373 assert_eq!(records[0], record2);
374 }
375
376 #[test]
377 fn trim_removes_old_records() {
378 let temp_file = NamedTempFile::new().unwrap();
379 let wal = Arc::new(WriteAheadLog::open(temp_file.path()).unwrap());
380 let checkpoint_mgr = CheckpointManager::new(wal.clone(), temp_file.path());
381
382 // Add records
383 let record1 = WalRecord::Insert {
384 table: "users".to_string(),
385 key: b"user:1".to_vec(),
386 value: b"Alice".to_vec(),
387 ts: 0,
388 };
389 wal.append(&record1).unwrap();
390
391 // Checkpoint
392 let checkpoint = WalRecord::Checkpoint { sequence: 1 };
393 wal.append(&checkpoint).unwrap();
394
395 let record2 = WalRecord::Insert {
396 table: "users".to_string(),
397 key: b"user:2".to_vec(),
398 value: b"Bob".to_vec(),
399 ts: 2,
400 };
401 wal.append(&record2).unwrap();
402 wal.sync().unwrap();
403
404 // Trim before sequence 1
405 checkpoint_mgr.trim_before(1).unwrap();
406
407 // Re-open and verify
408 let wal2 = WriteAheadLog::open(temp_file.path()).unwrap();
409 let records = wal2.replay().unwrap();
410
411 // Should only have checkpoint and record2
412 assert_eq!(records.len(), 2);
413 assert!(matches!(records[0], WalRecord::Checkpoint { sequence: 1 }));
414 assert_eq!(records[1], record2);
415 }
416
417 #[test]
418 fn recover_empty_wal() {
419 let temp_file = NamedTempFile::new().unwrap();
420 std::fs::remove_file(temp_file.path()).unwrap();
421
422 let apply_fn = |_: &WalRecord| Ok(());
423 let count = CheckpointManager::recover(temp_file.path(), apply_fn).unwrap();
424
425 assert_eq!(count, 0);
426 }
427
428 #[test]
429 fn checkpoint_interval() {
430 let temp_file = NamedTempFile::new().unwrap();
431 let wal = Arc::new(WriteAheadLog::open(temp_file.path()).unwrap());
432
433 let checkpoint_mgr =
434 CheckpointManager::new(wal, temp_file.path()).with_interval(Duration::from_secs(60));
435
436 assert_eq!(checkpoint_mgr.interval(), Duration::from_secs(60));
437 }
438}