memvid_core/io/wal.rs

use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};

use crate::{
    constants::{WAL_CHECKPOINT_PERIOD, WAL_CHECKPOINT_THRESHOLD},
    error::{MemvidError, Result},
    types::Header,
};

// Each WAL record header: [seq: u64][len: u32][reserved: 4 bytes][checksum: 32 bytes]
const ENTRY_HEADER_SIZE: usize = 48;

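/// Point-in-time view of WAL occupancy, as returned by [`EmbeddedWal::stats`].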
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WalStats {
    pub region_size: u64,
    pub pending_bytes: u64,
    pub appends_since_checkpoint: u64,
    pub sequence: u64,
}

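/// A recovered WAL entry: its sequence number plus the raw payload bytes.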
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WalRecord {
    pub sequence: u64,
    pub payload: Vec<u8>,
}

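/// Wrapping write-ahead log embedded in a fixed-size region of the memory file.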
#[derive(Debug)]
pub struct EmbeddedWal {
    file: File,
    region_offset: u64,
    region_size: u64,
    write_head: u64,
    checkpoint_head: u64,
    pending_bytes: u64,
    sequence: u64,
    checkpoint_sequence: u64,
    appends_since_checkpoint: u64,
    read_only: bool,
}

impl EmbeddedWal {
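    /// Opens the WAL region described by `header` with write access, scanning
    /// existing records to rebuild the write head and pending counters.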
    pub fn open(file: &File, header: &Header) -> Result<Self> {
        Self::open_internal(file, header, false)
    }

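    /// Opens the WAL region for reading only; appends and checkpoints will fail.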
    pub fn open_read_only(file: &File, header: &Header) -> Result<Self> {
        Self::open_internal(file, header, true)
    }

    fn open_internal(file: &File, header: &Header, read_only: bool) -> Result<Self> {
        if header.wal_size == 0 {
            return Err(MemvidError::InvalidHeader {
                reason: "wal_size must be non-zero".into(),
            });
        }
        let mut clone = file.try_clone()?;
        let region_offset = header.wal_offset;
        let region_size = header.wal_size;
        let checkpoint_sequence = header.wal_sequence;

        let (entries, next_head) = Self::scan_records(&mut clone, region_offset, region_size)?;

        let pending_bytes = entries
            .iter()
            .filter(|entry| entry.sequence > checkpoint_sequence)
            .map(|entry| entry.total_size)
            .sum();
        let sequence = entries
            .last()
            .map_or(checkpoint_sequence, |entry| entry.sequence);

        let mut wal = Self {
            file: clone,
            region_offset,
            region_size,
            write_head: next_head % region_size,
            checkpoint_head: header.wal_checkpoint_pos % region_size,
            pending_bytes,
            sequence,
            checkpoint_sequence,
            appends_since_checkpoint: 0,
            read_only,
        };

        if !wal.read_only {
            wal.initialise_sentinel()?;
        }
        Ok(wal)
    }

    fn assert_writable(&self) -> Result<()> {
        if self.read_only {
            return Err(MemvidError::Lock(
                "wal is read-only; reopen memory with write access".into(),
            ));
        }
        Ok(())
    }

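    /// Appends `payload` as a new record and returns its sequence number.
    /// Fails if the region cannot hold the entry without overwriting
    /// records that have not been checkpointed yet.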
    pub fn append_entry(&mut self, payload: &[u8]) -> Result<u64> {
        self.assert_writable()?;
        let payload_len = payload.len();
        if payload_len > u32::MAX as usize {
            return Err(MemvidError::CheckpointFailed {
                reason: "WAL payload too large".into(),
            });
        }

        let entry_size = ENTRY_HEADER_SIZE as u64 + payload_len as u64;
        if entry_size > self.region_size {
            return Err(MemvidError::CheckpointFailed {
                reason: "embedded WAL region too small for entry".into(),
            });
        }
        if self.pending_bytes + entry_size > self.region_size {
            return Err(MemvidError::CheckpointFailed {
                reason: "embedded WAL region full".into(),
            });
        }

        // Wrap back to the start of the region if the entry would run past the end.
        let wrapping = self.write_head + entry_size > self.region_size;
        if wrapping {
            // Wrapping while entries are still pending would overwrite data that
            // has not been checkpointed yet. Report the region as full instead of
            // silently overwriting; the caller responds by growing the WAL.
            if self.pending_bytes > 0 {
                return Err(MemvidError::CheckpointFailed {
                    reason: "embedded WAL region full".into(),
                });
            }
            self.write_head = 0;
        }

        let next_sequence = self.sequence + 1;
        tracing::debug!(
            wal.write_head = self.write_head,
            wal.sequence = next_sequence,
            wal.payload_len = payload_len,
            "wal append entry"
        );
        self.write_record(self.write_head, next_sequence, payload)?;

        self.write_head = (self.write_head + entry_size) % self.region_size;
        self.pending_bytes += entry_size;
        self.sequence = self.sequence.wrapping_add(1);
        self.appends_since_checkpoint = self.appends_since_checkpoint.saturating_add(1);

        self.maybe_write_sentinel()?;

        Ok(self.sequence)
    }

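    /// Returns true once pending bytes cross the occupancy threshold or the
    /// append count reaches the checkpoint period.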
    #[must_use]
    pub fn should_checkpoint(&self) -> bool {
        if self.read_only || self.region_size == 0 {
            return false;
        }
        let occupancy = self.pending_bytes as f64 / self.region_size as f64;
        occupancy >= WAL_CHECKPOINT_THRESHOLD
            || self.appends_since_checkpoint >= WAL_CHECKPOINT_PERIOD
    }

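    /// Marks everything up to the current write head as checkpointed and
    /// mirrors the new checkpoint position and sequence into `header`.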
    pub fn record_checkpoint(&mut self, header: &mut Header) -> Result<()> {
        self.assert_writable()?;
        self.checkpoint_head = self.write_head;
        self.pending_bytes = 0;
        self.appends_since_checkpoint = 0;
        self.checkpoint_sequence = self.sequence;
        header.wal_checkpoint_pos = self.checkpoint_head;
        header.wal_sequence = self.checkpoint_sequence;
        self.maybe_write_sentinel()
    }

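    /// Returns every record appended since the last checkpoint.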
    pub fn pending_records(&mut self) -> Result<Vec<WalRecord>> {
        self.records_after(self.checkpoint_sequence)
    }

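    /// Rescans the region from disk and returns all records with a sequence
    /// number greater than `sequence`, refreshing the in-memory counters.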
    pub fn records_after(&mut self, sequence: u64) -> Result<Vec<WalRecord>> {
        let (entries, next_head) =
            Self::scan_records(&mut self.file, self.region_offset, self.region_size)?;

        self.sequence = entries.last().map_or(self.sequence, |entry| entry.sequence);
        self.pending_bytes = entries
            .iter()
            .filter(|entry| entry.sequence > self.checkpoint_sequence)
            .map(|entry| entry.total_size)
            .sum();
        self.write_head = next_head % self.region_size;
        if !self.read_only {
            self.initialise_sentinel()?;
        }

        Ok(entries
            .into_iter()
            .filter(|entry| entry.sequence > sequence)
            .map(|entry| WalRecord {
                sequence: entry.sequence,
                payload: entry.payload,
            })
            .collect())
    }

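    /// Returns a snapshot of the current WAL counters.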
    #[must_use]
    pub fn stats(&self) -> WalStats {
        WalStats {
            region_size: self.region_size,
            pending_bytes: self.pending_bytes,
            appends_since_checkpoint: self.appends_since_checkpoint,
            sequence: self.sequence,
        }
    }

    #[must_use]
    pub fn region_offset(&self) -> u64 {
        self.region_offset
    }

    #[must_use]
    pub fn file(&self) -> &File {
        &self.file
    }

    fn initialise_sentinel(&mut self) -> Result<()> {
        self.maybe_write_sentinel()
    }

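    // Serialises one record (header + payload) at `position` and fsyncs it.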
    fn write_record(&mut self, position: u64, sequence: u64, payload: &[u8]) -> Result<()> {
        self.assert_writable()?;
        let digest = blake3::hash(payload);
        let mut header = [0u8; ENTRY_HEADER_SIZE];
        header[..8].copy_from_slice(&sequence.to_le_bytes());
        header[8..12]
            .copy_from_slice(&(u32::try_from(payload.len()).unwrap_or(u32::MAX)).to_le_bytes());
        header[16..48].copy_from_slice(digest.as_bytes());

        // Write header and payload in a single call to narrow the window for a
        // torn record if the process dies mid-append.
        let mut combined = Vec::with_capacity(ENTRY_HEADER_SIZE + payload.len());
        combined.extend_from_slice(&header);
        combined.extend_from_slice(payload);

        self.seek_and_write(position, &combined)?;
        if tracing::enabled!(tracing::Level::DEBUG) {
            if let Err(err) = self.debug_verify_header(position, sequence, payload.len()) {
                tracing::warn!(error = %err, "wal header verify failed");
            }
        }

        // fsync before returning so the record is durable once append_entry
        // reports success.
        self.file.sync_all()?;

        Ok(())
    }

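    // Writes an all-zero record header (the sentinel) at `position`, zero-filling
    // any tail too small to hold a header, and returns the position actually used.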
    fn write_zero_header(&mut self, position: u64) -> Result<u64> {
        self.assert_writable()?;
        if self.region_size == 0 {
            return Ok(0);
        }
        let mut pos = position % self.region_size;
        let remaining = self.region_size - pos;
        if remaining < ENTRY_HEADER_SIZE as u64 {
            if remaining > 0 {
                // Safe: remaining < ENTRY_HEADER_SIZE (48) so always fits in usize
                #[allow(clippy::cast_possible_truncation)]
                let zero_tail = vec![0u8; remaining as usize];
                self.seek_and_write(pos, &zero_tail)?;
            }
            pos = 0;
        }
        let zero = [0u8; ENTRY_HEADER_SIZE];
        self.seek_and_write(pos, &zero)?;
        Ok(pos)
    }

    fn seek_and_write(&mut self, position: u64, bytes: &[u8]) -> Result<()> {
        self.assert_writable()?;
        let pos = position % self.region_size;
        let absolute = self.region_offset + pos;
        self.file.seek(SeekFrom::Start(absolute))?;
        self.file.write_all(bytes)?;
        Ok(())
    }

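    // Keeps the sentinel at the write head so scan_records stops at the first
    // unwritten slot; skipped when the region is completely occupied.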
    fn maybe_write_sentinel(&mut self) -> Result<()> {
        if self.read_only || self.region_size == 0 {
            return Ok(());
        }
        if self.pending_bytes >= self.region_size {
            return Ok(());
        }
        // Sentinel marks the end of valid entries; keep write_head in sync.
        let next = self.write_zero_header(self.write_head)?;
        self.write_head = next;
        Ok(())
    }

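    // Walks the region from offset zero, validating each header and checksum,
    // until it reaches the zeroed sentinel. Returns the decoded records and the
    // cursor position where the next append should land.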
    fn scan_records(file: &mut File, offset: u64, size: u64) -> Result<(Vec<ScannedRecord>, u64)> {
        let mut records = Vec::new();
        let mut cursor = 0u64;
        while cursor + ENTRY_HEADER_SIZE as u64 <= size {
            file.seek(SeekFrom::Start(offset + cursor))?;
            let mut header = [0u8; ENTRY_HEADER_SIZE];
            file.read_exact(&mut header)?;

            let sequence = u64::from_le_bytes(header[..8].try_into().map_err(|_| {
                MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "invalid wal sequence header".into(),
                }
            })?);
            let length = u64::from(u32::from_le_bytes(header[8..12].try_into().map_err(
                |_| MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "invalid wal length header".into(),
                },
            )?));
            let checksum = &header[16..48];

            if sequence == 0 && length == 0 {
                break;
            }
            if length == 0 || cursor + ENTRY_HEADER_SIZE as u64 + length > size {
                tracing::error!(
                    wal.scan_offset = cursor,
                    wal.sequence = sequence,
                    wal.length = length,
                    wal.region_size = size,
                    "wal record length invalid"
                );
                return Err(MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "wal record length invalid".into(),
                });
            }

            // Safe: length comes from u32::from_le_bytes above, so max is u32::MAX,
            // which fits in usize on all supported platforms (32-bit and 64-bit).
            let length_usize = usize::try_from(length).map_err(|_| MemvidError::WalCorruption {
                offset: cursor,
                reason: "wal record length too large for platform".into(),
            })?;
            let mut payload = vec![0u8; length_usize];
            file.read_exact(&mut payload)?;
            let expected = blake3::hash(&payload);
            if expected.as_bytes() != checksum {
                return Err(MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "wal record checksum mismatch".into(),
                });
            }

            records.push(ScannedRecord {
                sequence,
                payload,
                total_size: ENTRY_HEADER_SIZE as u64 + length,
            });

            cursor += ENTRY_HEADER_SIZE as u64 + length;
        }

        Ok((records, cursor))
    }
}

#[derive(Debug)]
struct ScannedRecord {
    sequence: u64,
    payload: Vec<u8>,
    total_size: u64,
}

impl EmbeddedWal {
    fn debug_verify_header(
        &mut self,
        position: u64,
        expected_sequence: u64,
        expected_len: usize,
    ) -> Result<()> {
        if self.region_size == 0 {
            return Ok(());
        }
        let pos = position % self.region_size;
        let absolute = self.region_offset + pos;
        let mut buf = [0u8; ENTRY_HEADER_SIZE];
        self.file.seek(SeekFrom::Start(absolute))?;
        self.file.read_exact(&mut buf)?;

        // Best-effort extraction: fall back to zero on malformed slices, since
        // this is a debug-only helper.
        let seq = buf
            .get(..8)
            .and_then(|s| <[u8; 8]>::try_from(s).ok())
            .map_or(0, u64::from_le_bytes);
        let len = buf
            .get(8..12)
            .and_then(|s| <[u8; 4]>::try_from(s).ok())
            .map_or(0, u32::from_le_bytes);

        tracing::debug!(
            wal.verify_position = pos,
            wal.verify_sequence = seq,
            wal.expected_sequence = expected_sequence,
            wal.verify_length = len,
            wal.expected_length = expected_len,
            "wal header verify"
        );
        Ok(())
    }
}

413
414#[cfg(test)]
415mod tests {
416    use super::*;
417    use crate::constants::WAL_OFFSET;
418    use std::io::{Seek, SeekFrom, Write};
419    use tempfile::tempfile;
420
421    fn header_for(size: u64) -> Header {
422        Header {
423            magic: *b"MV2\0",
424            version: 0x0201,
425            footer_offset: 0,
426            wal_offset: WAL_OFFSET,
427            wal_size: size,
428            wal_checkpoint_pos: 0,
429            wal_sequence: 0,
430            toc_checksum: [0u8; 32],
431        }
432    }
433
434    fn prepare_wal(size: u64) -> (File, Header) {
435        let file = tempfile().expect("temp file");
436        file.set_len(WAL_OFFSET + size).expect("set_len");
437        let header = header_for(size);
438        (file, header)
439    }
440
441    #[test]
442    fn append_and_recover() {
443        let (file, header) = prepare_wal(1024);
444        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");
445
446        wal.append_entry(b"first").expect("append first");
447        wal.append_entry(b"second").expect("append second");
448
449        let records = wal.records_after(0).expect("records");
450        assert_eq!(records.len(), 2);
451        assert_eq!(records[0].payload, b"first");
452        assert_eq!(records[0].sequence, 1);
453        assert_eq!(records[1].payload, b"second");
454        assert_eq!(records[1].sequence, 2);
455    }
456
457    #[test]
458    fn wrap_and_checkpoint() {
459        let size = (ENTRY_HEADER_SIZE as u64 * 2) + 64;
460        let (file, mut header) = prepare_wal(size);
461        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");
462
463        wal.append_entry(&[0xAA; 32]).expect("append a");
464        wal.append_entry(&[0xBB; 32]).expect("append b");
465        wal.record_checkpoint(&mut header).expect("checkpoint");
466
467        assert!(wal.pending_records().expect("pending").is_empty());
468
469        wal.append_entry(&[0xCC; 32]).expect("append c");
470        let records = wal.pending_records().expect("after append");
471        assert_eq!(records.len(), 1);
472        assert_eq!(records[0].payload, vec![0xCC; 32]);
473    }
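
    // A read-only handle must reject appends while reads keep working.
    #[test]
    fn read_only_rejects_appends() {
        let (file, header) = prepare_wal(1024);
        {
            let mut writer = EmbeddedWal::open(&file, &header).expect("open wal");
            writer.append_entry(b"seed").expect("append seed");
        }

        let mut wal = EmbeddedWal::open_read_only(&file, &header).expect("open read-only");
        let err = wal.append_entry(b"denied").expect_err("append must fail");
        assert!(matches!(err, MemvidError::Lock(_)));
        assert_eq!(wal.records_after(0).expect("records").len(), 1);
    }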

    #[test]
    fn corrupted_record_reports_offset() {
        let (mut file, header) = prepare_wal(64);
        // Write a record header that claims an impossible length so scan_records trips.
        file.seek(SeekFrom::Start(header.wal_offset)).expect("seek");
        let mut record = [0u8; ENTRY_HEADER_SIZE];
        record[..8].copy_from_slice(&1u64.to_le_bytes()); // sequence
        record[8..12].copy_from_slice(&(u32::MAX).to_le_bytes()); // absurd length
        file.write_all(&record).expect("write corrupt header");
        file.sync_all().expect("sync");

        let err = EmbeddedWal::open(&file, &header).expect_err("open should fail");
        match err {
            MemvidError::WalCorruption { offset, reason } => {
                assert_eq!(offset, 0);
                assert!(reason.contains("length"), "reason should mention length");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }
}
495}