Skip to main content

storage/
wal.rs

//! Write-Ahead Log (WAL) for Buffer
//!
//! Provides durability and crash recovery through sequential logging:
//! - All mutations are logged before they are applied
//! - Supports checkpointing for log compaction
//! - Recovery from incomplete transactions
7
8use common::{DakeraError, Result};
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::fs::{self, File, OpenOptions};
12use std::io::{BufRead, BufReader, BufWriter, Write};
13use std::path::{Path, PathBuf};
14use std::sync::atomic::{AtomicU64, Ordering};
15
16/// WAL configuration
17#[derive(Debug, Clone)]
18pub struct WalConfig {
19    /// Directory for WAL files
20    pub wal_dir: PathBuf,
21    /// Maximum WAL segment size in bytes
22    pub max_segment_size: u64,
23    /// Sync mode for durability
24    pub sync_mode: WalSyncMode,
25    /// Maximum entries before forced checkpoint
26    pub checkpoint_threshold: u64,
27}
28
29impl Default for WalConfig {
30    fn default() -> Self {
31        Self {
32            wal_dir: PathBuf::from("./data/wal"),
33            max_segment_size: 64 * 1024 * 1024, // 64MB
34            sync_mode: WalSyncMode::EveryWrite,
35            checkpoint_threshold: 10000,
36        }
37    }
38}
39
/// Policy that decides when appended entries are flushed from the write buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WalSyncMode {
    /// Flush after each appended entry (safest).
    EveryWrite,
    /// Flush once every N appended entries.
    Periodic(u32),
    /// Never flush implicitly; the caller flushes explicitly.
    Manual,
}
50
51/// WAL entry types
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub enum WalEntry {
54    /// Insert or update vectors
55    Upsert {
56        namespace: String,
57        vectors: Vec<SerializedVector>,
58    },
59    /// Delete vectors
60    Delete { namespace: String, ids: Vec<String> },
61    /// Create namespace
62    CreateNamespace { namespace: String },
63    /// Delete namespace
64    DeleteNamespace { namespace: String },
65    /// Checkpoint marker
66    Checkpoint { lsn: u64 },
67    /// Transaction begin
68    TxnBegin { txn_id: u64 },
69    /// Transaction commit
70    TxnCommit { txn_id: u64 },
71    /// Transaction rollback
72    TxnRollback { txn_id: u64 },
73}
74
75/// Serialized vector for WAL storage
76#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct SerializedVector {
78    pub id: String,
79    pub values: Vec<f32>,
80    pub metadata: Option<String>,
81}
82
/// In-memory bookkeeping for the segment file currently being written.
#[derive(Debug)]
struct WalSegment {
    /// Bytes appended to this segment so far.
    size: u64,
    /// LSN recorded when the segment was opened.
    start_lsn: u64,
    /// LSN of the most recent entry appended to the segment.
    end_lsn: u64,
}
93
94/// Write-Ahead Log manager
95pub struct WriteAheadLog {
96    config: WalConfig,
97    /// Current log sequence number
98    lsn: AtomicU64,
99    /// Current active segment
100    current_segment: RwLock<Option<WalSegment>>,
101    /// Active writer
102    writer: RwLock<Option<BufWriter<File>>>,
103    /// Entries since last checkpoint
104    entries_since_checkpoint: AtomicU64,
105    /// Last checkpointed LSN
106    last_checkpoint_lsn: AtomicU64,
107    /// Write count for periodic sync
108    write_count: AtomicU64,
109}
110
111impl WriteAheadLog {
112    /// Create a new WAL
113    pub fn new(config: WalConfig) -> Result<Self> {
114        // Ensure WAL directory exists
115        fs::create_dir_all(&config.wal_dir)
116            .map_err(|e| DakeraError::Storage(format!("Failed to create WAL dir: {}", e)))?;
117
118        let wal = Self {
119            config,
120            lsn: AtomicU64::new(0),
121            current_segment: RwLock::new(None),
122            writer: RwLock::new(None),
123            entries_since_checkpoint: AtomicU64::new(0),
124            last_checkpoint_lsn: AtomicU64::new(0),
125            write_count: AtomicU64::new(0),
126        };
127
128        // Recover LSN from existing segments
129        wal.recover_lsn()?;
130
131        Ok(wal)
132    }
133
134    /// Recover LSN from existing WAL segments
135    fn recover_lsn(&self) -> Result<()> {
136        let segments = self.list_segments()?;
137        if let Some(last_segment) = segments.last() {
138            // Read the last segment to find the highest LSN
139            let entries = self.read_segment(last_segment)?;
140            if let Some(last_entry) = entries.last() {
141                self.lsn.store(last_entry.0 + 1, Ordering::SeqCst);
142            }
143        }
144        Ok(())
145    }
146
147    /// List all WAL segment files
148    fn list_segments(&self) -> Result<Vec<PathBuf>> {
149        let mut segments = Vec::new();
150
151        if let Ok(entries) = fs::read_dir(&self.config.wal_dir) {
152            for entry in entries.flatten() {
153                let path = entry.path();
154                if path.extension().map(|e| e == "wal").unwrap_or(false) {
155                    segments.push(path);
156                }
157            }
158        }
159
160        segments.sort();
161        Ok(segments)
162    }
163
164    /// Read entries from a segment file
165    fn read_segment(&self, path: &Path) -> Result<Vec<(u64, WalEntry)>> {
166        let file = File::open(path)
167            .map_err(|e| DakeraError::Storage(format!("Failed to open WAL: {}", e)))?;
168
169        let reader = BufReader::new(file);
170        let mut entries = Vec::new();
171
172        for line_result in reader.lines() {
173            let line =
174                line_result.map_err(|e| DakeraError::Storage(format!("WAL read error: {}", e)))?;
175
176            if line.trim().is_empty() {
177                continue;
178            }
179
180            // Format: LSN|JSON_ENTRY
181            if let Some((lsn_str, entry_json)) = line.split_once('|') {
182                let lsn: u64 = lsn_str.parse().unwrap_or(0);
183                if let Ok(entry) = serde_json::from_str::<WalEntry>(entry_json) {
184                    entries.push((lsn, entry));
185                }
186            }
187        }
188
189        Ok(entries)
190    }
191
192    /// Append an entry to the WAL
193    pub fn append(&self, entry: WalEntry) -> Result<u64> {
194        let lsn = self.lsn.fetch_add(1, Ordering::SeqCst);
195
196        // Ensure we have an active segment
197        self.ensure_segment()?;
198
199        // Serialize and write
200        let entry_json = serde_json::to_string(&entry)
201            .map_err(|e| DakeraError::Storage(format!("WAL serialize error: {}", e)))?;
202
203        let line = format!("{}|{}\n", lsn, entry_json);
204
205        {
206            let mut writer_guard = self.writer.write();
207            let writer = writer_guard.as_mut().ok_or_else(|| {
208                DakeraError::Storage("WAL writer not available after ensure_segment".to_string())
209            })?;
210
211            writer
212                .write_all(line.as_bytes())
213                .map_err(|e| DakeraError::Storage(format!("WAL write error: {}", e)))?;
214
215            // Handle sync based on mode
216            let write_count = self.write_count.fetch_add(1, Ordering::Relaxed) + 1;
217            match self.config.sync_mode {
218                WalSyncMode::EveryWrite => {
219                    if let Err(e) = writer.flush() {
220                        tracing::warn!(error = %e, "WAL flush failed during EveryWrite sync");
221                    }
222                }
223                WalSyncMode::Periodic(n) if n > 0 && write_count.is_multiple_of(n as u64) => {
224                    if let Err(e) = writer.flush() {
225                        tracing::warn!(error = %e, write_count, "WAL flush failed during periodic sync");
226                    }
227                }
228                _ => {}
229            }
230        }
231
232        // Update segment size
233        {
234            let mut segment_guard = self.current_segment.write();
235            if let Some(ref mut segment) = *segment_guard {
236                segment.size += line.len() as u64;
237                segment.end_lsn = lsn;
238            }
239        }
240
241        // Track entries for checkpoint threshold
242        let _entries = self
243            .entries_since_checkpoint
244            .fetch_add(1, Ordering::Relaxed);
245
246        Ok(lsn)
247    }
248
249    /// Ensure we have an active segment for writing
250    fn ensure_segment(&self) -> Result<()> {
251        let needs_new_segment = {
252            let segment_guard = self.current_segment.read();
253            match &*segment_guard {
254                None => true,
255                Some(seg) => seg.size >= self.config.max_segment_size,
256            }
257        };
258
259        if needs_new_segment {
260            self.rotate_segment()?;
261        }
262
263        Ok(())
264    }
265
266    /// Rotate to a new segment file
267    fn rotate_segment(&self) -> Result<()> {
268        // Close current writer
269        {
270            let mut writer_guard = self.writer.write();
271            if let Some(ref mut writer) = *writer_guard {
272                if let Err(e) = writer.flush() {
273                    tracing::warn!(error = %e, "WAL flush failed during segment rotation");
274                }
275            }
276            *writer_guard = None;
277        }
278
279        // Create new segment
280        let current_lsn = self.lsn.load(Ordering::SeqCst);
281        let segment_id = current_lsn;
282        let segment_path = self.config.wal_dir.join(format!("{:020}.wal", segment_id));
283
284        let file = OpenOptions::new()
285            .create(true)
286            .append(true)
287            .open(&segment_path)
288            .map_err(|e| DakeraError::Storage(format!("Failed to create WAL segment: {}", e)))?;
289
290        let writer = BufWriter::new(file);
291
292        // Update state
293        {
294            let mut segment_guard = self.current_segment.write();
295            *segment_guard = Some(WalSegment {
296                size: 0,
297                start_lsn: current_lsn,
298                end_lsn: current_lsn,
299            });
300        }
301
302        tracing::debug!(
303            segment_id = segment_id,
304            path = %segment_path.display(),
305            "Created new WAL segment"
306        );
307
308        {
309            let mut writer_guard = self.writer.write();
310            *writer_guard = Some(writer);
311        }
312
313        Ok(())
314    }
315
316    /// Write a checkpoint marker
317    pub fn checkpoint(&self) -> Result<u64> {
318        let lsn = self.lsn.load(Ordering::SeqCst);
319
320        // Write checkpoint entry
321        self.append(WalEntry::Checkpoint { lsn })?;
322
323        // Update checkpoint tracking
324        self.last_checkpoint_lsn.store(lsn, Ordering::SeqCst);
325        self.entries_since_checkpoint.store(0, Ordering::SeqCst);
326
327        // Flush writer
328        {
329            let mut writer_guard = self.writer.write();
330            if let Some(ref mut writer) = *writer_guard {
331                if let Err(e) = writer.flush() {
332                    tracing::warn!(error = %e, lsn, "WAL flush failed during checkpoint");
333                }
334            }
335        }
336
337        Ok(lsn)
338    }
339
340    /// Recover entries since last checkpoint
341    pub fn recover(&self) -> Result<Vec<WalEntry>> {
342        let segments = self.list_segments()?;
343        let checkpoint_lsn = self.last_checkpoint_lsn.load(Ordering::SeqCst);
344
345        let mut entries = Vec::new();
346
347        for segment_path in segments {
348            let segment_entries = self.read_segment(&segment_path)?;
349
350            for (lsn, entry) in segment_entries {
351                // Skip entries before checkpoint (but only if checkpoint was done)
352                if checkpoint_lsn > 0 && lsn <= checkpoint_lsn {
353                    // But track checkpoint LSN updates
354                    if let WalEntry::Checkpoint { lsn: cp_lsn } = entry {
355                        self.last_checkpoint_lsn.store(cp_lsn, Ordering::SeqCst);
356                    }
357                    continue;
358                }
359
360                // Skip transaction control entries for recovery
361                match entry {
362                    WalEntry::TxnBegin { .. }
363                    | WalEntry::TxnCommit { .. }
364                    | WalEntry::TxnRollback { .. }
365                    | WalEntry::Checkpoint { .. } => continue,
366                    _ => entries.push(entry),
367                }
368            }
369        }
370
371        Ok(entries)
372    }
373
374    /// Truncate WAL up to (and including) the given LSN
375    pub fn truncate(&self, up_to_lsn: u64) -> Result<u64> {
376        let segments = self.list_segments()?;
377        let mut removed_count = 0u64;
378
379        // Get the active segment's start LSN to avoid deleting it
380        let active_start_lsn = {
381            let segment_guard = self.current_segment.read();
382            segment_guard.as_ref().map(|s| s.start_lsn)
383        };
384
385        for segment_path in segments {
386            // Check if entire segment is before truncation point
387            let segment_entries = self.read_segment(&segment_path)?;
388
389            if let Some((first_lsn, _)) = segment_entries.first() {
390                // Skip the currently active segment to avoid deleting a file being written to
391                if active_start_lsn == Some(*first_lsn) {
392                    continue;
393                }
394            }
395
396            if let Some((last_lsn, _)) = segment_entries.last() {
397                if *last_lsn <= up_to_lsn {
398                    // Safe to remove this segment
399                    fs::remove_file(&segment_path).ok();
400                    removed_count += segment_entries.len() as u64;
401                }
402            }
403        }
404
405        Ok(removed_count)
406    }
407
408    /// Get current LSN
409    pub fn current_lsn(&self) -> u64 {
410        self.lsn.load(Ordering::SeqCst)
411    }
412
413    /// Get WAL statistics
414    pub fn stats(&self) -> WalStats {
415        let segment_count = self.list_segments().map(|s| s.len()).unwrap_or(0);
416
417        let (current_segment_size, current_segment_entries) = {
418            let segment_guard = self.current_segment.read();
419            match &*segment_guard {
420                Some(seg) => (seg.size, seg.end_lsn.saturating_sub(seg.start_lsn)),
421                None => (0, 0),
422            }
423        };
424
425        WalStats {
426            current_lsn: self.lsn.load(Ordering::SeqCst),
427            last_checkpoint_lsn: self.last_checkpoint_lsn.load(Ordering::SeqCst),
428            segment_count,
429            current_segment_size,
430            current_segment_entries,
431            entries_since_checkpoint: self.entries_since_checkpoint.load(Ordering::Relaxed),
432        }
433    }
434
435    /// Flush WAL to disk
436    pub fn flush(&self) -> Result<()> {
437        let mut writer_guard = self.writer.write();
438        if let Some(ref mut writer) = *writer_guard {
439            writer
440                .flush()
441                .map_err(|e| DakeraError::Storage(format!("WAL flush error: {}", e)))?;
442        }
443        Ok(())
444    }
445}
446
/// Point-in-time snapshot of WAL counters.
#[derive(Debug, Clone)]
pub struct WalStats {
    /// Value of the LSN counter when the snapshot was taken.
    pub current_lsn: u64,
    /// LSN recorded by the most recent checkpoint.
    pub last_checkpoint_lsn: u64,
    /// How many segment files exist in the WAL directory.
    pub segment_count: usize,
    /// Bytes written to the active segment.
    pub current_segment_size: u64,
    /// Entry count reported for the active segment.
    pub current_segment_entries: u64,
    /// Entries appended since the last checkpoint.
    pub entries_since_checkpoint: u64,
}
463
464/// WAL-wrapped storage that logs all mutations
465pub struct WalStorage<S> {
466    /// Underlying storage
467    inner: S,
468    /// Write-ahead log
469    wal: WriteAheadLog,
470}
471
472impl<S> WalStorage<S> {
473    /// Create a new WAL-wrapped storage
474    pub fn new(inner: S, wal_config: WalConfig) -> Result<Self> {
475        let wal = WriteAheadLog::new(wal_config)?;
476        Ok(Self { inner, wal })
477    }
478
479    /// Get WAL reference
480    pub fn wal(&self) -> &WriteAheadLog {
481        &self.wal
482    }
483
484    /// Get inner storage reference
485    pub fn inner(&self) -> &S {
486        &self.inner
487    }
488
489    /// Checkpoint the WAL
490    pub fn checkpoint(&self) -> Result<u64> {
491        self.wal.checkpoint()
492    }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Config rooted at `dir` with a 1 KiB segment limit and flush-on-write.
    fn test_config(dir: &Path) -> WalConfig {
        WalConfig {
            wal_dir: dir.to_path_buf(),
            max_segment_size: 1024,
            sync_mode: WalSyncMode::EveryWrite,
            checkpoint_threshold: 100,
        }
    }

    /// `CreateNamespace` entry for the "test" namespace.
    fn create_test_namespace() -> WalEntry {
        WalEntry::CreateNamespace {
            namespace: "test".to_string(),
        }
    }

    /// `Upsert` entry into the "test" namespace carrying `vectors`.
    fn upsert(vectors: Vec<SerializedVector>) -> WalEntry {
        WalEntry::Upsert {
            namespace: "test".to_string(),
            vectors,
        }
    }

    /// Vector with the given id, values and optional metadata.
    fn vector(id: &str, values: Vec<f32>, metadata: Option<&str>) -> SerializedVector {
        SerializedVector {
            id: id.to_string(),
            values,
            metadata: metadata.map(str::to_string),
        }
    }

    #[test]
    fn test_wal_basic_operations() {
        let temp_dir = TempDir::new().unwrap();
        let wal = WriteAheadLog::new(test_config(temp_dir.path())).unwrap();

        // LSNs are handed out sequentially from 0.
        let lsn1 = wal.append(create_test_namespace()).unwrap();
        let lsn2 = wal
            .append(upsert(vec![vector("v1", vec![1.0, 2.0, 3.0], None)]))
            .unwrap();

        assert_eq!(lsn1, 0);
        assert_eq!(lsn2, 1);
        assert_eq!(wal.current_lsn(), 2);
    }

    #[test]
    fn test_wal_recovery() {
        let temp_dir = TempDir::new().unwrap();
        let config = test_config(temp_dir.path());

        // First instance: write two entries and flush them.
        {
            let wal = WriteAheadLog::new(config.clone()).unwrap();
            wal.append(create_test_namespace()).unwrap();
            wal.append(upsert(vec![vector("v1", vec![1.0, 2.0], None)]))
                .unwrap();
            wal.flush().unwrap();
        }

        // Second instance over the same directory: both entries come back in order.
        {
            let wal = WriteAheadLog::new(config).unwrap();
            let entries = wal.recover().unwrap();

            assert_eq!(entries.len(), 2);
            assert!(matches!(entries[0], WalEntry::CreateNamespace { .. }));
            assert!(matches!(entries[1], WalEntry::Upsert { .. }));
        }
    }

    #[test]
    fn test_wal_checkpoint() {
        let temp_dir = TempDir::new().unwrap();
        let wal = WriteAheadLog::new(test_config(temp_dir.path())).unwrap();

        // One entry, a checkpoint, then one more entry.
        wal.append(create_test_namespace()).unwrap();
        let _checkpoint_lsn = wal.checkpoint().unwrap();
        wal.append(upsert(vec![])).unwrap();

        let stats = wal.stats();
        assert!(stats.last_checkpoint_lsn > 0);
        assert_eq!(stats.entries_since_checkpoint, 1); // One entry after checkpoint
    }

    #[test]
    fn test_wal_stats() {
        let temp_dir = TempDir::new().unwrap();
        let wal = WriteAheadLog::new(test_config(temp_dir.path())).unwrap();

        (0..5).for_each(|i| {
            let id = format!("v{}", i);
            wal.append(upsert(vec![vector(&id, vec![i as f32], None)]))
                .unwrap();
        });

        let stats = wal.stats();
        assert_eq!(stats.current_lsn, 5);
        assert_eq!(stats.entries_since_checkpoint, 5);
    }

    #[test]
    fn test_segment_rotation() {
        let temp_dir = TempDir::new().unwrap();
        let config = WalConfig {
            max_segment_size: 100, // Very small to force rotation
            checkpoint_threshold: 1000,
            ..test_config(temp_dir.path())
        };

        let wal = WriteAheadLog::new(config).unwrap();

        // Each entry exceeds the segment limit, so rotation must occur.
        (0..10).for_each(|i| {
            let id = format!("v{}", i);
            let payload = vector(&id, vec![i as f32; 10], Some("some metadata here"));
            wal.append(upsert(vec![payload])).unwrap();
        });

        let stats = wal.stats();
        assert!(stats.segment_count > 1);
    }
}