use anyhow::{anyhow, Result};
use oxicode::{Decode, Encode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
/// Four-byte magic written at the start of every WAL file and checked when a file is read back.
const WAL_MAGIC: &[u8; 4] = b"WALV";

/// On-disk format version written after the magic; files with a different version are skipped.
const WAL_VERSION: u32 = 1;
51#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
53pub enum WalEntry {
54 Insert {
56 id: String,
57 vector: Vec<f32>,
58 metadata: Option<HashMap<String, String>>,
59 timestamp: u64,
60 },
61 Update {
63 id: String,
64 vector: Vec<f32>,
65 metadata: Option<HashMap<String, String>>,
66 timestamp: u64,
67 },
68 Delete { id: String, timestamp: u64 },
70 Batch {
72 entries: Vec<WalEntry>,
73 timestamp: u64,
74 },
75 Checkpoint {
77 sequence_number: u64,
78 timestamp: u64,
79 },
80 BeginTransaction { transaction_id: u64, timestamp: u64 },
82 CommitTransaction { transaction_id: u64, timestamp: u64 },
84 AbortTransaction { transaction_id: u64, timestamp: u64 },
86}
87
88impl WalEntry {
89 pub fn timestamp(&self) -> u64 {
91 match self {
92 WalEntry::Insert { timestamp, .. }
93 | WalEntry::Update { timestamp, .. }
94 | WalEntry::Delete { timestamp, .. }
95 | WalEntry::Batch { timestamp, .. }
96 | WalEntry::Checkpoint { timestamp, .. }
97 | WalEntry::BeginTransaction { timestamp, .. }
98 | WalEntry::CommitTransaction { timestamp, .. }
99 | WalEntry::AbortTransaction { timestamp, .. } => *timestamp,
100 }
101 }
102
103 pub fn is_checkpoint(&self) -> bool {
105 matches!(self, WalEntry::Checkpoint { .. })
106 }
107}
108
/// Settings that control where the write-ahead log lives and how it is written.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory holding the `wal-*.log` files; created by `WalManager::new` if missing.
    pub wal_directory: PathBuf,
    /// File size in bytes at which the manager rotates to a new WAL file.
    pub max_file_size: u64,
    /// When `true`, every written record is flushed and synced to disk immediately.
    pub sync_on_write: bool,
    /// Number of sequence numbers since the last checkpoint after which an append
    /// triggers a new checkpoint.
    pub checkpoint_interval: u64,
    /// Number of most recent WAL files kept when old files are cleaned up after a checkpoint.
    pub checkpoint_retention: usize,
    /// Capacity in bytes of the buffered writer wrapped around the current WAL file.
    pub buffer_size: usize,
}
126impl Default for WalConfig {
127 fn default() -> Self {
128 Self {
129 wal_directory: PathBuf::from("./wal"),
130 max_file_size: 100 * 1024 * 1024, sync_on_write: false, checkpoint_interval: 10000,
133 checkpoint_retention: 3,
134 buffer_size: 64 * 1024, }
136 }
137}
138
139pub struct WalManager {
141 config: WalConfig,
142 current_file: Arc<Mutex<Option<BufWriter<File>>>>,
143 current_file_path: Arc<Mutex<PathBuf>>,
144 sequence_number: Arc<Mutex<u64>>,
145 last_checkpoint: Arc<Mutex<u64>>,
146}
147
148impl WalManager {
149 pub fn new(config: WalConfig) -> Result<Self> {
151 std::fs::create_dir_all(&config.wal_directory)?;
153
154 let manager = Self {
155 config,
156 current_file: Arc::new(Mutex::new(None)),
157 current_file_path: Arc::new(Mutex::new(PathBuf::new())),
158 sequence_number: Arc::new(Mutex::new(0)),
159 last_checkpoint: Arc::new(Mutex::new(0)),
160 };
161
162 manager.rotate_wal_file()?;
164
165 Ok(manager)
166 }
167
168 pub fn append(&self, entry: WalEntry) -> Result<u64> {
170 let seq = {
171 let mut seq_guard = self
172 .sequence_number
173 .lock()
174 .expect("mutex lock should not be poisoned");
175 let seq = *seq_guard;
176 *seq_guard += 1;
177 seq
178 };
179
180 let needs_checkpoint = {
182 let mut file_guard = self
183 .current_file
184 .lock()
185 .expect("mutex lock should not be poisoned");
186
187 if let Some(ref mut writer) = *file_guard {
188 let entry_bytes =
190 oxicode::serde::encode_to_vec(&entry, oxicode::config::standard())
191 .map_err(|e| anyhow!("Failed to serialize WAL entry: {}", e))?;
192 let entry_len = entry_bytes.len() as u32;
193
194 writer.write_all(&seq.to_le_bytes())?;
196 writer.write_all(&entry_len.to_le_bytes())?;
197 writer.write_all(&entry_bytes)?;
198
199 if self.config.sync_on_write {
200 writer.flush()?;
201 writer.get_ref().sync_all()?;
202 }
203
204 let needs_rotation = if let Ok(metadata) = writer.get_ref().metadata() {
206 metadata.len() >= self.config.max_file_size
207 } else {
208 false
209 };
210
211 if needs_rotation {
212 drop(file_guard);
213 self.rotate_wal_file()?;
214 }
215
216 let last_checkpoint = *self
218 .last_checkpoint
219 .lock()
220 .expect("mutex lock should not be poisoned");
221 seq - last_checkpoint >= self.config.checkpoint_interval
222 } else {
223 return Err(anyhow!("WAL file not open"));
224 }
225 };
226
227 if needs_checkpoint {
229 self.checkpoint(seq)?;
230 }
231
232 Ok(seq)
233 }
234
235 pub fn checkpoint(&self, sequence_number: u64) -> Result<()> {
237 tracing::info!("Creating WAL checkpoint at sequence {}", sequence_number);
238
239 let timestamp = SystemTime::now()
240 .duration_since(UNIX_EPOCH)
241 .expect("system time should be after UNIX_EPOCH")
242 .as_secs();
243
244 let checkpoint_entry = WalEntry::Checkpoint {
245 sequence_number,
246 timestamp,
247 };
248
249 let seq = {
251 let mut seq_guard = self
252 .sequence_number
253 .lock()
254 .expect("mutex lock should not be poisoned");
255 let seq = *seq_guard;
256 *seq_guard += 1;
257 seq
258 };
259
260 {
261 let mut file_guard = self
262 .current_file
263 .lock()
264 .expect("mutex lock should not be poisoned");
265 if let Some(ref mut writer) = *file_guard {
266 let entry_bytes =
267 oxicode::serde::encode_to_vec(&checkpoint_entry, oxicode::config::standard())
268 .map_err(|e| anyhow!("Failed to serialize checkpoint entry: {}", e))?;
269 let entry_len = entry_bytes.len() as u32;
270
271 writer.write_all(&seq.to_le_bytes())?;
272 writer.write_all(&entry_len.to_le_bytes())?;
273 writer.write_all(&entry_bytes)?;
274
275 if self.config.sync_on_write {
276 writer.flush()?;
277 writer.get_ref().sync_all()?;
278 }
279 }
280 }
281
282 let mut last_checkpoint = self
283 .last_checkpoint
284 .lock()
285 .expect("mutex lock should not be poisoned");
286 *last_checkpoint = sequence_number;
287
288 self.cleanup_old_files()?;
290
291 Ok(())
292 }
293
294 fn rotate_wal_file(&self) -> Result<()> {
296 let timestamp = SystemTime::now()
297 .duration_since(UNIX_EPOCH)
298 .expect("system time should be after UNIX_EPOCH")
299 .as_secs();
300
301 let filename = format!("wal-{:016x}.log", timestamp);
302 let filepath = self.config.wal_directory.join(&filename);
303
304 tracing::info!("Rotating WAL to new file: {:?}", filepath);
305
306 let file = OpenOptions::new()
307 .create(true)
308 .append(true)
309 .open(&filepath)?;
310
311 let mut writer = BufWriter::with_capacity(self.config.buffer_size, file);
312
313 writer.write_all(WAL_MAGIC)?;
315 writer.write_all(&WAL_VERSION.to_le_bytes())?;
316 writer.write_all(×tamp.to_le_bytes())?;
317
318 if self.config.sync_on_write {
319 writer.flush()?;
320 writer.get_ref().sync_all()?;
321 }
322
323 let mut file_guard = self
324 .current_file
325 .lock()
326 .expect("mutex lock should not be poisoned");
327 let mut path_guard = self
328 .current_file_path
329 .lock()
330 .expect("mutex lock should not be poisoned");
331
332 if let Some(mut old_writer) = file_guard.take() {
334 old_writer.flush()?;
335 }
336
337 *file_guard = Some(writer);
338 *path_guard = filepath;
339
340 Ok(())
341 }
342
343 fn cleanup_old_files(&self) -> Result<()> {
345 let mut wal_files: Vec<_> = std::fs::read_dir(&self.config.wal_directory)?
346 .filter_map(|entry| entry.ok())
347 .filter(|entry| {
348 entry
349 .file_name()
350 .to_str()
351 .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
352 .unwrap_or(false)
353 })
354 .collect();
355
356 wal_files.sort_by_key(|entry| entry.file_name());
358
359 if wal_files.len() > self.config.checkpoint_retention {
361 let to_remove = wal_files.len() - self.config.checkpoint_retention;
362 for entry in wal_files.iter().take(to_remove) {
363 tracing::info!("Removing old WAL file: {:?}", entry.path());
364 std::fs::remove_file(entry.path())?;
365 }
366 }
367
368 Ok(())
369 }
370
371 pub fn recover(&self) -> Result<Vec<WalEntry>> {
373 tracing::info!("Starting WAL recovery");
374
375 let mut all_entries = Vec::new();
376 let mut last_checkpoint_seq = 0u64;
377
378 let mut wal_files: Vec<_> = std::fs::read_dir(&self.config.wal_directory)?
380 .filter_map(|entry| entry.ok())
381 .filter(|entry| {
382 entry
383 .file_name()
384 .to_str()
385 .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
386 .unwrap_or(false)
387 })
388 .collect();
389
390 wal_files.sort_by_key(|entry| entry.file_name());
392
393 for entry in wal_files {
395 let path = entry.path();
396 tracing::debug!("Reading WAL file: {:?}", path);
397
398 let file = File::open(&path)?;
399 let mut reader = BufReader::new(file);
400
401 let mut magic = [0u8; 4];
403 reader.read_exact(&mut magic)?;
404 if &magic != WAL_MAGIC {
405 tracing::warn!("Invalid WAL file magic number: {:?}", path);
406 continue;
407 }
408
409 let mut version_bytes = [0u8; 4];
411 reader.read_exact(&mut version_bytes)?;
412 let version = u32::from_le_bytes(version_bytes);
413 if version != WAL_VERSION {
414 tracing::warn!("Unsupported WAL version {} in {:?}", version, path);
415 continue;
416 }
417
418 let mut timestamp_bytes = [0u8; 8];
420 reader.read_exact(&mut timestamp_bytes)?;
421
422 loop {
424 let mut seq_bytes = [0u8; 8];
426 match reader.read_exact(&mut seq_bytes) {
427 Ok(_) => {}
428 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
429 tracing::debug!("Reached end of WAL file (expected)");
430 break;
431 }
432 Err(e) => return Err(e.into()),
433 }
434 let seq = u64::from_le_bytes(seq_bytes);
435
436 let mut len_bytes = [0u8; 4];
438 match reader.read_exact(&mut len_bytes) {
439 Ok(_) => {}
440 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
441 tracing::warn!(
442 "Incomplete entry at sequence {}: missing length field. Skipping rest of file.",
443 seq
444 );
445 break;
446 }
447 Err(e) => return Err(e.into()),
448 }
449 let len = u32::from_le_bytes(len_bytes);
450
451 if len > 100_000_000 {
453 tracing::warn!(
455 "Entry at sequence {} has suspicious length {}. Possibly corrupted. Skipping.",
456 seq,
457 len
458 );
459 break;
460 }
461
462 let mut entry_bytes = vec![0u8; len as usize];
464 match reader.read_exact(&mut entry_bytes) {
465 Ok(_) => {}
466 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
467 tracing::warn!(
468 "Incomplete entry at sequence {}: expected {} bytes but reached EOF. Skipping rest of file.",
469 seq,
470 len
471 );
472 break;
473 }
474 Err(e) => return Err(e.into()),
475 }
476
477 let entry: WalEntry = match oxicode::serde::decode_from_slice(
479 &entry_bytes,
480 oxicode::config::standard(),
481 ) {
482 Ok((e, _)) => e,
483 Err(e) => {
484 tracing::warn!(
485 "Failed to deserialize entry at sequence {}: {}. Skipping entry.",
486 seq,
487 e
488 );
489 continue; }
491 };
492
493 if let WalEntry::Checkpoint {
495 sequence_number, ..
496 } = &entry
497 {
498 last_checkpoint_seq = *sequence_number;
499 }
500
501 all_entries.push((seq, entry));
502 }
503 }
504
505 let recovered_entries: Vec<_> = all_entries
509 .iter()
510 .filter(|(seq, _)| {
511 if last_checkpoint_seq == 0 {
512 true } else {
514 *seq > last_checkpoint_seq }
516 })
517 .map(|(_, entry)| entry.clone())
518 .collect();
519
520 tracing::info!(
521 "Recovered {} entries from WAL (after checkpoint {})",
522 recovered_entries.len(),
523 last_checkpoint_seq
524 );
525
526 if let Some((max_seq, _)) = all_entries.iter().max_by_key(|(seq, _)| seq) {
528 let mut seq = self
529 .sequence_number
530 .lock()
531 .expect("mutex lock should not be poisoned");
532 *seq = max_seq + 1;
533 }
534
535 Ok(recovered_entries)
536 }
537
538 pub fn flush(&self) -> Result<()> {
540 let mut file_guard = self
541 .current_file
542 .lock()
543 .expect("mutex lock should not be poisoned");
544 if let Some(ref mut writer) = *file_guard {
545 writer.flush()?;
546 writer.get_ref().sync_all()?;
547 }
548 Ok(())
549 }
550
551 pub fn current_sequence(&self) -> u64 {
553 *self
554 .sequence_number
555 .lock()
556 .expect("mutex lock should not be poisoned")
557 }
558
559 pub fn last_checkpoint_sequence(&self) -> u64 {
561 *self
562 .last_checkpoint
563 .lock()
564 .expect("mutex lock should not be poisoned")
565 }
566}
567
568impl Drop for WalManager {
569 fn drop(&mut self) {
570 let _ = self.flush();
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A freshly created manager starts at sequence 0.
    #[test]
    fn test_wal_creation() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            wal_directory: temp_dir.path().to_path_buf(),
            ..Default::default()
        };

        let wal = WalManager::new(config).unwrap();
        assert_eq!(wal.current_sequence(), 0);
    }

    /// The first appended entry receives sequence 0.
    #[test]
    fn test_wal_append() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            wal_directory: temp_dir.path().to_path_buf(),
            sync_on_write: true,
            ..Default::default()
        };

        let wal = WalManager::new(config).unwrap();

        let entry = WalEntry::Insert {
            id: "vec1".to_string(),
            vector: vec![1.0, 2.0, 3.0],
            metadata: None,
            timestamp: 12345,
        };

        let seq = wal.append(entry).unwrap();
        assert_eq!(seq, 0);
    }

    /// Entries written by one manager are returned, in order, by a second manager's `recover`.
    #[test]
    fn test_wal_recovery() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            wal_directory: temp_dir.path().to_path_buf(),
            sync_on_write: true,
            checkpoint_interval: 100,
            ..Default::default()
        };

        {
            let wal = WalManager::new(config.clone()).unwrap();

            for i in 0..5 {
                let entry = WalEntry::Insert {
                    id: format!("vec{}", i),
                    vector: vec![i as f32, (i * 2) as f32],
                    metadata: None,
                    timestamp: (i + 1) * 1000,
                };
                wal.append(entry).unwrap();
            }

            wal.flush().unwrap();
            drop(wal);
        }

        std::thread::sleep(std::time::Duration::from_millis(100));

        {
            let wal = WalManager::new(config).unwrap();
            let recovered = wal.recover().unwrap();

            assert_eq!(
                recovered.len(),
                5,
                "Expected exactly 5 entries, got {}",
                recovered.len()
            );

            let timestamps: Vec<u64> = recovered.iter().map(|e| e.timestamp()).collect();
            assert_eq!(timestamps, vec![1000, 2000, 3000, 4000, 5000]);
        }
    }

    /// With an interval of 3, five appends trigger at least one automatic checkpoint.
    #[test]
    fn test_wal_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            wal_directory: temp_dir.path().to_path_buf(),
            sync_on_write: true,
            checkpoint_interval: 3,
            ..Default::default()
        };

        let wal = WalManager::new(config).unwrap();

        for i in 0..5 {
            let entry = WalEntry::Insert {
                id: format!("vec{}", i),
                vector: vec![i as f32],
                metadata: None,
                timestamp: i,
            };
            wal.append(entry).unwrap();
        }

        assert!(wal.last_checkpoint_sequence() > 0);
    }

    /// A batch is logged as a single record and therefore consumes one sequence number.
    #[test]
    fn test_wal_batch_operation() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            wal_directory: temp_dir.path().to_path_buf(),
            ..Default::default()
        };

        let wal = WalManager::new(config).unwrap();

        let batch = WalEntry::Batch {
            entries: vec![
                WalEntry::Insert {
                    id: "vec1".to_string(),
                    vector: vec![1.0],
                    metadata: None,
                    timestamp: 1,
                },
                WalEntry::Update {
                    id: "vec2".to_string(),
                    vector: vec![2.0],
                    metadata: None,
                    timestamp: 2,
                },
            ],
            timestamp: 3,
        };

        let seq = wal.append(batch).unwrap();
        assert_eq!(seq, 0);
        assert_eq!(wal.current_sequence(), 1);
        wal.flush().unwrap();
    }

    /// Begin, insert and commit records receive consecutive sequence numbers.
    #[test]
    fn test_wal_transaction() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            wal_directory: temp_dir.path().to_path_buf(),
            ..Default::default()
        };

        let wal = WalManager::new(config).unwrap();

        let begin = wal
            .append(WalEntry::BeginTransaction {
                transaction_id: 1,
                timestamp: 100,
            })
            .unwrap();

        let insert = wal
            .append(WalEntry::Insert {
                id: "vec1".to_string(),
                vector: vec![1.0],
                metadata: None,
                timestamp: 101,
            })
            .unwrap();

        let commit = wal
            .append(WalEntry::CommitTransaction {
                transaction_id: 1,
                timestamp: 102,
            })
            .unwrap();

        assert_eq!((begin, insert, commit), (0, 1, 2));

        wal.flush().unwrap();
    }
}