chie_core/
wal.rs

1//! Write-Ahead Logging (WAL) for crash recovery.
2//!
3//! This module implements a write-ahead log that ensures durability and enables
4//! crash recovery for storage operations. All mutations are logged before being
5//! applied, allowing recovery from incomplete operations after a crash.
6//!
7//! # Example
8//!
9//! ```rust
10//! use chie_core::wal::{WriteAheadLog, LogEntry, Operation};
11//! use std::path::PathBuf;
12//!
13//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
14//! let mut wal = WriteAheadLog::new(PathBuf::from("/tmp/wal")).await?;
15//!
16//! // Log a write operation
17//! let entry = LogEntry {
18//!     sequence: 1,
19//!     operation: Operation::WriteChunk {
20//!         cid: "QmTest".to_string(),
21//!         chunk_index: 0,
22//!         data: vec![1, 2, 3],
23//!     },
24//!     timestamp_ms: 1234567890,
25//! };
26//!
27//! wal.append(&entry).await?;
28//!
29//! // Replay log after crash
30//! let entries = wal.replay().await?;
31//! for entry in entries {
32//!     // Apply logged operations
33//! }
34//!
35//! // Truncate log after successful checkpoint
36//! wal.truncate(10).await?;
37//! # Ok(())
38//! # }
39//! ```
40
41use serde::{Deserialize, Serialize};
42use std::path::PathBuf;
43use thiserror::Error;
44use tokio::fs::{self, OpenOptions};
45use tokio::io::{AsyncReadExt, AsyncWriteExt};
46
47/// WAL error types.
48#[derive(Debug, Error)]
49pub enum WalError {
50    #[error("IO error: {0}")]
51    Io(#[from] std::io::Error),
52
53    #[error("Serialization error: {0}")]
54    Serialization(String),
55
56    #[error("Deserialization error: {0}")]
57    Deserialization(String),
58
59    #[error("Corrupted WAL entry at sequence {0}")]
60    CorruptedEntry(u64),
61
62    #[error("Invalid WAL format")]
63    InvalidFormat,
64}
65
66/// Types of operations that can be logged.
67#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
68pub enum Operation {
69    /// Write a chunk to storage.
70    WriteChunk {
71        cid: String,
72        chunk_index: u64,
73        data: Vec<u8>,
74    },
75    /// Delete a chunk from storage.
76    DeleteChunk { cid: String, chunk_index: u64 },
77    /// Pin content.
78    PinContent { cid: String, chunk_count: u64 },
79    /// Unpin content.
80    UnpinContent { cid: String },
81    /// Update metadata.
82    UpdateMetadata { cid: String, metadata: Vec<u8> },
83    /// Checkpoint marker (all prior operations completed).
84    Checkpoint { sequence: u64 },
85}
86
87/// A log entry in the WAL.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct LogEntry {
90    /// Sequence number (monotonically increasing).
91    pub sequence: u64,
92    /// Operation to perform.
93    pub operation: Operation,
94    /// Timestamp when logged (Unix milliseconds).
95    pub timestamp_ms: i64,
96}
97
98impl LogEntry {
99    /// Create a new log entry.
100    #[must_use]
101    pub fn new(sequence: u64, operation: Operation) -> Self {
102        let timestamp_ms = std::time::SystemTime::now()
103            .duration_since(std::time::UNIX_EPOCH)
104            .unwrap_or_default()
105            .as_millis() as i64;
106
107        Self {
108            sequence,
109            operation,
110            timestamp_ms,
111        }
112    }
113
114    /// Get the sequence number.
115    #[must_use]
116    #[inline]
117    pub const fn sequence(&self) -> u64 {
118        self.sequence
119    }
120
121    /// Get the operation.
122    #[must_use]
123    #[inline]
124    pub const fn operation(&self) -> &Operation {
125        &self.operation
126    }
127
128    /// Serialize to bytes with length prefix.
129    fn to_bytes(&self) -> Result<Vec<u8>, WalError> {
130        let data = crate::serde_helpers::encode(self)
131            .map_err(|e| WalError::Serialization(e.to_string()))?;
132
133        // Length prefix (4 bytes) + data
134        let len = data.len() as u32;
135        let mut result = Vec::with_capacity(4 + data.len());
136        result.extend_from_slice(&len.to_le_bytes());
137        result.extend_from_slice(&data);
138
139        Ok(result)
140    }
141
142    /// Deserialize from bytes.
143    fn from_bytes(bytes: &[u8]) -> Result<Self, WalError> {
144        crate::serde_helpers::decode(bytes).map_err(|e| WalError::Deserialization(e.to_string()))
145    }
146}
147
/// Write-ahead log for crash recovery.
#[derive(Debug)]
pub struct WriteAheadLog {
    /// Location of the log file on disk.
    log_path: PathBuf,
    /// Sequence number that the next automatically numbered entry receives.
    next_sequence: u64,
    /// Sequence number of the most recent checkpoint (`0` if none).
    checkpoint_sequence: u64,
}
154
155impl WriteAheadLog {
156    /// Create a new WAL or open an existing one.
157    pub async fn new(log_path: PathBuf) -> Result<Self, WalError> {
158        // Ensure parent directory exists
159        if let Some(parent) = log_path.parent() {
160            fs::create_dir_all(parent).await?;
161        }
162
163        let mut wal = Self {
164            log_path,
165            next_sequence: 1,
166            checkpoint_sequence: 0,
167        };
168
169        // Scan existing log to find next sequence number
170        if wal.log_path.exists() {
171            let entries = wal.replay().await?;
172            if let Some(last_entry) = entries.last() {
173                wal.next_sequence = last_entry.sequence + 1;
174
175                // Find latest checkpoint
176                for entry in entries.iter().rev() {
177                    if let Operation::Checkpoint { sequence } = entry.operation {
178                        wal.checkpoint_sequence = sequence;
179                        break;
180                    }
181                }
182            }
183        }
184
185        Ok(wal)
186    }
187
188    /// Append a new entry to the log.
189    pub async fn append(&mut self, entry: &LogEntry) -> Result<(), WalError> {
190        let bytes = entry.to_bytes()?;
191
192        let mut file = OpenOptions::new()
193            .create(true)
194            .append(true)
195            .open(&self.log_path)
196            .await?;
197
198        file.write_all(&bytes).await?;
199        file.sync_all().await?; // Ensure durability
200
201        self.next_sequence = self.next_sequence.max(entry.sequence + 1);
202
203        Ok(())
204    }
205
206    /// Append an operation, automatically assigning sequence number.
207    pub async fn log_operation(&mut self, operation: Operation) -> Result<u64, WalError> {
208        let sequence = self.next_sequence;
209        let entry = LogEntry::new(sequence, operation);
210        self.append(&entry).await?;
211        Ok(sequence)
212    }
213
214    /// Replay the log, returning all entries.
215    ///
216    /// This should be called during recovery to get all pending operations.
217    pub async fn replay(&self) -> Result<Vec<LogEntry>, WalError> {
218        if !self.log_path.exists() {
219            return Ok(Vec::new());
220        }
221
222        let mut file = fs::File::open(&self.log_path).await?;
223        let mut entries = Vec::new();
224
225        loop {
226            // Read length prefix
227            let mut len_bytes = [0u8; 4];
228            match file.read_exact(&mut len_bytes).await {
229                Ok(_) => {}
230                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
231                Err(e) => return Err(WalError::Io(e)),
232            }
233
234            let len = u32::from_le_bytes(len_bytes) as usize;
235
236            // Read entry data
237            let mut data = vec![0u8; len];
238            file.read_exact(&mut data).await?;
239
240            // Deserialize entry
241            let entry = LogEntry::from_bytes(&data)?;
242            entries.push(entry);
243        }
244
245        Ok(entries)
246    }
247
248    /// Truncate log up to and including the given sequence number.
249    ///
250    /// This is typically called after a successful checkpoint to remove old entries.
251    pub async fn truncate(&mut self, up_to_sequence: u64) -> Result<(), WalError> {
252        let entries = self.replay().await?;
253        let remaining: Vec<LogEntry> = entries
254            .into_iter()
255            .filter(|e| e.sequence > up_to_sequence)
256            .collect();
257
258        // Rewrite log with remaining entries
259        if self.log_path.exists() {
260            fs::remove_file(&self.log_path).await?;
261        }
262
263        for entry in &remaining {
264            self.append(entry).await?;
265        }
266
267        self.checkpoint_sequence = up_to_sequence;
268
269        Ok(())
270    }
271
272    /// Write a checkpoint entry.
273    pub async fn checkpoint(&mut self) -> Result<u64, WalError> {
274        let sequence = self.next_sequence;
275        let operation = Operation::Checkpoint { sequence };
276        self.log_operation(operation).await?;
277        self.checkpoint_sequence = sequence;
278        Ok(sequence)
279    }
280
281    /// Get entries since last checkpoint.
282    pub async fn entries_since_checkpoint(&self) -> Result<Vec<LogEntry>, WalError> {
283        let all_entries = self.replay().await?;
284        Ok(all_entries
285            .into_iter()
286            .filter(|e| e.sequence > self.checkpoint_sequence)
287            .collect())
288    }
289
290    /// Get the next sequence number.
291    #[must_use]
292    #[inline]
293    pub const fn next_sequence(&self) -> u64 {
294        self.next_sequence
295    }
296
297    /// Get the last checkpoint sequence number.
298    #[must_use]
299    #[inline]
300    pub const fn checkpoint_sequence(&self) -> u64 {
301        self.checkpoint_sequence
302    }
303
304    /// Clear the entire log.
305    pub async fn clear(&mut self) -> Result<(), WalError> {
306        if self.log_path.exists() {
307            fs::remove_file(&self.log_path).await?;
308        }
309        self.next_sequence = 1;
310        self.checkpoint_sequence = 0;
311        Ok(())
312    }
313}
314
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Path of the log file every test uses inside its temporary directory.
    fn wal_path(dir: &TempDir) -> std::path::PathBuf {
        dir.path().join("test.wal")
    }

    /// Open (or create) the test log inside `dir`.
    async fn open_in(dir: &TempDir) -> WriteAheadLog {
        WriteAheadLog::new(wal_path(dir)).await.unwrap()
    }

    /// Shorthand for a `PinContent` operation.
    fn pin(cid: &str, chunk_count: u64) -> Operation {
        Operation::PinContent {
            cid: cid.to_string(),
            chunk_count,
        }
    }

    /// A brand-new log starts at sequence 1 with no checkpoint.
    #[tokio::test]
    async fn test_wal_creation() {
        let dir = TempDir::new().unwrap();
        let wal = open_in(&dir).await;

        assert_eq!(wal.next_sequence(), 1);
        assert_eq!(wal.checkpoint_sequence(), 0);
    }

    /// Appended operations come back from `replay` in order and unchanged.
    #[tokio::test]
    async fn test_wal_append_and_replay() {
        let dir = TempDir::new().unwrap();
        let mut wal = open_in(&dir).await;

        let first = Operation::WriteChunk {
            cid: "QmTest1".to_string(),
            chunk_index: 0,
            data: vec![1, 2, 3],
        };
        let second = Operation::WriteChunk {
            cid: "QmTest2".to_string(),
            chunk_index: 1,
            data: vec![4, 5, 6],
        };

        for op in [first.clone(), second.clone()] {
            wal.log_operation(op).await.unwrap();
        }

        let replayed = wal.replay().await.unwrap();
        assert_eq!(replayed.len(), 2);
        assert_eq!(replayed[0].sequence, 1);
        assert_eq!(replayed[1].sequence, 2);
        assert_eq!(replayed[0].operation, first);
        assert_eq!(replayed[1].operation, second);
    }

    /// A checkpoint takes the next sequence number and is remembered.
    #[tokio::test]
    async fn test_wal_checkpoint() {
        let dir = TempDir::new().unwrap();
        let mut wal = open_in(&dir).await;

        wal.log_operation(pin("QmTest", 5)).await.unwrap();

        let checkpoint_seq = wal.checkpoint().await.unwrap();
        assert_eq!(checkpoint_seq, 2);
        assert_eq!(wal.checkpoint_sequence(), 2);
    }

    /// Truncation drops every entry up to and including the given sequence.
    #[tokio::test]
    async fn test_wal_truncate() {
        let dir = TempDir::new().unwrap();
        let mut wal = open_in(&dir).await;

        for idx in 0..5u64 {
            let op = Operation::WriteChunk {
                cid: format!("QmTest{}", idx),
                chunk_index: idx,
                data: vec![idx as u8],
            };
            wal.log_operation(op).await.unwrap();
        }

        wal.truncate(3).await.unwrap();

        // Only sequences 4 and 5 should survive.
        let survivors = wal.replay().await.unwrap();
        assert_eq!(survivors.len(), 2);
        assert_eq!(survivors[0].sequence, 4);
        assert_eq!(survivors[1].sequence, 5);
    }

    /// Only entries logged after the checkpoint are reported.
    #[tokio::test]
    async fn test_wal_entries_since_checkpoint() {
        let dir = TempDir::new().unwrap();
        let mut wal = open_in(&dir).await;

        // Two operations before the checkpoint.
        wal.log_operation(pin("QmTest1", 1)).await.unwrap();
        wal.log_operation(pin("QmTest2", 2)).await.unwrap();

        wal.checkpoint().await.unwrap();

        // One operation after it.
        wal.log_operation(pin("QmTest3", 3)).await.unwrap();

        let pending = wal.entries_since_checkpoint().await.unwrap();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].sequence, 4);
    }

    /// Reopening the log continues numbering after the stored entries.
    #[tokio::test]
    async fn test_wal_persistence() {
        let dir = TempDir::new().unwrap();

        {
            let mut writer = open_in(&dir).await;
            writer.log_operation(pin("QmPersist", 10)).await.unwrap();
        }

        let reopened = open_in(&dir).await;
        assert_eq!(reopened.next_sequence(), 2); // Should continue from where we left off

        let stored = reopened.replay().await.unwrap();
        assert_eq!(stored.len(), 1);
    }

    /// Clearing removes every entry and resets the sequence counter.
    #[tokio::test]
    async fn test_wal_clear() {
        let dir = TempDir::new().unwrap();
        let mut wal = open_in(&dir).await;

        for idx in 0..3u64 {
            let op = Operation::DeleteChunk {
                cid: format!("QmTest{}", idx),
                chunk_index: idx,
            };
            wal.log_operation(op).await.unwrap();
        }

        wal.clear().await.unwrap();

        let leftover = wal.replay().await.unwrap();
        assert_eq!(leftover.len(), 0);
        assert_eq!(wal.next_sequence(), 1);
    }
}