1use anyhow::{anyhow, Result};
36use oxicode::{Decode, Encode};
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use std::fs::{File, OpenOptions};
40use std::io::{BufReader, BufWriter, Read, Write};
41use std::path::PathBuf;
42use std::sync::{Arc, Mutex};
43use std::time::{SystemTime, UNIX_EPOCH};
44
/// Magic bytes written at the very start of every WAL file; recovery skips
/// any file whose first four bytes differ.
const WAL_MAGIC: &[u8; 4] = b"WALV";

/// On-disk format version, stored little-endian right after the magic bytes.
/// Recovery skips files that carry a different version.
const WAL_VERSION: u32 = 1;
50
51#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
53pub enum WalEntry {
54 Insert {
56 id: String,
57 vector: Vec<f32>,
58 metadata: Option<HashMap<String, String>>,
59 timestamp: u64,
60 },
61 Update {
63 id: String,
64 vector: Vec<f32>,
65 metadata: Option<HashMap<String, String>>,
66 timestamp: u64,
67 },
68 Delete { id: String, timestamp: u64 },
70 Batch {
72 entries: Vec<WalEntry>,
73 timestamp: u64,
74 },
75 Checkpoint {
77 sequence_number: u64,
78 timestamp: u64,
79 },
80 BeginTransaction { transaction_id: u64, timestamp: u64 },
82 CommitTransaction { transaction_id: u64, timestamp: u64 },
84 AbortTransaction { transaction_id: u64, timestamp: u64 },
86}
87
88impl WalEntry {
89 pub fn timestamp(&self) -> u64 {
91 match self {
92 WalEntry::Insert { timestamp, .. }
93 | WalEntry::Update { timestamp, .. }
94 | WalEntry::Delete { timestamp, .. }
95 | WalEntry::Batch { timestamp, .. }
96 | WalEntry::Checkpoint { timestamp, .. }
97 | WalEntry::BeginTransaction { timestamp, .. }
98 | WalEntry::CommitTransaction { timestamp, .. }
99 | WalEntry::AbortTransaction { timestamp, .. } => *timestamp,
100 }
101 }
102
103 pub fn is_checkpoint(&self) -> bool {
105 matches!(self, WalEntry::Checkpoint { .. })
106 }
107}
108
/// Tunables for a `WalManager`.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory holding the `wal-*.log` files; created on manager start-up
    /// if it does not exist.
    pub wal_directory: PathBuf,
    /// On-disk size in bytes at which the active file is rotated after an
    /// append.
    pub max_file_size: u64,
    /// When `true`, every written record is flushed and fsynced immediately.
    pub sync_on_write: bool,
    /// Distance in sequence numbers from the last checkpoint at which an
    /// append triggers an automatic checkpoint.
    pub checkpoint_interval: u64,
    /// Number of WAL files (newest by file name) kept when old files are
    /// cleaned up after a checkpoint.
    pub checkpoint_retention: usize,
    /// Capacity in bytes of the buffered writer wrapping the active file.
    pub buffer_size: usize,
}
125
126impl Default for WalConfig {
127 fn default() -> Self {
128 Self {
129 wal_directory: PathBuf::from("./wal"),
130 max_file_size: 100 * 1024 * 1024, sync_on_write: false, checkpoint_interval: 10000,
133 checkpoint_retention: 3,
134 buffer_size: 64 * 1024, }
136 }
137}
138
139pub struct WalManager {
141 config: WalConfig,
142 current_file: Arc<Mutex<Option<BufWriter<File>>>>,
143 current_file_path: Arc<Mutex<PathBuf>>,
144 sequence_number: Arc<Mutex<u64>>,
145 last_checkpoint: Arc<Mutex<u64>>,
146}
147
148impl WalManager {
149 pub fn new(config: WalConfig) -> Result<Self> {
151 std::fs::create_dir_all(&config.wal_directory)?;
153
154 let manager = Self {
155 config,
156 current_file: Arc::new(Mutex::new(None)),
157 current_file_path: Arc::new(Mutex::new(PathBuf::new())),
158 sequence_number: Arc::new(Mutex::new(0)),
159 last_checkpoint: Arc::new(Mutex::new(0)),
160 };
161
162 manager.rotate_wal_file()?;
164
165 Ok(manager)
166 }
167
168 pub fn append(&self, entry: WalEntry) -> Result<u64> {
170 let seq = {
171 let mut seq_guard = self
172 .sequence_number
173 .lock()
174 .expect("mutex lock should not be poisoned");
175 let seq = *seq_guard;
176 *seq_guard += 1;
177 seq
178 };
179
180 let needs_checkpoint = {
182 let mut file_guard = self
183 .current_file
184 .lock()
185 .expect("mutex lock should not be poisoned");
186
187 if let Some(ref mut writer) = *file_guard {
188 let entry_bytes =
190 oxicode::serde::encode_to_vec(&entry, oxicode::config::standard())
191 .map_err(|e| anyhow!("Failed to serialize WAL entry: {}", e))?;
192 let entry_len = entry_bytes.len() as u32;
193
194 writer.write_all(&seq.to_le_bytes())?;
196 writer.write_all(&entry_len.to_le_bytes())?;
197 writer.write_all(&entry_bytes)?;
198
199 if self.config.sync_on_write {
200 writer.flush()?;
201 writer.get_ref().sync_all()?;
202 }
203
204 let needs_rotation = if let Ok(metadata) = writer.get_ref().metadata() {
206 metadata.len() >= self.config.max_file_size
207 } else {
208 false
209 };
210
211 if needs_rotation {
212 drop(file_guard);
213 self.rotate_wal_file()?;
214 }
215
216 let last_checkpoint = *self
218 .last_checkpoint
219 .lock()
220 .expect("mutex lock should not be poisoned");
221 seq - last_checkpoint >= self.config.checkpoint_interval
222 } else {
223 return Err(anyhow!("WAL file not open"));
224 }
225 };
226
227 if needs_checkpoint {
229 self.checkpoint(seq)?;
230 }
231
232 Ok(seq)
233 }
234
235 pub fn checkpoint(&self, sequence_number: u64) -> Result<()> {
237 tracing::info!("Creating WAL checkpoint at sequence {}", sequence_number);
238
239 let timestamp = SystemTime::now()
240 .duration_since(UNIX_EPOCH)
241 .expect("system time should be after UNIX_EPOCH")
242 .as_secs();
243
244 let checkpoint_entry = WalEntry::Checkpoint {
245 sequence_number,
246 timestamp,
247 };
248
249 let seq = {
251 let mut seq_guard = self
252 .sequence_number
253 .lock()
254 .expect("mutex lock should not be poisoned");
255 let seq = *seq_guard;
256 *seq_guard += 1;
257 seq
258 };
259
260 {
261 let mut file_guard = self
262 .current_file
263 .lock()
264 .expect("mutex lock should not be poisoned");
265 if let Some(ref mut writer) = *file_guard {
266 let entry_bytes =
267 oxicode::serde::encode_to_vec(&checkpoint_entry, oxicode::config::standard())
268 .map_err(|e| anyhow!("Failed to serialize checkpoint entry: {}", e))?;
269 let entry_len = entry_bytes.len() as u32;
270
271 writer.write_all(&seq.to_le_bytes())?;
272 writer.write_all(&entry_len.to_le_bytes())?;
273 writer.write_all(&entry_bytes)?;
274
275 if self.config.sync_on_write {
276 writer.flush()?;
277 writer.get_ref().sync_all()?;
278 }
279 }
280 }
281
282 let mut last_checkpoint = self
283 .last_checkpoint
284 .lock()
285 .expect("mutex lock should not be poisoned");
286 *last_checkpoint = sequence_number;
287
288 self.cleanup_old_files()?;
290
291 Ok(())
292 }
293
294 fn rotate_wal_file(&self) -> Result<()> {
296 let timestamp = SystemTime::now()
297 .duration_since(UNIX_EPOCH)
298 .expect("system time should be after UNIX_EPOCH")
299 .as_secs();
300
301 let filename = format!("wal-{:016x}.log", timestamp);
302 let filepath = self.config.wal_directory.join(&filename);
303
304 tracing::info!("Rotating WAL to new file: {:?}", filepath);
305
306 let file = OpenOptions::new()
307 .create(true)
308 .append(true)
309 .open(&filepath)?;
310
311 let mut writer = BufWriter::with_capacity(self.config.buffer_size, file);
312
313 writer.write_all(WAL_MAGIC)?;
315 writer.write_all(&WAL_VERSION.to_le_bytes())?;
316 writer.write_all(×tamp.to_le_bytes())?;
317
318 if self.config.sync_on_write {
319 writer.flush()?;
320 writer.get_ref().sync_all()?;
321 }
322
323 let mut file_guard = self
324 .current_file
325 .lock()
326 .expect("mutex lock should not be poisoned");
327 let mut path_guard = self
328 .current_file_path
329 .lock()
330 .expect("mutex lock should not be poisoned");
331
332 if let Some(mut old_writer) = file_guard.take() {
334 old_writer.flush()?;
335 }
336
337 *file_guard = Some(writer);
338 *path_guard = filepath;
339
340 Ok(())
341 }
342
343 fn cleanup_old_files(&self) -> Result<()> {
345 let mut wal_files: Vec<_> = std::fs::read_dir(&self.config.wal_directory)?
346 .filter_map(|entry| entry.ok())
347 .filter(|entry| {
348 entry
349 .file_name()
350 .to_str()
351 .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
352 .unwrap_or(false)
353 })
354 .collect();
355
356 wal_files.sort_by_key(|entry| entry.file_name());
358
359 if wal_files.len() > self.config.checkpoint_retention {
361 let to_remove = wal_files.len() - self.config.checkpoint_retention;
362 for entry in wal_files.iter().take(to_remove) {
363 tracing::info!("Removing old WAL file: {:?}", entry.path());
364 std::fs::remove_file(entry.path())?;
365 }
366 }
367
368 Ok(())
369 }
370
371 pub fn recover(&self) -> Result<Vec<WalEntry>> {
373 tracing::info!("Starting WAL recovery");
374
375 let mut all_entries = Vec::new();
376 let mut last_checkpoint_seq = 0u64;
377
378 let mut wal_files: Vec<_> = std::fs::read_dir(&self.config.wal_directory)?
380 .filter_map(|entry| entry.ok())
381 .filter(|entry| {
382 entry
383 .file_name()
384 .to_str()
385 .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
386 .unwrap_or(false)
387 })
388 .collect();
389
390 wal_files.sort_by_key(|entry| entry.file_name());
392
393 for entry in wal_files {
395 let path = entry.path();
396 tracing::debug!("Reading WAL file: {:?}", path);
397
398 let file = File::open(&path)?;
399 let mut reader = BufReader::new(file);
400
401 let mut magic = [0u8; 4];
403 reader.read_exact(&mut magic)?;
404 if &magic != WAL_MAGIC {
405 tracing::warn!("Invalid WAL file magic number: {:?}", path);
406 continue;
407 }
408
409 let mut version_bytes = [0u8; 4];
411 reader.read_exact(&mut version_bytes)?;
412 let version = u32::from_le_bytes(version_bytes);
413 if version != WAL_VERSION {
414 tracing::warn!("Unsupported WAL version {} in {:?}", version, path);
415 continue;
416 }
417
418 let mut timestamp_bytes = [0u8; 8];
420 reader.read_exact(&mut timestamp_bytes)?;
421
422 loop {
424 let mut seq_bytes = [0u8; 8];
426 match reader.read_exact(&mut seq_bytes) {
427 Ok(_) => {}
428 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
429 tracing::debug!("Reached end of WAL file (expected)");
430 break;
431 }
432 Err(e) => return Err(e.into()),
433 }
434 let seq = u64::from_le_bytes(seq_bytes);
435
436 let mut len_bytes = [0u8; 4];
438 match reader.read_exact(&mut len_bytes) {
439 Ok(_) => {}
440 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
441 tracing::warn!(
442 "Incomplete entry at sequence {}: missing length field. Skipping rest of file.",
443 seq
444 );
445 break;
446 }
447 Err(e) => return Err(e.into()),
448 }
449 let len = u32::from_le_bytes(len_bytes);
450
451 if len > 100_000_000 {
453 tracing::warn!(
455 "Entry at sequence {} has suspicious length {}. Possibly corrupted. Skipping.",
456 seq,
457 len
458 );
459 break;
460 }
461
462 let mut entry_bytes = vec![0u8; len as usize];
464 match reader.read_exact(&mut entry_bytes) {
465 Ok(_) => {}
466 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
467 tracing::warn!(
468 "Incomplete entry at sequence {}: expected {} bytes but reached EOF. Skipping rest of file.",
469 seq,
470 len
471 );
472 break;
473 }
474 Err(e) => return Err(e.into()),
475 }
476
477 let entry: WalEntry = match oxicode::serde::decode_from_slice(
479 &entry_bytes,
480 oxicode::config::standard(),
481 ) {
482 Ok((e, _)) => e,
483 Err(e) => {
484 tracing::warn!(
485 "Failed to deserialize entry at sequence {}: {}. Skipping entry.",
486 seq,
487 e
488 );
489 continue; }
491 };
492
493 if let WalEntry::Checkpoint {
495 sequence_number, ..
496 } = &entry
497 {
498 last_checkpoint_seq = *sequence_number;
499 }
500
501 all_entries.push((seq, entry));
502 }
503 }
504
505 let recovered_entries: Vec<_> = all_entries
509 .iter()
510 .filter(|(seq, _)| {
511 if last_checkpoint_seq == 0 {
512 true } else {
514 *seq > last_checkpoint_seq }
516 })
517 .map(|(_, entry)| entry.clone())
518 .collect();
519
520 tracing::info!(
521 "Recovered {} entries from WAL (after checkpoint {})",
522 recovered_entries.len(),
523 last_checkpoint_seq
524 );
525
526 if let Some((max_seq, _)) = all_entries.iter().max_by_key(|(seq, _)| seq) {
528 let mut seq = self
529 .sequence_number
530 .lock()
531 .expect("mutex lock should not be poisoned");
532 *seq = max_seq + 1;
533 }
534
535 Ok(recovered_entries)
536 }
537
538 pub fn flush(&self) -> Result<()> {
540 let mut file_guard = self
541 .current_file
542 .lock()
543 .expect("mutex lock should not be poisoned");
544 if let Some(ref mut writer) = *file_guard {
545 writer.flush()?;
546 writer.get_ref().sync_all()?;
547 }
548 Ok(())
549 }
550
551 pub fn current_sequence(&self) -> u64 {
553 *self
554 .sequence_number
555 .lock()
556 .expect("mutex lock should not be poisoned")
557 }
558
559 pub fn last_checkpoint_sequence(&self) -> u64 {
561 *self
562 .last_checkpoint
563 .lock()
564 .expect("mutex lock should not be poisoned")
565 }
566}
567
568impl Drop for WalManager {
569 fn drop(&mut self) {
570 let _ = self.flush();
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use tempfile::TempDir;

    /// Default configuration whose WAL directory is the given temp dir.
    fn config_in(dir: &TempDir) -> WalConfig {
        WalConfig {
            wal_directory: dir.path().to_path_buf(),
            ..Default::default()
        }
    }

    /// Shorthand for an `Insert` entry without metadata.
    fn insert(id: &str, vector: Vec<f32>, timestamp: u64) -> WalEntry {
        WalEntry::Insert {
            id: id.to_string(),
            vector,
            metadata: None,
            timestamp,
        }
    }

    /// A fresh manager starts counting at sequence 0.
    #[test]
    fn test_wal_creation() -> Result<()> {
        let temp_dir = TempDir::new()?;

        let wal = WalManager::new(config_in(&temp_dir))?;

        assert_eq!(wal.current_sequence(), 0);
        Ok(())
    }

    /// The first appended entry receives sequence 0.
    #[test]
    fn test_wal_append() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let config = WalConfig {
            sync_on_write: true,
            ..config_in(&temp_dir)
        };
        let wal = WalManager::new(config)?;

        let seq = wal.append(insert("vec1", vec![1.0, 2.0, 3.0], 12345))?;

        assert_eq!(seq, 0);
        Ok(())
    }

    /// Entries written by one manager are returned, in order, by `recover`
    /// on a second manager over the same directory.
    #[test]
    fn test_wal_recovery() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let config = WalConfig {
            sync_on_write: true,
            checkpoint_interval: 100,
            ..config_in(&temp_dir)
        };

        // Phase 1: write five entries, then let the manager go out of scope.
        {
            let writer = WalManager::new(config.clone())?;
            for i in 0..5u64 {
                let id = format!("vec{}", i);
                writer.append(insert(&id, vec![i as f32, (i * 2) as f32], (i + 1) * 1000))?;
            }
            writer.flush()?;
        }

        std::thread::sleep(std::time::Duration::from_millis(100));

        // Phase 2: a new manager replays what phase 1 persisted.
        let reader = WalManager::new(config)?;
        let recovered = reader.recover()?;

        assert_eq!(
            recovered.len(),
            5,
            "Expected exactly 5 entries, got {}",
            recovered.len()
        );

        let timestamps: Vec<u64> = recovered.iter().map(WalEntry::timestamp).collect();
        assert_eq!(timestamps, vec![1000, 2000, 3000, 4000, 5000]);
        Ok(())
    }

    /// With an interval of 3, five appends must have triggered a checkpoint.
    #[test]
    fn test_wal_checkpoint() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let config = WalConfig {
            sync_on_write: true,
            checkpoint_interval: 3,
            ..config_in(&temp_dir)
        };
        let wal = WalManager::new(config)?;

        for i in 0..5u64 {
            let id = format!("vec{}", i);
            wal.append(insert(&id, vec![i as f32], i))?;
        }

        assert!(wal.last_checkpoint_sequence() > 0);
        Ok(())
    }

    /// A batch holding mixed entries can be appended and flushed.
    #[test]
    fn test_wal_batch_operation() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let wal = WalManager::new(config_in(&temp_dir))?;

        let members = vec![
            insert("vec1", vec![1.0], 1),
            WalEntry::Update {
                id: "vec2".to_string(),
                vector: vec![2.0],
                metadata: None,
                timestamp: 2,
            },
        ];
        let batch = WalEntry::Batch {
            entries: members,
            timestamp: 3,
        };

        wal.append(batch)?;
        wal.flush()?;
        Ok(())
    }

    /// A begin / insert / commit sequence can be appended and flushed.
    #[test]
    fn test_wal_transaction() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let wal = WalManager::new(config_in(&temp_dir))?;

        let steps = vec![
            WalEntry::BeginTransaction {
                transaction_id: 1,
                timestamp: 100,
            },
            insert("vec1", vec![1.0], 101),
            WalEntry::CommitTransaction {
                transaction_id: 1,
                timestamp: 102,
            },
        ];
        for step in steps {
            wal.append(step)?;
        }

        wal.flush()?;
        Ok(())
    }
}