1use serde::{Deserialize, Serialize};
42use std::fs::{File, OpenOptions};
43use std::io::{self, BufReader, BufWriter, Read, Write};
44use std::path::{Path, PathBuf};
45use thiserror::Error;
46use tracing::{debug, info, warn};
47
48#[derive(Error, Debug)]
50pub enum WalError {
51 #[error("I/O error: {0}")]
53 Io(#[from] io::Error),
54
55 #[error("Serialization error: {0}")]
57 Serialization(#[from] bincode::Error),
58
59 #[error("WAL corruption detected at offset {0}")]
61 Corruption(u64),
62
63 #[error("Invalid log entry: {0}")]
65 InvalidEntry(String),
66}
67
68pub type WalResult<T> = Result<T, WalError>;
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub enum WalEntry {
73 CreateNode {
75 tenant: String,
76 node_id: u64,
77 labels: Vec<String>,
78 properties: Vec<u8>, },
80 CreateEdge {
82 tenant: String,
83 edge_id: u64,
84 source: u64,
85 target: u64,
86 edge_type: String,
87 properties: Vec<u8>, },
89 DeleteNode { tenant: String, node_id: u64 },
91 DeleteEdge { tenant: String, edge_id: u64 },
93 UpdateNodeProperties {
95 tenant: String,
96 node_id: u64,
97 properties: Vec<u8>,
98 },
99 UpdateEdgeProperties {
101 tenant: String,
102 edge_id: u64,
103 properties: Vec<u8>,
104 },
105 Checkpoint { sequence: u64, timestamp: i64 },
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111struct WalRecord {
112 sequence: u64,
114 entry: WalEntry,
116 checksum: u32,
118}
119
120impl WalRecord {
121 fn new(sequence: u64, entry: WalEntry) -> Self {
122 let mut record = Self {
123 sequence,
124 entry,
125 checksum: 0,
126 };
127 record.checksum = record.calculate_checksum();
129 record
130 }
131
132 fn calculate_checksum(&self) -> u32 {
133 let bytes = bincode::serialize(&self.entry).unwrap_or_default();
135 bytes.iter().fold(0u32, |acc, &b| acc ^ (b as u32))
136 }
137
138 fn verify_checksum(&self) -> bool {
139 self.checksum == self.calculate_checksum()
140 }
141}
142
/// Append-only write-ahead log kept as `wal-<hex sequence>.log` segment files
/// inside a single directory.
pub struct Wal {
    /// Directory that contains all segment files.
    path: PathBuf,
    /// Buffered writer for the segment currently receiving appends; `None`
    /// until the next append opens one.
    current_file: Option<BufWriter<File>>,
    /// Last sequence number assigned, or the value restored when the WAL was opened.
    sequence: u64,
    /// When `true`, `append` pushes each record out of the write buffer before returning.
    sync_mode: bool,
}
154
155impl Wal {
156 pub fn new(path: impl AsRef<Path>) -> WalResult<Self> {
158 let path = path.as_ref().to_path_buf();
159
160 std::fs::create_dir_all(&path)?;
162
163 let sequence = Self::find_latest_sequence(&path)?;
165
166 info!("Initializing WAL at {:?}, sequence: {}", path, sequence);
167
168 Ok(Self {
169 path,
170 current_file: None,
171 sequence,
172 sync_mode: false, })
174 }
175
176 pub fn set_sync_mode(&mut self, sync: bool) {
178 self.sync_mode = sync;
179 debug!("WAL sync mode: {}", sync);
180 }
181
182 pub fn current_sequence(&self) -> u64 {
189 self.sequence
190 }
191
192 pub fn append(&mut self, entry: WalEntry) -> WalResult<u64> {
194 self.sequence += 1;
196 let sequence = self.sequence;
197
198 let record = WalRecord::new(sequence, entry);
200
201 let data = bincode::serialize(&record)?;
203
204 if self.current_file.is_none() {
206 self.open_new_file()?;
207 }
208
209 if let Some(ref mut file) = self.current_file {
211 file.write_all(&(data.len() as u32).to_le_bytes())?;
213 file.write_all(&data)?;
215
216 if self.sync_mode {
218 file.flush()?;
219 }
220 }
221
222 Ok(sequence)
223 }
224
225 pub fn flush(&mut self) -> WalResult<()> {
227 if let Some(ref mut file) = self.current_file {
228 file.flush()?;
229 }
230 Ok(())
231 }
232
233 pub fn replay<F>(&self, from_sequence: u64, mut callback: F) -> WalResult<u64>
235 where
236 F: FnMut(&WalEntry) -> WalResult<()>,
237 {
238 info!("Replaying WAL from sequence {}", from_sequence);
239
240 let files = self.get_wal_files()?;
241 let mut replayed = 0u64;
242 let mut last_sequence = from_sequence;
243
244 for file_path in files {
245 let file = File::open(&file_path)?;
246 let mut reader = BufReader::new(file);
247 let mut buf = Vec::new();
248
249 loop {
250 let mut len_bytes = [0u8; 4];
252 match reader.read_exact(&mut len_bytes) {
253 Ok(_) => {}
254 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => break,
255 Err(e) => return Err(e.into()),
256 }
257
258 let len = u32::from_le_bytes(len_bytes) as usize;
259
260 buf.resize(len, 0);
262 reader.read_exact(&mut buf)?;
263
264 let record: WalRecord = bincode::deserialize(&buf)?;
266
267 if !record.verify_checksum() {
269 warn!("WAL corruption detected at sequence {}", record.sequence);
270 return Err(WalError::Corruption(record.sequence));
271 }
272
273 if record.sequence < from_sequence {
275 continue;
276 }
277
278 callback(&record.entry)?;
280 replayed += 1;
281 last_sequence = record.sequence;
282 }
283 }
284
285 info!(
286 "Replayed {} WAL entries, last sequence: {}",
287 replayed, last_sequence
288 );
289 Ok(last_sequence)
290 }
291
292 pub fn checkpoint(&mut self, sequence: u64) -> WalResult<()> {
294 info!("Creating WAL checkpoint at sequence {}", sequence);
295
296 let timestamp = chrono::Utc::now().timestamp();
298 self.append(WalEntry::Checkpoint {
299 sequence,
300 timestamp,
301 })?;
302
303 self.flush()?;
305
306 self.current_file = None;
308
309 Ok(())
314 }
315
316 fn open_new_file(&mut self) -> WalResult<()> {
318 let filename = format!("wal-{:016x}.log", self.sequence);
319 let file_path = self.path.join(filename);
320
321 debug!("Opening new WAL file: {:?}", file_path);
322
323 let file = OpenOptions::new()
324 .create(true)
325 .append(true)
326 .open(file_path)?;
327
328 self.current_file = Some(BufWriter::new(file));
329 Ok(())
330 }
331
332 fn find_latest_sequence(path: &Path) -> WalResult<u64> {
334 let files = match std::fs::read_dir(path) {
335 Ok(entries) => entries,
336 Err(_) => return Ok(0), };
338
339 let mut max_sequence = 0u64;
340
341 for entry in files.flatten() {
342 if let Some(filename) = entry.file_name().to_str() {
343 if filename.starts_with("wal-") && filename.ends_with(".log") {
344 if let Some(seq_str) = filename
346 .strip_prefix("wal-")
347 .and_then(|s| s.strip_suffix(".log"))
348 {
349 if let Ok(seq) = u64::from_str_radix(seq_str, 16) {
350 max_sequence = max_sequence.max(seq);
351 }
352 }
353 }
354 }
355 }
356
357 Ok(max_sequence)
358 }
359
360 fn get_wal_files(&self) -> WalResult<Vec<PathBuf>> {
362 let mut files = Vec::new();
363
364 let entries = std::fs::read_dir(&self.path)?;
365
366 for entry in entries.flatten() {
367 if let Some(filename) = entry.file_name().to_str() {
368 if filename.starts_with("wal-") && filename.ends_with(".log") {
369 files.push(entry.path());
370 }
371 }
372 }
373
374 files.sort();
376
377 Ok(files)
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds an unlabeled, property-less node-creation entry for the default tenant.
    fn plain_node(node_id: u64) -> WalEntry {
        WalEntry::CreateNode {
            tenant: "default".to_string(),
            node_id,
            labels: vec![],
            properties: vec![],
        }
    }

    #[test]
    fn test_wal_creation() {
        let dir = TempDir::new().unwrap();
        let wal = Wal::new(dir.path()).unwrap();
        assert_eq!(wal.sequence, 0);
    }

    #[test]
    fn test_wal_append() {
        let dir = TempDir::new().unwrap();
        let mut wal = Wal::new(dir.path()).unwrap();

        let assigned = wal
            .append(WalEntry::CreateNode {
                tenant: "default".to_string(),
                node_id: 1,
                labels: vec!["Person".to_string()],
                properties: vec![],
            })
            .unwrap();
        assert_eq!(assigned, 1);

        wal.flush().unwrap();
    }

    #[test]
    fn test_wal_replay() {
        let dir = TempDir::new().unwrap();
        let mut wal = Wal::new(dir.path()).unwrap();

        (1..=5).map(plain_node).for_each(|entry| {
            wal.append(entry).unwrap();
        });
        wal.flush().unwrap();

        let mut seen = 0;
        wal.replay(0, |_| {
            seen += 1;
            Ok(())
        })
        .unwrap();

        assert_eq!(seen, 5);
    }

    #[test]
    fn test_wal_checkpoint() {
        let dir = TempDir::new().unwrap();
        let mut wal = Wal::new(dir.path()).unwrap();

        for node_id in 1..=10 {
            wal.append(plain_node(node_id)).unwrap();
        }
        wal.checkpoint(10).unwrap();

        let mut found_checkpoint = false;
        wal.replay(0, |entry| {
            found_checkpoint |= matches!(entry, WalEntry::Checkpoint { .. });
            Ok(())
        })
        .unwrap();

        assert!(found_checkpoint);
    }
}