1use crate::wal::{WalConfig, WalEntry, WalManager};
13use anyhow::{anyhow, Result};
14use std::path::{Path, PathBuf};
15
/// Identifies a WAL checkpoint by the two values stored in its checkpoint entry.
///
/// The derived [`Default`] yields a reference with both fields set to zero,
/// which is exactly what the previous hand-written implementation produced.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CheckpointRef {
    /// Sequence number recorded in the checkpoint entry.
    pub sequence_number: u64,
    /// Timestamp recorded in the checkpoint entry.
    pub timestamp: u64,
}
38
/// Parameters of a point-in-time restore driven by write-ahead-log files.
#[derive(Debug, Clone)]
pub struct PointInTimeRestore {
    /// Inclusive upper bound on the timestamps of checkpoints and entries
    /// that take part in the restore.
    pub target_timestamp: u64,
    /// Directory that is scanned for `wal-*.log` files.
    pub wal_dir: PathBuf,
}
58
59impl PointInTimeRestore {
60 pub fn new(target_timestamp_secs: u64, wal_dir: PathBuf) -> Self {
66 Self {
67 target_timestamp: target_timestamp_secs,
68 wal_dir,
69 }
70 }
71
72 pub fn find_base_checkpoint(&self) -> Result<Option<CheckpointRef>> {
77 let entries = self.read_all_wal_entries()?;
78
79 let mut best: Option<CheckpointRef> = None;
80
81 for entry in &entries {
82 if let WalEntry::Checkpoint {
83 sequence_number,
84 timestamp,
85 } = entry
86 {
87 if *timestamp <= self.target_timestamp {
88 let candidate = CheckpointRef {
89 sequence_number: *sequence_number,
90 timestamp: *timestamp,
91 };
92 match &best {
93 None => best = Some(candidate),
94 Some(prev) if candidate.timestamp > prev.timestamp => {
95 best = Some(candidate)
96 }
97 _ => {}
98 }
99 }
100 }
101 }
102
103 Ok(best)
104 }
105
106 pub fn replay_wal_to_timestamp(&self, base: Option<&CheckpointRef>) -> Result<Vec<WalEntry>> {
112 let base_seq = base.map(|b| b.sequence_number).unwrap_or(0);
113 let all_entries = self.read_all_indexed_wal_entries()?;
114
115 let mut result = Vec::new();
116 for (seq, entry) in all_entries {
117 if seq <= base_seq && base.is_some() {
119 continue;
120 }
121 if entry.timestamp() > self.target_timestamp {
123 continue;
124 }
125 if entry.is_checkpoint() {
127 continue;
128 }
129 match &entry {
130 WalEntry::BeginTransaction { .. }
131 | WalEntry::CommitTransaction { .. }
132 | WalEntry::AbortTransaction { .. } => {
133 }
135 _ => result.push(entry),
136 }
137 }
138
139 Ok(result)
140 }
141
142 fn read_all_wal_entries(&self) -> Result<Vec<WalEntry>> {
148 Ok(self
149 .read_all_indexed_wal_entries()?
150 .into_iter()
151 .map(|(_, e)| e)
152 .collect())
153 }
154
155 fn read_all_indexed_wal_entries(&self) -> Result<Vec<(u64, WalEntry)>> {
159 if !self.wal_dir.exists() {
160 return Ok(Vec::new());
161 }
162
163 let config = WalConfig {
165 wal_directory: self.wal_dir.clone(),
166 checkpoint_interval: u64::MAX,
169 checkpoint_retention: usize::MAX,
170 sync_on_write: false,
171 ..WalConfig::default()
172 };
173 let mgr = WalManager::new(config)
174 .map_err(|e| anyhow!("Cannot open WAL for PIT recovery: {}", e))?;
175
176 let entries = self.scan_wal_files(&self.wal_dir)?;
185 drop(mgr);
186 Ok(entries)
187 }
188
189 fn scan_wal_files(&self, dir: &Path) -> Result<Vec<(u64, WalEntry)>> {
192 use std::fs::File;
193 use std::io::{BufReader, Read};
194
195 const WAL_MAGIC: &[u8; 4] = b"WALV";
196
197 let mut wal_files: Vec<_> = std::fs::read_dir(dir)?
198 .filter_map(|e| e.ok())
199 .filter(|e| {
200 e.file_name()
201 .to_str()
202 .map(|s| s.starts_with("wal-") && s.ends_with(".log"))
203 .unwrap_or(false)
204 })
205 .collect();
206
207 wal_files.sort_by_key(|e| e.file_name());
209
210 let mut result: Vec<(u64, WalEntry)> = Vec::new();
211
212 for file_entry in wal_files {
213 let path = file_entry.path();
214 let file = match File::open(&path) {
215 Ok(f) => f,
216 Err(e) => {
217 tracing::warn!("PIT: cannot open WAL file {:?}: {}", path, e);
218 continue;
219 }
220 };
221 let mut reader = BufReader::new(file);
222
223 let mut magic = [0u8; 4];
225 if reader.read_exact(&mut magic).is_err() {
226 continue;
227 }
228 if &magic != WAL_MAGIC {
229 tracing::warn!("PIT: invalid magic in {:?}", path);
230 continue;
231 }
232
233 let mut skip = [0u8; 12];
235 if reader.read_exact(&mut skip).is_err() {
236 continue;
237 }
238
239 loop {
240 let mut seq_bytes = [0u8; 8];
242 match reader.read_exact(&mut seq_bytes) {
243 Ok(_) => {}
244 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
245 Err(e) => {
246 tracing::warn!("PIT: read error in {:?}: {}", path, e);
247 break;
248 }
249 }
250 let seq = u64::from_le_bytes(seq_bytes);
251
252 let mut len_bytes = [0u8; 4];
254 match reader.read_exact(&mut len_bytes) {
255 Ok(_) => {}
256 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
257 Err(e) => {
258 tracing::warn!("PIT: read error in {:?}: {}", path, e);
259 break;
260 }
261 }
262 let len = u32::from_le_bytes(len_bytes) as usize;
263
264 if len > 100_000_000 {
266 tracing::warn!(
267 "PIT: suspicious entry length {} at seq {} in {:?}",
268 len,
269 seq,
270 path
271 );
272 break;
273 }
274
275 let mut entry_bytes = vec![0u8; len];
277 match reader.read_exact(&mut entry_bytes) {
278 Ok(_) => {}
279 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
280 Err(e) => {
281 tracing::warn!("PIT: read error in {:?}: {}", path, e);
282 break;
283 }
284 }
285
286 match oxicode::serde::decode_from_slice::<WalEntry, _>(
287 &entry_bytes,
288 oxicode::config::standard(),
289 ) {
290 Ok((entry, _)) => result.push((seq, entry)),
291 Err(e) => {
292 tracing::warn!("PIT: cannot deserialise entry at seq {}: {}", seq, e);
293 }
294 }
295 }
296 }
297
298 Ok(result)
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::wal::{WalConfig, WalEntry, WalManager};
    use tempfile::TempDir;

    /// Builds an insert entry holding a one-element vector and no metadata.
    macro_rules! insert {
        ($id:expr, $value:expr, $ts:expr) => {
            WalEntry::Insert {
                id: $id.into(),
                vector: vec![$value],
                metadata: None,
                timestamp: $ts,
            }
        };
    }

    /// Builds a checkpoint entry.
    fn checkpoint(sequence_number: u64, timestamp: u64) -> WalEntry {
        WalEntry::Checkpoint {
            sequence_number,
            timestamp,
        }
    }

    /// Appends `entries` through a `WalManager` opened on `dir`, then flushes it.
    fn write_test_wal(dir: &std::path::Path, entries: &[WalEntry]) -> Result<()> {
        let mgr = WalManager::new(WalConfig {
            wal_directory: dir.to_path_buf(),
            checkpoint_interval: u64::MAX,
            sync_on_write: true,
            ..WalConfig::default()
        })?;
        for entry in entries.iter().cloned() {
            mgr.append(entry)?;
        }
        mgr.flush()?;
        Ok(())
    }

    /// A log without checkpoint entries yields no base checkpoint.
    #[test]
    fn test_find_base_checkpoint_no_checkpoint() -> Result<()> {
        let tmp = TempDir::new()?;
        write_test_wal(
            tmp.path(),
            &[insert!("v1", 1.0, 1000), insert!("v2", 2.0, 2000)],
        )?;

        let pit = PointInTimeRestore::new(5000, tmp.path().to_path_buf());
        let base = pit.find_base_checkpoint()?;
        assert!(base.is_none(), "expected None when no checkpoint exists");
        Ok(())
    }

    /// Of three checkpoints, the newest one not after the target is chosen.
    #[test]
    fn test_find_base_checkpoint_selects_latest_before_target() -> Result<()> {
        let tmp = TempDir::new()?;
        write_test_wal(
            tmp.path(),
            &[checkpoint(0, 1000), checkpoint(1, 3000), checkpoint(2, 6000)],
        )?;

        let pit = PointInTimeRestore::new(4000, tmp.path().to_path_buf());
        let base = pit.find_base_checkpoint()?.expect("should find a checkpoint");
        assert_eq!(base.timestamp, 3000, "should pick checkpoint at ts=3000");
        Ok(())
    }

    /// Entries stamped after the target are left out of the replay set.
    #[test]
    fn test_replay_wal_to_timestamp_filters_correctly() -> Result<()> {
        let tmp = TempDir::new()?;
        write_test_wal(
            tmp.path(),
            &[
                insert!("v1", 1.0, 1000),
                insert!("v2", 2.0, 2000),
                insert!("v3", 3.0, 4000),
            ],
        )?;

        let pit = PointInTimeRestore::new(2500, tmp.path().to_path_buf());
        let replayed = pit.replay_wal_to_timestamp(None)?;
        assert_eq!(replayed.len(), 2);

        let has_insert = |wanted: &str| {
            replayed
                .iter()
                .any(|e| matches!(e, WalEntry::Insert { id, .. } if id.as_str() == wanted))
        };
        assert!(has_insert("v1"));
        assert!(has_insert("v2"));
        assert!(!has_insert("v3"));

        Ok(())
    }

    /// With five ascending checkpoints, the one just below the target wins.
    #[test]
    fn test_checkpoint_discovery_ordered() -> Result<()> {
        let tmp = TempDir::new()?;

        let entries: Vec<WalEntry> = (0u64..)
            .zip([500u64, 1500, 2500, 3500, 4500].iter())
            .map(|(seq, &ts)| checkpoint(seq, ts))
            .collect();
        write_test_wal(tmp.path(), &entries)?;

        let pit = PointInTimeRestore::new(3000, tmp.path().to_path_buf());
        let base = pit.find_base_checkpoint()?.expect("checkpoint expected");
        assert_eq!(base.timestamp, 2500);
        assert_eq!(base.sequence_number, 2);

        Ok(())
    }
}