memvid_core/io/
wal.rs

1use std::fs::File;
2use std::io::{Read, Seek, SeekFrom, Write};
3
4use crate::{
5    constants::{WAL_CHECKPOINT_PERIOD, WAL_CHECKPOINT_THRESHOLD},
6    error::{MemvidError, Result},
7    types::Header,
8};
9
// Fixed-size header preceding every WAL record:
// [seq: u64][len: u32][reserved: 4 bytes][checksum: 32 bytes] = 8 + 4 + 4 + 32 bytes.
// The reserved bytes are left zeroed by `write_record`; the checksum is BLAKE3 of the payload.
const ENTRY_HEADER_SIZE: usize = 48;
12
/// Point-in-time snapshot of WAL occupancy, exposed via [`EmbeddedWal::stats`]
/// so callers can decide when to checkpoint.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WalStats {
    /// Total size in bytes of the embedded WAL region.
    pub region_size: u64,
    /// Bytes occupied by records appended since the last checkpoint.
    pub pending_bytes: u64,
    /// Number of appends performed since the last checkpoint.
    pub appends_since_checkpoint: u64,
    /// Highest record sequence number assigned so far.
    pub sequence: u64,
}
20
/// A single recovered WAL entry, as returned by
/// [`EmbeddedWal::records_after`] / [`EmbeddedWal::pending_records`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WalRecord {
    /// Record sequence number. A zero sequence never appears here: the
    /// all-zero header is reserved as the scan-terminating sentinel.
    pub sequence: u64,
    /// Raw payload bytes exactly as they were appended.
    pub payload: Vec<u8>,
}
26
/// Write-ahead log embedded in a fixed-size region of the memory file.
///
/// Records are written sequentially and wrap to the start of the region when
/// the tail cannot hold the next entry; a zeroed header (sentinel) after the
/// newest record marks the end of valid data for recovery scans.
#[derive(Debug)]
pub struct EmbeddedWal {
    // Independent clone of the caller's file handle (own seek cursor).
    file: File,
    // Absolute file offset where the WAL region begins.
    region_offset: u64,
    // Size of the WAL region in bytes; always non-zero after open.
    region_size: u64,
    // Offset within the region where the next record will be written.
    write_head: u64,
    // Offset within the region recorded at the last checkpoint.
    checkpoint_head: u64,
    // Bytes occupied by records appended since the last checkpoint.
    pending_bytes: u64,
    // Highest sequence number assigned so far.
    sequence: u64,
    // Sequence number captured by the last checkpoint; records at or below
    // this value are no longer considered pending.
    checkpoint_sequence: u64,
    // Append counter used by `should_checkpoint`'s period trigger.
    appends_since_checkpoint: u64,
    // When true, all mutating operations are rejected.
    read_only: bool,
}
40
41impl EmbeddedWal {
42    pub fn open(file: &File, header: &Header) -> Result<Self> {
43        Self::open_internal(file, header, false)
44    }
45
46    pub fn open_read_only(file: &File, header: &Header) -> Result<Self> {
47        Self::open_internal(file, header, true)
48    }
49
50    fn open_internal(file: &File, header: &Header, read_only: bool) -> Result<Self> {
51        if header.wal_size == 0 {
52            return Err(MemvidError::InvalidHeader {
53                reason: "wal_size must be non-zero".into(),
54            });
55        }
56        let mut clone = file.try_clone()?;
57        let region_offset = header.wal_offset;
58        let region_size = header.wal_size;
59        let checkpoint_sequence = header.wal_sequence;
60
61        let (entries, next_head) = Self::scan_records(&mut clone, region_offset, region_size)?;
62
63        let pending_bytes = entries
64            .iter()
65            .filter(|entry| entry.sequence > checkpoint_sequence)
66            .map(|entry| entry.total_size)
67            .sum();
68        let sequence = entries
69            .last()
70            .map(|entry| entry.sequence)
71            .unwrap_or(checkpoint_sequence);
72
73        let mut wal = Self {
74            file: clone,
75            region_offset,
76            region_size,
77            write_head: next_head % region_size,
78            checkpoint_head: header.wal_checkpoint_pos % region_size,
79            pending_bytes,
80            sequence,
81            checkpoint_sequence,
82            appends_since_checkpoint: 0,
83            read_only,
84        };
85
86        if !wal.read_only {
87            wal.initialise_sentinel()?;
88        }
89        Ok(wal)
90    }
91
92    fn assert_writable(&self) -> Result<()> {
93        if self.read_only {
94            return Err(MemvidError::Lock(
95                "wal is read-only; reopen memory with write access".into(),
96            ));
97        }
98        Ok(())
99    }
100
101    pub fn append_entry(&mut self, payload: &[u8]) -> Result<u64> {
102        self.assert_writable()?;
103        let payload_len = payload.len();
104        if payload_len > u32::MAX as usize {
105            return Err(MemvidError::CheckpointFailed {
106                reason: "WAL payload too large".into(),
107            });
108        }
109
110        let entry_size = ENTRY_HEADER_SIZE as u64 + payload_len as u64;
111        if entry_size > self.region_size {
112            return Err(MemvidError::CheckpointFailed {
113                reason: "embedded WAL region too small for entry".into(),
114            });
115        }
116        if self.pending_bytes + entry_size > self.region_size {
117            return Err(MemvidError::CheckpointFailed {
118                reason: "embedded WAL region full".into(),
119            });
120        }
121
122        if self.write_head + entry_size > self.region_size {
123            self.write_head = 0;
124        }
125
126        let next_sequence = self.sequence + 1;
127        tracing::debug!(
128            wal.write_head = self.write_head,
129            wal.sequence = next_sequence,
130            wal.payload_len = payload_len,
131            "wal append entry"
132        );
133        self.write_record(self.write_head, next_sequence, payload)?;
134
135        self.write_head = (self.write_head + entry_size) % self.region_size;
136        self.pending_bytes += entry_size;
137        self.sequence = self.sequence.wrapping_add(1);
138        self.appends_since_checkpoint = self.appends_since_checkpoint.saturating_add(1);
139
140        self.maybe_write_sentinel()?;
141
142        Ok(self.sequence)
143    }
144
145    pub fn should_checkpoint(&self) -> bool {
146        if self.read_only || self.region_size == 0 {
147            return false;
148        }
149        let occupancy = self.pending_bytes as f64 / self.region_size as f64;
150        occupancy >= WAL_CHECKPOINT_THRESHOLD
151            || self.appends_since_checkpoint >= WAL_CHECKPOINT_PERIOD
152    }
153
154    pub fn record_checkpoint(&mut self, header: &mut Header) -> Result<()> {
155        self.assert_writable()?;
156        self.checkpoint_head = self.write_head;
157        self.pending_bytes = 0;
158        self.appends_since_checkpoint = 0;
159        self.checkpoint_sequence = self.sequence;
160        header.wal_checkpoint_pos = self.checkpoint_head;
161        header.wal_sequence = self.checkpoint_sequence;
162        self.maybe_write_sentinel()
163    }
164
165    pub fn pending_records(&mut self) -> Result<Vec<WalRecord>> {
166        self.records_after(self.checkpoint_sequence)
167    }
168
169    pub fn records_after(&mut self, sequence: u64) -> Result<Vec<WalRecord>> {
170        let (entries, next_head) =
171            Self::scan_records(&mut self.file, self.region_offset, self.region_size)?;
172
173        self.sequence = entries
174            .last()
175            .map(|entry| entry.sequence)
176            .unwrap_or(self.sequence);
177        self.pending_bytes = entries
178            .iter()
179            .filter(|entry| entry.sequence > self.checkpoint_sequence)
180            .map(|entry| entry.total_size)
181            .sum();
182        self.write_head = next_head % self.region_size;
183        if !self.read_only {
184            self.initialise_sentinel()?;
185        }
186
187        Ok(entries
188            .into_iter()
189            .filter(|entry| entry.sequence > sequence)
190            .map(|entry| WalRecord {
191                sequence: entry.sequence,
192                payload: entry.payload,
193            })
194            .collect())
195    }
196
197    pub fn stats(&self) -> WalStats {
198        WalStats {
199            region_size: self.region_size,
200            pending_bytes: self.pending_bytes,
201            appends_since_checkpoint: self.appends_since_checkpoint,
202            sequence: self.sequence,
203        }
204    }
205
206    pub fn region_offset(&self) -> u64 {
207        self.region_offset
208    }
209
210    pub fn file(&self) -> &File {
211        &self.file
212    }
213
214    fn initialise_sentinel(&mut self) -> Result<()> {
215        self.maybe_write_sentinel()
216    }
217
218    fn write_record(&mut self, position: u64, sequence: u64, payload: &[u8]) -> Result<()> {
219        self.assert_writable()?;
220        let digest = blake3::hash(payload);
221        let mut header = [0u8; ENTRY_HEADER_SIZE];
222        header[..8].copy_from_slice(&sequence.to_le_bytes());
223        header[8..12].copy_from_slice(&(payload.len() as u32).to_le_bytes());
224        header[16..48].copy_from_slice(digest.as_bytes());
225
226        // Atomic write: combine header and payload into single buffer
227        // This prevents corruption if the file is closed mid-write
228        let mut combined = Vec::with_capacity(ENTRY_HEADER_SIZE + payload.len());
229        combined.extend_from_slice(&header);
230        combined.extend_from_slice(payload);
231
232        self.seek_and_write(position, &combined)?;
233        if tracing::enabled!(tracing::Level::DEBUG) {
234            if let Err(err) = self.debug_verify_header(position, sequence, payload.len()) {
235                tracing::warn!(error = %err, "wal header verify failed");
236            }
237        }
238
239        // Force fsync to ensure data is durable before returning
240        // Critical for preventing corruption during rapid file operations
241        self.file.sync_all()?;
242
243        Ok(())
244    }
245
246    fn write_zero_header(&mut self, position: u64) -> Result<u64> {
247        self.assert_writable()?;
248        if self.region_size == 0 {
249            return Ok(0);
250        }
251        let mut pos = position % self.region_size;
252        let remaining = self.region_size - pos;
253        if remaining < ENTRY_HEADER_SIZE as u64 {
254            if remaining > 0 {
255                let zero_tail = vec![0u8; remaining as usize];
256                self.seek_and_write(pos, &zero_tail)?;
257            }
258            pos = 0;
259        }
260        let zero = [0u8; ENTRY_HEADER_SIZE];
261        self.seek_and_write(pos, &zero)?;
262        Ok(pos)
263    }
264
265    fn seek_and_write(&mut self, position: u64, bytes: &[u8]) -> Result<()> {
266        self.assert_writable()?;
267        let pos = position % self.region_size;
268        let absolute = self.region_offset + pos;
269        self.file.seek(SeekFrom::Start(absolute))?;
270        self.file.write_all(bytes)?;
271        Ok(())
272    }
273
274    fn maybe_write_sentinel(&mut self) -> Result<()> {
275        if self.read_only || self.region_size == 0 {
276            return Ok(());
277        }
278        if self.pending_bytes >= self.region_size {
279            return Ok(());
280        }
281        // Sentinel marks end of valid entries - always keep write_head in sync
282        let next = self.write_zero_header(self.write_head)?;
283        self.write_head = next;
284        Ok(())
285    }
286
287    fn scan_records(file: &mut File, offset: u64, size: u64) -> Result<(Vec<ScannedRecord>, u64)> {
288        let mut records = Vec::new();
289        let mut cursor = 0u64;
290        while cursor + ENTRY_HEADER_SIZE as u64 <= size {
291            file.seek(SeekFrom::Start(offset + cursor))?;
292            let mut header = [0u8; ENTRY_HEADER_SIZE];
293            file.read_exact(&mut header)?;
294
295            let sequence = u64::from_le_bytes(header[..8].try_into().map_err(|_| {
296                MemvidError::WalCorruption {
297                    offset: cursor,
298                    reason: "invalid wal sequence header".into(),
299                }
300            })?);
301            let length = u32::from_le_bytes(header[8..12].try_into().map_err(|_| {
302                MemvidError::WalCorruption {
303                    offset: cursor,
304                    reason: "invalid wal length header".into(),
305                }
306            })?) as u64;
307            let checksum = &header[16..48];
308
309            if sequence == 0 && length == 0 {
310                break;
311            }
312            if length == 0 || cursor + ENTRY_HEADER_SIZE as u64 + length > size {
313                tracing::error!(
314                    wal.scan_offset = cursor,
315                    wal.sequence = sequence,
316                    wal.length = length,
317                    wal.region_size = size,
318                    "wal record length invalid"
319                );
320                return Err(MemvidError::WalCorruption {
321                    offset: cursor,
322                    reason: "wal record length invalid".into(),
323                });
324            }
325
326            let mut payload = vec![0u8; length as usize];
327            file.read_exact(&mut payload)?;
328            let expected = blake3::hash(&payload);
329            if expected.as_bytes() != checksum {
330                return Err(MemvidError::WalCorruption {
331                    offset: cursor,
332                    reason: "wal record checksum mismatch".into(),
333                });
334            }
335
336            records.push(ScannedRecord {
337                sequence,
338                payload,
339                total_size: ENTRY_HEADER_SIZE as u64 + length,
340            });
341
342            cursor += ENTRY_HEADER_SIZE as u64 + length;
343        }
344
345        Ok((records, cursor))
346    }
347}
348
/// Internal result of `scan_records`: one decoded entry plus its total
/// on-disk footprint (header + payload) in bytes.
#[derive(Debug)]
struct ScannedRecord {
    // Sequence number read from the record header.
    sequence: u64,
    // Checksum-verified payload bytes.
    payload: Vec<u8>,
    // ENTRY_HEADER_SIZE + payload length, used for occupancy accounting.
    total_size: u64,
}
355
356impl EmbeddedWal {
357    fn debug_verify_header(
358        &mut self,
359        position: u64,
360        expected_sequence: u64,
361        expected_len: usize,
362    ) -> Result<()> {
363        if self.region_size == 0 {
364            return Ok(());
365        }
366        let pos = position % self.region_size;
367        let absolute = self.region_offset + pos;
368        let mut buf = [0u8; ENTRY_HEADER_SIZE];
369        self.file.seek(SeekFrom::Start(absolute))?;
370        self.file.read_exact(&mut buf)?;
371        let seq = u64::from_le_bytes(buf[..8].try_into().unwrap());
372        let len = u32::from_le_bytes(buf[8..12].try_into().unwrap());
373        tracing::debug!(
374            wal.verify_position = pos,
375            wal.verify_sequence = seq,
376            wal.expected_sequence = expected_sequence,
377            wal.verify_length = len,
378            wal.expected_length = expected_len,
379            "wal header verify"
380        );
381        Ok(())
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::WAL_OFFSET;
    use std::io::{Seek, SeekFrom, Write};
    use tempfile::tempfile;

    /// Minimal valid header describing a WAL region of `size` bytes.
    fn header_for(size: u64) -> Header {
        Header {
            magic: *b"MV2\0",
            version: 0x0201,
            footer_offset: 0,
            wal_offset: WAL_OFFSET,
            wal_size: size,
            wal_checkpoint_pos: 0,
            wal_sequence: 0,
            toc_checksum: [0u8; 32],
        }
    }

    /// Creates a temp file sized to hold the WAL region plus its leading
    /// offset, together with a matching header.
    fn prepare_wal(size: u64) -> (File, Header) {
        let backing = tempfile().expect("temp file");
        backing.set_len(WAL_OFFSET + size).expect("set_len");
        (backing, header_for(size))
    }

    #[test]
    fn append_and_recover() {
        let (backing, header) = prepare_wal(1024);
        let mut wal = EmbeddedWal::open(&backing, &header).expect("open wal");

        wal.append_entry(b"first").expect("append first");
        wal.append_entry(b"second").expect("append second");

        let recovered = wal.records_after(0).expect("records");
        assert_eq!(recovered.len(), 2);
        let expected: [(&[u8], u64); 2] = [(b"first", 1), (b"second", 2)];
        for (record, (payload, sequence)) in recovered.iter().zip(expected) {
            assert_eq!(record.payload, payload);
            assert_eq!(record.sequence, sequence);
        }
    }

    #[test]
    fn wrap_and_checkpoint() {
        let region = (ENTRY_HEADER_SIZE as u64 * 2) + 64;
        let (backing, mut header) = prepare_wal(region);
        let mut wal = EmbeddedWal::open(&backing, &header).expect("open wal");

        wal.append_entry(&[0xAA; 32]).expect("append a");
        wal.append_entry(&[0xBB; 32]).expect("append b");
        wal.record_checkpoint(&mut header).expect("checkpoint");

        // Checkpointing empties the pending set.
        assert!(wal.pending_records().expect("pending").is_empty());

        // A fresh append becomes the only pending record.
        wal.append_entry(&[0xCC; 32]).expect("append c");
        let pending = wal.pending_records().expect("after append");
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].payload, vec![0xCC; 32]);
    }

    #[test]
    fn corrupted_record_reports_offset() {
        let (mut backing, header) = prepare_wal(64);
        // Fabricate a record header claiming a length that cannot fit, so
        // scan_records must report corruption at offset 0.
        let mut bogus = [0u8; ENTRY_HEADER_SIZE];
        bogus[..8].copy_from_slice(&1u64.to_le_bytes()); // sequence
        bogus[8..12].copy_from_slice(&(u32::MAX).to_le_bytes()); // absurd length
        backing
            .seek(SeekFrom::Start(header.wal_offset))
            .expect("seek");
        backing.write_all(&bogus).expect("write corrupt header");
        backing.sync_all().expect("sync");

        let failure = EmbeddedWal::open(&backing, &header).expect_err("open should fail");
        match failure {
            MemvidError::WalCorruption { offset, reason } => {
                assert_eq!(offset, 0);
                assert!(reason.contains("length"), "reason should mention length");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }
}