omendb_core/omen/
wal.rs

1//! Write-Ahead Log for crash-consistent operations
2//!
3//! Based on P-HNSW research: `NLog` (node ops) + `NlistLog` (neighbor ops)
4
5use std::fs::{File, OpenOptions};
6use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
7use std::path::Path;
8
/// Applies platform-specific settings to `opts` before the WAL file is opened.
///
/// Windows variant: requests read, write and delete sharing on the handle,
/// which avoids "Access is denied" errors.
#[cfg(windows)]
fn configure_open_options(opts: &mut OpenOptions) {
    use std::os::windows::fs::OpenOptionsExt;

    const FILE_SHARE_READ: u32 = 0x1;
    const FILE_SHARE_WRITE: u32 = 0x2;
    const FILE_SHARE_DELETE: u32 = 0x4;

    opts.share_mode(FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE);
}

/// Applies platform-specific settings to `opts` before the WAL file is opened.
///
/// Non-Windows variant: there is nothing to adjust, so `_opts` is left untouched.
#[cfg(not(windows))]
fn configure_open_options(_opts: &mut OpenOptions) {}
22
/// Kind tag stored in the first byte of every WAL entry header.
///
/// The discriminants are the on-disk byte values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum WalEntryType {
    /// Insert a new node: {id, level, vector, metadata}
    InsertNode = 1,
    /// Delete a node: {id}
    DeleteNode = 2,
    /// Update neighbors: {id, level, [`neighbor_ids`]}
    UpdateNeighbors = 3,
    /// Update metadata: {id, metadata}
    UpdateMetadata = 4,
    /// Checkpoint marker - safe truncation point
    Checkpoint = 100,
}

impl From<u8> for WalEntryType {
    /// Decodes a type byte read from disk.
    ///
    /// Bytes 1-4 map to their matching variant. Every other value, not only
    /// 100, decodes to [`WalEntryType::Checkpoint`].
    ///
    /// NOTE(review): because unknown bytes become `Checkpoint`, a corrupted
    /// type byte is indistinguishable from a real checkpoint marker after
    /// decoding. Callers that need to tell them apart must inspect the raw
    /// byte themselves.
    fn from(v: u8) -> Self {
        // Every variant except the fallback, matched by its discriminant.
        const NON_CHECKPOINT: [WalEntryType; 4] = [
            WalEntryType::InsertNode,
            WalEntryType::DeleteNode,
            WalEntryType::UpdateNeighbors,
            WalEntryType::UpdateMetadata,
        ];

        NON_CHECKPOINT
            .iter()
            .copied()
            .find(|kind| *kind as u8 == v)
            .unwrap_or(WalEntryType::Checkpoint)
    }
}
50
51/// WAL entry header (20 bytes)
52/// Layout: `entry_type(1)` + reserved(3) + timestamp(8) + `data_len(4)` + checksum(4)
53#[derive(Debug, Clone)]
54pub struct WalEntryHeader {
55    pub entry_type: WalEntryType,
56    pub timestamp: u64, // Monotonic counter
57    pub data_len: u32,
58    pub checksum: u32,
59}
60
61impl WalEntryHeader {
62    pub const SIZE: usize = 20;
63
64    pub fn to_bytes(&self) -> [u8; Self::SIZE] {
65        let mut buf = [0u8; Self::SIZE];
66        buf[0] = self.entry_type as u8;
67        // bytes 1-3: reserved/padding
68        buf[4..12].copy_from_slice(&self.timestamp.to_le_bytes());
69        buf[12..16].copy_from_slice(&self.data_len.to_le_bytes());
70        buf[16..20].copy_from_slice(&self.checksum.to_le_bytes());
71        buf
72    }
73
74    pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Self {
75        // Direct array indexing - infallible for fixed-size input buffer
76        Self {
77            entry_type: WalEntryType::from(buf[0]),
78            timestamp: u64::from_le_bytes([
79                buf[4], buf[5], buf[6], buf[7], buf[8], buf[9], buf[10], buf[11],
80            ]),
81            data_len: u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]),
82            checksum: u32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]),
83        }
84    }
85}
86
87/// WAL entry (header + data)
88#[derive(Debug, Clone)]
89pub struct WalEntry {
90    pub header: WalEntryHeader,
91    pub data: Vec<u8>,
92}
93
94impl WalEntry {
95    /// Create insert node entry
96    #[must_use]
97    pub fn insert_node(
98        timestamp: u64,
99        string_id: &str,
100        level: u8,
101        vector: &[f32],
102        metadata: &[u8],
103    ) -> Self {
104        let mut data = Vec::new();
105
106        // String ID (length-prefixed)
107        data.extend_from_slice(&(string_id.len() as u32).to_le_bytes());
108        data.extend_from_slice(string_id.as_bytes());
109
110        // Level
111        data.push(level);
112
113        // Vector (length-prefixed f32 array)
114        data.extend_from_slice(&(vector.len() as u32).to_le_bytes());
115        for &val in vector {
116            data.extend_from_slice(&val.to_le_bytes());
117        }
118
119        // Metadata (length-prefixed)
120        data.extend_from_slice(&(metadata.len() as u32).to_le_bytes());
121        data.extend_from_slice(metadata);
122
123        let checksum = crc32fast::hash(&data);
124
125        Self {
126            header: WalEntryHeader {
127                entry_type: WalEntryType::InsertNode,
128                timestamp,
129                data_len: data.len() as u32,
130                checksum,
131            },
132            data,
133        }
134    }
135
136    /// Create delete node entry
137    #[must_use]
138    pub fn delete_node(timestamp: u64, string_id: &str) -> Self {
139        let mut data = Vec::new();
140        data.extend_from_slice(&(string_id.len() as u32).to_le_bytes());
141        data.extend_from_slice(string_id.as_bytes());
142
143        let checksum = crc32fast::hash(&data);
144
145        Self {
146            header: WalEntryHeader {
147                entry_type: WalEntryType::DeleteNode,
148                timestamp,
149                data_len: data.len() as u32,
150                checksum,
151            },
152            data,
153        }
154    }
155
156    /// Create update neighbors entry
157    #[must_use]
158    pub fn update_neighbors(timestamp: u64, node_id: u32, level: u8, neighbors: &[u32]) -> Self {
159        let mut data = Vec::new();
160
161        // Node ID
162        data.extend_from_slice(&node_id.to_le_bytes());
163
164        // Level
165        data.push(level);
166
167        // Neighbors (length-prefixed)
168        data.extend_from_slice(&(neighbors.len() as u32).to_le_bytes());
169        for &neighbor in neighbors {
170            data.extend_from_slice(&neighbor.to_le_bytes());
171        }
172
173        let checksum = crc32fast::hash(&data);
174
175        Self {
176            header: WalEntryHeader {
177                entry_type: WalEntryType::UpdateNeighbors,
178                timestamp,
179                data_len: data.len() as u32,
180                checksum,
181            },
182            data,
183        }
184    }
185
186    /// Create checkpoint entry
187    #[must_use]
188    pub fn checkpoint(timestamp: u64) -> Self {
189        Self {
190            header: WalEntryHeader {
191                entry_type: WalEntryType::Checkpoint,
192                timestamp,
193                data_len: 0,
194                checksum: 0,
195            },
196            data: Vec::new(),
197        }
198    }
199
200    /// Verify entry checksum
201    #[must_use]
202    pub fn verify(&self) -> bool {
203        if self.data.is_empty() {
204            return self.header.checksum == 0;
205        }
206        crc32fast::hash(&self.data) == self.header.checksum
207    }
208}
209
210/// Write-Ahead Log
211pub struct Wal {
212    file: BufWriter<File>,
213    #[allow(dead_code)]
214    path: std::path::PathBuf,
215    next_timestamp: u64,
216    entry_count: u64,
217}
218
219impl Wal {
220    /// Open or create WAL file
221    pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
222        let path = path.as_ref().to_path_buf();
223        let mut opts = OpenOptions::new();
224        // Use write mode instead of append for Windows compatibility
225        // (append mode on Windows may prevent truncation)
226        opts.read(true).write(true).create(true);
227        configure_open_options(&mut opts);
228        let mut file = opts.open(&path)?;
229
230        let metadata = file.metadata()?;
231        let file_len = metadata.len();
232
233        // Seek to end for append-like behavior
234        if file_len > 0 {
235            file.seek(SeekFrom::End(0))?;
236        }
237
238        let mut wal = Self {
239            file: BufWriter::new(file),
240            path,
241            next_timestamp: 0,
242            entry_count: 0,
243        };
244
245        // Scan to find last timestamp
246        if file_len > 0 {
247            wal.scan_for_timestamp()?;
248        }
249
250        Ok(wal)
251    }
252
253    /// Scan WAL to find highest timestamp
254    fn scan_for_timestamp(&mut self) -> io::Result<()> {
255        let file = self.file.get_mut();
256        file.seek(SeekFrom::Start(0))?;
257
258        let mut header_buf = [0u8; WalEntryHeader::SIZE];
259        let mut max_timestamp = 0u64;
260        let mut count = 0u64;
261
262        loop {
263            match file.read_exact(&mut header_buf) {
264                Ok(()) => {
265                    let header = WalEntryHeader::from_bytes(&header_buf);
266                    max_timestamp = max_timestamp.max(header.timestamp);
267                    count += 1;
268
269                    // Skip data
270                    if header.data_len > 0 {
271                        file.seek(SeekFrom::Current(header.data_len as i64))?;
272                    }
273                }
274                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
275                Err(e) => return Err(e),
276            }
277        }
278
279        self.next_timestamp = max_timestamp + 1;
280        self.entry_count = count;
281
282        // Seek to end for appending
283        file.seek(SeekFrom::End(0))?;
284
285        Ok(())
286    }
287
288    /// Append entry to WAL
289    pub fn append(&mut self, mut entry: WalEntry) -> io::Result<()> {
290        entry.header.timestamp = self.next_timestamp;
291        self.next_timestamp += 1;
292
293        self.file.write_all(&entry.header.to_bytes())?;
294        if !entry.data.is_empty() {
295            self.file.write_all(&entry.data)?;
296        }
297
298        self.entry_count += 1;
299        Ok(())
300    }
301
302    /// Flush WAL to disk
303    pub fn sync(&mut self) -> io::Result<()> {
304        self.file.flush()?;
305        self.file.get_mut().sync_all()
306    }
307
308    /// Read all entries after last checkpoint
309    pub fn entries_after_checkpoint(&mut self) -> io::Result<Vec<WalEntry>> {
310        let file = self.file.get_mut();
311        file.seek(SeekFrom::Start(0))?;
312
313        let mut all_entries = Vec::new();
314        let mut last_checkpoint_idx: Option<usize> = None;
315        let mut header_buf = [0u8; WalEntryHeader::SIZE];
316
317        loop {
318            match file.read_exact(&mut header_buf) {
319                Ok(()) => {
320                    let header = WalEntryHeader::from_bytes(&header_buf);
321                    let mut data = vec![0u8; header.data_len as usize];
322                    if header.data_len > 0 {
323                        file.read_exact(&mut data)?;
324                    }
325
326                    let entry = WalEntry { header, data };
327
328                    if entry.header.entry_type == WalEntryType::Checkpoint {
329                        last_checkpoint_idx = Some(all_entries.len());
330                    }
331
332                    all_entries.push(entry);
333                }
334                Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
335                Err(e) => return Err(e),
336            }
337        }
338
339        // Return entries after last checkpoint
340        match last_checkpoint_idx {
341            Some(idx) => Ok(all_entries.split_off(idx + 1)),
342            None => Ok(all_entries),
343        }
344    }
345
346    /// Get entry count
347    #[must_use]
348    pub fn len(&self) -> u64 {
349        self.entry_count
350    }
351
352    /// Check if WAL is empty
353    #[must_use]
354    pub fn is_empty(&self) -> bool {
355        self.entry_count == 0
356    }
357
358    /// Truncate WAL (after checkpoint)
359    pub fn truncate(&mut self) -> io::Result<()> {
360        // Flush buffer before truncating (required on Windows)
361        self.file.flush()?;
362        self.file.get_mut().set_len(0)?;
363        self.file.get_mut().seek(SeekFrom::Start(0))?;
364        self.next_timestamp = 0;
365        self.entry_count = 0;
366        Ok(())
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Writes four entries with a checkpoint in third position, reopens the
    /// file, and checks that only the entry written after the checkpoint is
    /// returned.
    #[test]
    fn test_wal_roundtrip() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");

        // Scope the writer so the file handle is dropped before reopening.
        {
            let mut wal = Wal::open(&wal_path).unwrap();
            wal.append(WalEntry::insert_node(0, "vec1", 0, &[1.0, 2.0, 3.0], b"{}"))
                .unwrap();
            wal.append(WalEntry::delete_node(0, "vec2")).unwrap();
            wal.append(WalEntry::checkpoint(0)).unwrap();
            wal.append(WalEntry::insert_node(0, "vec3", 1, &[4.0, 5.0, 6.0], b"{}"))
                .unwrap();
            wal.sync().unwrap();
        }

        {
            let mut wal = Wal::open(&wal_path).unwrap();
            let entries = wal.entries_after_checkpoint().unwrap();

            // Should only have entries after checkpoint
            assert_eq!(entries.len(), 1);
            assert_eq!(entries[0].header.entry_type, WalEntryType::InsertNode);
        }
    }

    /// A freshly constructed entry must pass its own checksum verification.
    #[test]
    fn test_entry_checksum() {
        let entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
        assert!(entry.verify());
    }

    /// Flipping bits in the payload must make `verify()` fail.
    #[test]
    fn test_corrupted_entry_data_detected() {
        let mut entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
        assert!(entry.verify());

        // Corrupt the data
        if !entry.data.is_empty() {
            entry.data[0] ^= 0xFF;
        }

        // Verify should now fail
        assert!(!entry.verify(), "Corrupted data should fail verification");
    }

    /// Flipping bits in the stored checksum must make `verify()` fail.
    #[test]
    fn test_corrupted_entry_checksum_detected() {
        let mut entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
        assert!(entry.verify());

        // Corrupt the checksum
        entry.header.checksum ^= 0xFFFF_FFFF;

        // Verify should now fail
        assert!(
            !entry.verify(),
            "Corrupted checksum should fail verification"
        );
    }

    /// Overwrites payload bytes on disk and checks that the damaged entry is
    /// still returned but fails `verify()`.
    ///
    /// Despite the name, nothing is skipped by the reader: the test only
    /// asserts that at least one returned entry fails verification.
    #[test]
    fn test_wal_recovery_skips_corrupted_entries() {
        use std::io::Write;

        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test_corrupt.wal");

        // Write valid entries
        {
            let mut wal = Wal::open(&wal_path).unwrap();
            wal.append(WalEntry::insert_node(0, "vec1", 0, &[1.0, 2.0, 3.0], b"{}"))
                .unwrap();
            wal.append(WalEntry::insert_node(0, "vec2", 0, &[4.0, 5.0, 6.0], b"{}"))
                .unwrap();
            wal.sync().unwrap();
        }

        // Corrupt four payload bytes in place, leaving every header intact
        {
            let mut file = OpenOptions::new()
                .read(true)
                .write(true)
                .open(&wal_path)
                .unwrap();

            // Each entry here is header(20) + payload(31) = 51 bytes, so the
            // first entry spans file offsets 0..51. Offset 40 is payload
            // offset 20, inside the first entry's vector floats. The length
            // prefixes are untouched, so framing stays readable.
            file.seek(SeekFrom::Start(40)).unwrap();
            file.write_all(&[0xFF, 0xFF, 0xFF, 0xFF]).unwrap();
            file.sync_all().unwrap();
        }

        // Read entries - corrupted entries should fail verify()
        {
            let mut wal = Wal::open(&wal_path).unwrap();
            let entries = wal.entries_after_checkpoint().unwrap();

            // At least one entry should fail verification
            let invalid_count = entries.iter().filter(|e| !e.verify()).count();
            assert!(
                invalid_count > 0,
                "Expected at least one corrupted entry, got none"
            );

            // Valid entries should still verify correctly
            let valid_count = entries.iter().filter(|e| e.verify()).count();
            // Sanity check only: the valid and invalid counts together must
            // account for every returned entry.
            assert!(
                valid_count + invalid_count == entries.len(),
                "All entries should be either valid or invalid"
            );
        }
    }
}