omendb_core/omen/
wal.rs

1//! Write-Ahead Log for crash-consistent operations
2//!
3//! Based on P-HNSW research: `NLog` (node ops) + `NlistLog` (neighbor ops)
4
5use std::fs::{File, OpenOptions};
6use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
7use std::path::Path;
8
9/// Configure OpenOptions for cross-platform compatibility.
10/// On Windows, enables full file sharing to avoid "Access is denied" errors.
11#[cfg(windows)]
12fn configure_open_options(opts: &mut OpenOptions) {
13    use std::os::windows::fs::OpenOptionsExt;
14    // FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE
15    opts.share_mode(0x1 | 0x2 | 0x4);
16}
17
18#[cfg(not(windows))]
19fn configure_open_options(_opts: &mut OpenOptions) {
20    // No-op on Unix
21}
22
23/// WAL entry types
24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
25#[repr(u8)]
26pub enum WalEntryType {
27    /// Insert a new node: {id, level, vector, metadata}
28    InsertNode = 1,
29    /// Delete a node: {id}
30    DeleteNode = 2,
31    /// Update neighbors: {id, level, [`neighbor_ids`]}
32    UpdateNeighbors = 3,
33    /// Update metadata: {id, metadata}
34    UpdateMetadata = 4,
35    /// Checkpoint marker - safe truncation point
36    Checkpoint = 100,
37}
38
39impl WalEntryType {
40    /// Try to parse a WAL entry type from a byte
41    ///
42    /// Returns None for unknown entry types (which should be skipped during recovery)
43    #[must_use]
44    pub fn from_byte(v: u8) -> Option<Self> {
45        match v {
46            1 => Some(Self::InsertNode),
47            2 => Some(Self::DeleteNode),
48            3 => Some(Self::UpdateNeighbors),
49            4 => Some(Self::UpdateMetadata),
50            100 => Some(Self::Checkpoint),
51            _ => None, // Unknown entry type - caller should handle
52        }
53    }
54}
55
56impl From<u8> for WalEntryType {
57    fn from(v: u8) -> Self {
58        // For backwards compatibility, unknown entries are treated as checkpoint
59        // But prefer using from_byte() which returns Option
60        Self::from_byte(v).unwrap_or(Self::Checkpoint)
61    }
62}
63
64/// WAL entry header (20 bytes)
65/// Layout: `entry_type(1)` + reserved(3) + timestamp(8) + `data_len(4)` + checksum(4)
66#[derive(Debug, Clone)]
67pub struct WalEntryHeader {
68    pub entry_type: WalEntryType,
69    pub timestamp: u64, // Monotonic counter
70    pub data_len: u32,
71    pub checksum: u32,
72}
73
74impl WalEntryHeader {
75    pub const SIZE: usize = 20;
76
77    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
78        let mut buf = [0u8; Self::SIZE];
79        buf[0] = self.entry_type as u8;
80        // bytes 1-3: reserved/padding
81        buf[4..12].copy_from_slice(&self.timestamp.to_le_bytes());
82        buf[12..16].copy_from_slice(&self.data_len.to_le_bytes());
83        buf[16..20].copy_from_slice(&self.checksum.to_le_bytes());
84        buf
85    }
86
87    pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Self {
88        // Direct array indexing - infallible for fixed-size input buffer
89        Self {
90            entry_type: WalEntryType::from(buf[0]),
91            timestamp: u64::from_le_bytes([
92                buf[4], buf[5], buf[6], buf[7], buf[8], buf[9], buf[10], buf[11],
93            ]),
94            data_len: u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]),
95            checksum: u32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]),
96        }
97    }
98}
99
100/// WAL entry (header + data)
101#[derive(Debug, Clone)]
102pub struct WalEntry {
103    pub header: WalEntryHeader,
104    pub data: Vec<u8>,
105}
106
107impl WalEntry {
108    /// Create insert node entry
109    #[must_use]
110    pub fn insert_node(
111        timestamp: u64,
112        string_id: &str,
113        level: u8,
114        vector: &[f32],
115        metadata: &[u8],
116    ) -> Self {
117        let mut data = Vec::new();
118
119        // String ID (length-prefixed)
120        data.extend_from_slice(&(string_id.len() as u32).to_le_bytes());
121        data.extend_from_slice(string_id.as_bytes());
122
123        // Level
124        data.push(level);
125
126        // Vector (length-prefixed f32 array)
127        data.extend_from_slice(&(vector.len() as u32).to_le_bytes());
128        for &val in vector {
129            data.extend_from_slice(&val.to_le_bytes());
130        }
131
132        // Metadata (length-prefixed)
133        data.extend_from_slice(&(metadata.len() as u32).to_le_bytes());
134        data.extend_from_slice(metadata);
135
136        let checksum = crc32fast::hash(&data);
137
138        Self {
139            header: WalEntryHeader {
140                entry_type: WalEntryType::InsertNode,
141                timestamp,
142                data_len: data.len() as u32,
143                checksum,
144            },
145            data,
146        }
147    }
148
149    /// Create delete node entry
150    #[must_use]
151    pub fn delete_node(timestamp: u64, string_id: &str) -> Self {
152        let mut data = Vec::new();
153        data.extend_from_slice(&(string_id.len() as u32).to_le_bytes());
154        data.extend_from_slice(string_id.as_bytes());
155
156        let checksum = crc32fast::hash(&data);
157
158        Self {
159            header: WalEntryHeader {
160                entry_type: WalEntryType::DeleteNode,
161                timestamp,
162                data_len: data.len() as u32,
163                checksum,
164            },
165            data,
166        }
167    }
168
169    /// Create update neighbors entry
170    #[must_use]
171    pub fn update_neighbors(timestamp: u64, node_id: u32, level: u8, neighbors: &[u32]) -> Self {
172        let mut data = Vec::new();
173
174        // Node ID
175        data.extend_from_slice(&node_id.to_le_bytes());
176
177        // Level
178        data.push(level);
179
180        // Neighbors (length-prefixed)
181        data.extend_from_slice(&(neighbors.len() as u32).to_le_bytes());
182        for &neighbor in neighbors {
183            data.extend_from_slice(&neighbor.to_le_bytes());
184        }
185
186        let checksum = crc32fast::hash(&data);
187
188        Self {
189            header: WalEntryHeader {
190                entry_type: WalEntryType::UpdateNeighbors,
191                timestamp,
192                data_len: data.len() as u32,
193                checksum,
194            },
195            data,
196        }
197    }
198
199    /// Create checkpoint entry
200    #[must_use]
201    pub fn checkpoint(timestamp: u64) -> Self {
202        Self {
203            header: WalEntryHeader {
204                entry_type: WalEntryType::Checkpoint,
205                timestamp,
206                data_len: 0,
207                checksum: 0,
208            },
209            data: Vec::new(),
210        }
211    }
212
213    /// Verify entry checksum
214    #[must_use]
215    pub fn verify(&self) -> bool {
216        if self.data.is_empty() {
217            return self.header.checksum == 0;
218        }
219        crc32fast::hash(&self.data) == self.header.checksum
220    }
221}
222
223/// Write-Ahead Log
224pub struct Wal {
225    file: BufWriter<File>,
226    #[allow(dead_code)]
227    path: std::path::PathBuf,
228    next_timestamp: u64,
229    entry_count: u64,
230}
231
232impl Wal {
233    /// Open or create WAL file
234    pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
235        let path = path.as_ref().to_path_buf();
236        let mut opts = OpenOptions::new();
237        // Use write mode instead of append for Windows compatibility
238        // (append mode on Windows may prevent truncation)
239        opts.read(true).write(true).create(true);
240        configure_open_options(&mut opts);
241        let mut file = opts.open(&path)?;
242
243        let metadata = file.metadata()?;
244        let file_len = metadata.len();
245
246        // Seek to end for append-like behavior
247        if file_len > 0 {
248            file.seek(SeekFrom::End(0))?;
249        }
250
251        let mut wal = Self {
252            file: BufWriter::new(file),
253            path,
254            next_timestamp: 0,
255            entry_count: 0,
256        };
257
258        // Scan to find last timestamp
259        if file_len > 0 {
260            wal.scan_for_timestamp()?;
261        }
262
263        Ok(wal)
264    }
265
266    /// Scan WAL to find highest timestamp
267    fn scan_for_timestamp(&mut self) -> io::Result<()> {
268        let file = self.file.get_mut();
269        file.seek(SeekFrom::Start(0))?;
270
271        let mut header_buf = [0u8; WalEntryHeader::SIZE];
272        let mut max_timestamp = 0u64;
273        let mut count = 0u64;
274
275        // Maximum reasonable entry size (100MB) - protects against corrupted data_len
276        const MAX_ENTRY_SIZE: u32 = 100 * 1024 * 1024;
277
278        loop {
279            match file.read_exact(&mut header_buf) {
280                Ok(()) => {
281                    let header = WalEntryHeader::from_bytes(&header_buf);
282
283                    // Sanity check: reject obviously corrupted entries
284                    if header.data_len > MAX_ENTRY_SIZE {
285                        return Err(io::Error::new(
286                            io::ErrorKind::InvalidData,
287                            format!(
288                                "WAL entry has suspicious data_len: {} bytes (max: {})",
289                                header.data_len, MAX_ENTRY_SIZE
290                            ),
291                        ));
292                    }
293
294                    max_timestamp = max_timestamp.max(header.timestamp);
295                    count += 1;
296
297                    // Skip data
298                    if header.data_len > 0 {
299                        file.seek(SeekFrom::Current(header.data_len as i64))?;
300                    }
301                }
302                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
303                Err(e) => return Err(e),
304            }
305        }
306
307        self.next_timestamp = max_timestamp + 1;
308        self.entry_count = count;
309
310        // Seek to end for appending
311        file.seek(SeekFrom::End(0))?;
312
313        Ok(())
314    }
315
316    /// Append entry to WAL
317    pub fn append(&mut self, mut entry: WalEntry) -> io::Result<()> {
318        entry.header.timestamp = self.next_timestamp;
319        self.next_timestamp += 1;
320
321        self.file.write_all(&entry.header.to_bytes())?;
322        if !entry.data.is_empty() {
323            self.file.write_all(&entry.data)?;
324        }
325
326        self.entry_count += 1;
327        Ok(())
328    }
329
330    /// Flush WAL to disk
331    pub fn sync(&mut self) -> io::Result<()> {
332        self.file.flush()?;
333        self.file.get_mut().sync_all()
334    }
335
336    /// Read all entries after last checkpoint
337    ///
338    /// Note: Entries are validated via checksum. Invalid entries are skipped.
339    /// Unknown entry types are also skipped (not treated as checkpoints).
340    pub fn entries_after_checkpoint(&mut self) -> io::Result<Vec<WalEntry>> {
341        let file = self.file.get_mut();
342        file.seek(SeekFrom::Start(0))?;
343
344        let mut all_entries = Vec::new();
345        let mut last_checkpoint_idx: Option<usize> = None;
346        let mut header_buf = [0u8; WalEntryHeader::SIZE];
347
348        // Maximum reasonable entry size (100MB)
349        const MAX_ENTRY_SIZE: u32 = 100 * 1024 * 1024;
350
351        loop {
352            match file.read_exact(&mut header_buf) {
353                Ok(()) => {
354                    let header = WalEntryHeader::from_bytes(&header_buf);
355
356                    // Sanity check on data_len
357                    if header.data_len > MAX_ENTRY_SIZE {
358                        return Err(io::Error::new(
359                            io::ErrorKind::InvalidData,
360                            format!(
361                                "WAL entry has suspicious data_len: {} bytes",
362                                header.data_len
363                            ),
364                        ));
365                    }
366
367                    let mut data = vec![0u8; header.data_len as usize];
368                    if header.data_len > 0 {
369                        file.read_exact(&mut data)?;
370                    }
371
372                    let entry = WalEntry { header, data };
373
374                    // Skip entries that fail checksum verification
375                    if !entry.verify() {
376                        continue;
377                    }
378
379                    // Only count as checkpoint if it's actually a valid checkpoint entry type
380                    // (not an unknown entry type that defaulted to Checkpoint via From<u8>)
381                    let entry_type_byte = entry.header.entry_type as u8;
382                    if WalEntryType::from_byte(entry_type_byte) == Some(WalEntryType::Checkpoint) {
383                        last_checkpoint_idx = Some(all_entries.len());
384                    }
385
386                    all_entries.push(entry);
387                }
388                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
389                Err(e) => return Err(e),
390            }
391        }
392
393        // Return entries after last checkpoint
394        match last_checkpoint_idx {
395            Some(idx) => Ok(all_entries.split_off(idx + 1)),
396            None => Ok(all_entries),
397        }
398    }
399
400    /// Get entry count
401    #[must_use]
402    pub fn len(&self) -> u64 {
403        self.entry_count
404    }
405
406    /// Check if WAL is empty
407    #[must_use]
408    pub fn is_empty(&self) -> bool {
409        self.entry_count == 0
410    }
411
412    /// Truncate WAL (after checkpoint)
413    pub fn truncate(&mut self) -> io::Result<()> {
414        // Flush buffer before truncating (required on Windows)
415        self.file.flush()?;
416        self.file.get_mut().set_len(0)?;
417        self.file.get_mut().seek(SeekFrom::Start(0))?;
418        self.next_timestamp = 0;
419        self.entry_count = 0;
420        Ok(())
421    }
422}
423
424#[cfg(test)]
425mod tests {
426    use super::*;
427    use tempfile::tempdir;
428
429    #[test]
430    fn test_wal_roundtrip() {
431        let dir = tempdir().unwrap();
432        let wal_path = dir.path().join("test.wal");
433
434        {
435            let mut wal = Wal::open(&wal_path).unwrap();
436            wal.append(WalEntry::insert_node(0, "vec1", 0, &[1.0, 2.0, 3.0], b"{}"))
437                .unwrap();
438            wal.append(WalEntry::delete_node(0, "vec2")).unwrap();
439            wal.append(WalEntry::checkpoint(0)).unwrap();
440            wal.append(WalEntry::insert_node(0, "vec3", 1, &[4.0, 5.0, 6.0], b"{}"))
441                .unwrap();
442            wal.sync().unwrap();
443        }
444
445        {
446            let mut wal = Wal::open(&wal_path).unwrap();
447            let entries = wal.entries_after_checkpoint().unwrap();
448
449            // Should only have entries after checkpoint
450            assert_eq!(entries.len(), 1);
451            assert_eq!(entries[0].header.entry_type, WalEntryType::InsertNode);
452        }
453    }
454
455    #[test]
456    fn test_entry_checksum() {
457        let entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
458        assert!(entry.verify());
459    }
460
461    #[test]
462    fn test_corrupted_entry_data_detected() {
463        let mut entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
464        assert!(entry.verify());
465
466        // Corrupt the data
467        if !entry.data.is_empty() {
468            entry.data[0] ^= 0xFF;
469        }
470
471        // Verify should now fail
472        assert!(!entry.verify(), "Corrupted data should fail verification");
473    }
474
475    #[test]
476    fn test_corrupted_entry_checksum_detected() {
477        let mut entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
478        assert!(entry.verify());
479
480        // Corrupt the checksum
481        entry.header.checksum ^= 0xFFFF_FFFF;
482
483        // Verify should now fail
484        assert!(
485            !entry.verify(),
486            "Corrupted checksum should fail verification"
487        );
488    }
489
490    #[test]
491    fn test_wal_recovery_skips_corrupted_entries() {
492        use std::io::Write;
493
494        let dir = tempdir().unwrap();
495        let wal_path = dir.path().join("test_corrupt.wal");
496
497        // Write valid entries
498        {
499            let mut wal = Wal::open(&wal_path).unwrap();
500            wal.append(WalEntry::insert_node(0, "vec1", 0, &[1.0, 2.0, 3.0], b"{}"))
501                .unwrap();
502            wal.append(WalEntry::insert_node(0, "vec2", 0, &[4.0, 5.0, 6.0], b"{}"))
503                .unwrap();
504            wal.sync().unwrap();
505        }
506
507        // Corrupt the middle of the file (corrupt first entry's data)
508        {
509            let mut file = OpenOptions::new()
510                .read(true)
511                .write(true)
512                .open(&wal_path)
513                .unwrap();
514
515            // Corrupt bytes in the first entry's data section (after header)
516            // Header is 20 bytes, so corrupt data at offset 25
517            file.seek(SeekFrom::Start(25)).unwrap();
518            file.write_all(&[0xFF, 0xFF, 0xFF, 0xFF]).unwrap();
519            file.sync_all().unwrap();
520        }
521
522        // Read entries - corrupted entries should be SKIPPED (not returned)
523        {
524            let mut wal = Wal::open(&wal_path).unwrap();
525            let entries = wal.entries_after_checkpoint().unwrap();
526
527            // All returned entries should pass verification (corrupted ones are skipped)
528            for entry in &entries {
529                assert!(entry.verify(), "All returned entries should be valid");
530            }
531
532            // We started with 2 entries, at least one should have been corrupted and skipped
533            // The exact count depends on how corruption affects parsing
534            assert!(
535                entries.len() <= 2,
536                "Should have at most 2 entries after skipping corrupted ones"
537            );
538        }
539    }
540}