oxirs_vec/
wal.rs

//! Write-Ahead Logging (WAL) for crash recovery
//!
//! This module provides write-ahead logging for vector index operations,
//! enabling crash recovery and improving data durability. Every modification is
//! recorded in the WAL before it is applied to the index, so after a crash the
//! system can recover by replaying the log.
//!
//! # Features
//!
//! - Transaction markers (begin / commit / abort) and batched operations
//! - Crash recovery by replaying the entries recorded after the last checkpoint
//! - Log file rotation, checkpointing, and removal of old log files
//! - Concurrent appends from multiple threads, serialized by internal locks
//! - Configurable fsync behavior for performance tuning
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────┐
//! │ Index Ops   │
//! └──────┬──────┘
//!        │
//!        ▼
//! ┌─────────────┐     ┌──────────────┐
//! │ WAL Writer  │────▶│ Log File     │
//! └─────────────┘     └──────────────┘
//!        │                    │
//!        │                    │ (on crash)
//!        ▼                    ▼
//! ┌─────────────┐     ┌──────────────┐
//! │ Index       │◀────│ WAL Recovery │
//! └─────────────┘     └──────────────┘
//! ```
34
35use anyhow::{anyhow, Result};
36use oxicode::{Decode, Encode};
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use std::fs::{File, OpenOptions};
40use std::io::{BufReader, BufWriter, Read, Write};
41use std::path::PathBuf;
42use std::sync::{Arc, Mutex};
43use std::time::{SystemTime, UNIX_EPOCH};
44
/// Four-byte tag at the start of every WAL segment file ("WALV" = WAL Vector).
const WAL_MAGIC: &[u8; 4] = &[b'W', b'A', b'L', b'V'];

/// Version of the on-disk WAL format, stored little-endian right after the magic tag.
const WAL_VERSION: u32 = 0x0000_0001;
50
51/// Write-Ahead Log entry representing a single operation
52#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
53pub enum WalEntry {
54    /// Insert a new vector
55    Insert {
56        id: String,
57        vector: Vec<f32>,
58        metadata: Option<HashMap<String, String>>,
59        timestamp: u64,
60    },
61    /// Update an existing vector
62    Update {
63        id: String,
64        vector: Vec<f32>,
65        metadata: Option<HashMap<String, String>>,
66        timestamp: u64,
67    },
68    /// Delete a vector
69    Delete { id: String, timestamp: u64 },
70    /// Batch operation (multiple entries)
71    Batch {
72        entries: Vec<WalEntry>,
73        timestamp: u64,
74    },
75    /// Checkpoint marker (all operations before this are persisted)
76    Checkpoint {
77        sequence_number: u64,
78        timestamp: u64,
79    },
80    /// Transaction begin
81    BeginTransaction { transaction_id: u64, timestamp: u64 },
82    /// Transaction commit
83    CommitTransaction { transaction_id: u64, timestamp: u64 },
84    /// Transaction abort
85    AbortTransaction { transaction_id: u64, timestamp: u64 },
86}
87
88impl WalEntry {
89    /// Get the timestamp of this entry
90    pub fn timestamp(&self) -> u64 {
91        match self {
92            WalEntry::Insert { timestamp, .. }
93            | WalEntry::Update { timestamp, .. }
94            | WalEntry::Delete { timestamp, .. }
95            | WalEntry::Batch { timestamp, .. }
96            | WalEntry::Checkpoint { timestamp, .. }
97            | WalEntry::BeginTransaction { timestamp, .. }
98            | WalEntry::CommitTransaction { timestamp, .. }
99            | WalEntry::AbortTransaction { timestamp, .. } => *timestamp,
100        }
101    }
102
103    /// Check if this is a checkpoint entry
104    pub fn is_checkpoint(&self) -> bool {
105        matches!(self, WalEntry::Checkpoint { .. })
106    }
107}
108
/// Tuning knobs for [`WalManager`].
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory holding the `wal-*.log` segment files.
    pub wal_directory: PathBuf,
    /// Size in bytes at which the active segment is rotated to a new file.
    pub max_file_size: u64,
    /// When `true`, every record is flushed and fsynced immediately (durable but slower).
    pub sync_on_write: bool,
    /// Number of sequence numbers between automatic checkpoints.
    pub checkpoint_interval: u64,
    /// How many of the newest WAL segment files survive the cleanup run after a checkpoint.
    pub checkpoint_retention: usize,
    /// Capacity in bytes of the in-memory write buffer.
    pub buffer_size: usize,
}

impl Default for WalConfig {
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: u64 = 1024 * 1024;

        WalConfig {
            wal_directory: "./wal".into(),
            max_file_size: 100 * MIB,
            // Trades a small window of possible loss on crash for much higher throughput.
            sync_on_write: false,
            checkpoint_interval: 10_000,
            checkpoint_retention: 3,
            buffer_size: 64 * KIB,
        }
    }
}
138
139/// Write-Ahead Log manager
140pub struct WalManager {
141    config: WalConfig,
142    current_file: Arc<Mutex<Option<BufWriter<File>>>>,
143    current_file_path: Arc<Mutex<PathBuf>>,
144    sequence_number: Arc<Mutex<u64>>,
145    last_checkpoint: Arc<Mutex<u64>>,
146}
147
148impl WalManager {
149    /// Create a new WAL manager
150    pub fn new(config: WalConfig) -> Result<Self> {
151        // Ensure WAL directory exists
152        std::fs::create_dir_all(&config.wal_directory)?;
153
154        let manager = Self {
155            config,
156            current_file: Arc::new(Mutex::new(None)),
157            current_file_path: Arc::new(Mutex::new(PathBuf::new())),
158            sequence_number: Arc::new(Mutex::new(0)),
159            last_checkpoint: Arc::new(Mutex::new(0)),
160        };
161
162        // Open or create the current WAL file
163        manager.rotate_wal_file()?;
164
165        Ok(manager)
166    }
167
168    /// Append an entry to the WAL
169    pub fn append(&self, entry: WalEntry) -> Result<u64> {
170        let seq = {
171            let mut seq_guard = self
172                .sequence_number
173                .lock()
174                .expect("mutex lock should not be poisoned");
175            let seq = *seq_guard;
176            *seq_guard += 1;
177            seq
178        };
179
180        // Write to file
181        let needs_checkpoint = {
182            let mut file_guard = self
183                .current_file
184                .lock()
185                .expect("mutex lock should not be poisoned");
186
187            if let Some(ref mut writer) = *file_guard {
188                // Serialize the entry
189                let entry_bytes =
190                    oxicode::serde::encode_to_vec(&entry, oxicode::config::standard())
191                        .map_err(|e| anyhow!("Failed to serialize WAL entry: {}", e))?;
192                let entry_len = entry_bytes.len() as u32;
193
194                // Write sequence number, length, and data
195                writer.write_all(&seq.to_le_bytes())?;
196                writer.write_all(&entry_len.to_le_bytes())?;
197                writer.write_all(&entry_bytes)?;
198
199                if self.config.sync_on_write {
200                    writer.flush()?;
201                    writer.get_ref().sync_all()?;
202                }
203
204                // Check if file rotation is needed
205                let needs_rotation = if let Ok(metadata) = writer.get_ref().metadata() {
206                    metadata.len() >= self.config.max_file_size
207                } else {
208                    false
209                };
210
211                if needs_rotation {
212                    drop(file_guard);
213                    self.rotate_wal_file()?;
214                }
215
216                // Check if checkpoint is needed
217                let last_checkpoint = *self
218                    .last_checkpoint
219                    .lock()
220                    .expect("mutex lock should not be poisoned");
221                seq - last_checkpoint >= self.config.checkpoint_interval
222            } else {
223                return Err(anyhow!("WAL file not open"));
224            }
225        };
226
227        // Checkpoint outside of lock
228        if needs_checkpoint {
229            self.checkpoint(seq)?;
230        }
231
232        Ok(seq)
233    }
234
235    /// Create a checkpoint
236    pub fn checkpoint(&self, sequence_number: u64) -> Result<()> {
237        tracing::info!("Creating WAL checkpoint at sequence {}", sequence_number);
238
239        let timestamp = SystemTime::now()
240            .duration_since(UNIX_EPOCH)
241            .expect("system time should be after UNIX_EPOCH")
242            .as_secs();
243
244        let checkpoint_entry = WalEntry::Checkpoint {
245            sequence_number,
246            timestamp,
247        };
248
249        // Write checkpoint directly without going through append() to avoid recursion
250        let seq = {
251            let mut seq_guard = self
252                .sequence_number
253                .lock()
254                .expect("mutex lock should not be poisoned");
255            let seq = *seq_guard;
256            *seq_guard += 1;
257            seq
258        };
259
260        {
261            let mut file_guard = self
262                .current_file
263                .lock()
264                .expect("mutex lock should not be poisoned");
265            if let Some(ref mut writer) = *file_guard {
266                let entry_bytes =
267                    oxicode::serde::encode_to_vec(&checkpoint_entry, oxicode::config::standard())
268                        .map_err(|e| anyhow!("Failed to serialize checkpoint entry: {}", e))?;
269                let entry_len = entry_bytes.len() as u32;
270
271                writer.write_all(&seq.to_le_bytes())?;
272                writer.write_all(&entry_len.to_le_bytes())?;
273                writer.write_all(&entry_bytes)?;
274
275                if self.config.sync_on_write {
276                    writer.flush()?;
277                    writer.get_ref().sync_all()?;
278                }
279            }
280        }
281
282        let mut last_checkpoint = self
283            .last_checkpoint
284            .lock()
285            .expect("mutex lock should not be poisoned");
286        *last_checkpoint = sequence_number;
287
288        // Cleanup old WAL files
289        self.cleanup_old_files()?;
290
291        Ok(())
292    }
293
294    /// Rotate to a new WAL file
295    fn rotate_wal_file(&self) -> Result<()> {
296        let timestamp = SystemTime::now()
297            .duration_since(UNIX_EPOCH)
298            .expect("system time should be after UNIX_EPOCH")
299            .as_secs();
300
301        let filename = format!("wal-{:016x}.log", timestamp);
302        let filepath = self.config.wal_directory.join(&filename);
303
304        tracing::info!("Rotating WAL to new file: {:?}", filepath);
305
306        let file = OpenOptions::new()
307            .create(true)
308            .append(true)
309            .open(&filepath)?;
310
311        let mut writer = BufWriter::with_capacity(self.config.buffer_size, file);
312
313        // Write WAL file header
314        writer.write_all(WAL_MAGIC)?;
315        writer.write_all(&WAL_VERSION.to_le_bytes())?;
316        writer.write_all(&timestamp.to_le_bytes())?;
317
318        if self.config.sync_on_write {
319            writer.flush()?;
320            writer.get_ref().sync_all()?;
321        }
322
323        let mut file_guard = self
324            .current_file
325            .lock()
326            .expect("mutex lock should not be poisoned");
327        let mut path_guard = self
328            .current_file_path
329            .lock()
330            .expect("mutex lock should not be poisoned");
331
332        // Flush and close old file
333        if let Some(mut old_writer) = file_guard.take() {
334            old_writer.flush()?;
335        }
336
337        *file_guard = Some(writer);
338        *path_guard = filepath;
339
340        Ok(())
341    }
342
343    /// Clean up old WAL files (keep only recent checkpoints)
344    fn cleanup_old_files(&self) -> Result<()> {
345        let mut wal_files: Vec<_> = std::fs::read_dir(&self.config.wal_directory)?
346            .filter_map(|entry| entry.ok())
347            .filter(|entry| {
348                entry
349                    .file_name()
350                    .to_str()
351                    .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
352                    .unwrap_or(false)
353            })
354            .collect();
355
356        // Sort by filename (timestamp-based)
357        wal_files.sort_by_key(|entry| entry.file_name());
358
359        // Keep the most recent files
360        if wal_files.len() > self.config.checkpoint_retention {
361            let to_remove = wal_files.len() - self.config.checkpoint_retention;
362            for entry in wal_files.iter().take(to_remove) {
363                tracing::info!("Removing old WAL file: {:?}", entry.path());
364                std::fs::remove_file(entry.path())?;
365            }
366        }
367
368        Ok(())
369    }
370
371    /// Recover from WAL files
372    pub fn recover(&self) -> Result<Vec<WalEntry>> {
373        tracing::info!("Starting WAL recovery");
374
375        let mut all_entries = Vec::new();
376        let mut last_checkpoint_seq = 0u64;
377
378        // Find all WAL files
379        let mut wal_files: Vec<_> = std::fs::read_dir(&self.config.wal_directory)?
380            .filter_map(|entry| entry.ok())
381            .filter(|entry| {
382                entry
383                    .file_name()
384                    .to_str()
385                    .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
386                    .unwrap_or(false)
387            })
388            .collect();
389
390        // Sort by filename (timestamp-based)
391        wal_files.sort_by_key(|entry| entry.file_name());
392
393        // Read all WAL files
394        for entry in wal_files {
395            let path = entry.path();
396            tracing::debug!("Reading WAL file: {:?}", path);
397
398            let file = File::open(&path)?;
399            let mut reader = BufReader::new(file);
400
401            // Verify magic number
402            let mut magic = [0u8; 4];
403            reader.read_exact(&mut magic)?;
404            if &magic != WAL_MAGIC {
405                tracing::warn!("Invalid WAL file magic number: {:?}", path);
406                continue;
407            }
408
409            // Read version
410            let mut version_bytes = [0u8; 4];
411            reader.read_exact(&mut version_bytes)?;
412            let version = u32::from_le_bytes(version_bytes);
413            if version != WAL_VERSION {
414                tracing::warn!("Unsupported WAL version {} in {:?}", version, path);
415                continue;
416            }
417
418            // Read file timestamp
419            let mut timestamp_bytes = [0u8; 8];
420            reader.read_exact(&mut timestamp_bytes)?;
421
422            // Read entries with robust error handling for incomplete writes
423            loop {
424                // Read sequence number
425                let mut seq_bytes = [0u8; 8];
426                match reader.read_exact(&mut seq_bytes) {
427                    Ok(_) => {}
428                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
429                        tracing::debug!("Reached end of WAL file (expected)");
430                        break;
431                    }
432                    Err(e) => return Err(e.into()),
433                }
434                let seq = u64::from_le_bytes(seq_bytes);
435
436                // Read entry length
437                let mut len_bytes = [0u8; 4];
438                match reader.read_exact(&mut len_bytes) {
439                    Ok(_) => {}
440                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
441                        tracing::warn!(
442                            "Incomplete entry at sequence {}: missing length field. Skipping rest of file.",
443                            seq
444                        );
445                        break;
446                    }
447                    Err(e) => return Err(e.into()),
448                }
449                let len = u32::from_le_bytes(len_bytes);
450
451                // Sanity check on entry length (prevent excessive memory allocation)
452                if len > 100_000_000 {
453                    // 100MB max entry size
454                    tracing::warn!(
455                        "Entry at sequence {} has suspicious length {}. Possibly corrupted. Skipping.",
456                        seq,
457                        len
458                    );
459                    break;
460                }
461
462                // Read entry data
463                let mut entry_bytes = vec![0u8; len as usize];
464                match reader.read_exact(&mut entry_bytes) {
465                    Ok(_) => {}
466                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
467                        tracing::warn!(
468                            "Incomplete entry at sequence {}: expected {} bytes but reached EOF. Skipping rest of file.",
469                            seq,
470                            len
471                        );
472                        break;
473                    }
474                    Err(e) => return Err(e.into()),
475                }
476
477                // Deserialize entry
478                let entry: WalEntry = match oxicode::serde::decode_from_slice(
479                    &entry_bytes,
480                    oxicode::config::standard(),
481                ) {
482                    Ok((e, _)) => e,
483                    Err(e) => {
484                        tracing::warn!(
485                            "Failed to deserialize entry at sequence {}: {}. Skipping entry.",
486                            seq,
487                            e
488                        );
489                        continue; // Skip corrupted entry but continue reading
490                    }
491                };
492
493                // Track last checkpoint
494                if let WalEntry::Checkpoint {
495                    sequence_number, ..
496                } = &entry
497                {
498                    last_checkpoint_seq = *sequence_number;
499                }
500
501                all_entries.push((seq, entry));
502            }
503        }
504
505        // Filter entries after last checkpoint
506        // Note: If last_checkpoint_seq == 0 (no checkpoint), recover all entries including seq 0
507        // Otherwise, only recover entries strictly after the checkpoint
508        let recovered_entries: Vec<_> = all_entries
509            .iter()
510            .filter(|(seq, _)| {
511                if last_checkpoint_seq == 0 {
512                    true // No checkpoint, recover everything
513                } else {
514                    *seq > last_checkpoint_seq // Checkpoint exists, only after it
515                }
516            })
517            .map(|(_, entry)| entry.clone())
518            .collect();
519
520        tracing::info!(
521            "Recovered {} entries from WAL (after checkpoint {})",
522            recovered_entries.len(),
523            last_checkpoint_seq
524        );
525
526        // Update sequence number based on the maximum sequence number seen
527        if let Some((max_seq, _)) = all_entries.iter().max_by_key(|(seq, _)| seq) {
528            let mut seq = self
529                .sequence_number
530                .lock()
531                .expect("mutex lock should not be poisoned");
532            *seq = max_seq + 1;
533        }
534
535        Ok(recovered_entries)
536    }
537
538    /// Flush all pending writes to disk
539    pub fn flush(&self) -> Result<()> {
540        let mut file_guard = self
541            .current_file
542            .lock()
543            .expect("mutex lock should not be poisoned");
544        if let Some(ref mut writer) = *file_guard {
545            writer.flush()?;
546            writer.get_ref().sync_all()?;
547        }
548        Ok(())
549    }
550
551    /// Get current sequence number
552    pub fn current_sequence(&self) -> u64 {
553        *self
554            .sequence_number
555            .lock()
556            .expect("mutex lock should not be poisoned")
557    }
558
559    /// Get last checkpoint sequence number
560    pub fn last_checkpoint_sequence(&self) -> u64 {
561        *self
562            .last_checkpoint
563            .lock()
564            .expect("mutex lock should not be poisoned")
565    }
566}
567
568impl Drop for WalManager {
569    fn drop(&mut self) {
570        // Ensure all data is flushed on drop
571        let _ = self.flush();
572    }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// WAL configuration rooted in `dir`; every other option keeps its default.
    fn config_in(dir: &TempDir) -> WalConfig {
        WalConfig {
            wal_directory: dir.path().to_path_buf(),
            ..Default::default()
        }
    }

    /// Convenience constructor for an `Insert` entry without metadata.
    fn insert(id: &str, vector: Vec<f32>, timestamp: u64) -> WalEntry {
        WalEntry::Insert {
            id: id.to_string(),
            vector,
            metadata: None,
            timestamp,
        }
    }

    #[test]
    fn test_wal_creation() {
        let temp_dir = TempDir::new().unwrap();
        let wal = WalManager::new(config_in(&temp_dir)).unwrap();
        assert_eq!(wal.current_sequence(), 0);
    }

    #[test]
    fn test_wal_append() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            sync_on_write: true,
            ..config_in(&temp_dir)
        };
        let wal = WalManager::new(config).unwrap();

        let seq = wal
            .append(insert("vec1", vec![1.0, 2.0, 3.0], 12345))
            .unwrap();
        assert_eq!(seq, 0);
        assert_eq!(wal.current_sequence(), 1);
    }

    #[test]
    fn test_wal_recovery() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            sync_on_write: true,
            checkpoint_interval: 100,
            ..config_in(&temp_dir)
        };

        // Write some entries, then drop the manager. `flush()` already fsyncs the data,
        // so no delay is needed before reopening.
        {
            let wal = WalManager::new(config.clone()).unwrap();
            for i in 0..5u64 {
                let id = format!("vec{}", i);
                let vector = vec![i as f32, (i * 2) as f32];
                wal.append(insert(&id, vector, (i + 1) * 1000)).unwrap();
            }
            wal.flush().unwrap();
        }

        let wal = WalManager::new(config).unwrap();
        let recovered = wal.recover().unwrap();

        assert_eq!(
            recovered.len(),
            5,
            "Expected exactly 5 entries, got {}",
            recovered.len()
        );
        let timestamps: Vec<u64> = recovered.iter().map(WalEntry::timestamp).collect();
        assert_eq!(timestamps, vec![1000, 2000, 3000, 4000, 5000]);
        // The sequence counter continues after the highest recovered sequence number.
        assert_eq!(wal.current_sequence(), 5);
    }

    #[test]
    fn test_wal_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            sync_on_write: true,
            checkpoint_interval: 3,
            ..config_in(&temp_dir)
        };
        let wal = WalManager::new(config).unwrap();

        // The append with sequence 3 triggers a checkpoint at sequence 3.
        for i in 0..5u64 {
            wal.append(insert(&format!("vec{}", i), vec![i as f32], i))
                .unwrap();
        }

        assert_eq!(wal.last_checkpoint_sequence(), 3);
    }

    #[test]
    fn test_wal_batch_operation() {
        let temp_dir = TempDir::new().unwrap();
        let wal = WalManager::new(config_in(&temp_dir)).unwrap();

        let batch = WalEntry::Batch {
            entries: vec![
                insert("vec1", vec![1.0], 1),
                WalEntry::Update {
                    id: "vec2".to_string(),
                    vector: vec![2.0],
                    metadata: None,
                    timestamp: 2,
                },
            ],
            timestamp: 3,
        };

        wal.append(batch).unwrap();
        wal.flush().unwrap();

        let recovered = wal.recover().unwrap();
        assert_eq!(recovered.len(), 1);
        match &recovered[0] {
            WalEntry::Batch { entries, timestamp } => {
                assert_eq!(*timestamp, 3);
                assert_eq!(entries.len(), 2);
                assert_eq!(entries[1].timestamp(), 2);
            }
            other => panic!("expected a Batch entry, got {:?}", other),
        }
    }

    #[test]
    fn test_wal_transaction() {
        let temp_dir = TempDir::new().unwrap();
        let wal = WalManager::new(config_in(&temp_dir)).unwrap();

        wal.append(WalEntry::BeginTransaction {
            transaction_id: 1,
            timestamp: 100,
        })
        .unwrap();
        wal.append(insert("vec1", vec![1.0], 101)).unwrap();
        wal.append(WalEntry::CommitTransaction {
            transaction_id: 1,
            timestamp: 102,
        })
        .unwrap();
        wal.flush().unwrap();

        let recovered = wal.recover().unwrap();
        let timestamps: Vec<u64> = recovered.iter().map(WalEntry::timestamp).collect();
        assert_eq!(timestamps, vec![100, 101, 102]);
        assert!(matches!(
            recovered.first(),
            Some(WalEntry::BeginTransaction {
                transaction_id: 1,
                ..
            })
        ));
        assert!(matches!(
            recovered.last(),
            Some(WalEntry::CommitTransaction {
                transaction_id: 1,
                ..
            })
        ));
    }
}