1use std::fs::File;
2use std::io::{Read, Seek, SeekFrom, Write};
3
4use crate::{
5 constants::{WAL_CHECKPOINT_PERIOD, WAL_CHECKPOINT_THRESHOLD},
6 error::{MemvidError, Result},
7 types::Header,
8};
9
/// On-disk size of each record header: 8-byte LE sequence, 4-byte LE payload
/// length, 4 reserved (zero) bytes, and a 32-byte BLAKE3 payload checksum.
const ENTRY_HEADER_SIZE: usize = 48;
12
/// Point-in-time snapshot of the WAL's occupancy and sequence counters.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WalStats {
    /// Total size of the embedded WAL region, in bytes.
    pub region_size: u64,
    /// Bytes of records appended since the last checkpoint.
    pub pending_bytes: u64,
    /// Number of `append_entry` calls since the last checkpoint.
    pub appends_since_checkpoint: u64,
    /// Sequence number of the most recently appended record.
    pub sequence: u64,
}
20
/// A recovered WAL record: the sequence number assigned at append time plus
/// the raw payload bytes exactly as they were written.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WalRecord {
    /// Sequence number assigned when the record was appended.
    pub sequence: u64,
    /// Payload bytes (checksum-verified during recovery scan).
    pub payload: Vec<u8>,
}
26
/// Circular write-ahead log embedded at a fixed region of the backing file.
///
/// Records are laid out back-to-back from the region start; a zeroed entry
/// header (the "sentinel") marks the end of valid data so a recovery scan
/// knows where to stop.
#[derive(Debug)]
pub struct EmbeddedWal {
    // Cloned handle onto the backing file (see `open_internal`).
    file: File,
    // Absolute byte offset of the WAL region within the file.
    region_offset: u64,
    // Size of the WAL region in bytes; validated non-zero at open.
    region_size: u64,
    // Region-relative offset where the next record will be written.
    write_head: u64,
    // Region-relative offset recorded at the last checkpoint.
    checkpoint_head: u64,
    // Bytes of records appended since the last checkpoint.
    pending_bytes: u64,
    // Sequence number of the most recently appended record.
    sequence: u64,
    // Highest sequence covered by the last checkpoint.
    checkpoint_sequence: u64,
    // Appends since the last checkpoint; drives `should_checkpoint`.
    appends_since_checkpoint: u64,
    // When true, every mutating operation fails with a lock error.
    read_only: bool,
}
40
41impl EmbeddedWal {
42 pub fn open(file: &File, header: &Header) -> Result<Self> {
43 Self::open_internal(file, header, false)
44 }
45
46 pub fn open_read_only(file: &File, header: &Header) -> Result<Self> {
47 Self::open_internal(file, header, true)
48 }
49
50 fn open_internal(file: &File, header: &Header, read_only: bool) -> Result<Self> {
51 if header.wal_size == 0 {
52 return Err(MemvidError::InvalidHeader {
53 reason: "wal_size must be non-zero".into(),
54 });
55 }
56 let mut clone = file.try_clone()?;
57 let region_offset = header.wal_offset;
58 let region_size = header.wal_size;
59 let checkpoint_sequence = header.wal_sequence;
60
61 let (entries, next_head) = Self::scan_records(&mut clone, region_offset, region_size)?;
62
63 let pending_bytes = entries
64 .iter()
65 .filter(|entry| entry.sequence > checkpoint_sequence)
66 .map(|entry| entry.total_size)
67 .sum();
68 let sequence = entries
69 .last()
70 .map(|entry| entry.sequence)
71 .unwrap_or(checkpoint_sequence);
72
73 let mut wal = Self {
74 file: clone,
75 region_offset,
76 region_size,
77 write_head: next_head % region_size,
78 checkpoint_head: header.wal_checkpoint_pos % region_size,
79 pending_bytes,
80 sequence,
81 checkpoint_sequence,
82 appends_since_checkpoint: 0,
83 read_only,
84 };
85
86 if !wal.read_only {
87 wal.initialise_sentinel()?;
88 }
89 Ok(wal)
90 }
91
92 fn assert_writable(&self) -> Result<()> {
93 if self.read_only {
94 return Err(MemvidError::Lock(
95 "wal is read-only; reopen memory with write access".into(),
96 ));
97 }
98 Ok(())
99 }
100
101 pub fn append_entry(&mut self, payload: &[u8]) -> Result<u64> {
102 self.assert_writable()?;
103 let payload_len = payload.len();
104 if payload_len > u32::MAX as usize {
105 return Err(MemvidError::CheckpointFailed {
106 reason: "WAL payload too large".into(),
107 });
108 }
109
110 let entry_size = ENTRY_HEADER_SIZE as u64 + payload_len as u64;
111 if entry_size > self.region_size {
112 return Err(MemvidError::CheckpointFailed {
113 reason: "embedded WAL region too small for entry".into(),
114 });
115 }
116 if self.pending_bytes + entry_size > self.region_size {
117 return Err(MemvidError::CheckpointFailed {
118 reason: "embedded WAL region full".into(),
119 });
120 }
121
122 if self.write_head + entry_size > self.region_size {
123 self.write_head = 0;
124 }
125
126 let next_sequence = self.sequence + 1;
127 tracing::debug!(
128 wal.write_head = self.write_head,
129 wal.sequence = next_sequence,
130 wal.payload_len = payload_len,
131 "wal append entry"
132 );
133 self.write_record(self.write_head, next_sequence, payload)?;
134
135 self.write_head = (self.write_head + entry_size) % self.region_size;
136 self.pending_bytes += entry_size;
137 self.sequence = self.sequence.wrapping_add(1);
138 self.appends_since_checkpoint = self.appends_since_checkpoint.saturating_add(1);
139
140 self.maybe_write_sentinel()?;
141
142 Ok(self.sequence)
143 }
144
145 pub fn should_checkpoint(&self) -> bool {
146 if self.read_only || self.region_size == 0 {
147 return false;
148 }
149 let occupancy = self.pending_bytes as f64 / self.region_size as f64;
150 occupancy >= WAL_CHECKPOINT_THRESHOLD
151 || self.appends_since_checkpoint >= WAL_CHECKPOINT_PERIOD
152 }
153
154 pub fn record_checkpoint(&mut self, header: &mut Header) -> Result<()> {
155 self.assert_writable()?;
156 self.checkpoint_head = self.write_head;
157 self.pending_bytes = 0;
158 self.appends_since_checkpoint = 0;
159 self.checkpoint_sequence = self.sequence;
160 header.wal_checkpoint_pos = self.checkpoint_head;
161 header.wal_sequence = self.checkpoint_sequence;
162 self.maybe_write_sentinel()
163 }
164
165 pub fn pending_records(&mut self) -> Result<Vec<WalRecord>> {
166 self.records_after(self.checkpoint_sequence)
167 }
168
169 pub fn records_after(&mut self, sequence: u64) -> Result<Vec<WalRecord>> {
170 let (entries, next_head) =
171 Self::scan_records(&mut self.file, self.region_offset, self.region_size)?;
172
173 self.sequence = entries
174 .last()
175 .map(|entry| entry.sequence)
176 .unwrap_or(self.sequence);
177 self.pending_bytes = entries
178 .iter()
179 .filter(|entry| entry.sequence > self.checkpoint_sequence)
180 .map(|entry| entry.total_size)
181 .sum();
182 self.write_head = next_head % self.region_size;
183 if !self.read_only {
184 self.initialise_sentinel()?;
185 }
186
187 Ok(entries
188 .into_iter()
189 .filter(|entry| entry.sequence > sequence)
190 .map(|entry| WalRecord {
191 sequence: entry.sequence,
192 payload: entry.payload,
193 })
194 .collect())
195 }
196
197 pub fn stats(&self) -> WalStats {
198 WalStats {
199 region_size: self.region_size,
200 pending_bytes: self.pending_bytes,
201 appends_since_checkpoint: self.appends_since_checkpoint,
202 sequence: self.sequence,
203 }
204 }
205
206 pub fn region_offset(&self) -> u64 {
207 self.region_offset
208 }
209
210 pub fn file(&self) -> &File {
211 &self.file
212 }
213
214 fn initialise_sentinel(&mut self) -> Result<()> {
215 self.maybe_write_sentinel()
216 }
217
218 fn write_record(&mut self, position: u64, sequence: u64, payload: &[u8]) -> Result<()> {
219 self.assert_writable()?;
220 let digest = blake3::hash(payload);
221 let mut header = [0u8; ENTRY_HEADER_SIZE];
222 header[..8].copy_from_slice(&sequence.to_le_bytes());
223 header[8..12].copy_from_slice(&(payload.len() as u32).to_le_bytes());
224 header[16..48].copy_from_slice(digest.as_bytes());
225
226 let mut combined = Vec::with_capacity(ENTRY_HEADER_SIZE + payload.len());
229 combined.extend_from_slice(&header);
230 combined.extend_from_slice(payload);
231
232 self.seek_and_write(position, &combined)?;
233 if tracing::enabled!(tracing::Level::DEBUG) {
234 if let Err(err) = self.debug_verify_header(position, sequence, payload.len()) {
235 tracing::warn!(error = %err, "wal header verify failed");
236 }
237 }
238
239 self.file.sync_all()?;
242
243 Ok(())
244 }
245
246 fn write_zero_header(&mut self, position: u64) -> Result<u64> {
247 self.assert_writable()?;
248 if self.region_size == 0 {
249 return Ok(0);
250 }
251 let mut pos = position % self.region_size;
252 let remaining = self.region_size - pos;
253 if remaining < ENTRY_HEADER_SIZE as u64 {
254 if remaining > 0 {
255 let zero_tail = vec![0u8; remaining as usize];
256 self.seek_and_write(pos, &zero_tail)?;
257 }
258 pos = 0;
259 }
260 let zero = [0u8; ENTRY_HEADER_SIZE];
261 self.seek_and_write(pos, &zero)?;
262 Ok(pos)
263 }
264
265 fn seek_and_write(&mut self, position: u64, bytes: &[u8]) -> Result<()> {
266 self.assert_writable()?;
267 let pos = position % self.region_size;
268 let absolute = self.region_offset + pos;
269 self.file.seek(SeekFrom::Start(absolute))?;
270 self.file.write_all(bytes)?;
271 Ok(())
272 }
273
274 fn maybe_write_sentinel(&mut self) -> Result<()> {
275 if self.read_only || self.region_size == 0 {
276 return Ok(());
277 }
278 if self.pending_bytes >= self.region_size {
279 return Ok(());
280 }
281 let next = self.write_zero_header(self.write_head)?;
283 self.write_head = next;
284 Ok(())
285 }
286
287 fn scan_records(file: &mut File, offset: u64, size: u64) -> Result<(Vec<ScannedRecord>, u64)> {
288 let mut records = Vec::new();
289 let mut cursor = 0u64;
290 while cursor + ENTRY_HEADER_SIZE as u64 <= size {
291 file.seek(SeekFrom::Start(offset + cursor))?;
292 let mut header = [0u8; ENTRY_HEADER_SIZE];
293 file.read_exact(&mut header)?;
294
295 let sequence = u64::from_le_bytes(header[..8].try_into().map_err(|_| {
296 MemvidError::WalCorruption {
297 offset: cursor,
298 reason: "invalid wal sequence header".into(),
299 }
300 })?);
301 let length = u32::from_le_bytes(header[8..12].try_into().map_err(|_| {
302 MemvidError::WalCorruption {
303 offset: cursor,
304 reason: "invalid wal length header".into(),
305 }
306 })?) as u64;
307 let checksum = &header[16..48];
308
309 if sequence == 0 && length == 0 {
310 break;
311 }
312 if length == 0 || cursor + ENTRY_HEADER_SIZE as u64 + length > size {
313 tracing::error!(
314 wal.scan_offset = cursor,
315 wal.sequence = sequence,
316 wal.length = length,
317 wal.region_size = size,
318 "wal record length invalid"
319 );
320 return Err(MemvidError::WalCorruption {
321 offset: cursor,
322 reason: "wal record length invalid".into(),
323 });
324 }
325
326 let mut payload = vec![0u8; length as usize];
327 file.read_exact(&mut payload)?;
328 let expected = blake3::hash(&payload);
329 if expected.as_bytes() != checksum {
330 return Err(MemvidError::WalCorruption {
331 offset: cursor,
332 reason: "wal record checksum mismatch".into(),
333 });
334 }
335
336 records.push(ScannedRecord {
337 sequence,
338 payload,
339 total_size: ENTRY_HEADER_SIZE as u64 + length,
340 });
341
342 cursor += ENTRY_HEADER_SIZE as u64 + length;
343 }
344
345 Ok((records, cursor))
346 }
347}
348
/// Internal result of a region scan: one decoded on-disk record.
#[derive(Debug)]
struct ScannedRecord {
    // Sequence number read from the entry header.
    sequence: u64,
    // Payload bytes; the checksum was verified during the scan.
    payload: Vec<u8>,
    // Header plus payload size — bytes the record occupies on disk.
    total_size: u64,
}
355
356impl EmbeddedWal {
357 fn debug_verify_header(
358 &mut self,
359 position: u64,
360 expected_sequence: u64,
361 expected_len: usize,
362 ) -> Result<()> {
363 if self.region_size == 0 {
364 return Ok(());
365 }
366 let pos = position % self.region_size;
367 let absolute = self.region_offset + pos;
368 let mut buf = [0u8; ENTRY_HEADER_SIZE];
369 self.file.seek(SeekFrom::Start(absolute))?;
370 self.file.read_exact(&mut buf)?;
371 let seq = u64::from_le_bytes(buf[..8].try_into().unwrap());
372 let len = u32::from_le_bytes(buf[8..12].try_into().unwrap());
373 tracing::debug!(
374 wal.verify_position = pos,
375 wal.verify_sequence = seq,
376 wal.expected_sequence = expected_sequence,
377 wal.verify_length = len,
378 wal.expected_length = expected_len,
379 "wal header verify"
380 );
381 Ok(())
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::constants::WAL_OFFSET;
    use std::io::{Seek, SeekFrom, Write};
    use tempfile::tempfile;

    /// Builds a minimal header describing a WAL region of `size` bytes.
    fn header_for(size: u64) -> Header {
        Header {
            magic: *b"MV2\0",
            version: 0x0201,
            footer_offset: 0,
            wal_offset: WAL_OFFSET,
            wal_size: size,
            wal_checkpoint_pos: 0,
            wal_sequence: 0,
            toc_checksum: [0u8; 32],
        }
    }

    /// Creates a temp file sized for a WAL region plus a matching header.
    fn prepare_wal(size: u64) -> (File, Header) {
        let backing = tempfile().expect("temp file");
        backing.set_len(WAL_OFFSET + size).expect("set_len");
        (backing, header_for(size))
    }

    #[test]
    fn append_and_recover() {
        let (file, header) = prepare_wal(1024);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        wal.append_entry(b"first").expect("append first");
        wal.append_entry(b"second").expect("append second");

        // Both records should come back, in order, with 1-based sequences.
        let recovered = wal.records_after(0).expect("records");
        assert_eq!(recovered.len(), 2);
        assert_eq!(recovered[0].payload, b"first");
        assert_eq!(recovered[0].sequence, 1);
        assert_eq!(recovered[1].payload, b"second");
        assert_eq!(recovered[1].sequence, 2);
    }

    #[test]
    fn wrap_and_checkpoint() {
        // Region just large enough for two 32-byte records plus slack.
        let region = (ENTRY_HEADER_SIZE as u64 * 2) + 64;
        let (file, mut header) = prepare_wal(region);
        let mut wal = EmbeddedWal::open(&file, &header).expect("open wal");

        wal.append_entry(&vec![0xAA; 32]).expect("append a");
        wal.append_entry(&vec![0xBB; 32]).expect("append b");
        wal.record_checkpoint(&mut header).expect("checkpoint");

        // Checkpointing drains the pending set...
        assert!(wal.pending_records().expect("pending").is_empty());

        // ...and a fresh append becomes the only pending record.
        wal.append_entry(&vec![0xCC; 32]).expect("append c");
        let pending = wal.pending_records().expect("after append");
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].payload, vec![0xCC; 32]);
    }

    #[test]
    fn corrupted_record_reports_offset() {
        let (mut file, header) = prepare_wal(64);

        // Hand-craft an entry header whose length field is absurd.
        file.seek(SeekFrom::Start(header.wal_offset)).expect("seek");
        let mut bogus = [0u8; ENTRY_HEADER_SIZE];
        bogus[..8].copy_from_slice(&1u64.to_le_bytes());
        bogus[8..12].copy_from_slice(&(u32::MAX).to_le_bytes());
        file.write_all(&bogus).expect("write corrupt header");
        file.sync_all().expect("sync");

        let err = EmbeddedWal::open(&file, &header).expect_err("open should fail");
        match err {
            MemvidError::WalCorruption { offset, reason } => {
                assert_eq!(offset, 0);
                assert!(reason.contains("length"), "reason should mention length");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }
}
466}