1use std::fs::{File, OpenOptions};
6use std::io::{self, BufWriter, Read, Seek, SeekFrom, Write};
7use std::path::Path;
8
/// Applies Windows-specific settings to `opts` before the WAL file is opened.
///
/// The share mode is set to the combination of the read, write and delete
/// sharing flags, so other handles to the same file may be opened for any of
/// those accesses while this one is alive.
#[cfg(windows)]
fn configure_open_options(opts: &mut OpenOptions) {
    use std::os::windows::fs::OpenOptionsExt;

    // Win32 `dwShareMode` bits.
    const FILE_SHARE_READ: u32 = 0x1;
    const FILE_SHARE_WRITE: u32 = 0x2;
    const FILE_SHARE_DELETE: u32 = 0x4;

    opts.share_mode(FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE);
}
17
/// Non-Windows counterpart of `configure_open_options`.
///
/// Leaves `_opts` untouched; it exists so callers can invoke the same
/// function on every target.
#[cfg(not(windows))]
fn configure_open_options(_opts: &mut OpenOptions) {
    // Intentionally empty: nothing to configure on this target.
}
22
/// Kind of operation recorded by a WAL entry.
///
/// The enum is `#[repr(u8)]`; the discriminant is the tag byte stored in the
/// first byte of an encoded entry header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum WalEntryType {
    /// Tag `1`.
    InsertNode = 1,
    /// Tag `2`.
    DeleteNode = 2,
    /// Tag `3`.
    UpdateNeighbors = 3,
    /// Tag `4`.
    UpdateMetadata = 4,
    /// Tag `100`.
    Checkpoint = 100,
}

impl From<u8> for WalEntryType {
    /// Decodes a tag byte.
    ///
    /// Tags `1` through `4` map to their variants. Every other value — `100`
    /// as well as any unrecognised byte — decodes as
    /// [`WalEntryType::Checkpoint`], so this conversion alone cannot tell a
    /// genuine checkpoint from an unknown or corrupted tag.
    fn from(v: u8) -> Self {
        // Variants with a dedicated tag; anything not listed here falls back
        // to `Checkpoint`.
        const TAGGED: [WalEntryType; 4] = [
            WalEntryType::InsertNode,
            WalEntryType::DeleteNode,
            WalEntryType::UpdateNeighbors,
            WalEntryType::UpdateMetadata,
        ];

        TAGGED
            .iter()
            .copied()
            .find(|kind| *kind as u8 == v)
            .unwrap_or(WalEntryType::Checkpoint)
    }
}
50
51#[derive(Debug, Clone)]
54pub struct WalEntryHeader {
55 pub entry_type: WalEntryType,
56 pub timestamp: u64, pub data_len: u32,
58 pub checksum: u32,
59}
60
61impl WalEntryHeader {
62 pub const SIZE: usize = 20;
63
64 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
65 let mut buf = [0u8; Self::SIZE];
66 buf[0] = self.entry_type as u8;
67 buf[4..12].copy_from_slice(&self.timestamp.to_le_bytes());
69 buf[12..16].copy_from_slice(&self.data_len.to_le_bytes());
70 buf[16..20].copy_from_slice(&self.checksum.to_le_bytes());
71 buf
72 }
73
74 pub fn from_bytes(buf: &[u8; Self::SIZE]) -> Self {
75 Self {
77 entry_type: WalEntryType::from(buf[0]),
78 timestamp: u64::from_le_bytes([
79 buf[4], buf[5], buf[6], buf[7], buf[8], buf[9], buf[10], buf[11],
80 ]),
81 data_len: u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]),
82 checksum: u32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]),
83 }
84 }
85}
86
87#[derive(Debug, Clone)]
89pub struct WalEntry {
90 pub header: WalEntryHeader,
91 pub data: Vec<u8>,
92}
93
94impl WalEntry {
95 #[must_use]
97 pub fn insert_node(
98 timestamp: u64,
99 string_id: &str,
100 level: u8,
101 vector: &[f32],
102 metadata: &[u8],
103 ) -> Self {
104 let mut data = Vec::new();
105
106 data.extend_from_slice(&(string_id.len() as u32).to_le_bytes());
108 data.extend_from_slice(string_id.as_bytes());
109
110 data.push(level);
112
113 data.extend_from_slice(&(vector.len() as u32).to_le_bytes());
115 for &val in vector {
116 data.extend_from_slice(&val.to_le_bytes());
117 }
118
119 data.extend_from_slice(&(metadata.len() as u32).to_le_bytes());
121 data.extend_from_slice(metadata);
122
123 let checksum = crc32fast::hash(&data);
124
125 Self {
126 header: WalEntryHeader {
127 entry_type: WalEntryType::InsertNode,
128 timestamp,
129 data_len: data.len() as u32,
130 checksum,
131 },
132 data,
133 }
134 }
135
136 #[must_use]
138 pub fn delete_node(timestamp: u64, string_id: &str) -> Self {
139 let mut data = Vec::new();
140 data.extend_from_slice(&(string_id.len() as u32).to_le_bytes());
141 data.extend_from_slice(string_id.as_bytes());
142
143 let checksum = crc32fast::hash(&data);
144
145 Self {
146 header: WalEntryHeader {
147 entry_type: WalEntryType::DeleteNode,
148 timestamp,
149 data_len: data.len() as u32,
150 checksum,
151 },
152 data,
153 }
154 }
155
156 #[must_use]
158 pub fn update_neighbors(timestamp: u64, node_id: u32, level: u8, neighbors: &[u32]) -> Self {
159 let mut data = Vec::new();
160
161 data.extend_from_slice(&node_id.to_le_bytes());
163
164 data.push(level);
166
167 data.extend_from_slice(&(neighbors.len() as u32).to_le_bytes());
169 for &neighbor in neighbors {
170 data.extend_from_slice(&neighbor.to_le_bytes());
171 }
172
173 let checksum = crc32fast::hash(&data);
174
175 Self {
176 header: WalEntryHeader {
177 entry_type: WalEntryType::UpdateNeighbors,
178 timestamp,
179 data_len: data.len() as u32,
180 checksum,
181 },
182 data,
183 }
184 }
185
186 #[must_use]
188 pub fn checkpoint(timestamp: u64) -> Self {
189 Self {
190 header: WalEntryHeader {
191 entry_type: WalEntryType::Checkpoint,
192 timestamp,
193 data_len: 0,
194 checksum: 0,
195 },
196 data: Vec::new(),
197 }
198 }
199
200 #[must_use]
202 pub fn verify(&self) -> bool {
203 if self.data.is_empty() {
204 return self.header.checksum == 0;
205 }
206 crc32fast::hash(&self.data) == self.header.checksum
207 }
208}
209
210pub struct Wal {
212 file: BufWriter<File>,
213 #[allow(dead_code)]
214 path: std::path::PathBuf,
215 next_timestamp: u64,
216 entry_count: u64,
217}
218
219impl Wal {
220 pub fn open(path: impl AsRef<Path>) -> io::Result<Self> {
222 let path = path.as_ref().to_path_buf();
223 let mut opts = OpenOptions::new();
224 opts.read(true).write(true).create(true);
227 configure_open_options(&mut opts);
228 let mut file = opts.open(&path)?;
229
230 let metadata = file.metadata()?;
231 let file_len = metadata.len();
232
233 if file_len > 0 {
235 file.seek(SeekFrom::End(0))?;
236 }
237
238 let mut wal = Self {
239 file: BufWriter::new(file),
240 path,
241 next_timestamp: 0,
242 entry_count: 0,
243 };
244
245 if file_len > 0 {
247 wal.scan_for_timestamp()?;
248 }
249
250 Ok(wal)
251 }
252
253 fn scan_for_timestamp(&mut self) -> io::Result<()> {
255 let file = self.file.get_mut();
256 file.seek(SeekFrom::Start(0))?;
257
258 let mut header_buf = [0u8; WalEntryHeader::SIZE];
259 let mut max_timestamp = 0u64;
260 let mut count = 0u64;
261
262 loop {
263 match file.read_exact(&mut header_buf) {
264 Ok(()) => {
265 let header = WalEntryHeader::from_bytes(&header_buf);
266 max_timestamp = max_timestamp.max(header.timestamp);
267 count += 1;
268
269 if header.data_len > 0 {
271 file.seek(SeekFrom::Current(header.data_len as i64))?;
272 }
273 }
274 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
275 Err(e) => return Err(e),
276 }
277 }
278
279 self.next_timestamp = max_timestamp + 1;
280 self.entry_count = count;
281
282 file.seek(SeekFrom::End(0))?;
284
285 Ok(())
286 }
287
288 pub fn append(&mut self, mut entry: WalEntry) -> io::Result<()> {
290 entry.header.timestamp = self.next_timestamp;
291 self.next_timestamp += 1;
292
293 self.file.write_all(&entry.header.to_bytes())?;
294 if !entry.data.is_empty() {
295 self.file.write_all(&entry.data)?;
296 }
297
298 self.entry_count += 1;
299 Ok(())
300 }
301
302 pub fn sync(&mut self) -> io::Result<()> {
304 self.file.flush()?;
305 self.file.get_mut().sync_all()
306 }
307
308 pub fn entries_after_checkpoint(&mut self) -> io::Result<Vec<WalEntry>> {
310 let file = self.file.get_mut();
311 file.seek(SeekFrom::Start(0))?;
312
313 let mut all_entries = Vec::new();
314 let mut last_checkpoint_idx: Option<usize> = None;
315 let mut header_buf = [0u8; WalEntryHeader::SIZE];
316
317 loop {
318 match file.read_exact(&mut header_buf) {
319 Ok(()) => {
320 let header = WalEntryHeader::from_bytes(&header_buf);
321 let mut data = vec![0u8; header.data_len as usize];
322 if header.data_len > 0 {
323 file.read_exact(&mut data)?;
324 }
325
326 let entry = WalEntry { header, data };
327
328 if entry.header.entry_type == WalEntryType::Checkpoint {
329 last_checkpoint_idx = Some(all_entries.len());
330 }
331
332 all_entries.push(entry);
333 }
334 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
335 Err(e) => return Err(e),
336 }
337 }
338
339 match last_checkpoint_idx {
341 Some(idx) => Ok(all_entries.split_off(idx + 1)),
342 None => Ok(all_entries),
343 }
344 }
345
346 #[must_use]
348 pub fn len(&self) -> u64 {
349 self.entry_count
350 }
351
352 #[must_use]
354 pub fn is_empty(&self) -> bool {
355 self.entry_count == 0
356 }
357
358 pub fn truncate(&mut self) -> io::Result<()> {
360 self.file.flush()?;
362 self.file.get_mut().set_len(0)?;
363 self.file.get_mut().seek(SeekFrom::Start(0))?;
364 self.next_timestamp = 0;
365 self.entry_count = 0;
366 Ok(())
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Entries written before a checkpoint are not returned after reopening;
    /// only the entry appended after the checkpoint is.
    #[test]
    fn test_wal_roundtrip() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");

        {
            let mut wal = Wal::open(&wal_path).unwrap();
            wal.append(WalEntry::insert_node(0, "vec1", 0, &[1.0, 2.0, 3.0], b"{}"))
                .unwrap();
            wal.append(WalEntry::delete_node(0, "vec2")).unwrap();
            wal.append(WalEntry::checkpoint(0)).unwrap();
            wal.append(WalEntry::insert_node(0, "vec3", 1, &[4.0, 5.0, 6.0], b"{}"))
                .unwrap();
            wal.sync().unwrap();
        }

        {
            let mut wal = Wal::open(&wal_path).unwrap();
            let entries = wal.entries_after_checkpoint().unwrap();

            assert_eq!(entries.len(), 1);
            assert_eq!(entries[0].header.entry_type, WalEntryType::InsertNode);
        }
    }

    /// A freshly built entry passes its own checksum verification.
    #[test]
    fn test_entry_checksum() {
        let entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
        assert!(entry.verify());
    }

    /// Flipping payload bits makes verification fail.
    #[test]
    fn test_corrupted_entry_data_detected() {
        let mut entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
        assert!(entry.verify());
        assert!(!entry.data.is_empty(), "insert_node payload is never empty");

        entry.data[0] ^= 0xFF;

        assert!(!entry.verify(), "Corrupted data should fail verification");
    }

    /// Flipping the stored checksum makes verification fail.
    #[test]
    fn test_corrupted_entry_checksum_detected() {
        let mut entry = WalEntry::insert_node(1, "test", 0, &[1.0, 2.0], b"metadata");
        assert!(entry.verify());

        entry.header.checksum ^= 0xFFFF_FFFF;

        assert!(
            !entry.verify(),
            "Corrupted checksum should fail verification"
        );
    }

    /// Payload corruption on disk is returned by recovery but flagged by
    /// `verify`, while the untouched entry stays valid.
    #[test]
    fn test_wal_recovery_skips_corrupted_entries() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test_corrupt.wal");

        {
            let mut wal = Wal::open(&wal_path).unwrap();
            wal.append(WalEntry::insert_node(0, "vec1", 0, &[1.0, 2.0, 3.0], b"{}"))
                .unwrap();
            wal.append(WalEntry::insert_node(0, "vec2", 0, &[4.0, 5.0, 6.0], b"{}"))
                .unwrap();
            wal.sync().unwrap();
        }

        {
            let mut file = OpenOptions::new()
                .read(true)
                .write(true)
                .open(&wal_path)
                .unwrap();

            // Offset 40: past the first 20-byte header and 20 bytes into the
            // first entry's 31-byte payload, i.e. inside its vector values.
            // Lengths and the second entry are left untouched.
            file.seek(SeekFrom::Start(WalEntryHeader::SIZE as u64 + 20))
                .unwrap();
            file.write_all(&[0xFF; 4]).unwrap();
            file.sync_all().unwrap();
        }

        {
            let mut wal = Wal::open(&wal_path).unwrap();
            let entries = wal.entries_after_checkpoint().unwrap();

            assert_eq!(entries.len(), 2, "Both entries should still be readable");
            assert!(
                !entries[0].verify(),
                "Expected the first entry to be reported as corrupted"
            );
            assert!(
                entries[1].verify(),
                "The untouched second entry should still verify"
            );
        }
    }
}