oxirs_vec/
wal.rs

1//! Write-Ahead Logging (WAL) for crash recovery
2//!
3//! This module provides comprehensive write-ahead logging for vector index operations,
4//! enabling crash recovery and ensuring data durability. The WAL records all modifications
5//! before they are applied to the index, allowing the system to recover from crashes by
6//! replaying the log.
7//!
8//! # Features
9//!
10//! - Transaction-based logging
11//! - Automatic crash recovery
12//! - Log compaction and checkpointing
13//! - Concurrent write support with proper synchronization
14//! - Configurable fsync behavior for performance tuning
15//!
16//! # Architecture
17//!
18//! ```text
19//! ┌─────────────┐
20//! │ Index Ops   │
21//! └──────┬──────┘
22//!        │
23//!        ▼
24//! ┌─────────────┐     ┌──────────────┐
25//! │ WAL Writer  │────▶│ Log File     │
26//! └─────────────┘     └──────────────┘
27//!        │                    │
28//!        │                    │ (on crash)
29//!        ▼                    ▼
30//! ┌─────────────┐     ┌──────────────┐
31//! │ Index       │◀────│ WAL Recovery │
32//! └─────────────┘     └──────────────┘
33//! ```
34
35use anyhow::{anyhow, Result};
36use bincode::{Decode, Encode};
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use std::fs::{File, OpenOptions};
40use std::io::{BufReader, BufWriter, Read, Write};
41use std::path::PathBuf;
42use std::sync::{Arc, Mutex};
43use std::time::{SystemTime, UNIX_EPOCH};
44
45/// WAL magic number for file format validation
46const WAL_MAGIC: &[u8; 4] = b"WALV"; // WAL Vector
47
48/// WAL format version
49const WAL_VERSION: u32 = 1;
50
51/// Write-Ahead Log entry representing a single operation
52#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
53pub enum WalEntry {
54    /// Insert a new vector
55    Insert {
56        id: String,
57        vector: Vec<f32>,
58        metadata: Option<HashMap<String, String>>,
59        timestamp: u64,
60    },
61    /// Update an existing vector
62    Update {
63        id: String,
64        vector: Vec<f32>,
65        metadata: Option<HashMap<String, String>>,
66        timestamp: u64,
67    },
68    /// Delete a vector
69    Delete { id: String, timestamp: u64 },
70    /// Batch operation (multiple entries)
71    Batch {
72        entries: Vec<WalEntry>,
73        timestamp: u64,
74    },
75    /// Checkpoint marker (all operations before this are persisted)
76    Checkpoint {
77        sequence_number: u64,
78        timestamp: u64,
79    },
80    /// Transaction begin
81    BeginTransaction { transaction_id: u64, timestamp: u64 },
82    /// Transaction commit
83    CommitTransaction { transaction_id: u64, timestamp: u64 },
84    /// Transaction abort
85    AbortTransaction { transaction_id: u64, timestamp: u64 },
86}
87
88impl WalEntry {
89    /// Get the timestamp of this entry
90    pub fn timestamp(&self) -> u64 {
91        match self {
92            WalEntry::Insert { timestamp, .. }
93            | WalEntry::Update { timestamp, .. }
94            | WalEntry::Delete { timestamp, .. }
95            | WalEntry::Batch { timestamp, .. }
96            | WalEntry::Checkpoint { timestamp, .. }
97            | WalEntry::BeginTransaction { timestamp, .. }
98            | WalEntry::CommitTransaction { timestamp, .. }
99            | WalEntry::AbortTransaction { timestamp, .. } => *timestamp,
100        }
101    }
102
103    /// Check if this is a checkpoint entry
104    pub fn is_checkpoint(&self) -> bool {
105        matches!(self, WalEntry::Checkpoint { .. })
106    }
107}
108
/// Tunable settings for the write-ahead log.
#[derive(Debug, Clone)]
pub struct WalConfig {
    /// Directory that holds the WAL files.
    pub wal_directory: PathBuf,
    /// Size in bytes at which the current WAL file is rotated.
    pub max_file_size: u64,
    /// When `true`, every write is flushed and fsynced (slower but safer).
    pub sync_on_write: bool,
    /// Number of operations between automatic checkpoints.
    pub checkpoint_interval: u64,
    /// How many WAL files survive cleanup after a checkpoint.
    pub checkpoint_retention: usize,
    /// Capacity in bytes of the write buffer placed in front of the WAL file.
    pub buffer_size: usize,
}

impl Default for WalConfig {
    /// Defaults: `./wal`, 100 MiB files, no fsync per write, a checkpoint every
    /// 10 000 operations, 3 retained files and a 64 KiB write buffer.
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: u64 = 1024 * 1024;

        WalConfig {
            wal_directory: PathBuf::from("./wal"),
            max_file_size: 100 * MIB,
            // Trades a small durability window for throughput.
            sync_on_write: false,
            checkpoint_interval: 10_000,
            checkpoint_retention: 3,
            buffer_size: 64 * KIB,
        }
    }
}
138
139/// Write-Ahead Log manager
140pub struct WalManager {
141    config: WalConfig,
142    current_file: Arc<Mutex<Option<BufWriter<File>>>>,
143    current_file_path: Arc<Mutex<PathBuf>>,
144    sequence_number: Arc<Mutex<u64>>,
145    last_checkpoint: Arc<Mutex<u64>>,
146}
147
148impl WalManager {
149    /// Create a new WAL manager
150    pub fn new(config: WalConfig) -> Result<Self> {
151        // Ensure WAL directory exists
152        std::fs::create_dir_all(&config.wal_directory)?;
153
154        let manager = Self {
155            config,
156            current_file: Arc::new(Mutex::new(None)),
157            current_file_path: Arc::new(Mutex::new(PathBuf::new())),
158            sequence_number: Arc::new(Mutex::new(0)),
159            last_checkpoint: Arc::new(Mutex::new(0)),
160        };
161
162        // Open or create the current WAL file
163        manager.rotate_wal_file()?;
164
165        Ok(manager)
166    }
167
168    /// Append an entry to the WAL
169    pub fn append(&self, entry: WalEntry) -> Result<u64> {
170        let seq = {
171            let mut seq_guard = self.sequence_number.lock().unwrap();
172            let seq = *seq_guard;
173            *seq_guard += 1;
174            seq
175        };
176
177        // Write to file
178        let needs_checkpoint = {
179            let mut file_guard = self.current_file.lock().unwrap();
180
181            if let Some(ref mut writer) = *file_guard {
182                // Serialize the entry
183                let entry_bytes = bincode::encode_to_vec(&entry, bincode::config::standard())
184                    .map_err(|e| anyhow!("Failed to serialize WAL entry: {}", e))?;
185                let entry_len = entry_bytes.len() as u32;
186
187                // Write sequence number, length, and data
188                writer.write_all(&seq.to_le_bytes())?;
189                writer.write_all(&entry_len.to_le_bytes())?;
190                writer.write_all(&entry_bytes)?;
191
192                if self.config.sync_on_write {
193                    writer.flush()?;
194                    writer.get_ref().sync_all()?;
195                }
196
197                // Check if file rotation is needed
198                let needs_rotation = if let Ok(metadata) = writer.get_ref().metadata() {
199                    metadata.len() >= self.config.max_file_size
200                } else {
201                    false
202                };
203
204                if needs_rotation {
205                    drop(file_guard);
206                    self.rotate_wal_file()?;
207                }
208
209                // Check if checkpoint is needed
210                let last_checkpoint = *self.last_checkpoint.lock().unwrap();
211                seq - last_checkpoint >= self.config.checkpoint_interval
212            } else {
213                return Err(anyhow!("WAL file not open"));
214            }
215        };
216
217        // Checkpoint outside of lock
218        if needs_checkpoint {
219            self.checkpoint(seq)?;
220        }
221
222        Ok(seq)
223    }
224
225    /// Create a checkpoint
226    pub fn checkpoint(&self, sequence_number: u64) -> Result<()> {
227        tracing::info!("Creating WAL checkpoint at sequence {}", sequence_number);
228
229        let timestamp = SystemTime::now()
230            .duration_since(UNIX_EPOCH)
231            .unwrap()
232            .as_secs();
233
234        let checkpoint_entry = WalEntry::Checkpoint {
235            sequence_number,
236            timestamp,
237        };
238
239        // Write checkpoint directly without going through append() to avoid recursion
240        let seq = {
241            let mut seq_guard = self.sequence_number.lock().unwrap();
242            let seq = *seq_guard;
243            *seq_guard += 1;
244            seq
245        };
246
247        {
248            let mut file_guard = self.current_file.lock().unwrap();
249            if let Some(ref mut writer) = *file_guard {
250                let entry_bytes =
251                    bincode::encode_to_vec(&checkpoint_entry, bincode::config::standard())
252                        .map_err(|e| anyhow!("Failed to serialize checkpoint entry: {}", e))?;
253                let entry_len = entry_bytes.len() as u32;
254
255                writer.write_all(&seq.to_le_bytes())?;
256                writer.write_all(&entry_len.to_le_bytes())?;
257                writer.write_all(&entry_bytes)?;
258
259                if self.config.sync_on_write {
260                    writer.flush()?;
261                    writer.get_ref().sync_all()?;
262                }
263            }
264        }
265
266        let mut last_checkpoint = self.last_checkpoint.lock().unwrap();
267        *last_checkpoint = sequence_number;
268
269        // Cleanup old WAL files
270        self.cleanup_old_files()?;
271
272        Ok(())
273    }
274
275    /// Rotate to a new WAL file
276    fn rotate_wal_file(&self) -> Result<()> {
277        let timestamp = SystemTime::now()
278            .duration_since(UNIX_EPOCH)
279            .unwrap()
280            .as_secs();
281
282        let filename = format!("wal-{:016x}.log", timestamp);
283        let filepath = self.config.wal_directory.join(&filename);
284
285        tracing::info!("Rotating WAL to new file: {:?}", filepath);
286
287        let file = OpenOptions::new()
288            .create(true)
289            .append(true)
290            .open(&filepath)?;
291
292        let mut writer = BufWriter::with_capacity(self.config.buffer_size, file);
293
294        // Write WAL file header
295        writer.write_all(WAL_MAGIC)?;
296        writer.write_all(&WAL_VERSION.to_le_bytes())?;
297        writer.write_all(&timestamp.to_le_bytes())?;
298
299        if self.config.sync_on_write {
300            writer.flush()?;
301            writer.get_ref().sync_all()?;
302        }
303
304        let mut file_guard = self.current_file.lock().unwrap();
305        let mut path_guard = self.current_file_path.lock().unwrap();
306
307        // Flush and close old file
308        if let Some(mut old_writer) = file_guard.take() {
309            old_writer.flush()?;
310        }
311
312        *file_guard = Some(writer);
313        *path_guard = filepath;
314
315        Ok(())
316    }
317
318    /// Clean up old WAL files (keep only recent checkpoints)
319    fn cleanup_old_files(&self) -> Result<()> {
320        let mut wal_files: Vec<_> = std::fs::read_dir(&self.config.wal_directory)?
321            .filter_map(|entry| entry.ok())
322            .filter(|entry| {
323                entry
324                    .file_name()
325                    .to_str()
326                    .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
327                    .unwrap_or(false)
328            })
329            .collect();
330
331        // Sort by filename (timestamp-based)
332        wal_files.sort_by_key(|entry| entry.file_name());
333
334        // Keep the most recent files
335        if wal_files.len() > self.config.checkpoint_retention {
336            let to_remove = wal_files.len() - self.config.checkpoint_retention;
337            for entry in wal_files.iter().take(to_remove) {
338                tracing::info!("Removing old WAL file: {:?}", entry.path());
339                std::fs::remove_file(entry.path())?;
340            }
341        }
342
343        Ok(())
344    }
345
346    /// Recover from WAL files
347    pub fn recover(&self) -> Result<Vec<WalEntry>> {
348        tracing::info!("Starting WAL recovery");
349
350        let mut all_entries = Vec::new();
351        let mut last_checkpoint_seq = 0u64;
352
353        // Find all WAL files
354        let mut wal_files: Vec<_> = std::fs::read_dir(&self.config.wal_directory)?
355            .filter_map(|entry| entry.ok())
356            .filter(|entry| {
357                entry
358                    .file_name()
359                    .to_str()
360                    .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
361                    .unwrap_or(false)
362            })
363            .collect();
364
365        // Sort by filename (timestamp-based)
366        wal_files.sort_by_key(|entry| entry.file_name());
367
368        // Read all WAL files
369        for entry in wal_files {
370            let path = entry.path();
371            tracing::debug!("Reading WAL file: {:?}", path);
372
373            let file = File::open(&path)?;
374            let mut reader = BufReader::new(file);
375
376            // Verify magic number
377            let mut magic = [0u8; 4];
378            reader.read_exact(&mut magic)?;
379            if &magic != WAL_MAGIC {
380                tracing::warn!("Invalid WAL file magic number: {:?}", path);
381                continue;
382            }
383
384            // Read version
385            let mut version_bytes = [0u8; 4];
386            reader.read_exact(&mut version_bytes)?;
387            let version = u32::from_le_bytes(version_bytes);
388            if version != WAL_VERSION {
389                tracing::warn!("Unsupported WAL version {} in {:?}", version, path);
390                continue;
391            }
392
393            // Read file timestamp
394            let mut timestamp_bytes = [0u8; 8];
395            reader.read_exact(&mut timestamp_bytes)?;
396
397            // Read entries with robust error handling for incomplete writes
398            loop {
399                // Read sequence number
400                let mut seq_bytes = [0u8; 8];
401                match reader.read_exact(&mut seq_bytes) {
402                    Ok(_) => {}
403                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
404                        tracing::debug!("Reached end of WAL file (expected)");
405                        break;
406                    }
407                    Err(e) => return Err(e.into()),
408                }
409                let seq = u64::from_le_bytes(seq_bytes);
410
411                // Read entry length
412                let mut len_bytes = [0u8; 4];
413                match reader.read_exact(&mut len_bytes) {
414                    Ok(_) => {}
415                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
416                        tracing::warn!(
417                            "Incomplete entry at sequence {}: missing length field. Skipping rest of file.",
418                            seq
419                        );
420                        break;
421                    }
422                    Err(e) => return Err(e.into()),
423                }
424                let len = u32::from_le_bytes(len_bytes);
425
426                // Sanity check on entry length (prevent excessive memory allocation)
427                if len > 100_000_000 {
428                    // 100MB max entry size
429                    tracing::warn!(
430                        "Entry at sequence {} has suspicious length {}. Possibly corrupted. Skipping.",
431                        seq,
432                        len
433                    );
434                    break;
435                }
436
437                // Read entry data
438                let mut entry_bytes = vec![0u8; len as usize];
439                match reader.read_exact(&mut entry_bytes) {
440                    Ok(_) => {}
441                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
442                        tracing::warn!(
443                            "Incomplete entry at sequence {}: expected {} bytes but reached EOF. Skipping rest of file.",
444                            seq,
445                            len
446                        );
447                        break;
448                    }
449                    Err(e) => return Err(e.into()),
450                }
451
452                // Deserialize entry
453                let entry: WalEntry =
454                    match bincode::decode_from_slice(&entry_bytes, bincode::config::standard()) {
455                        Ok((e, _)) => e,
456                        Err(e) => {
457                            tracing::warn!(
458                                "Failed to deserialize entry at sequence {}: {}. Skipping entry.",
459                                seq,
460                                e
461                            );
462                            continue; // Skip corrupted entry but continue reading
463                        }
464                    };
465
466                // Track last checkpoint
467                if let WalEntry::Checkpoint {
468                    sequence_number, ..
469                } = &entry
470                {
471                    last_checkpoint_seq = *sequence_number;
472                }
473
474                all_entries.push((seq, entry));
475            }
476        }
477
478        // Filter entries after last checkpoint
479        // Note: If last_checkpoint_seq == 0 (no checkpoint), recover all entries including seq 0
480        // Otherwise, only recover entries strictly after the checkpoint
481        let recovered_entries: Vec<_> = all_entries
482            .iter()
483            .filter(|(seq, _)| {
484                if last_checkpoint_seq == 0 {
485                    true // No checkpoint, recover everything
486                } else {
487                    *seq > last_checkpoint_seq // Checkpoint exists, only after it
488                }
489            })
490            .map(|(_, entry)| entry.clone())
491            .collect();
492
493        tracing::info!(
494            "Recovered {} entries from WAL (after checkpoint {})",
495            recovered_entries.len(),
496            last_checkpoint_seq
497        );
498
499        // Update sequence number based on the maximum sequence number seen
500        if let Some((max_seq, _)) = all_entries.iter().max_by_key(|(seq, _)| seq) {
501            let mut seq = self.sequence_number.lock().unwrap();
502            *seq = max_seq + 1;
503        }
504
505        Ok(recovered_entries)
506    }
507
508    /// Flush all pending writes to disk
509    pub fn flush(&self) -> Result<()> {
510        let mut file_guard = self.current_file.lock().unwrap();
511        if let Some(ref mut writer) = *file_guard {
512            writer.flush()?;
513            writer.get_ref().sync_all()?;
514        }
515        Ok(())
516    }
517
518    /// Get current sequence number
519    pub fn current_sequence(&self) -> u64 {
520        *self.sequence_number.lock().unwrap()
521    }
522
523    /// Get last checkpoint sequence number
524    pub fn last_checkpoint_sequence(&self) -> u64 {
525        *self.last_checkpoint.lock().unwrap()
526    }
527}
528
529impl Drop for WalManager {
530    fn drop(&mut self) {
531        // Ensure all data is flushed on drop
532        let _ = self.flush();
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Default configuration pointed at the given scratch directory.
    fn scratch_config(dir: &TempDir) -> WalConfig {
        WalConfig {
            wal_directory: dir.path().to_path_buf(),
            ..Default::default()
        }
    }

    /// Shorthand for an `Insert` entry without metadata.
    fn insert_entry(id: impl Into<String>, vector: Vec<f32>, timestamp: u64) -> WalEntry {
        WalEntry::Insert {
            id: id.into(),
            vector,
            metadata: None,
            timestamp,
        }
    }

    /// A fresh manager starts numbering at zero.
    #[test]
    fn test_wal_creation() {
        let scratch = TempDir::new().unwrap();

        let wal = WalManager::new(scratch_config(&scratch)).unwrap();
        assert_eq!(wal.current_sequence(), 0);
    }

    /// The first appended entry receives sequence number zero.
    #[test]
    fn test_wal_append() {
        let scratch = TempDir::new().unwrap();
        let config = WalConfig {
            sync_on_write: true,
            ..scratch_config(&scratch)
        };

        let wal = WalManager::new(config).unwrap();

        let seq = wal
            .append(insert_entry("vec1", vec![1.0, 2.0, 3.0], 12345))
            .unwrap();
        assert_eq!(seq, 0);
    }

    /// Entries written by one manager are returned, in order, by a later one.
    #[test]
    fn test_wal_recovery() {
        let scratch = TempDir::new().unwrap();
        let config = WalConfig {
            sync_on_write: true,
            checkpoint_interval: 100,
            ..scratch_config(&scratch)
        };

        // Phase 1: write five entries and let the manager go out of scope.
        {
            let writer = WalManager::new(config.clone()).unwrap();

            for i in 0..5u64 {
                // Distinct timestamps make the entries identifiable after recovery.
                let entry = insert_entry(
                    format!("vec{}", i),
                    vec![i as f32, (i * 2) as f32],
                    (i + 1) * 1000,
                );
                writer.append(entry).unwrap();
            }

            writer.flush().unwrap();
            // Dropping the manager flushes once more.
            drop(writer);
        }

        // Small delay to ensure file is written
        std::thread::sleep(std::time::Duration::from_millis(100));

        // Phase 2: a new manager over the same directory recovers them.
        {
            let reader = WalManager::new(config).unwrap();
            let recovered = reader.recover().unwrap();

            assert_eq!(
                recovered.len(),
                5,
                "Expected exactly 5 entries, got {}",
                recovered.len()
            );

            let timestamps: Vec<u64> = recovered.iter().map(WalEntry::timestamp).collect();
            assert_eq!(timestamps, vec![1000, 2000, 3000, 4000, 5000]);
        }
    }

    /// Appending past the checkpoint interval records a checkpoint.
    #[test]
    fn test_wal_checkpoint() {
        let scratch = TempDir::new().unwrap();
        let config = WalConfig {
            sync_on_write: true,
            checkpoint_interval: 3,
            ..scratch_config(&scratch)
        };

        let wal = WalManager::new(config).unwrap();

        // Five appends with an interval of three must cross the threshold.
        for i in 0..5u64 {
            wal.append(insert_entry(format!("vec{}", i), vec![i as f32], i))
                .unwrap();
        }

        assert!(wal.last_checkpoint_sequence() > 0);
    }

    /// A batch containing mixed operations can be appended and flushed.
    #[test]
    fn test_wal_batch_operation() {
        let scratch = TempDir::new().unwrap();
        let wal = WalManager::new(scratch_config(&scratch)).unwrap();

        let first = insert_entry("vec1", vec![1.0], 1);
        let second = WalEntry::Update {
            id: "vec2".to_string(),
            vector: vec![2.0],
            metadata: None,
            timestamp: 2,
        };
        let batch = WalEntry::Batch {
            entries: vec![first, second],
            timestamp: 3,
        };

        wal.append(batch).unwrap();
        wal.flush().unwrap();
    }

    /// Begin / operation / commit markers can be appended in sequence.
    #[test]
    fn test_wal_transaction() {
        let scratch = TempDir::new().unwrap();
        let wal = WalManager::new(scratch_config(&scratch)).unwrap();

        let steps = vec![
            WalEntry::BeginTransaction {
                transaction_id: 1,
                timestamp: 100,
            },
            insert_entry("vec1", vec![1.0], 101),
            WalEntry::CommitTransaction {
                transaction_id: 1,
                timestamp: 102,
            },
        ];

        for step in steps {
            wal.append(step).unwrap();
        }

        wal.flush().unwrap();
    }
}