1use std::fs::{File, OpenOptions};
6use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
7use std::path::Path;
8
/// Applies the Windows-specific share mode to `opts` before the WAL file is opened.
///
/// The share mode is the bitwise OR of `0x1`, `0x2` and `0x4`, i.e. the Win32
/// `FILE_SHARE_READ`, `FILE_SHARE_WRITE` and `FILE_SHARE_DELETE` flags.
#[cfg(windows)]
fn configure_open_options(opts: &mut OpenOptions) {
    use std::os::windows::fs::OpenOptionsExt;

    const FILE_SHARE_READ: u32 = 0x1;
    const FILE_SHARE_WRITE: u32 = 0x2;
    const FILE_SHARE_DELETE: u32 = 0x4;

    opts.share_mode(FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE);
}
17
/// Non-Windows counterpart of `configure_open_options`: there is nothing to configure,
/// so `_opts` is left exactly as the caller built it.
#[cfg(not(windows))]
fn configure_open_options(_opts: &mut OpenOptions) {}
22
/// Kind tag of a write-ahead-log entry.
///
/// The explicit `u8` discriminants are what `WalEntryHeader::to_bytes` stores in the first
/// header byte, so they are part of the on-disk format and must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum WalEntryType {
    /// Tag `1`: entry produced by `WalEntry::insert_node`.
    InsertNode = 1,
    /// Tag `2`: entry produced by `WalEntry::delete_node`.
    DeleteNode = 2,
    /// Tag `3`: entry produced by `WalEntry::update_neighbors`.
    UpdateNeighbors = 3,
    /// Tag `4`: metadata update (no constructor for this kind is defined in this file).
    UpdateMetadata = 4,
    /// Tag `100`: checkpoint marker; replay returns only the entries that follow the last one.
    Checkpoint = 100,
}

impl WalEntryType {
    /// Decodes a raw tag byte.
    ///
    /// Returns `None` when `v` is not one of the discriminants declared on the enum.
    #[must_use]
    pub fn from_byte(v: u8) -> Option<Self> {
        const KNOWN: [WalEntryType; 5] = [
            WalEntryType::InsertNode,
            WalEntryType::DeleteNode,
            WalEntryType::UpdateNeighbors,
            WalEntryType::UpdateMetadata,
            WalEntryType::Checkpoint,
        ];

        KNOWN.iter().copied().find(|kind| *kind as u8 == v)
    }
}

impl From<u8> for WalEntryType {
    /// Infallible decoding: every byte that [`WalEntryType::from_byte`] rejects is mapped to
    /// [`WalEntryType::Checkpoint`].
    ///
    /// NOTE(review): because of that fallback a corrupted tag byte is indistinguishable from
    /// a real checkpoint after conversion; code handling untrusted bytes should call
    /// `from_byte` instead.
    fn from(v: u8) -> Self {
        match Self::from_byte(v) {
            Some(kind) => kind,
            None => Self::Checkpoint,
        }
    }
}
63
64#[derive(Debug, Clone)]
67pub struct WalEntryHeader {
68 pub entry_type: WalEntryType,
69 pub timestamp: u64, pub data_len: u32,
71 pub checksum: u32,
72}
73
74impl WalEntryHeader {
75 pub const SIZE: usize = 20;
76
77 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
78 let mut buf = [0u8; Self::SIZE];
79 buf[0] = self.entry_type as u8;
80 buf[4..12].copy_from_slice(&self.timestamp.to_le_bytes());
82 buf[12..16].copy_from_slice(&self.data_len.to_le_bytes());
83 buf[16..20].copy_from_slice(&self.checksum.to_le_bytes());
84 buf
85 }
86
87 pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Self {
88 Self {
90 entry_type: WalEntryType::from(buf[0]),
91 timestamp: u64::from_le_bytes([
92 buf[4], buf[5], buf[6], buf[7], buf[8], buf[9], buf[10], buf[11],
93 ]),
94 data_len: u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]),
95 checksum: u32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]),
96 }
97 }
98}
99
100#[derive(Debug, Clone)]
102pub struct WalEntry {
103 pub header: WalEntryHeader,
104 pub data: Vec<u8>,
105}
106
107impl WalEntry {
108 #[must_use]
110 pub fn insert_node(
111 timestamp: u64,
112 string_id: &str,
113 level: u8,
114 vector: &[f32],
115 metadata: &[u8],
116 ) -> Self {
117 let mut data = Vec::new();
118
119 data.extend_from_slice(&(string_id.len() as u32).to_le_bytes());
121 data.extend_from_slice(string_id.as_bytes());
122
123 data.push(level);
125
126 data.extend_from_slice(&(vector.len() as u32).to_le_bytes());
128 for &val in vector {
129 data.extend_from_slice(&val.to_le_bytes());
130 }
131
132 data.extend_from_slice(&(metadata.len() as u32).to_le_bytes());
134 data.extend_from_slice(metadata);
135
136 let checksum = crc32fast::hash(&data);
137
138 Self {
139 header: WalEntryHeader {
140 entry_type: WalEntryType::InsertNode,
141 timestamp,
142 data_len: data.len() as u32,
143 checksum,
144 },
145 data,
146 }
147 }
148
149 #[must_use]
151 pub fn delete_node(timestamp: u64, string_id: &str) -> Self {
152 let mut data = Vec::new();
153 data.extend_from_slice(&(string_id.len() as u32).to_le_bytes());
154 data.extend_from_slice(string_id.as_bytes());
155
156 let checksum = crc32fast::hash(&data);
157
158 Self {
159 header: WalEntryHeader {
160 entry_type: WalEntryType::DeleteNode,
161 timestamp,
162 data_len: data.len() as u32,
163 checksum,
164 },
165 data,
166 }
167 }
168
169 #[must_use]
171 pub fn update_neighbors(timestamp: u64, node_id: u32, level: u8, neighbors: &[u32]) -> Self {
172 let mut data = Vec::new();
173
174 data.extend_from_slice(&node_id.to_le_bytes());
176
177 data.push(level);
179
180 data.extend_from_slice(&(neighbors.len() as u32).to_le_bytes());
182 for &neighbor in neighbors {
183 data.extend_from_slice(&neighbor.to_le_bytes());
184 }
185
186 let checksum = crc32fast::hash(&data);
187
188 Self {
189 header: WalEntryHeader {
190 entry_type: WalEntryType::UpdateNeighbors,
191 timestamp,
192 data_len: data.len() as u32,
193 checksum,
194 },
195 data,
196 }
197 }
198
199 #[must_use]
201 pub fn checkpoint(timestamp: u64) -> Self {
202 Self {
203 header: WalEntryHeader {
204 entry_type: WalEntryType::Checkpoint,
205 timestamp,
206 data_len: 0,
207 checksum: 0,
208 },
209 data: Vec::new(),
210 }
211 }
212
213 #[must_use]
215 pub fn verify(&self) -> bool {
216 if self.data.is_empty() {
217 return self.header.checksum == 0;
218 }
219 crc32fast::hash(&self.data) == self.header.checksum
220 }
221}
222
223pub struct Wal {
225 file: BufWriter<File>,
226 #[allow(dead_code)]
227 path: std::path::PathBuf,
228 next_timestamp: u64,
229 entry_count: u64,
230}
231
232impl Wal {
233 pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
235 let path = path.as_ref().to_path_buf();
236 let mut opts = OpenOptions::new();
237 opts.read(true).write(true).create(true);
240 configure_open_options(&mut opts);
241 let mut file = opts.open(&path)?;
242
243 let metadata = file.metadata()?;
244 let file_len = metadata.len();
245
246 if file_len > 0 {
248 file.seek(SeekFrom::End(0))?;
249 }
250
251 let mut wal = Self {
252 file: BufWriter::new(file),
253 path,
254 next_timestamp: 0,
255 entry_count: 0,
256 };
257
258 if file_len > 0 {
260 wal.scan_for_timestamp()?;
261 }
262
263 Ok(wal)
264 }
265
266 fn scan_for_timestamp(&mut self) -> io::Result<()> {
268 let file = self.file.get_mut();
269 file.seek(SeekFrom::Start(0))?;
270
271 let mut header_buf = [0u8; WalEntryHeader::SIZE];
272 let mut max_timestamp = 0u64;
273 let mut count = 0u64;
274
275 const MAX_ENTRY_SIZE: u32 = 100 * 1024 * 1024;
277
278 loop {
279 match file.read_exact(&mut header_buf) {
280 Ok(()) => {
281 let header = WalEntryHeader::from_bytes(&header_buf);
282
283 if header.data_len > MAX_ENTRY_SIZE {
285 return Err(io::Error::new(
286 io::ErrorKind::InvalidData,
287 format!(
288 "WAL entry has suspicious data_len: {} bytes (max: {})",
289 header.data_len, MAX_ENTRY_SIZE
290 ),
291 ));
292 }
293
294 max_timestamp = max_timestamp.max(header.timestamp);
295 count += 1;
296
297 if header.data_len > 0 {
299 file.seek(SeekFrom::Current(header.data_len as i64))?;
300 }
301 }
302 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
303 Err(e) => return Err(e),
304 }
305 }
306
307 self.next_timestamp = max_timestamp + 1;
308 self.entry_count = count;
309
310 file.seek(SeekFrom::End(0))?;
312
313 Ok(())
314 }
315
316 pub fn append(&mut self, mut entry: WalEntry) -> io::Result<()> {
318 entry.header.timestamp = self.next_timestamp;
319 self.next_timestamp += 1;
320
321 self.file.write_all(&entry.header.to_bytes())?;
322 if !entry.data.is_empty() {
323 self.file.write_all(&entry.data)?;
324 }
325
326 self.entry_count += 1;
327 Ok(())
328 }
329
330 pub fn sync(&mut self) -> io::Result<()> {
332 self.file.flush()?;
333 self.file.get_mut().sync_all()
334 }
335
336 pub fn entries_after_checkpoint(&mut self) -> io::Result<Vec<WalEntry>> {
341 let file = self.file.get_mut();
342 file.seek(SeekFrom::Start(0))?;
343
344 let mut all_entries = Vec::new();
345 let mut last_checkpoint_idx: Option<usize> = None;
346 let mut header_buf = [0u8; WalEntryHeader::SIZE];
347
348 const MAX_ENTRY_SIZE: u32 = 100 * 1024 * 1024;
350
351 loop {
352 match file.read_exact(&mut header_buf) {
353 Ok(()) => {
354 let header = WalEntryHeader::from_bytes(&header_buf);
355
356 if header.data_len > MAX_ENTRY_SIZE {
358 return Err(io::Error::new(
359 io::ErrorKind::InvalidData,
360 format!(
361 "WAL entry has suspicious data_len: {} bytes",
362 header.data_len
363 ),
364 ));
365 }
366
367 let mut data = vec![0u8; header.data_len as usize];
368 if header.data_len > 0 {
369 file.read_exact(&mut data)?;
370 }
371
372 let entry = WalEntry { header, data };
373
374 if !entry.verify() {
376 continue;
377 }
378
379 let entry_type_byte = entry.header.entry_type as u8;
382 if WalEntryType::from_byte(entry_type_byte) == Some(WalEntryType::Checkpoint) {
383 last_checkpoint_idx = Some(all_entries.len());
384 }
385
386 all_entries.push(entry);
387 }
388 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
389 Err(e) => return Err(e),
390 }
391 }
392
393 match last_checkpoint_idx {
395 Some(idx) => Ok(all_entries.split_off(idx + 1)),
396 None => Ok(all_entries),
397 }
398 }
399
400 #[must_use]
402 pub fn len(&self) -> u64 {
403 self.entry_count
404 }
405
406 #[must_use]
408 pub fn is_empty(&self) -> bool {
409 self.entry_count == 0
410 }
411
412 pub fn truncate(&mut self) -> io::Result<()> {
414 self.file.flush()?;
416 self.file.get_mut().set_len(0)?;
417 self.file.get_mut().seek(SeekFrom::Start(0))?;
418 self.next_timestamp = 0;
419 self.entry_count = 0;
420 Ok(())
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Entries written before a checkpoint are not replayed after reopening the log.
    #[test]
    fn test_wal_roundtrip() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");

        {
            let mut wal = Wal::open(&wal_path).unwrap();
            wal.append(WalEntry::insert_node(0, "vec1", 0, &[1.0, 2.0, 3.0], b"{}"))
                .unwrap();
            wal.append(WalEntry::delete_node(0, "vec2")).unwrap();
            wal.append(WalEntry::checkpoint(0)).unwrap();
            wal.append(WalEntry::insert_node(0, "vec3", 1, &[4.0, 5.0, 6.0], b"{}"))
                .unwrap();
            wal.sync().unwrap();
        }

        {
            let mut wal = Wal::open(&wal_path).unwrap();
            assert_eq!(wal.len(), 4);

            let entries = wal.entries_after_checkpoint().unwrap();

            // Only the insert that follows the checkpoint is returned.
            assert_eq!(entries.len(), 1);
            assert_eq!(entries[0].header.entry_type, WalEntryType::InsertNode);
        }
    }

    /// A header survives encoding and decoding unchanged.
    #[test]
    fn test_header_roundtrip() {
        let header = WalEntryHeader {
            entry_type: WalEntryType::UpdateNeighbors,
            timestamp: 0x0102_0304_0506_0708,
            data_len: 42,
            checksum: 0xDEAD_BEEF,
        };

        let bytes = header.to_bytes();
        assert_eq!(bytes[0], WalEntryType::UpdateNeighbors as u8);

        let decoded = WalEntryHeader::from_bytes(&bytes);
        assert_eq!(decoded.entry_type, header.entry_type);
        assert_eq!(decoded.timestamp, header.timestamp);
        assert_eq!(decoded.data_len, header.data_len);
        assert_eq!(decoded.checksum, header.checksum);
    }

    /// A freshly built entry passes its own checksum verification.
    #[test]
    fn test_entry_checksum() {
        let entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
        assert!(entry.verify());
    }

    /// Flipping payload bits makes verification fail.
    #[test]
    fn test_corrupted_entry_data_detected() {
        let mut entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
        assert!(entry.verify());
        assert!(!entry.data.is_empty());

        entry.data[0] ^= 0xFF;

        assert!(!entry.verify(), "Corrupted data should fail verification");
    }

    /// Flipping checksum bits makes verification fail.
    #[test]
    fn test_corrupted_entry_checksum_detected() {
        let mut entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
        assert!(entry.verify());

        entry.header.checksum ^= 0xFFFF_FFFF;

        assert!(
            !entry.verify(),
            "Corrupted checksum should fail verification"
        );
    }

    /// An entry with a damaged payload is skipped; the intact entry after it is still replayed.
    #[test]
    fn test_wal_recovery_skips_corrupted_entries() {
        use std::io::Write;

        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test_corrupt.wal");

        {
            let mut wal = Wal::open(&wal_path).unwrap();
            wal.append(WalEntry::insert_node(0, "vec1", 0, &[1.0, 2.0, 3.0], b"{}"))
                .unwrap();
            wal.append(WalEntry::insert_node(0, "vec2", 0, &[4.0, 5.0, 6.0], b"{}"))
                .unwrap();
            wal.sync().unwrap();
        }

        {
            let mut file = OpenOptions::new()
                .read(true)
                .write(true)
                .open(&wal_path)
                .unwrap();

            // Offset 25 lies inside the first entry's payload (the header is 20 bytes), so
            // only that entry's checksum is invalidated; its length field stays intact.
            file.seek(SeekFrom::Start(25)).unwrap();
            file.write_all(&[0xFF, 0xFF, 0xFF, 0xFF]).unwrap();
            file.sync_all().unwrap();
        }

        {
            let mut wal = Wal::open(&wal_path).unwrap();
            let entries = wal.entries_after_checkpoint().unwrap();

            for entry in &entries {
                assert!(entry.verify(), "All returned entries should be valid");
            }

            assert_eq!(
                entries.len(),
                1,
                "Exactly the uncorrupted entry should be recovered"
            );
            // The payload starts with a 4-byte id length followed by the id itself.
            assert_eq!(&entries[0].data[4..8], &b"vec2"[..]);
        }
    }
}