// memvid_core/io/wal.rs

use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};

use crate::{
    constants::{WAL_CHECKPOINT_PERIOD, WAL_CHECKPOINT_THRESHOLD},
    error::{MemvidError, Result},
    types::Header,
};

// Each WAL record header: [seq: u64][len: u32][reserved: 4 bytes][checksum: 32 bytes]
const ENTRY_HEADER_SIZE: usize = 48;
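
// Compile-time guard: the field layout documented above must sum to
// ENTRY_HEADER_SIZE (8 + 4 + 4 + 32 = 48).
const _: () = assert!(ENTRY_HEADER_SIZE == 8 + 4 + 4 + 32);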

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WalStats {
    pub region_size: u64,
    pub pending_bytes: u64,
    pub appends_since_checkpoint: u64,
    pub sequence: u64,
}

#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WalRecord {
    pub sequence: u64,
    pub payload: Vec<u8>,
}

#[derive(Debug)]
pub struct EmbeddedWal {
    file: File,
    region_offset: u64,
    region_size: u64,
    write_head: u64,
    checkpoint_head: u64,
    pending_bytes: u64,
    sequence: u64,
    checkpoint_sequence: u64,
    appends_since_checkpoint: u64,
    read_only: bool,
    skip_sync: bool,
}

impl EmbeddedWal {
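    /// Open the embedded WAL region described by `header` with write access,
    /// recovering any records already present and positioning the write head
    /// after the last valid entry.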
    pub fn open(file: &File, header: &Header) -> Result<Self> {
        Self::open_internal(file, header, false)
    }

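    /// Read-only variant of [`open`](Self::open): recovery still runs, but
    /// appends are rejected and sentinel writes are skipped.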
    pub fn open_read_only(file: &File, header: &Header) -> Result<Self> {
        Self::open_internal(file, header, true)
    }

    fn open_internal(file: &File, header: &Header, read_only: bool) -> Result<Self> {
        if header.wal_size == 0 {
            return Err(MemvidError::InvalidHeader {
                reason: "wal_size must be non-zero".into(),
            });
        }
        let mut clone = file.try_clone()?;
        let region_offset = header.wal_offset;
        let region_size = header.wal_size;
        let checkpoint_sequence = header.wal_sequence;

        let (entries, next_head) = Self::scan_records(&mut clone, region_offset, region_size)?;

        let pending_bytes = entries
            .iter()
            .filter(|entry| entry.sequence > checkpoint_sequence)
            .map(|entry| entry.total_size)
            .sum();
        let sequence = entries
            .last()
            .map_or(checkpoint_sequence, |entry| entry.sequence);

        let mut wal = Self {
            file: clone,
            region_offset,
            region_size,
            write_head: next_head % region_size,
            checkpoint_head: header.wal_checkpoint_pos % region_size,
            pending_bytes,
            sequence,
            checkpoint_sequence,
            appends_since_checkpoint: 0,
            read_only,
            skip_sync: false,
        };

        if !wal.read_only {
            wal.initialise_sentinel()?;
        }
        Ok(wal)
    }

    fn assert_writable(&self) -> Result<()> {
        if self.read_only {
            return Err(MemvidError::Lock(
                "wal is read-only; reopen memory with write access".into(),
            ));
        }
        Ok(())
    }

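    /// Append `payload` as a new WAL record and return its sequence number.
    ///
    /// Fails with `CheckpointFailed` when the payload exceeds `u32::MAX`
    /// bytes, when the entry cannot fit in the region, or when wrapping
    /// would overwrite pending (uncheckpointed) data.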
    pub fn append_entry(&mut self, payload: &[u8]) -> Result<u64> {
        self.assert_writable()?;
        let payload_len = payload.len();
        if payload_len > u32::MAX as usize {
            return Err(MemvidError::CheckpointFailed {
                reason: "WAL payload too large".into(),
            });
        }

        let entry_size = ENTRY_HEADER_SIZE as u64 + payload_len as u64;
        if entry_size > self.region_size {
            return Err(MemvidError::CheckpointFailed {
                reason: "embedded WAL region too small for entry".into(),
            });
        }
        if self.pending_bytes + entry_size > self.region_size {
            return Err(MemvidError::CheckpointFailed {
                reason: "embedded WAL region full".into(),
            });
        }

        // Check whether this entry would run past the end of the region.
        let wrapping = self.write_head + entry_size > self.region_size;
        if wrapping {
            // checkpoint_head marks the boundary of checkpointed data. If any
            // pending (uncheckpointed) bytes remain, wrapping would overwrite
            // them, so return "WAL full" instead of silently clobbering data;
            // the caller responds by checkpointing or growing the WAL.
            if self.pending_bytes > 0 {
                return Err(MemvidError::CheckpointFailed {
                    reason: "embedded WAL region full".into(),
                });
            }
            self.write_head = 0;
        }

        let next_sequence = self.sequence + 1;
        tracing::debug!(
            wal.write_head = self.write_head,
            wal.sequence = next_sequence,
            wal.payload_len = payload_len,
            "wal append entry"
        );
        self.write_record(self.write_head, next_sequence, payload)?;

        self.write_head = (self.write_head + entry_size) % self.region_size;
        self.pending_bytes += entry_size;
        self.sequence = self.sequence.wrapping_add(1);
        self.appends_since_checkpoint = self.appends_since_checkpoint.saturating_add(1);

        self.maybe_write_sentinel()?;

        Ok(self.sequence)
    }

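    /// True when a checkpoint is due: either region occupancy has reached
    /// `WAL_CHECKPOINT_THRESHOLD` or the number of appends since the last
    /// checkpoint has reached `WAL_CHECKPOINT_PERIOD`.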
    #[must_use]
    pub fn should_checkpoint(&self) -> bool {
        if self.read_only || self.region_size == 0 {
            return false;
        }
        let occupancy = self.pending_bytes as f64 / self.region_size as f64;
        occupancy >= WAL_CHECKPOINT_THRESHOLD
            || self.appends_since_checkpoint >= WAL_CHECKPOINT_PERIOD
    }

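    /// Mark everything up to the current write head as checkpointed and
    /// mirror the new checkpoint position and sequence into `header`.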
    pub fn record_checkpoint(&mut self, header: &mut Header) -> Result<()> {
        self.assert_writable()?;
        self.checkpoint_head = self.write_head;
        self.pending_bytes = 0;
        self.appends_since_checkpoint = 0;
        self.checkpoint_sequence = self.sequence;
        header.wal_checkpoint_pos = self.checkpoint_head;
        header.wal_sequence = self.checkpoint_sequence;
        self.maybe_write_sentinel()
    }

    pub fn pending_records(&mut self) -> Result<Vec<WalRecord>> {
        self.records_after(self.checkpoint_sequence)
    }

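    /// Rescan the region and return every record with a sequence greater
    /// than `sequence`, refreshing the in-memory write head and pending
    /// counters as a side effect.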
    pub fn records_after(&mut self, sequence: u64) -> Result<Vec<WalRecord>> {
        let (entries, next_head) =
            Self::scan_records(&mut self.file, self.region_offset, self.region_size)?;

        self.sequence = entries.last().map_or(self.sequence, |entry| entry.sequence);
        self.pending_bytes = entries
            .iter()
            .filter(|entry| entry.sequence > self.checkpoint_sequence)
            .map(|entry| entry.total_size)
            .sum();
        self.write_head = next_head % self.region_size;
        if !self.read_only {
            self.initialise_sentinel()?;
        }

        Ok(entries
            .into_iter()
            .filter(|entry| entry.sequence > sequence)
            .map(|entry| WalRecord {
                sequence: entry.sequence,
                payload: entry.payload,
            })
            .collect())
    }

    #[must_use]
    pub fn stats(&self) -> WalStats {
        WalStats {
            region_size: self.region_size,
            pending_bytes: self.pending_bytes,
            appends_since_checkpoint: self.appends_since_checkpoint,
            sequence: self.sequence,
        }
    }

    #[must_use]
    pub fn region_offset(&self) -> u64 {
        self.region_offset
    }

    #[must_use]
    pub fn file(&self) -> &File {
        &self.file
    }

    /// Enable or disable per-entry fsync.
    ///
    /// When `skip` is `true`, `write_record()` will not call `sync_all()` after
    /// each WAL append. The caller **must** call [`flush()`](Self::flush) after
    /// the batch to ensure durability.
    pub fn set_skip_sync(&mut self, skip: bool) {
        self.skip_sync = skip;
    }

    /// Force an `fsync` on the underlying WAL file.
    ///
    /// Call this after a batch of appends performed with `skip_sync = true`
    /// to ensure all data is durable on disk.
    pub fn flush(&mut self) -> Result<()> {
        self.file.sync_all().map_err(Into::into)
    }

    fn initialise_sentinel(&mut self) -> Result<()> {
        self.maybe_write_sentinel()
    }

    fn write_record(&mut self, position: u64, sequence: u64, payload: &[u8]) -> Result<()> {
        self.assert_writable()?;
        let digest = blake3::hash(payload);
        let mut header = [0u8; ENTRY_HEADER_SIZE];
        header[..8].copy_from_slice(&sequence.to_le_bytes());
        header[8..12]
            .copy_from_slice(&(u32::try_from(payload.len()).unwrap_or(u32::MAX)).to_le_bytes());
        header[16..48].copy_from_slice(digest.as_bytes());

        // Write header and payload as a single buffer: one write_all call
        // narrows the window for a torn record if the process dies mid-append.
        let mut combined = Vec::with_capacity(ENTRY_HEADER_SIZE + payload.len());
        combined.extend_from_slice(&header);
        combined.extend_from_slice(payload);

        self.seek_and_write(position, &combined)?;
        if tracing::enabled!(tracing::Level::DEBUG) {
            if let Err(err) = self.debug_verify_header(position, sequence, payload.len()) {
                tracing::warn!(error = %err, "wal header verify failed");
            }
        }

        // fsync so the entry is durable before we return. In batch mode
        // (skip_sync = true) the fsync is deferred to flush() for throughput.
        if !self.skip_sync {
            self.file.sync_all()?;
        }

        Ok(())
    }

    fn write_zero_header(&mut self, position: u64) -> Result<u64> {
        self.assert_writable()?;
        if self.region_size == 0 {
            return Ok(0);
        }
        let mut pos = position % self.region_size;
        let remaining = self.region_size - pos;
        if remaining < ENTRY_HEADER_SIZE as u64 {
            if remaining > 0 {
                // Safe: remaining < ENTRY_HEADER_SIZE (48) so it always fits in usize
                #[allow(clippy::cast_possible_truncation)]
                let zero_tail = vec![0u8; remaining as usize];
                self.seek_and_write(pos, &zero_tail)?;
            }
            pos = 0;
        }
        let zero = [0u8; ENTRY_HEADER_SIZE];
        self.seek_and_write(pos, &zero)?;
        Ok(pos)
    }

    fn seek_and_write(&mut self, position: u64, bytes: &[u8]) -> Result<()> {
        self.assert_writable()?;
        let pos = position % self.region_size;
        let absolute = self.region_offset + pos;
        self.file.seek(SeekFrom::Start(absolute))?;
        self.file.write_all(bytes)?;
        Ok(())
    }

    fn maybe_write_sentinel(&mut self) -> Result<()> {
        if self.read_only || self.region_size == 0 {
            return Ok(());
        }
        if self.pending_bytes >= self.region_size {
            return Ok(());
        }
        // The sentinel (a zeroed header) marks the end of valid entries;
        // keep write_head in sync with wherever it lands.
        let next = self.write_zero_header(self.write_head)?;
        self.write_head = next;
        Ok(())
    }

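    /// Walk the region from offset 0, validating each header and BLAKE3
    /// payload checksum, until a zeroed sentinel header or the end of the
    /// region is reached. Returns the records and the next write offset.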
    fn scan_records(file: &mut File, offset: u64, size: u64) -> Result<(Vec<ScannedRecord>, u64)> {
        let mut records = Vec::new();
        let mut cursor = 0u64;
        while cursor + ENTRY_HEADER_SIZE as u64 <= size {
            file.seek(SeekFrom::Start(offset + cursor))?;
            let mut header = [0u8; ENTRY_HEADER_SIZE];
            file.read_exact(&mut header)?;

            let sequence = u64::from_le_bytes(header[..8].try_into().map_err(|_| {
                MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "invalid wal sequence header".into(),
                }
            })?);
            let length = u64::from(u32::from_le_bytes(header[8..12].try_into().map_err(
                |_| MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "invalid wal length header".into(),
                },
            )?));
            let checksum = &header[16..48];

            if sequence == 0 && length == 0 {
                break;
            }
            if length == 0 || cursor + ENTRY_HEADER_SIZE as u64 + length > size {
                tracing::error!(
                    wal.scan_offset = cursor,
                    wal.sequence = sequence,
                    wal.length = length,
                    wal.region_size = size,
                    "wal record length invalid"
                );
                return Err(MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "wal record length invalid".into(),
                });
            }

            // Safe: length comes from u32::from_le_bytes above, so max is u32::MAX
            // which fits in usize on all supported platforms (32-bit and 64-bit)
            let length_usize = usize::try_from(length).map_err(|_| MemvidError::WalCorruption {
                offset: cursor,
                reason: "wal record length too large for platform".into(),
            })?;
            let mut payload = vec![0u8; length_usize];
            file.read_exact(&mut payload)?;
            let expected = blake3::hash(&payload);
            if expected.as_bytes() != checksum {
                return Err(MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "wal record checksum mismatch".into(),
                });
            }

            records.push(ScannedRecord {
                sequence,
                payload,
                total_size: ENTRY_HEADER_SIZE as u64 + length,
            });

            cursor += ENTRY_HEADER_SIZE as u64 + length;
        }

        Ok((records, cursor))
    }
}

#[derive(Debug)]
struct ScannedRecord {
    sequence: u64,
    payload: Vec<u8>,
    total_size: u64,
}

impl EmbeddedWal {
    fn debug_verify_header(
        &mut self,
        position: u64,
        expected_sequence: u64,
        expected_len: usize,
    ) -> Result<()> {
        if self.region_size == 0 {
            return Ok(());
        }
        let pos = position % self.region_size;
        let absolute = self.region_offset + pos;
        let mut buf = [0u8; ENTRY_HEADER_SIZE];
        self.file.seek(SeekFrom::Start(absolute))?;
        self.file.read_exact(&mut buf)?;

        // Best-effort byte extraction; fall back to 0 on malformed slices
        // (this is a debug-only diagnostic, not a validation path).
        let seq = buf
            .get(..8)
            .and_then(|s| <[u8; 8]>::try_from(s).ok())
            .map_or(0, u64::from_le_bytes);
        let len = buf
            .get(8..12)
            .and_then(|s| <[u8; 4]>::try_from(s).ok())
            .map_or(0, u32::from_le_bytes);

        tracing::debug!(
            wal.verify_position = pos,
            wal.verify_sequence = seq,
            wal.expected_sequence = expected_sequence,
            wal.verify_length = len,
            wal.expected_length = expected_len,
            "wal header verify"
        );
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::WAL_OFFSET;
    use std::io::{Seek, SeekFrom, Write};
    use tempfile::tempfile;

    fn header_for(size: u64) -> Header {
        Header {
            magic: *b"MV2\0",
            version: 0x0201,
            footer_offset: 0,
            wal_offset: WAL_OFFSET,
            wal_size: size,
            wal_checkpoint_pos: 0,
            wal_sequence: 0,
            toc_checksum: [0u8; 32],
        }
    }

    fn prepare_wal(size: u64) -> (File, Header) {
        let file = tempfile().expect("temp file");
        file.set_len(WAL_OFFSET + size).expect("set_len");
        let header = header_for(size);
        (file, header)
    }

    #[test]
    fn append_and_recover() {
        let (file, header) = prepare_wal(1024);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        wal.append_entry(b"first").expect("append first");
        wal.append_entry(b"second").expect("append second");

        let records = wal.records_after(0).expect("records");
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].payload, b"first");
        assert_eq!(records[0].sequence, 1);
        assert_eq!(records[1].payload, b"second");
        assert_eq!(records[1].sequence, 2);
    }

    #[test]
    fn wrap_and_checkpoint() {
        let size = (ENTRY_HEADER_SIZE as u64 * 2) + 64;
        let (file, mut header) = prepare_wal(size);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        wal.append_entry(&[0xAA; 32]).expect("append a");
        wal.append_entry(&[0xBB; 32]).expect("append b");
        wal.record_checkpoint(&mut header).expect("checkpoint");

        assert!(wal.pending_records().expect("pending").is_empty());

        wal.append_entry(&[0xCC; 32]).expect("append c");
        let records = wal.pending_records().expect("after append");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].payload, vec![0xCC; 32]);
    }

    #[test]
    fn corrupted_record_reports_offset() {
        let (mut file, header) = prepare_wal(64);
        // Write a record header that claims an impossible length so scan_records trips.
        file.seek(SeekFrom::Start(header.wal_offset)).expect("seek");
        let mut record = [0u8; ENTRY_HEADER_SIZE];
        record[..8].copy_from_slice(&1u64.to_le_bytes()); // sequence
        record[8..12].copy_from_slice(&(u32::MAX).to_le_bytes()); // absurd length
        file.write_all(&record).expect("write corrupt header");
        file.sync_all().expect("sync");

        let err = EmbeddedWal::open(&file, &header).expect_err("open should fail");
        match err {
            MemvidError::WalCorruption { offset, reason } => {
                assert_eq!(offset, 0);
                assert!(reason.contains("length"), "reason should mention length");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }
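
    // The tests below are added sketches exercising behaviour documented in
    // the code above; they are not part of the original suite.

    // Filling the region exactly and appending once more should surface the
    // "embedded WAL region full" error rather than overwrite pending entries.
    #[test]
    fn full_region_rejects_append() {
        // Room for exactly three 32-byte payloads (80 bytes per entry).
        let size = (ENTRY_HEADER_SIZE as u64 + 32) * 3;
        let (file, header) = prepare_wal(size);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        for byte in [0x11u8, 0x22, 0x33] {
            wal.append_entry(&[byte; 32]).expect("append fits");
        }
        let err = wal.append_entry(&[0x44; 32]).expect_err("region is full");
        assert!(matches!(err, MemvidError::CheckpointFailed { .. }));
    }

    // Batch flow from set_skip_sync()/flush(): per-entry fsync is disabled,
    // several appends run, then one flush() makes the batch durable.
    #[test]
    fn batch_appends_recover_after_flush() {
        let (file, header) = prepare_wal(1024);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        wal.set_skip_sync(true);
        for i in 0..4u8 {
            wal.append_entry(&[i; 16]).expect("append");
        }
        wal.flush().expect("flush batch");

        let records = wal.records_after(0).expect("records");
        assert_eq!(records.len(), 4);
    }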
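
    // Sketch that assumes WAL_CHECKPOINT_THRESHOLD <= 1.0: a region at full
    // occupancy must request a checkpoint, and record_checkpoint clears it.
    #[test]
    fn full_occupancy_requests_checkpoint() {
        let size = ENTRY_HEADER_SIZE as u64 + 32;
        let (file, mut header) = prepare_wal(size);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        wal.append_entry(&[0x5A; 32]).expect("append fills region");
        assert!(wal.should_checkpoint());

        wal.record_checkpoint(&mut header).expect("checkpoint");
        assert_eq!(wal.stats().pending_bytes, 0);
    }

    // A read-only handle must reject appends with a Lock error.
    #[test]
    fn read_only_rejects_append() {
        let (file, header) = prepare_wal(256);
        let mut wal = EmbeddedWal::open_read_only(&file, &header).expect("open read-only");
        let err = wal.append_entry(b"nope").expect_err("append must fail");
        assert!(matches!(err, MemvidError::Lock(_)));
    }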
}