Skip to main content

oxirs_vec/persistence/
point_in_time.rs

1//! Point-in-time snapshot restore for vector indexes.
2//!
3//! Enables recovery of a vector store to its exact state at any historical
4//! timestamp by:
//! 1. Locating the latest WAL checkpoint whose timestamp is at or before the target.
6//! 2. Replaying WAL entries from that checkpoint up to (and including) the
7//!    target timestamp.
8//!
9//! WAL timestamps are Unix-epoch seconds (`u64`) matching the format used by
10//! [`crate::wal::WalEntry`].
11
12use crate::wal::{WalConfig, WalEntry, WalManager};
13use anyhow::{anyhow, Result};
14use std::path::{Path, PathBuf};
15
16// ─────────────────────────────────────────────────────────────────────────────
17// CheckpointRef
18// ─────────────────────────────────────────────────────────────────────────────
19
/// Reference to a WAL checkpoint discovered during point-in-time search.
///
/// The [`Default`] value (sequence 0, epoch 0) denotes an empty base, i.e. the
/// very beginning of the WAL. Note that passing `Some(&CheckpointRef::default())`
/// to [`PointInTimeRestore::replay_wal_to_timestamp`] still skips a record whose
/// sequence number is 0; pass `None` to replay from the very first record.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CheckpointRef {
    /// Sequence number of the checkpoint marker in the WAL.
    pub sequence_number: u64,
    /// Unix-epoch timestamp of the checkpoint (seconds).
    pub timestamp: u64,
}
38
39// ─────────────────────────────────────────────────────────────────────────────
40// PointInTimeRestore
41// ─────────────────────────────────────────────────────────────────────────────
42
/// Driver for point-in-time recovery.
///
/// Usage:
/// ```ignore
/// let pit = PointInTimeRestore::new(target_ts, wal_dir);
/// let base = pit.find_base_checkpoint()?;
/// let entries = pit.replay_wal_to_timestamp(base.as_ref())?;
/// // Apply entries to the index...
/// ```
#[derive(Debug, Clone)]
pub struct PointInTimeRestore {
    /// Target Unix-epoch timestamp in seconds; entries stamped at or before it
    /// are eligible for replay.
    pub target_timestamp: u64,
    /// Directory that contains the `wal-*.log` files.
    pub wal_dir: PathBuf,
}
58
59impl PointInTimeRestore {
60    /// Create a new driver.
61    ///
62    /// `target_timestamp_secs` is seconds since Unix epoch; callers that have a
63    /// `std::time::SystemTime` should convert with
64    /// `system_time.duration_since(UNIX_EPOCH).unwrap_or_default().as_secs()`.
65    pub fn new(target_timestamp_secs: u64, wal_dir: PathBuf) -> Self {
66        Self {
67            target_timestamp: target_timestamp_secs,
68            wal_dir,
69        }
70    }
71
72    /// Find the latest [`CheckpointRef`] whose timestamp is ≤ `target_timestamp`.
73    ///
74    /// Returns `None` when no checkpoint precedes the target (the caller should
75    /// treat that as "start from an empty base").
76    pub fn find_base_checkpoint(&self) -> Result<Option<CheckpointRef>> {
77        let entries = self.read_all_wal_entries()?;
78
79        let mut best: Option<CheckpointRef> = None;
80
81        for entry in &entries {
82            if let WalEntry::Checkpoint {
83                sequence_number,
84                timestamp,
85            } = entry
86            {
87                if *timestamp <= self.target_timestamp {
88                    let candidate = CheckpointRef {
89                        sequence_number: *sequence_number,
90                        timestamp: *timestamp,
91                    };
92                    match &best {
93                        None => best = Some(candidate),
94                        Some(prev) if candidate.timestamp > prev.timestamp => {
95                            best = Some(candidate)
96                        }
97                        _ => {}
98                    }
99                }
100            }
101        }
102
103        Ok(best)
104    }
105
106    /// Replay WAL entries that fall between `base.sequence_number` (exclusive)
107    /// and `target_timestamp` (inclusive).
108    ///
109    /// When `base` is `None` (no prior checkpoint) all entries up to
110    /// `target_timestamp` are returned.
111    pub fn replay_wal_to_timestamp(&self, base: Option<&CheckpointRef>) -> Result<Vec<WalEntry>> {
112        let base_seq = base.map(|b| b.sequence_number).unwrap_or(0);
113        let all_entries = self.read_all_indexed_wal_entries()?;
114
115        let mut result = Vec::new();
116        for (seq, entry) in all_entries {
117            // Skip everything at or before the base checkpoint sequence number
118            if seq <= base_seq && base.is_some() {
119                continue;
120            }
121            // Skip entries beyond the target timestamp
122            if entry.timestamp() > self.target_timestamp {
123                continue;
124            }
125            // Skip structural markers — they have no data to replay
126            if entry.is_checkpoint() {
127                continue;
128            }
129            match &entry {
130                WalEntry::BeginTransaction { .. }
131                | WalEntry::CommitTransaction { .. }
132                | WalEntry::AbortTransaction { .. } => {
133                    // Skip transaction bookkeeping; only data entries matter
134                }
135                _ => result.push(entry),
136            }
137        }
138
139        Ok(result)
140    }
141
142    // ─────────────────────────────────────────────────────────────────────────
143    // Private helpers
144    // ─────────────────────────────────────────────────────────────────────────
145
146    /// Read all WAL entries in chronological file order.
147    fn read_all_wal_entries(&self) -> Result<Vec<WalEntry>> {
148        Ok(self
149            .read_all_indexed_wal_entries()?
150            .into_iter()
151            .map(|(_, e)| e)
152            .collect())
153    }
154
155    /// Read all WAL entries together with their sequence numbers, ordered by
156    /// (file name, sequence number).  WAL files are named `wal-<hex_ts>.log`
157    /// so lexicographic order equals chronological order.
158    fn read_all_indexed_wal_entries(&self) -> Result<Vec<(u64, WalEntry)>> {
159        if !self.wal_dir.exists() {
160            return Ok(Vec::new());
161        }
162
163        // Open a temporary WalManager in read-only mode (just to run recover())
164        let config = WalConfig {
165            wal_directory: self.wal_dir.clone(),
166            // Very large interval so we never auto-trigger a checkpoint during
167            // our recovery scan.
168            checkpoint_interval: u64::MAX,
169            checkpoint_retention: usize::MAX,
170            sync_on_write: false,
171            ..WalConfig::default()
172        };
173        let mgr = WalManager::new(config)
174            .map_err(|e| anyhow!("Cannot open WAL for PIT recovery: {}", e))?;
175
176        // recover() returns entries *after* the last checkpoint — we need all
177        // entries, so we scan the files directly via the WalManager's recover
178        // (it skips nothing when checkpoint_retention is huge and the interval
179        // is MAX, because no new checkpoint will be written).
180        //
181        // However recover() still filters on last_checkpoint_seq.  To get
182        // *everything*, we need to parse files ourselves.  We borrow the
183        // helper below.
184        let entries = self.scan_wal_files(&self.wal_dir)?;
185        drop(mgr);
186        Ok(entries)
187    }
188
189    /// Low-level file scanner: returns (sequence_number, WalEntry) for every
190    /// parseable record across all `wal-*.log` files in `dir`.
191    fn scan_wal_files(&self, dir: &Path) -> Result<Vec<(u64, WalEntry)>> {
192        use std::fs::File;
193        use std::io::{BufReader, Read};
194
195        const WAL_MAGIC: &[u8; 4] = b"WALV";
196
197        let mut wal_files: Vec<_> = std::fs::read_dir(dir)?
198            .filter_map(|e| e.ok())
199            .filter(|e| {
200                e.file_name()
201                    .to_str()
202                    .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
203                    .unwrap_or(false)
204            })
205            .collect();
206
207        // Lexicographic == chronological for hex-timestamp file names
208        wal_files.sort_by_key(|e| e.file_name());
209
210        let mut result: Vec<(u64, WalEntry)> = Vec::new();
211
212        for file_entry in wal_files {
213            let path = file_entry.path();
214            let file = match File::open(&path) {
215                Ok(f) => f,
216                Err(e) => {
217                    tracing::warn!("PIT: cannot open WAL file {:?}: {}", path, e);
218                    continue;
219                }
220            };
221            let mut reader = BufReader::new(file);
222
223            // Validate magic
224            let mut magic = [0u8; 4];
225            if reader.read_exact(&mut magic).is_err() {
226                continue;
227            }
228            if &magic != WAL_MAGIC {
229                tracing::warn!("PIT: invalid magic in {:?}", path);
230                continue;
231            }
232
233            // Version (4 bytes) + file timestamp (8 bytes) — skip both
234            let mut skip = [0u8; 12];
235            if reader.read_exact(&mut skip).is_err() {
236                continue;
237            }
238
239            loop {
240                // Sequence number
241                let mut seq_bytes = [0u8; 8];
242                match reader.read_exact(&mut seq_bytes) {
243                    Ok(_) => {}
244                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
245                    Err(e) => {
246                        tracing::warn!("PIT: read error in {:?}: {}", path, e);
247                        break;
248                    }
249                }
250                let seq = u64::from_le_bytes(seq_bytes);
251
252                // Entry length
253                let mut len_bytes = [0u8; 4];
254                match reader.read_exact(&mut len_bytes) {
255                    Ok(_) => {}
256                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
257                    Err(e) => {
258                        tracing::warn!("PIT: read error in {:?}: {}", path, e);
259                        break;
260                    }
261                }
262                let len = u32::from_le_bytes(len_bytes) as usize;
263
264                // Sanity guard
265                if len > 100_000_000 {
266                    tracing::warn!(
267                        "PIT: suspicious entry length {} at seq {} in {:?}",
268                        len,
269                        seq,
270                        path
271                    );
272                    break;
273                }
274
275                // Entry bytes
276                let mut entry_bytes = vec![0u8; len];
277                match reader.read_exact(&mut entry_bytes) {
278                    Ok(_) => {}
279                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
280                    Err(e) => {
281                        tracing::warn!("PIT: read error in {:?}: {}", path, e);
282                        break;
283                    }
284                }
285
286                match oxicode::serde::decode_from_slice::<WalEntry, _>(
287                    &entry_bytes,
288                    oxicode::config::standard(),
289                ) {
290                    Ok((entry, _)) => result.push((seq, entry)),
291                    Err(e) => {
292                        tracing::warn!("PIT: cannot deserialise entry at seq {}: {}", seq, e);
293                    }
294                }
295            }
296        }
297
298        Ok(result)
299    }
300}
301
302// ─────────────────────────────────────────────────────────────────────────────
303// Tests
304// ─────────────────────────────────────────────────────────────────────────────
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wal::{WalConfig, WalEntry, WalManager};
    use tempfile::TempDir;

    /// Write `entries` to a fresh WAL in `dir` via the real `WalManager`.
    fn write_test_wal(dir: &std::path::Path, entries: &[WalEntry]) -> Result<()> {
        let config = WalConfig {
            wal_directory: dir.to_path_buf(),
            checkpoint_interval: u64::MAX,
            sync_on_write: true,
            ..WalConfig::default()
        };
        let mgr = WalManager::new(config)?;
        for entry in entries {
            mgr.append(entry.clone())?;
        }
        mgr.flush()?;
        Ok(())
    }

    #[test]
    fn test_find_base_checkpoint_no_checkpoint() -> Result<()> {
        let tmp = TempDir::new()?;

        let entries = vec![
            WalEntry::Insert {
                id: "v1".into(),
                vector: vec![1.0],
                metadata: None,
                timestamp: 1000,
            },
            WalEntry::Insert {
                id: "v2".into(),
                vector: vec![2.0],
                metadata: None,
                timestamp: 2000,
            },
        ];
        write_test_wal(tmp.path(), &entries)?;

        let pit = PointInTimeRestore::new(5000, tmp.path().to_path_buf());
        let base = pit.find_base_checkpoint()?;
        assert!(base.is_none(), "expected None when no checkpoint exists");
        Ok(())
    }

    #[test]
    fn test_find_base_checkpoint_selects_latest_before_target() -> Result<()> {
        let tmp = TempDir::new()?;

        let entries = vec![
            WalEntry::Checkpoint {
                sequence_number: 0,
                timestamp: 1000,
            },
            WalEntry::Checkpoint {
                sequence_number: 1,
                timestamp: 3000,
            },
            WalEntry::Checkpoint {
                sequence_number: 2,
                timestamp: 6000,
            },
        ];
        write_test_wal(tmp.path(), &entries)?;

        let pit = PointInTimeRestore::new(4000, tmp.path().to_path_buf());
        let base = pit.find_base_checkpoint()?;
        let base = base.expect("should find a checkpoint");
        assert_eq!(base.timestamp, 3000, "should pick checkpoint at ts=3000");
        Ok(())
    }

    #[test]
    fn test_find_base_checkpoint_none_when_all_after_target() -> Result<()> {
        let tmp = TempDir::new()?;

        let entries = vec![
            WalEntry::Checkpoint {
                sequence_number: 0,
                timestamp: 1000,
            },
            WalEntry::Checkpoint {
                sequence_number: 1,
                timestamp: 3000,
            },
        ];
        write_test_wal(tmp.path(), &entries)?;

        // Every checkpoint is newer than the target → empty base.
        let pit = PointInTimeRestore::new(500, tmp.path().to_path_buf());
        assert!(pit.find_base_checkpoint()?.is_none());
        Ok(())
    }

    #[test]
    fn test_replay_wal_to_timestamp_filters_correctly() -> Result<()> {
        let tmp = TempDir::new()?;

        let raw = vec![
            WalEntry::Insert {
                id: "v1".into(),
                vector: vec![1.0],
                metadata: None,
                timestamp: 1000,
            },
            WalEntry::Insert {
                id: "v2".into(),
                vector: vec![2.0],
                metadata: None,
                timestamp: 2000,
            },
            WalEntry::Insert {
                id: "v3".into(),
                vector: vec![3.0],
                metadata: None,
                timestamp: 4000,
            },
        ];
        write_test_wal(tmp.path(), &raw)?;

        // Target = 2500 → should replay v1 and v2 but NOT v3
        let pit = PointInTimeRestore::new(2500, tmp.path().to_path_buf());
        let replayed = pit.replay_wal_to_timestamp(None)?;
        assert_eq!(replayed.len(), 2);
        let ids: Vec<_> = replayed
            .iter()
            .filter_map(|e| {
                if let WalEntry::Insert { id, .. } = e {
                    Some(id.as_str())
                } else {
                    None
                }
            })
            .collect();
        assert!(ids.contains(&"v1"));
        assert!(ids.contains(&"v2"));
        assert!(!ids.contains(&"v3"));

        Ok(())
    }

    #[test]
    fn test_replay_before_first_entry_is_empty() -> Result<()> {
        let tmp = TempDir::new()?;

        let raw = vec![
            WalEntry::Insert {
                id: "v1".into(),
                vector: vec![1.0],
                metadata: None,
                timestamp: 1000,
            },
            WalEntry::Insert {
                id: "v2".into(),
                vector: vec![2.0],
                metadata: None,
                timestamp: 2000,
            },
        ];
        write_test_wal(tmp.path(), &raw)?;

        // Target precedes every entry → nothing to replay.
        let pit = PointInTimeRestore::new(500, tmp.path().to_path_buf());
        assert!(pit.replay_wal_to_timestamp(None)?.is_empty());
        Ok(())
    }

    #[test]
    fn test_missing_wal_dir_yields_empty_results() -> Result<()> {
        let tmp = TempDir::new()?;
        let missing = tmp.path().join("no-such-wal-dir");

        // A nonexistent WAL directory is treated as an empty log, not an error.
        let pit = PointInTimeRestore::new(u64::MAX, missing);
        assert!(pit.find_base_checkpoint()?.is_none());
        assert!(pit.replay_wal_to_timestamp(None)?.is_empty());
        Ok(())
    }

    #[test]
    fn test_checkpoint_discovery_ordered() -> Result<()> {
        let tmp = TempDir::new()?;

        let timestamps = [500u64, 1500, 2500, 3500, 4500];
        let entries: Vec<WalEntry> = timestamps
            .iter()
            .enumerate()
            .map(|(i, &ts)| WalEntry::Checkpoint {
                sequence_number: i as u64,
                timestamp: ts,
            })
            .collect();
        write_test_wal(tmp.path(), &entries)?;

        // Target = 3000 → best checkpoint is ts=2500 (seq=2)
        let pit = PointInTimeRestore::new(3000, tmp.path().to_path_buf());
        let base = pit.find_base_checkpoint()?.expect("checkpoint expected");
        assert_eq!(base.timestamp, 2500);
        assert_eq!(base.sequence_number, 2);

        Ok(())
    }
}