use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};

use crate::{
    constants::{WAL_CHECKPOINT_PERIOD, WAL_CHECKPOINT_THRESHOLD},
    error::{MemvidError, Result},
    types::Header,
};
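/// Size of the fixed per-entry header: an 8-byte little-endian sequence
/// number, a 4-byte little-endian payload length, 4 reserved (zeroed) bytes,
/// and a 32-byte BLAKE3 digest of the payload.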
const ENTRY_HEADER_SIZE: usize = 48;
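/// Point-in-time counters describing the WAL region's occupancy and progress.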
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WalStats {
    pub region_size: u64,
    pub pending_bytes: u64,
    pub appends_since_checkpoint: u64,
    pub sequence: u64,
}
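/// A recovered WAL entry paired with the sequence number it was written under.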
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WalRecord {
    pub sequence: u64,
    pub payload: Vec<u8>,
}
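/// Write-ahead log embedded in a fixed region of the memvid file. Entries are
/// appended contiguously and wrap back to the start of the region only once
/// everything pending has been checkpointed; a zeroed header acts as a
/// sentinel marking the end of the live records during recovery scans.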
#[derive(Debug)]
pub struct EmbeddedWal {
    file: File,
    region_offset: u64,
    region_size: u64,
    write_head: u64,
    checkpoint_head: u64,
    pending_bytes: u64,
    sequence: u64,
    checkpoint_sequence: u64,
    appends_since_checkpoint: u64,
    read_only: bool,
}
impl EmbeddedWal {
    pub fn open(file: &File, header: &Header) -> Result<Self> {
        Self::open_internal(file, header, false)
    }

    pub fn open_read_only(file: &File, header: &Header) -> Result<Self> {
        Self::open_internal(file, header, true)
    }

    fn open_internal(file: &File, header: &Header, read_only: bool) -> Result<Self> {
        if header.wal_size == 0 {
            return Err(MemvidError::InvalidHeader {
                reason: "wal_size must be non-zero".into(),
            });
        }
        let mut clone = file.try_clone()?;
        let region_offset = header.wal_offset;
        let region_size = header.wal_size;
        let checkpoint_sequence = header.wal_sequence;

        let (entries, next_head) = Self::scan_records(&mut clone, region_offset, region_size)?;

        let pending_bytes = entries
            .iter()
            .filter(|entry| entry.sequence > checkpoint_sequence)
            .map(|entry| entry.total_size)
            .sum();
        let sequence = entries
            .last()
            .map_or(checkpoint_sequence, |entry| entry.sequence);

        let mut wal = Self {
            file: clone,
            region_offset,
            region_size,
            write_head: next_head % region_size,
            checkpoint_head: header.wal_checkpoint_pos % region_size,
            pending_bytes,
            sequence,
            checkpoint_sequence,
            appends_since_checkpoint: 0,
            read_only,
        };

        if !wal.read_only {
            wal.initialise_sentinel()?;
        }
        Ok(wal)
    }

    fn assert_writable(&self) -> Result<()> {
        if self.read_only {
            return Err(MemvidError::Lock(
                "wal is read-only; reopen memory with write access".into(),
            ));
        }
        Ok(())
    }
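    /// Appends `payload` as a new entry and returns its sequence number.
    /// Refuses to wrap or overwrite while uncheckpointed entries remain.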
    pub fn append_entry(&mut self, payload: &[u8]) -> Result<u64> {
        self.assert_writable()?;
        let payload_len = payload.len();
        if payload_len > u32::MAX as usize {
            return Err(MemvidError::CheckpointFailed {
                reason: "WAL payload too large".into(),
            });
        }

        let entry_size = ENTRY_HEADER_SIZE as u64 + payload_len as u64;
        if entry_size > self.region_size {
            return Err(MemvidError::CheckpointFailed {
                reason: "embedded WAL region too small for entry".into(),
            });
        }
        if self.pending_bytes + entry_size > self.region_size {
            return Err(MemvidError::CheckpointFailed {
                reason: "embedded WAL region full".into(),
            });
        }

        let wrapping = self.write_head + entry_size > self.region_size;
        if wrapping {
            if self.pending_bytes > 0 {
                return Err(MemvidError::CheckpointFailed {
                    reason: "embedded WAL region full".into(),
                });
            }
            self.write_head = 0;
        }

        let next_sequence = self.sequence + 1;
        tracing::debug!(
            wal.write_head = self.write_head,
            wal.sequence = next_sequence,
            wal.payload_len = payload_len,
            "wal append entry"
        );
        self.write_record(self.write_head, next_sequence, payload)?;

        self.write_head = (self.write_head + entry_size) % self.region_size;
        self.pending_bytes += entry_size;
        self.sequence = self.sequence.wrapping_add(1);
        self.appends_since_checkpoint = self.appends_since_checkpoint.saturating_add(1);

        self.maybe_write_sentinel()?;

        Ok(self.sequence)
    }
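    /// Returns `true` once occupancy crosses `WAL_CHECKPOINT_THRESHOLD` or the
    /// number of appends since the last checkpoint reaches
    /// `WAL_CHECKPOINT_PERIOD`.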
    #[must_use]
    pub fn should_checkpoint(&self) -> bool {
        if self.read_only || self.region_size == 0 {
            return false;
        }
        let occupancy = self.pending_bytes as f64 / self.region_size as f64;
        occupancy >= WAL_CHECKPOINT_THRESHOLD
            || self.appends_since_checkpoint >= WAL_CHECKPOINT_PERIOD
    }
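    /// Records that all appended entries are now durable elsewhere, resetting
    /// the pending counters and mirroring the checkpoint position and sequence
    /// into `header` so they persist across reopen.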
    pub fn record_checkpoint(&mut self, header: &mut Header) -> Result<()> {
        self.assert_writable()?;
        self.checkpoint_head = self.write_head;
        self.pending_bytes = 0;
        self.appends_since_checkpoint = 0;
        self.checkpoint_sequence = self.sequence;
        header.wal_checkpoint_pos = self.checkpoint_head;
        header.wal_sequence = self.checkpoint_sequence;
        self.maybe_write_sentinel()
    }
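    /// Returns the entries appended since the last checkpoint.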
    pub fn pending_records(&mut self) -> Result<Vec<WalRecord>> {
        self.records_after(self.checkpoint_sequence)
    }
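    /// Rescans the region and returns every record whose sequence number is
    /// greater than `sequence`, refreshing the in-memory counters along the way.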
    pub fn records_after(&mut self, sequence: u64) -> Result<Vec<WalRecord>> {
        let (entries, next_head) =
            Self::scan_records(&mut self.file, self.region_offset, self.region_size)?;

        self.sequence = entries.last().map_or(self.sequence, |entry| entry.sequence);
        self.pending_bytes = entries
            .iter()
            .filter(|entry| entry.sequence > self.checkpoint_sequence)
            .map(|entry| entry.total_size)
            .sum();
        self.write_head = next_head % self.region_size;
        if !self.read_only {
            self.initialise_sentinel()?;
        }

        Ok(entries
            .into_iter()
            .filter(|entry| entry.sequence > sequence)
            .map(|entry| WalRecord {
                sequence: entry.sequence,
                payload: entry.payload,
            })
            .collect())
    }
    #[must_use]
    pub fn stats(&self) -> WalStats {
        WalStats {
            region_size: self.region_size,
            pending_bytes: self.pending_bytes,
            appends_since_checkpoint: self.appends_since_checkpoint,
            sequence: self.sequence,
        }
    }

    #[must_use]
    pub fn region_offset(&self) -> u64 {
        self.region_offset
    }

    #[must_use]
    pub fn file(&self) -> &File {
        &self.file
    }

    fn initialise_sentinel(&mut self) -> Result<()> {
        self.maybe_write_sentinel()
    }
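    /// Serialises one entry (header plus payload) at `position` and syncs the
    /// file so the record survives a crash.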
    fn write_record(&mut self, position: u64, sequence: u64, payload: &[u8]) -> Result<()> {
        self.assert_writable()?;
        let digest = blake3::hash(payload);
        let mut header = [0u8; ENTRY_HEADER_SIZE];
        header[..8].copy_from_slice(&sequence.to_le_bytes());
        header[8..12]
            .copy_from_slice(&(u32::try_from(payload.len()).unwrap_or(u32::MAX)).to_le_bytes());
        header[16..48].copy_from_slice(digest.as_bytes());

        let mut combined = Vec::with_capacity(ENTRY_HEADER_SIZE + payload.len());
        combined.extend_from_slice(&header);
        combined.extend_from_slice(payload);

        self.seek_and_write(position, &combined)?;
        if tracing::enabled!(tracing::Level::DEBUG) {
            if let Err(err) = self.debug_verify_header(position, sequence, payload.len()) {
                tracing::warn!(error = %err, "wal header verify failed");
            }
        }

        self.file.sync_all()?;

        Ok(())
    }
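    /// Writes a zeroed sentinel header at `position`, zero-filling any region
    /// tail too small to hold a header, and returns where the sentinel landed.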
    fn write_zero_header(&mut self, position: u64) -> Result<u64> {
        self.assert_writable()?;
        if self.region_size == 0 {
            return Ok(0);
        }
        let mut pos = position % self.region_size;
        let remaining = self.region_size - pos;
        if remaining < ENTRY_HEADER_SIZE as u64 {
            if remaining > 0 {
                #[allow(clippy::cast_possible_truncation)]
                let zero_tail = vec![0u8; remaining as usize];
                self.seek_and_write(pos, &zero_tail)?;
            }
            pos = 0;
        }
        let zero = [0u8; ENTRY_HEADER_SIZE];
        self.seek_and_write(pos, &zero)?;
        Ok(pos)
    }
    fn seek_and_write(&mut self, position: u64, bytes: &[u8]) -> Result<()> {
        self.assert_writable()?;
        let pos = position % self.region_size;
        let absolute = self.region_offset + pos;
        self.file.seek(SeekFrom::Start(absolute))?;
        self.file.write_all(bytes)?;
        Ok(())
    }
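    /// Keeps a sentinel at the current write head so recovery scans stop at
    /// the end of the live records. Skipped when the region is exactly full,
    /// since the sentinel would clobber a pending entry.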
    fn maybe_write_sentinel(&mut self) -> Result<()> {
        if self.read_only || self.region_size == 0 {
            return Ok(());
        }
        if self.pending_bytes >= self.region_size {
            return Ok(());
        }
        let next = self.write_zero_header(self.write_head)?;
        self.write_head = next;
        Ok(())
    }
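    /// Walks the region from its start, validating each header and payload
    /// checksum, until it hits a zeroed sentinel or the end of the region.
    /// Returns the recovered records and the next write position.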
    fn scan_records(file: &mut File, offset: u64, size: u64) -> Result<(Vec<ScannedRecord>, u64)> {
        let mut records = Vec::new();
        let mut cursor = 0u64;
        while cursor + ENTRY_HEADER_SIZE as u64 <= size {
            file.seek(SeekFrom::Start(offset + cursor))?;
            let mut header = [0u8; ENTRY_HEADER_SIZE];
            file.read_exact(&mut header)?;

            let sequence = u64::from_le_bytes(header[..8].try_into().map_err(|_| {
                MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "invalid wal sequence header".into(),
                }
            })?);
            let length = u64::from(u32::from_le_bytes(header[8..12].try_into().map_err(
                |_| MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "invalid wal length header".into(),
                },
            )?));
            let checksum = &header[16..48];

            if sequence == 0 && length == 0 {
                break;
            }
            if length == 0 || cursor + ENTRY_HEADER_SIZE as u64 + length > size {
                tracing::error!(
                    wal.scan_offset = cursor,
                    wal.sequence = sequence,
                    wal.length = length,
                    wal.region_size = size,
                    "wal record length invalid"
                );
                return Err(MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "wal record length invalid".into(),
                });
            }

            let length_usize = usize::try_from(length).map_err(|_| MemvidError::WalCorruption {
                offset: cursor,
                reason: "wal record length too large for platform".into(),
            })?;
            let mut payload = vec![0u8; length_usize];
            file.read_exact(&mut payload)?;
            let expected = blake3::hash(&payload);
            if expected.as_bytes() != checksum {
                return Err(MemvidError::WalCorruption {
                    offset: cursor,
                    reason: "wal record checksum mismatch".into(),
                });
            }

            records.push(ScannedRecord {
                sequence,
                payload,
                total_size: ENTRY_HEADER_SIZE as u64 + length,
            });

            cursor += ENTRY_HEADER_SIZE as u64 + length;
        }

        Ok((records, cursor))
    }
}
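/// A record recovered by `scan_records`, including its on-disk footprint.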
#[derive(Debug)]
struct ScannedRecord {
    sequence: u64,
    payload: Vec<u8>,
    total_size: u64,
}
impl EmbeddedWal {
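    /// Reads back the header that was just written and logs its fields, so
    /// torn or misplaced writes surface early when debug logging is enabled.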
    fn debug_verify_header(
        &mut self,
        position: u64,
        expected_sequence: u64,
        expected_len: usize,
    ) -> Result<()> {
        if self.region_size == 0 {
            return Ok(());
        }
        let pos = position % self.region_size;
        let absolute = self.region_offset + pos;
        let mut buf = [0u8; ENTRY_HEADER_SIZE];
        self.file.seek(SeekFrom::Start(absolute))?;
        self.file.read_exact(&mut buf)?;

        let seq = buf
            .get(..8)
            .and_then(|s| <[u8; 8]>::try_from(s).ok())
            .map_or(0, u64::from_le_bytes);
        let len = buf
            .get(8..12)
            .and_then(|s| <[u8; 4]>::try_from(s).ok())
            .map_or(0, u32::from_le_bytes);

        tracing::debug!(
            wal.verify_position = pos,
            wal.verify_sequence = seq,
            wal.expected_sequence = expected_sequence,
            wal.verify_length = len,
            wal.expected_length = expected_len,
            "wal header verify"
        );
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::WAL_OFFSET;
    use std::io::{Seek, SeekFrom, Write};
    use tempfile::tempfile;

    fn header_for(size: u64) -> Header {
        Header {
            magic: *b"MV2\0",
            version: 0x0201,
            footer_offset: 0,
            wal_offset: WAL_OFFSET,
            wal_size: size,
            wal_checkpoint_pos: 0,
            wal_sequence: 0,
            toc_checksum: [0u8; 32],
        }
    }

    fn prepare_wal(size: u64) -> (File, Header) {
        let file = tempfile().expect("temp file");
        file.set_len(WAL_OFFSET + size).expect("set_len");
        let header = header_for(size);
        (file, header)
    }

    #[test]
    fn append_and_recover() {
        let (file, header) = prepare_wal(1024);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        wal.append_entry(b"first").expect("append first");
        wal.append_entry(b"second").expect("append second");

        let records = wal.records_after(0).expect("records");
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].payload, b"first");
        assert_eq!(records[0].sequence, 1);
        assert_eq!(records[1].payload, b"second");
        assert_eq!(records[1].sequence, 2);
    }
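    // Region sized to hold exactly two 32-byte-payload entries:
    // 2 * (48-byte header + 32-byte payload) = ENTRY_HEADER_SIZE * 2 + 64.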
    #[test]
    fn wrap_and_checkpoint() {
        let size = (ENTRY_HEADER_SIZE as u64 * 2) + 64;
        let (file, mut header) = prepare_wal(size);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        wal.append_entry(&[0xAA; 32]).expect("append a");
        wal.append_entry(&[0xBB; 32]).expect("append b");
        wal.record_checkpoint(&mut header).expect("checkpoint");

        assert!(wal.pending_records().expect("pending").is_empty());

        wal.append_entry(&[0xCC; 32]).expect("append c");
        let records = wal.pending_records().expect("after append");
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].payload, vec![0xCC; 32]);
    }

    #[test]
    fn corrupted_record_reports_offset() {
        let (mut file, header) = prepare_wal(64);
        file.seek(SeekFrom::Start(header.wal_offset)).expect("seek");
        let mut record = [0u8; ENTRY_HEADER_SIZE];
        record[..8].copy_from_slice(&1u64.to_le_bytes());
        record[8..12].copy_from_slice(&(u32::MAX).to_le_bytes());
        file.write_all(&record).expect("write corrupt header");
        file.sync_all().expect("sync");

        let err = EmbeddedWal::open(&file, &header).expect_err("open should fail");
        match err {
            MemvidError::WalCorruption { offset, reason } => {
                assert_eq!(offset, 0);
                assert!(reason.contains("length"), "reason should mention length");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }
}