Skip to main content

graphos_adapters/storage/wal/
recovery.rs

1//! WAL recovery.
2
3use super::{CheckpointMetadata, WalManager, WalRecord};
4use graphos_common::utils::error::{Error, Result, StorageError};
5use std::fs::File;
6use std::io::{BufReader, Read};
7use std::path::Path;
8
/// File name, inside the WAL directory, that holds the serialized checkpoint metadata.
const CHECKPOINT_METADATA_FILE: &str = "checkpoint.meta";
11
/// Handles WAL recovery after a crash.
///
/// The handler only stores the path of the WAL directory; files are opened
/// lazily by the individual recovery methods. It is cheap to clone and can be
/// printed with `{:?}` for diagnostics.
#[derive(Debug, Clone)]
pub struct WalRecovery {
    /// Directory containing WAL files.
    dir: std::path::PathBuf,
}
17
18impl WalRecovery {
19    /// Creates a new recovery handler for the given WAL directory.
20    pub fn new(dir: impl AsRef<Path>) -> Self {
21        Self {
22            dir: dir.as_ref().to_path_buf(),
23        }
24    }
25
26    /// Creates a recovery handler from a WAL manager.
27    #[must_use]
28    pub fn from_wal(wal: &WalManager) -> Self {
29        Self {
30            dir: wal.dir().to_path_buf(),
31        }
32    }
33
34    /// Reads checkpoint metadata if it exists.
35    ///
36    /// Returns `None` if no checkpoint metadata is found.
37    pub fn read_checkpoint_metadata(&self) -> Result<Option<CheckpointMetadata>> {
38        let metadata_path = self.dir.join(CHECKPOINT_METADATA_FILE);
39
40        if !metadata_path.exists() {
41            return Ok(None);
42        }
43
44        let file = File::open(&metadata_path)?;
45        let mut reader = BufReader::new(file);
46        let mut data = Vec::new();
47        reader.read_to_end(&mut data)?;
48
49        let (metadata, _): (CheckpointMetadata, _) =
50            bincode::serde::decode_from_slice(&data, bincode::config::standard())
51                .map_err(|e| Error::Serialization(e.to_string()))?;
52
53        Ok(Some(metadata))
54    }
55
56    /// Returns the checkpoint metadata, if any.
57    ///
58    /// This is useful for determining whether to perform a full or
59    /// incremental recovery.
60    #[must_use]
61    pub fn checkpoint(&self) -> Option<CheckpointMetadata> {
62        self.read_checkpoint_metadata().ok().flatten()
63    }
64
65    /// Recovers committed records from all WAL files.
66    ///
67    /// Returns only records that were part of committed transactions.
68    /// If checkpoint metadata exists, only replays files from the
69    /// checkpoint sequence onwards.
70    ///
71    /// # Errors
72    ///
73    /// Returns an error if recovery fails.
74    pub fn recover(&self) -> Result<Vec<WalRecord>> {
75        // Check for checkpoint metadata
76        let checkpoint = self.read_checkpoint_metadata()?;
77        self.recover_internal(checkpoint)
78    }
79
80    /// Recovers committed records, starting from a specific checkpoint.
81    ///
82    /// This can be used for incremental recovery when you want to
83    /// skip WAL files that precede the checkpoint.
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if recovery fails.
88    pub fn recover_from_checkpoint(
89        &self,
90        checkpoint: Option<&CheckpointMetadata>,
91    ) -> Result<Vec<WalRecord>> {
92        self.recover_internal(checkpoint.cloned())
93    }
94
95    fn recover_internal(&self, checkpoint: Option<CheckpointMetadata>) -> Result<Vec<WalRecord>> {
96        let mut current_tx_records = Vec::new();
97        let mut committed_records = Vec::new();
98
99        // Get all log files in order
100        let log_files = self.get_log_files()?;
101
102        // Determine the minimum sequence number to process
103        let min_sequence = checkpoint.as_ref().map(|cp| cp.log_sequence).unwrap_or(0);
104
105        if checkpoint.is_some() {
106            tracing::info!(
107                "Recovering from checkpoint at epoch {:?}, starting from log sequence {}",
108                checkpoint.as_ref().map(|c| c.epoch),
109                min_sequence
110            );
111        }
112
113        // Read log files in sequence, skipping those before checkpoint
114        for log_file in log_files {
115            // Extract sequence number from filename
116            let sequence = Self::sequence_from_path(&log_file).unwrap_or(0);
117
118            // Skip files that are completely before the checkpoint
119            // We include the checkpoint sequence file because it may contain
120            // records after the checkpoint record itself
121            if sequence < min_sequence {
122                tracing::debug!(
123                    "Skipping log file {:?} (sequence {} < checkpoint {})",
124                    log_file,
125                    sequence,
126                    min_sequence
127                );
128                continue;
129            }
130
131            let file = match File::open(&log_file) {
132                Ok(f) => f,
133                Err(e) if e.kind() == std::io::ErrorKind::NotFound => continue,
134                Err(e) => return Err(e.into()),
135            };
136            let mut reader = BufReader::new(file);
137
138            // Read all records from this file
139            loop {
140                match self.read_record(&mut reader) {
141                    Ok(Some(record)) => {
142                        match &record {
143                            WalRecord::TxCommit { .. } => {
144                                // Commit current transaction
145                                committed_records.append(&mut current_tx_records);
146                                committed_records.push(record);
147                            }
148                            WalRecord::TxAbort { .. } => {
149                                // Discard current transaction
150                                current_tx_records.clear();
151                            }
152                            WalRecord::Checkpoint { .. } => {
153                                // Checkpoint - clear uncommitted, keep committed
154                                current_tx_records.clear();
155                                committed_records.push(record);
156                            }
157                            _ => {
158                                current_tx_records.push(record);
159                            }
160                        }
161                    }
162                    Ok(None) => break, // EOF
163                    Err(e) => {
164                        // Log corruption - stop reading this file but continue
165                        // with remaining files (best-effort recovery)
166                        tracing::warn!("WAL corruption detected in {:?}: {}", log_file, e);
167                        break;
168                    }
169                }
170            }
171        }
172
173        // Uncommitted records in current_tx_records are discarded
174
175        Ok(committed_records)
176    }
177
178    /// Extracts the sequence number from a WAL log file path.
179    fn sequence_from_path(path: &Path) -> Option<u64> {
180        path.file_stem()
181            .and_then(|s| s.to_str())
182            .and_then(|s| s.strip_prefix("wal_"))
183            .and_then(|s| s.parse().ok())
184    }
185
186    /// Recovers committed records from a single WAL file.
187    ///
188    /// # Errors
189    ///
190    /// Returns an error if recovery fails.
191    pub fn recover_file(&self, path: impl AsRef<Path>) -> Result<Vec<WalRecord>> {
192        let file = File::open(path.as_ref())?;
193        let mut reader = BufReader::new(file);
194
195        let mut current_tx_records = Vec::new();
196        let mut committed_records = Vec::new();
197
198        loop {
199            match self.read_record(&mut reader) {
200                Ok(Some(record)) => match &record {
201                    WalRecord::TxCommit { .. } => {
202                        committed_records.append(&mut current_tx_records);
203                        committed_records.push(record);
204                    }
205                    WalRecord::TxAbort { .. } => {
206                        current_tx_records.clear();
207                    }
208                    _ => {
209                        current_tx_records.push(record);
210                    }
211                },
212                Ok(None) => break,
213                Err(e) => {
214                    tracing::warn!("WAL corruption detected: {}", e);
215                    break;
216                }
217            }
218        }
219
220        Ok(committed_records)
221    }
222
223    fn get_log_files(&self) -> Result<Vec<std::path::PathBuf>> {
224        let mut files = Vec::new();
225
226        if !self.dir.exists() {
227            return Ok(files);
228        }
229
230        if let Ok(entries) = std::fs::read_dir(&self.dir) {
231            for entry in entries.flatten() {
232                let path = entry.path();
233                if path.extension().is_some_and(|ext| ext == "log") {
234                    files.push(path);
235                }
236            }
237        }
238
239        // Sort by filename (which includes sequence number)
240        files.sort();
241
242        Ok(files)
243    }
244
245    fn read_record(&self, reader: &mut BufReader<File>) -> Result<Option<WalRecord>> {
246        // Read length prefix
247        let mut len_buf = [0u8; 4];
248        match reader.read_exact(&mut len_buf) {
249            Ok(()) => {}
250            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
251            Err(e) => return Err(e.into()),
252        }
253        let len = u32::from_le_bytes(len_buf) as usize;
254
255        // Read data
256        let mut data = vec![0u8; len];
257        reader.read_exact(&mut data)?;
258
259        // Read and verify checksum
260        let mut checksum_buf = [0u8; 4];
261        reader.read_exact(&mut checksum_buf)?;
262        let stored_checksum = u32::from_le_bytes(checksum_buf);
263        let computed_checksum = crc32fast::hash(&data);
264
265        if stored_checksum != computed_checksum {
266            return Err(Error::Storage(StorageError::Corruption(
267                "WAL checksum mismatch".to_string(),
268            )));
269        }
270
271        // Deserialize
272        let (record, _): (WalRecord, _) =
273            bincode::serde::decode_from_slice(&data, bincode::config::standard())
274                .map_err(|e| Error::Serialization(e.to_string()))?;
275
276        Ok(Some(record))
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::super::WalConfig;
    use super::*;
    use graphos_common::types::{EpochId, NodeId, TxId};
    use tempfile::tempdir;

    /// Builds a `CreateNode` record carrying a single label.
    fn create_node(id: NodeId, label: &str) -> WalRecord {
        WalRecord::CreateNode {
            id,
            labels: vec![label.to_string()],
        }
    }

    /// Appends every record to the WAL in order, panicking on the first failure.
    fn log_all(wal: &WalManager, records: impl IntoIterator<Item = WalRecord>) {
        for record in records {
            wal.log(&record).unwrap();
        }
    }

    /// Appends `records` followed by a `TxCommit` for `tx_id`.
    fn log_committed(
        wal: &WalManager,
        records: impl IntoIterator<Item = WalRecord>,
        tx_id: TxId,
    ) {
        log_all(wal, records);
        wal.log(&WalRecord::TxCommit { tx_id }).unwrap();
    }

    /// Configuration with a tiny log size so that writes rotate across files.
    fn rotating_config() -> WalConfig {
        WalConfig {
            max_log_size: 100, // Force rotation
            ..Default::default()
        }
    }

    #[test]
    fn test_recovery_committed() {
        let dir = tempdir().unwrap();

        // The manager lives in its own scope so it is dropped before recovery.
        {
            let wal = WalManager::open(dir.path()).unwrap();
            log_committed(
                &wal,
                [create_node(NodeId::new(1), "Person")],
                TxId::new(1),
            );
            wal.sync().unwrap();
        }

        let recovered = WalRecovery::new(dir.path()).recover().unwrap();

        // One CreateNode plus its TxCommit.
        assert_eq!(recovered.len(), 2);
    }

    #[test]
    fn test_recovery_uncommitted() {
        let dir = tempdir().unwrap();

        // A record is written but its transaction is never committed.
        {
            let wal = WalManager::open(dir.path()).unwrap();
            log_all(&wal, [create_node(NodeId::new(1), "Person")]);
            wal.sync().unwrap();
        }

        let recovered = WalRecovery::new(dir.path()).recover().unwrap();

        // Uncommitted records should be discarded
        assert_eq!(recovered.len(), 0);
    }

    #[test]
    fn test_recovery_multiple_files() {
        let dir = tempdir().unwrap();

        // Two committed transactions, spread over several log files.
        {
            let wal = WalManager::with_config(dir.path(), rotating_config()).unwrap();

            log_committed(
                &wal,
                (0..5).map(|i| create_node(NodeId::new(i), "Test")),
                TxId::new(1),
            );
            log_committed(
                &wal,
                (5..10).map(|i| create_node(NodeId::new(i), "Test")),
                TxId::new(2),
            );

            wal.sync().unwrap();
        }

        let recovered = WalRecovery::new(dir.path()).recover().unwrap();

        // Should have 10 CreateNode + 2 TxCommit
        assert_eq!(recovered.len(), 12);
    }

    #[test]
    fn test_checkpoint_metadata() {
        let dir = tempdir().unwrap();

        // One transaction before the checkpoint and one after it.
        {
            let wal = WalManager::open(dir.path()).unwrap();

            log_committed(&wal, [create_node(NodeId::new(1), "Test")], TxId::new(1));
            wal.checkpoint(TxId::new(1), EpochId::new(10)).unwrap();
            log_committed(&wal, [create_node(NodeId::new(2), "Test")], TxId::new(2));

            wal.sync().unwrap();
        }

        // The checkpoint call must have left metadata behind.
        let checkpoint = WalRecovery::new(dir.path()).checkpoint();
        assert!(checkpoint.is_some(), "Checkpoint metadata should exist");

        let cp = checkpoint.unwrap();
        assert_eq!(cp.epoch.as_u64(), 10);
        assert_eq!(cp.tx_id.as_u64(), 1);
    }

    #[test]
    fn test_recovery_from_checkpoint() {
        let dir = tempdir().unwrap();

        // A batch before the checkpoint and a batch after it, with rotation enabled.
        {
            let wal = WalManager::with_config(dir.path(), rotating_config()).unwrap();

            log_committed(
                &wal,
                (0..5).map(|i| create_node(NodeId::new(i), "Before")),
                TxId::new(1),
            );
            wal.checkpoint(TxId::new(1), EpochId::new(100)).unwrap();
            log_committed(
                &wal,
                (100..103).map(|i| create_node(NodeId::new(i), "After")),
                TxId::new(2),
            );

            wal.sync().unwrap();
        }

        // Recovery picks up the checkpoint metadata on its own.
        let recovered = WalRecovery::new(dir.path()).recover().unwrap();

        // How many records come back depends on how many files were skipped,
        // so only non-emptiness is asserted.
        assert!(!recovered.is_empty(), "Should recover some records");
    }
}