Skip to main content

edgevec/persistence/
wal.rs

1use crate::persistence::entry::WalEntry;
2use crate::persistence::storage::StorageBackend;
3use crate::persistence::PersistenceError;
4use crc32fast::Hasher;
5use std::io::{self, Read};
6use thiserror::Error;
7
/// Size of a serialized record header in bytes:
/// sequence (8) + entry type (1) + padding (3) + payload length (4).
pub const WAL_HEADER_SIZE: usize = 8 + 1 + 3 + 4;

/// Size of the trailing CRC32 checksum in bytes (one little-endian `u32`).
pub const CRC_SIZE: usize = std::mem::size_of::<u32>();

/// Upper bound on a single record's payload: 16 MiB.
///
/// Checked both when appending and when replaying, so a corrupted or hostile
/// length field cannot force an unbounded allocation (`DoS`).
pub const MAX_PAYLOAD_SIZE: usize = 16 << 20;
16
17/// Errors that can occur during WAL iteration.
18#[derive(Debug, Error)]
19pub enum WalError {
20    /// I/O error reading from the WAL.
21    #[error("io error: {0}")]
22    Io(#[from] io::Error),
23
24    /// Persistence error.
25    #[error("persistence error: {0}")]
26    Persistence(#[from] PersistenceError),
27
28    /// CRC32 checksum mismatch.
29    #[error("checksum mismatch: expected {expected:#010x}, got {actual:#010x}")]
30    ChecksumMismatch {
31        /// The expected checksum read from the file.
32        expected: u32,
33        /// The actual calculated checksum.
34        actual: u32,
35    },
36
37    /// File ended unexpectedly (truncated).
38    #[error("file truncated: expected {expected} bytes, got {actual}")]
39    Truncated {
40        /// Number of bytes expected to read.
41        expected: usize,
42        /// Number of bytes actually read.
43        actual: usize,
44    },
45
46    /// Payload size exceeds maximum allowed limit.
47    #[error("payload too large: size {size} exceeds max {max}")]
48    PayloadTooLarge {
49        /// The requested payload size.
50        size: usize,
51        /// The maximum allowed size.
52        max: usize,
53    },
54}
55
56/// Iterator over entries in a Write-Ahead Log.
57///
58/// Reads strictly sequentially. Does not load the whole file into memory.
59pub struct WalIterator<R> {
60    reader: R,
61}
62
63impl<R: Read> WalIterator<R> {
64    /// Creates a new `WalIterator` wrapping the given reader.
65    pub fn new(reader: R) -> Self {
66        Self { reader }
67    }
68
69    /// Helper to read exact bytes or detect clean EOF.
70    /// Returns `Ok(true)` if buffer filled, `Ok(false)` if clean EOF (0 bytes read),
71    /// `Err(Truncated)` if partial read, or `Err(Io)`.
72    fn read_exact_or_eof(&mut self, buf: &mut [u8]) -> Result<bool, WalError> {
73        let mut total_read = 0;
74        while total_read < buf.len() {
75            match self.reader.read(&mut buf[total_read..]) {
76                Ok(0) => {
77                    if total_read == 0 {
78                        return Ok(false); // Clean EOF
79                    }
80                    return Err(WalError::Truncated {
81                        expected: buf.len(),
82                        actual: total_read,
83                    });
84                }
85                Ok(n) => total_read += n,
86                Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
87                Err(e) => return Err(WalError::Io(e)),
88            }
89        }
90        Ok(true)
91    }
92}
93
94impl<R: Read> Iterator for WalIterator<R> {
95    type Item = Result<(WalEntry, Vec<u8>), WalError>;
96
97    fn next(&mut self) -> Option<Self::Item> {
98        // 1. Read Header
99        let mut header_bytes = [0u8; WAL_HEADER_SIZE];
100        match self.read_exact_or_eof(&mut header_bytes) {
101            Ok(true) => {}            // Got header
102            Ok(false) => return None, // Clean EOF
103            Err(e) => return Some(Err(e)),
104        }
105
106        // 2. Parse WalEntry (Manual Little-Endian Deserialization)
107        // Layout: sequence(8) + type(1) + pad(3) + len(4)
108        // SAFETY: header_bytes is exactly 16 bytes (WAL_ENTRY_SIZE), so slicing
109        // [0..8] yields exactly 8 bytes and [12..16] yields exactly 4 bytes.
110        // The try_into() conversions are infallible for these exact-length slices.
111        let sequence = u64::from_le_bytes(
112            header_bytes[0..8]
113                .try_into()
114                .expect("slice is exactly 8 bytes from 16-byte header"),
115        );
116        let entry_type = header_bytes[8];
117        // Bytes 9..12 are padding, ignore them
118        let payload_len_u32 = u32::from_le_bytes(
119            header_bytes[12..16]
120                .try_into()
121                .expect("slice is exactly 4 bytes from 16-byte header"),
122        );
123
124        let entry = WalEntry::new(sequence, entry_type, payload_len_u32);
125
126        let payload_len = entry.payload_len as usize;
127
128        // Validation: Check payload size to prevent DoS (OOM)
129        if payload_len > MAX_PAYLOAD_SIZE {
130            return Some(Err(WalError::PayloadTooLarge {
131                size: payload_len,
132                max: MAX_PAYLOAD_SIZE,
133            }));
134        }
135
136        // 3. Read Payload
137        let mut payload = vec![0u8; payload_len];
138        match self.read_exact_or_eof(&mut payload) {
139            Ok(true) => {}
140            Ok(false) => {
141                return Some(Err(WalError::Truncated {
142                    expected: payload_len,
143                    actual: 0,
144                }))
145            }
146            Err(e) => return Some(Err(e)),
147        }
148
149        // 4. Read CRC
150        let mut crc_bytes = [0u8; CRC_SIZE];
151        match self.read_exact_or_eof(&mut crc_bytes) {
152            Ok(true) => {}
153            Ok(false) => {
154                return Some(Err(WalError::Truncated {
155                    expected: CRC_SIZE,
156                    actual: 0,
157                }))
158            }
159            Err(e) => return Some(Err(e)),
160        }
161        let stored_crc = u32::from_le_bytes(crc_bytes);
162
163        // 5. Compute CRC (Header + Payload)
164        // Note: CRC MUST be computed on the serialized bytes (header_bytes),
165        // which matches how they are written.
166        let mut hasher = Hasher::new();
167        hasher.update(&header_bytes);
168        hasher.update(&payload);
169        let calculated_crc = hasher.finalize();
170
171        // 6. Verify CRC
172        if calculated_crc != stored_crc {
173            return Some(Err(WalError::ChecksumMismatch {
174                expected: stored_crc,
175                actual: calculated_crc,
176            }));
177        }
178
179        // 7. Return Success
180        Some(Ok((entry, payload)))
181    }
182}
183
184/// Appends entries to the Write-Ahead Log.
185pub struct WalAppender {
186    backend: Box<dyn StorageBackend>,
187    next_sequence: u64,
188}
189
190impl WalAppender {
191    /// Creates a new `WalAppender` starting at the given sequence number.
192    #[must_use]
193    pub fn new(backend: Box<dyn StorageBackend>, next_sequence: u64) -> Self {
194        Self {
195            backend,
196            next_sequence,
197        }
198    }
199
200    /// Appends a new entry to the WAL.
201    ///
202    /// # Arguments
203    ///
204    /// * `entry_type` - Type of the entry (0=insert, 1=delete, etc.)
205    /// * `payload` - The data to store.
206    ///
207    /// # Errors
208    ///
209    /// Returns `WalError::Io` if writing fails, or `PayloadTooLarge` if payload exceeds limit.
210    pub fn append(&mut self, entry_type: u8, payload: &[u8]) -> Result<(), WalError> {
211        let payload_len = payload.len();
212        if payload_len > MAX_PAYLOAD_SIZE {
213            return Err(WalError::PayloadTooLarge {
214                size: payload_len,
215                max: MAX_PAYLOAD_SIZE,
216            });
217        }
218
219        // SAFETY: MAX_PAYLOAD_SIZE is 16MB, which fits in u32.
220        #[allow(clippy::cast_possible_truncation)]
221        let payload_len_u32 = payload_len as u32;
222
223        let entry_sequence = self.next_sequence;
224        self.next_sequence += 1;
225
226        // Serialize Header (Manual Little-Endian)
227        let mut header_bytes = [0u8; WAL_HEADER_SIZE];
228        header_bytes[0..8].copy_from_slice(&entry_sequence.to_le_bytes());
229        header_bytes[8] = entry_type;
230        header_bytes[9..12].fill(0); // Zero padding
231        header_bytes[12..16].copy_from_slice(&payload_len_u32.to_le_bytes());
232
233        // Calculate CRC on serialized header + payload
234        let mut hasher = Hasher::new();
235        hasher.update(&header_bytes);
236        hasher.update(payload);
237        let crc = hasher.finalize();
238
239        // Combine for single write (backend handles atomicity/sync)
240        let mut buffer = Vec::with_capacity(WAL_HEADER_SIZE + payload_len + CRC_SIZE);
241        buffer.extend_from_slice(&header_bytes);
242        buffer.extend_from_slice(payload);
243        buffer.extend_from_slice(&crc.to_le_bytes());
244
245        self.backend.append(&buffer)?;
246
247        Ok(())
248    }
249
250    /// Flushes the underlying writer to ensure durability.
251    ///
252    /// This is now a no-op as `StorageBackend::append` implies durability.
253    /// Retained for API compatibility.
254    ///
255    /// # Errors
256    ///
257    /// Returns `WalError::Io` if flushing fails.
258    pub fn sync(&mut self) -> Result<(), WalError> {
259        // Backend append handles sync.
260        Ok(())
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::persistence::storage::{MemoryBackend, StorageBackend};
    use std::io::Cursor;
    use std::mem::{align_of, size_of};

    /// Writes `count` records through a `WalAppender` and returns the raw bytes
    /// the in-memory backend holds. Record `i` has sequence `i`, type 0 and
    /// payload `i` as little-endian `u32`.
    fn encoded_log(count: u32) -> Vec<u8> {
        let memory = MemoryBackend::new();
        let mut appender = WalAppender::new(Box::new(memory.clone()), 0);
        for i in 0..count {
            appender.append(0, &i.to_le_bytes()).expect("append failed");
        }
        memory.read().expect("read failed").to_vec()
    }

    #[test]
    fn test_wal_constants() {
        assert_eq!(WAL_HEADER_SIZE, 16);
        assert_eq!(CRC_SIZE, 4);
    }

    #[test]
    fn test_wal_entry_layout() {
        assert_eq!(size_of::<WalEntry>(), WAL_HEADER_SIZE);
        assert_eq!(align_of::<WalEntry>(), 8);
    }

    #[test]
    fn test_wal_replay_integrity() {
        let data = encoded_log(100);

        let mut count: u32 = 0;
        for result in WalIterator::new(Cursor::new(data)) {
            let (entry, payload) = result.expect("replay failed");
            assert_eq!(entry.sequence, u64::from(count));
            assert_eq!(entry.entry_type, 0);
            assert_eq!(payload, count.to_le_bytes());
            count += 1;
        }

        assert_eq!(count, 100);
    }

    #[test]
    fn test_empty_log_yields_nothing() {
        let mut iterator = WalIterator::new(Cursor::new(Vec::<u8>::new()));
        assert!(iterator.next().is_none());
    }

    #[test]
    fn test_partial_header_is_truncated() {
        let first = WalIterator::new(Cursor::new(vec![0u8; 5])).next();
        assert!(matches!(
            first,
            Some(Err(WalError::Truncated { expected, actual }))
                if expected == WAL_HEADER_SIZE && actual == 5
        ));
    }

    #[test]
    fn test_missing_crc_byte_is_truncated() {
        let mut data = encoded_log(1);
        data.truncate(data.len() - 1);

        let first = WalIterator::new(Cursor::new(data)).next();
        assert!(matches!(
            first,
            Some(Err(WalError::Truncated { expected, actual }))
                if expected == CRC_SIZE && actual == CRC_SIZE - 1
        ));
    }

    #[test]
    fn test_corrupted_payload_fails_checksum() {
        let mut data = encoded_log(1);
        // First payload byte sits right after the header.
        data[WAL_HEADER_SIZE] ^= 0xFF;

        let first = WalIterator::new(Cursor::new(data)).next();
        assert!(matches!(first, Some(Err(WalError::ChecksumMismatch { .. }))));
    }

    #[test]
    fn test_oversized_payload_length_is_rejected_on_read() {
        let too_big = u32::try_from(MAX_PAYLOAD_SIZE + 1).expect("16 MiB + 1 fits in u32");
        let mut header = [0u8; WAL_HEADER_SIZE];
        header[12..16].copy_from_slice(&too_big.to_le_bytes());

        let first = WalIterator::new(Cursor::new(header.to_vec())).next();
        assert!(matches!(
            first,
            Some(Err(WalError::PayloadTooLarge { size, max }))
                if size == MAX_PAYLOAD_SIZE + 1 && max == MAX_PAYLOAD_SIZE
        ));
    }

    #[test]
    fn test_oversized_payload_is_rejected_on_append() {
        let mut appender = WalAppender::new(Box::new(MemoryBackend::new()), 0);
        let payload = vec![0u8; MAX_PAYLOAD_SIZE + 1];
        assert!(matches!(
            appender.append(0, &payload),
            Err(WalError::PayloadTooLarge { .. })
        ));
    }
}