use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};

use crate::{
    constants::{WAL_CHECKPOINT_PERIOD, WAL_CHECKPOINT_THRESHOLD},
    error::{MemvidError, Result},
    types::Header,
};
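
/// On-disk entry header layout: an 8-byte little-endian sequence number,
/// a 4-byte little-endian payload length, 4 reserved (zero) bytes, and a
/// 32-byte BLAKE3 digest of the payload.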
const ENTRY_HEADER_SIZE: usize = 48;
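
/// Point-in-time counters describing the state of the WAL region.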
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WalStats {
    pub region_size: u64,
    pub pending_bytes: u64,
    pub appends_since_checkpoint: u64,
    pub sequence: u64,
}
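
/// A recovered WAL entry: its sequence number and raw payload bytes.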
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WalRecord {
    pub sequence: u64,
    pub payload: Vec<u8>,
}
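
/// Circular write-ahead log embedded in a fixed-size region of the backing file.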
#[derive(Debug)]
pub struct EmbeddedWal {
    file: File,
    region_offset: u64,
    region_size: u64,
    write_head: u64,
    checkpoint_head: u64,
    pending_bytes: u64,
    sequence: u64,
    checkpoint_sequence: u64,
    appends_since_checkpoint: u64,
    read_only: bool,
    skip_sync: bool,
}

impl EmbeddedWal {
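    /// Opens the WAL region with write access, replaying any surviving records.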
    pub fn open(file: &File, header: &Header) -> Result<Self> {
        Self::open_internal(file, header, false)
    }

    /// Opens the WAL region for inspection only; all mutating calls will fail.
    pub fn open_read_only(file: &File, header: &Header) -> Result<Self> {
        Self::open_internal(file, header, true)
    }

    fn open_internal(file: &File, header: &Header, read_only: bool) -> Result<Self> {
        if header.wal_size == 0 {
            return Err(MemvidError::InvalidHeader {
                reason: "wal_size must be non-zero".into(),
            });
        }
        let mut clone = file.try_clone()?;
        let region_offset = header.wal_offset;
        let region_size = header.wal_size;
        let checkpoint_sequence = header.wal_sequence;

        let (entries, next_head) = Self::scan_records(&mut clone, region_offset, region_size)?;

        // Records newer than the checkpointed sequence are still pending replay.
        let pending_bytes = entries
            .iter()
            .filter(|entry| entry.sequence > checkpoint_sequence)
            .map(|entry| entry.total_size)
            .sum();
        let sequence = entries
            .last()
            .map_or(checkpoint_sequence, |entry| entry.sequence);

        let mut wal = Self {
            file: clone,
            region_offset,
            region_size,
            write_head: next_head % region_size,
            checkpoint_head: header.wal_checkpoint_pos % region_size,
            pending_bytes,
            sequence,
            checkpoint_sequence,
            appends_since_checkpoint: 0,
            read_only,
            skip_sync: false,
        };

        if !wal.read_only {
            wal.initialise_sentinel()?;
        }
        Ok(wal)
    }

    fn assert_writable(&self) -> Result<()> {
        if self.read_only {
            return Err(MemvidError::Lock(
                "wal is read-only; reopen memory with write access".into(),
            ));
        }
        Ok(())
    }
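
    /// Appends `payload` as a new entry and returns its sequence number.
    ///
    /// A minimal usage sketch (assumes an open file and a parsed `Header`):
    ///
    /// ```ignore
    /// let mut wal = EmbeddedWal::open(&file, &header)?;
    /// let seq = wal.append_entry(b"payload")?;
    /// if wal.should_checkpoint() {
    ///     wal.record_checkpoint(&mut header)?;
    /// }
    /// ```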
    pub fn append_entry(&mut self, payload: &[u8]) -> Result<u64> {
        self.assert_writable()?;
        let payload_len = payload.len();
        if payload_len > u32::MAX as usize {
            return Err(MemvidError::CheckpointFailed {
                reason: "WAL payload too large".into(),
            });
        }

        let entry_size = ENTRY_HEADER_SIZE as u64 + payload_len as u64;
        if entry_size > self.region_size {
            return Err(MemvidError::CheckpointFailed {
                reason: "embedded WAL region too small for entry".into(),
            });
        }
        if self.pending_bytes + entry_size > self.region_size {
            return Err(MemvidError::CheckpointFailed {
                reason: "embedded WAL region full".into(),
            });
        }

        // Entries never straddle the region boundary; wrap back to the start
        // instead, which is only safe once a checkpoint has drained the log.
        let wrapping = self.write_head + entry_size > self.region_size;
        if wrapping {
            if self.pending_bytes > 0 {
                return Err(MemvidError::CheckpointFailed {
                    reason: "embedded WAL region full".into(),
                });
            }
            self.write_head = 0;
        }

        let next_sequence = self.sequence + 1;
        tracing::debug!(
            wal.write_head = self.write_head,
            wal.sequence = next_sequence,
            wal.payload_len = payload_len,
            "wal append entry"
        );
        self.write_record(self.write_head, next_sequence, payload)?;

        self.write_head = (self.write_head + entry_size) % self.region_size;
        self.pending_bytes += entry_size;
        self.sequence = self.sequence.wrapping_add(1);
        self.appends_since_checkpoint = self.appends_since_checkpoint.saturating_add(1);

        self.maybe_write_sentinel()?;

        Ok(self.sequence)
    }
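
    /// Returns true once region occupancy or the number of appends since the
    /// last checkpoint crosses the configured thresholds.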
    #[must_use]
    pub fn should_checkpoint(&self) -> bool {
        if self.read_only || self.region_size == 0 {
            return false;
        }
        let occupancy = self.pending_bytes as f64 / self.region_size as f64;
        occupancy >= WAL_CHECKPOINT_THRESHOLD
            || self.appends_since_checkpoint >= WAL_CHECKPOINT_PERIOD
    }
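
    /// Marks everything written so far as checkpointed and mirrors the new
    /// checkpoint position and sequence into the file header.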
    pub fn record_checkpoint(&mut self, header: &mut Header) -> Result<()> {
        self.assert_writable()?;
        self.checkpoint_head = self.write_head;
        self.pending_bytes = 0;
        self.appends_since_checkpoint = 0;
        self.checkpoint_sequence = self.sequence;
        header.wal_checkpoint_pos = self.checkpoint_head;
        header.wal_sequence = self.checkpoint_sequence;
        self.maybe_write_sentinel()
    }
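
    /// Returns every record appended after the last checkpoint, in order.
    ///
    /// A minimal replay sketch (the `apply` callback here is hypothetical):
    ///
    /// ```ignore
    /// for record in wal.pending_records()? {
    ///     apply(record.sequence, &record.payload)?;
    /// }
    /// ```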
    pub fn pending_records(&mut self) -> Result<Vec<WalRecord>> {
        self.records_after(self.checkpoint_sequence)
    }
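
    /// Rescans the region and returns every record with a sequence number
    /// greater than `sequence`, refreshing the in-memory counters as a side effect.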
    pub fn records_after(&mut self, sequence: u64) -> Result<Vec<WalRecord>> {
        let (entries, next_head) =
            Self::scan_records(&mut self.file, self.region_offset, self.region_size)?;

        self.sequence = entries.last().map_or(self.sequence, |entry| entry.sequence);
        self.pending_bytes = entries
            .iter()
            .filter(|entry| entry.sequence > self.checkpoint_sequence)
            .map(|entry| entry.total_size)
            .sum();
        self.write_head = next_head % self.region_size;
        if !self.read_only {
            self.initialise_sentinel()?;
        }

        Ok(entries
            .into_iter()
            .filter(|entry| entry.sequence > sequence)
            .map(|entry| WalRecord {
                sequence: entry.sequence,
                payload: entry.payload,
            })
            .collect())
    }

    #[must_use]
    pub fn stats(&self) -> WalStats {
        WalStats {
            region_size: self.region_size,
            pending_bytes: self.pending_bytes,
            appends_since_checkpoint: self.appends_since_checkpoint,
            sequence: self.sequence,
        }
    }

    #[must_use]
    pub fn region_offset(&self) -> u64 {
        self.region_offset
    }

    #[must_use]
    pub fn file(&self) -> &File {
        &self.file
    }
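
    /// Skips the per-append `sync_all` when set; callers must invoke `flush`
    /// themselves before relying on durability.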
    pub fn set_skip_sync(&mut self, skip: bool) {
        self.skip_sync = skip;
    }
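
    /// Forces all buffered WAL writes to stable storage.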
    pub fn flush(&mut self) -> Result<()> {
        self.file.sync_all().map_err(Into::into)
    }

    fn initialise_sentinel(&mut self) -> Result<()> {
        self.maybe_write_sentinel()
    }

    fn write_record(&mut self, position: u64, sequence: u64, payload: &[u8]) -> Result<()> {
        self.assert_writable()?;
        let digest = blake3::hash(payload);
        let mut header = [0u8; ENTRY_HEADER_SIZE];
        header[..8].copy_from_slice(&sequence.to_le_bytes());
        header[8..12]
            .copy_from_slice(&(u32::try_from(payload.len()).unwrap_or(u32::MAX)).to_le_bytes());
        header[16..48].copy_from_slice(digest.as_bytes());

        // Stage header and payload in one buffer so they land in a single
        // write_all call rather than two separate writes.
        let mut combined = Vec::with_capacity(ENTRY_HEADER_SIZE + payload.len());
        combined.extend_from_slice(&header);
        combined.extend_from_slice(payload);

        self.seek_and_write(position, &combined)?;
        if tracing::enabled!(tracing::Level::DEBUG) {
            if let Err(err) = self.debug_verify_header(position, sequence, payload.len()) {
                tracing::warn!(error = %err, "wal header verify failed");
            }
        }

        if !self.skip_sync {
            self.file.sync_all()?;
        }

        Ok(())
    }

    fn write_zero_header(&mut self, position: u64) -> Result<u64> {
        self.assert_writable()?;
        if self.region_size == 0 {
            return Ok(0);
        }
        let mut pos = position % self.region_size;
        let remaining = self.region_size - pos;
        if remaining < ENTRY_HEADER_SIZE as u64 {
            // Not enough room for a full header at the tail: zero what is
            // left and place the sentinel at the start of the region.
            if remaining > 0 {
                #[allow(clippy::cast_possible_truncation)]
                let zero_tail = vec![0u8; remaining as usize];
                self.seek_and_write(pos, &zero_tail)?;
            }
            pos = 0;
        }
        let zero = [0u8; ENTRY_HEADER_SIZE];
        self.seek_and_write(pos, &zero)?;
        Ok(pos)
    }

    fn seek_and_write(&mut self, position: u64, bytes: &[u8]) -> Result<()> {
        self.assert_writable()?;
        let pos = position % self.region_size;
        let absolute = self.region_offset + pos;
        self.file.seek(SeekFrom::Start(absolute))?;
        self.file.write_all(bytes)?;
        Ok(())
    }
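
    /// Stamps a zeroed entry header at the write head so recovery scans stop
    /// at the end of live records; skipped when read-only or when the region
    /// is completely full.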
    fn maybe_write_sentinel(&mut self) -> Result<()> {
        if self.read_only || self.region_size == 0 {
            return Ok(());
        }
        if self.pending_bytes >= self.region_size {
            return Ok(());
        }
        // The sentinel may relocate to offset 0 when the tail cannot fit a
        // full header, so adopt the position it actually landed on.
        let next = self.write_zero_header(self.write_head)?;
        self.write_head = next;
        Ok(())
    }
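
    /// Scans the region from its start, validating entry lengths and BLAKE3
    /// checksums until the zero sentinel or the region end is reached.
    /// Returns the decoded records and the offset just past the last one.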
    fn scan_records(file: &mut File, offset: u64, size: u64) -> Result<(Vec<ScannedRecord>, u64)> {
        let mut records = Vec::new();
        let mut cursor = 0u64;
        while cursor + ENTRY_HEADER_SIZE as u64 <= size {
            file.seek(SeekFrom::Start(offset + cursor))?;
            let mut header = [0u8; ENTRY_HEADER_SIZE];
            file.read_exact(&mut header)?;

            let sequence = u64::from_le_bytes(header[..8].try_into().map_err(|_| {
                MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "invalid wal sequence header".into(),
                }
            })?);
            let length = u64::from(u32::from_le_bytes(header[8..12].try_into().map_err(
                |_| MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "invalid wal length header".into(),
                },
            )?));
            let checksum = &header[16..48];

            // An all-zero header is the sentinel marking the end of live records.
            if sequence == 0 && length == 0 {
                break;
            }
            if length == 0 || cursor + ENTRY_HEADER_SIZE as u64 + length > size {
                tracing::error!(
                    wal.scan_offset = cursor,
                    wal.sequence = sequence,
                    wal.length = length,
                    wal.region_size = size,
                    "wal record length invalid"
                );
                return Err(MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "wal record length invalid".into(),
                });
            }

            let length_usize = usize::try_from(length).map_err(|_| MemvidError::WalCorruption {
                offset: cursor,
                reason: "wal record length too large for platform".into(),
            })?;
            let mut payload = vec![0u8; length_usize];
            file.read_exact(&mut payload)?;
            let expected = blake3::hash(&payload);
            if expected.as_bytes() != checksum {
                return Err(MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "wal record checksum mismatch".into(),
                });
            }

            records.push(ScannedRecord {
                sequence,
                payload,
                total_size: ENTRY_HEADER_SIZE as u64 + length,
            });

            cursor += ENTRY_HEADER_SIZE as u64 + length;
        }

        Ok((records, cursor))
    }
}

#[derive(Debug)]
struct ScannedRecord {
    sequence: u64,
    payload: Vec<u8>,
    total_size: u64,
}

impl EmbeddedWal {
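    /// Reads back the header that was just written and traces its fields;
    /// invoked only when DEBUG-level tracing is enabled.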
    fn debug_verify_header(
        &mut self,
        position: u64,
        expected_sequence: u64,
        expected_len: usize,
    ) -> Result<()> {
        if self.region_size == 0 {
            return Ok(());
        }
        let pos = position % self.region_size;
        let absolute = self.region_offset + pos;
        let mut buf = [0u8; ENTRY_HEADER_SIZE];
        self.file.seek(SeekFrom::Start(absolute))?;
        self.file.read_exact(&mut buf)?;

        // Decode the freshly written header fields for logging only.
        let seq = buf
            .get(..8)
            .and_then(|s| <[u8; 8]>::try_from(s).ok())
            .map_or(0, u64::from_le_bytes);
        let len = buf
            .get(8..12)
            .and_then(|s| <[u8; 4]>::try_from(s).ok())
            .map_or(0, u32::from_le_bytes);

        tracing::debug!(
            wal.verify_position = pos,
            wal.verify_sequence = seq,
            wal.expected_sequence = expected_sequence,
            wal.verify_length = len,
            wal.expected_length = expected_len,
            "wal header verify"
        );
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::WAL_OFFSET;
    use std::io::{Seek, SeekFrom, Write};
    use tempfile::tempfile;

    fn header_for(size: u64) -> Header {
        Header {
            magic: *b"MV2\0",
            version: 0x0201,
            footer_offset: 0,
            wal_offset: WAL_OFFSET,
            wal_size: size,
            wal_checkpoint_pos: 0,
            wal_sequence: 0,
            toc_checksum: [0u8; 32],
        }
    }

    fn prepare_wal(size: u64) -> (File, Header) {
        let file = tempfile().expect("temp file");
        file.set_len(WAL_OFFSET + size).expect("set_len");
        let header = header_for(size);
        (file, header)
    }

    #[test]
    fn append_and_recover() {
        let (file, header) = prepare_wal(1024);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        wal.append_entry(b"first").expect("append first");
        wal.append_entry(b"second").expect("append second");

        let records = wal.records_after(0).expect("records");
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].payload, b"first");
        assert_eq!(records[0].sequence, 1);
        assert_eq!(records[1].payload, b"second");
        assert_eq!(records[1].sequence, 2);
    }

    #[test]
    fn wrap_and_checkpoint() {
        let size = (ENTRY_HEADER_SIZE as u64 * 2) + 64;
        let (file, mut header) = prepare_wal(size);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        wal.append_entry(&[0xAA; 32]).expect("append a");
        wal.append_entry(&[0xBB; 32]).expect("append b");
        wal.record_checkpoint(&mut header).expect("checkpoint");

        assert!(wal.pending_records().expect("pending").is_empty());

        wal.append_entry(&[0xCC; 32]).expect("append c");
        let records = wal.pending_records().expect("after append");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].payload, vec![0xCC; 32]);
    }

    #[test]
    fn corrupted_record_reports_offset() {
        let (mut file, header) = prepare_wal(64);
        file.seek(SeekFrom::Start(header.wal_offset)).expect("seek");
        // Forge a header whose declared payload length overruns the region.
        let mut record = [0u8; ENTRY_HEADER_SIZE];
        record[..8].copy_from_slice(&1u64.to_le_bytes());
        record[8..12].copy_from_slice(&u32::MAX.to_le_bytes());
        file.write_all(&record).expect("write corrupt header");
        file.sync_all().expect("sync");

        let err = EmbeddedWal::open(&file, &header).expect_err("open should fail");
        match err {
            MemvidError::WalCorruption { offset, reason } => {
                assert_eq!(offset, 0);
                assert!(reason.contains("length"), "reason should mention length");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }
}