use crate::checkpoint::CheckpointFile;
use crate::error::PersistenceResult;
use crate::storage::Directory;
use crate::walog::{WalEntry, WalReader, WalReplayMode};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Settings that control how `recover_with_wal` treats damaged input and how far it replays.
#[derive(Debug, Clone, Copy)]
pub struct RecoveryOptions {
    /// Mode handed to the WAL reader's `replay_each_with_mode`.
    pub wal_mode: WalReplayMode,
    /// When `true`, a checkpoint that fails to load is discarded and recovery starts
    /// from an empty state instead of returning the load error.
    pub ignore_corrupt_checkpoint: bool,
    /// Inclusive upper bound on the WAL entry ids that get applied; `None` replays everything.
    pub up_to_entry_id: Option<u64>,
}

impl RecoveryOptions {
    /// Strict preset: `WalReplayMode::Strict`, checkpoint errors are returned,
    /// and the whole log is replayed.
    pub fn strict() -> Self {
        Self {
            up_to_entry_id: None,
            ignore_corrupt_checkpoint: false,
            wal_mode: WalReplayMode::Strict,
        }
    }

    /// Best-effort preset: WAL read with `WalReplayMode::BestEffortTail` and an
    /// unreadable checkpoint is dropped rather than reported.
    pub fn best_effort() -> Self {
        Self {
            wal_mode: WalReplayMode::BestEffortTail,
            ignore_corrupt_checkpoint: true,
            ..Self::strict()
        }
    }

    /// Strict preset that stops applying WAL entries whose id exceeds `entry_id`.
    pub fn up_to(entry_id: u64) -> Self {
        Self {
            up_to_entry_id: Some(entry_id),
            ..Self::strict()
        }
    }
}
/// Outcome of [`recover_with_wal`]: the rebuilt state and the id of the last
/// WAL entry reflected in it.
#[derive(Debug, Clone)]
pub struct Recovery<S> {
    /// State produced by `init` and then mutated by `apply` for every replayed entry.
    pub state: S,
    /// Id of the most recently applied WAL entry; if none was applied, the
    /// checkpoint's id (or 0 when no checkpoint was loaded).
    pub last_entry_id: u64,
}

/// Generic checkpoint-plus-WAL recovery.
///
/// 1. If `checkpoint_path` is `Some` and the file exists in `dir`, it is decoded
///    with `CheckpointFile::read_postcard::<C>`. `init` receives the decoded state,
///    and the checkpoint's entry id becomes the replay starting point. A missing
///    file is treated exactly like `None` (even in strict mode): `init(None)` is
///    used and replay starts after id 0.
/// 2. The WAL in `dir` is replayed with `options.wal_mode`. Every record whose
///    id is greater than the starting point and not above
///    `options.up_to_entry_id` (inclusive) is handed to `apply`, in replay order.
///
/// # Errors
///
/// Returns the checkpoint read error unless `options.ignore_corrupt_checkpoint`
/// is set. In that case the checkpoint is discarded and the full WAL is replayed
/// from an empty state. Any error surfaced by the WAL replay is returned as well.
///
/// NOTE(review): the `up_to_entry_id` ceiling only filters WAL records. If the
/// checkpoint itself is newer than the ceiling, its state is still used and
/// `last_entry_id` ends up above the ceiling — confirm whether callers rely on this.
pub fn recover_with_wal<C, E, W>(
    dir: &Arc<dyn Directory>,
    checkpoint_path: Option<&str>,
    options: RecoveryOptions,
    init: impl FnOnce(Option<C>) -> W,
    mut apply: impl FnMut(&mut W, u64, E),
) -> PersistenceResult<Recovery<W>>
where
    C: serde::de::DeserializeOwned,
    E: serde::de::DeserializeOwned,
{
    let (mut state, mut last_entry_id) = match checkpoint_path {
        Some(path) if dir.exists(path) => {
            match CheckpointFile::new(Arc::clone(dir)).read_postcard::<C>(path) {
                Ok((checkpoint_id, checkpoint_state)) => {
                    (init(Some(checkpoint_state)), checkpoint_id)
                }
                // Best effort: any load failure drops the checkpoint; the full WAL
                // replay below rebuilds the state from scratch.
                Err(_) if options.ignore_corrupt_checkpoint => (init(None), 0),
                Err(e) => return Err(e),
            }
        }
        // No checkpoint requested, or the file is absent: start empty.
        _ => (init(None), 0),
    };

    // Copied out so the replay closure only captures `state`, `last_entry_id`
    // and `apply` mutably.
    let replay_after = last_entry_id;
    let ceiling = options.up_to_entry_id;

    let wal = WalReader::<E>::new(Arc::clone(dir));
    wal.replay_each_with_mode(options.wal_mode, |record| {
        let entry_id = record.entry_id;
        let already_in_checkpoint = entry_id <= replay_after;
        let past_ceiling = matches!(ceiling, Some(max) if entry_id > max);
        if !already_in_checkpoint && !past_ceiling {
            last_entry_id = entry_id;
            apply(&mut state, entry_id, record.payload);
        }
        Ok(())
    })?;

    Ok(Recovery {
        state,
        last_entry_id,
    })
}
/// Serializable snapshot of the segment catalogue stored in checkpoint files.
///
/// Written and read with postcard through `CheckpointFile::write_postcard` /
/// `read_postcard`. postcard encodes struct fields positionally in declaration
/// order, so the field order and types below are effectively an on-disk format.
/// Changing them makes existing checkpoint files unreadable.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CheckpointState {
/// Segments in the snapshot; `RecoveryManager::to_checkpoint_state` emits them sorted by `segment_id`.
pub segments: Vec<CheckpointSegment>,
}
/// Checkpointed metadata for a single segment.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CheckpointSegment {
/// Segment identifier; recovery uses it as the key of the segment map.
pub segment_id: u64,
/// Document count, copied from `RecoveredSegment::doc_count` when checkpointing.
pub doc_count: u32,
/// Deleted document ids within the segment. `RecoveryManager::to_checkpoint_state`
/// sorts them ascending, so equal states serialize to identical bytes.
pub deleted_docs: Vec<u32>,
}
/// Segment catalogue rebuilt from a checkpoint plus the WAL.
///
/// Derives `PartialEq`/`Eq` so whole recovery results can be compared directly,
/// e.g. strict vs. best-effort recovery, or recovery before and after a checkpoint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecoveredState {
    /// Recovered segments; `RecoveryManager` returns them sorted by `segment_id`.
    pub segments: Vec<RecoveredSegment>,
    /// Id of the last WAL entry reflected in `segments`. If no WAL entry was
    /// applied, this is the checkpoint's id, or 0 when there was no checkpoint.
    pub last_entry_id: u64,
}

/// In-memory view of one segment after recovery.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RecoveredSegment {
    /// Segment identifier.
    pub segment_id: u64,
    /// Number of documents in the segment.
    pub doc_count: u32,
    /// Deleted document ids; a set, so repeated delete records collapse and
    /// equality ignores insertion order.
    pub deleted_docs: HashSet<u32>,
}
pub struct RecoveryManager {
directory: Arc<dyn Directory>,
}
impl RecoveryManager {
pub fn new(directory: impl Into<Arc<dyn Directory>>) -> Self {
Self {
directory: directory.into(),
}
}
pub fn recover(&self, checkpoint_path: Option<&str>) -> PersistenceResult<RecoveredState> {
self.recover_with_options(checkpoint_path, RecoveryOptions::strict())
}
pub fn recover_best_effort(
&self,
checkpoint_path: Option<&str>,
) -> PersistenceResult<RecoveredState> {
self.recover_with_options(checkpoint_path, RecoveryOptions::best_effort())
}
pub fn recover_latest(&self) -> PersistenceResult<RecoveredState> {
let ckpt = self.latest_checkpoint_from_wal( false)?;
self.recover(ckpt.as_deref())
}
pub fn recover_latest_best_effort(&self) -> PersistenceResult<RecoveredState> {
let ckpt = self.latest_checkpoint_from_wal( true)?;
self.recover_best_effort(ckpt.as_deref())
}
fn latest_checkpoint_from_wal(
&self,
best_effort_wal: bool,
) -> PersistenceResult<Option<String>> {
let wal = WalReader::<WalEntry>::new(self.directory.clone());
let records = if best_effort_wal {
wal.replay_best_effort()?
} else {
wal.replay()?
};
let mut best: Option<(u64, String)> = None;
for r in records {
if let WalEntry::Checkpoint {
checkpoint_path, ..
} = r.payload
{
let entry_id = r.entry_id;
match &best {
None => best = Some((entry_id, checkpoint_path)),
Some((prev_id, _)) if entry_id > *prev_id => {
best = Some((entry_id, checkpoint_path))
}
_ => {}
}
}
}
Ok(best.map(|(_, p)| p))
}
fn recover_with_options(
&self,
checkpoint_path: Option<&str>,
options: RecoveryOptions,
) -> PersistenceResult<RecoveredState> {
let result = recover_with_wal::<CheckpointState, WalEntry, _>(
&self.directory,
checkpoint_path,
options,
|ckpt| {
let mut map: HashMap<u64, RecoveredSegment> = HashMap::new();
if let Some(state) = ckpt {
for s in state.segments {
map.insert(
s.segment_id,
RecoveredSegment {
segment_id: s.segment_id,
doc_count: s.doc_count,
deleted_docs: s.deleted_docs.into_iter().collect(),
},
);
}
}
map
},
|map, _entry_id, entry| {
Self::apply_entry(map, entry);
},
)?;
let mut segments: Vec<RecoveredSegment> = result.state.into_values().collect();
segments.sort_by_key(|s| s.segment_id);
Ok(RecoveredState {
segments,
last_entry_id: result.last_entry_id,
})
}
fn apply_entry(map: &mut HashMap<u64, RecoveredSegment>, entry: WalEntry) {
match entry {
WalEntry::AddSegment {
segment_id,
doc_count,
} => {
map.insert(
segment_id,
RecoveredSegment {
segment_id,
doc_count,
deleted_docs: HashSet::new(),
},
);
}
WalEntry::DeleteDocuments { deletes } => {
for (segment_id, doc_id) in deletes {
if let Some(seg) = map.get_mut(&segment_id) {
seg.deleted_docs.insert(doc_id);
}
}
}
WalEntry::EndMerge {
new_segment_id,
old_segment_ids,
remapped_deletes,
..
} => {
for old in old_segment_ids {
map.remove(&old);
}
let mut new_seg = RecoveredSegment {
segment_id: new_segment_id,
doc_count: 0,
deleted_docs: HashSet::new(),
};
for (seg_id, doc_id) in remapped_deletes {
if seg_id == new_segment_id {
new_seg.deleted_docs.insert(doc_id);
}
}
map.insert(new_segment_id, new_seg);
}
WalEntry::StartMerge { .. }
| WalEntry::CancelMerge { .. }
| WalEntry::Checkpoint { .. } => {}
}
}
pub fn to_checkpoint_state(state: &RecoveredState) -> CheckpointState {
let mut segments: Vec<CheckpointSegment> = state
.segments
.iter()
.map(|s| CheckpointSegment {
segment_id: s.segment_id,
doc_count: s.doc_count,
deleted_docs: {
let mut v: Vec<u32> = s.deleted_docs.iter().copied().collect();
v.sort_unstable();
v
},
})
.collect();
segments.sort_by_key(|s| s.segment_id);
CheckpointState { segments }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::MemoryDirectory;
    use crate::walog::{WalEntry, WalWriter};
    use std::io::Read;

    /// Fresh in-memory directory behind the trait object the APIs expect.
    fn memory_dir() -> Arc<dyn Directory> {
        Arc::new(MemoryDirectory::new())
    }

    /// Reads the whole file at `path`, panicking on any I/O error.
    fn read_file(dir: &Arc<dyn Directory>, path: &str) -> Vec<u8> {
        let mut contents = Vec::new();
        dir.open_file(path)
            .unwrap()
            .read_to_end(&mut contents)
            .unwrap();
        contents
    }

    #[test]
    fn recovery_applies_checkpoint_then_wal() {
        let dir = memory_dir();
        let snapshot = CheckpointState {
            segments: vec![CheckpointSegment {
                segment_id: 7,
                doc_count: 3,
                deleted_docs: vec![0],
            }],
        };
        CheckpointFile::new(dir.clone())
            .write_postcard("checkpoints/c1.chk", 0, &snapshot)
            .unwrap();

        let mut writer = WalWriter::<WalEntry>::new(dir.clone());
        let entries = vec![
            WalEntry::DeleteDocuments {
                deletes: vec![(7, 2)],
            },
            WalEntry::AddSegment {
                segment_id: 9,
                doc_count: 5,
            },
        ];
        for entry in &entries {
            writer.append(entry).unwrap();
        }
        writer.flush().unwrap();

        let recovered = RecoveryManager::new(dir)
            .recover(Some("checkpoints/c1.chk"))
            .unwrap();
        assert_eq!(recovered.last_entry_id, 2);
        assert_eq!(recovered.segments.len(), 2);
        let first = &recovered.segments[0];
        assert_eq!(first.segment_id, 7);
        assert!(first.deleted_docs.contains(&0));
        assert!(first.deleted_docs.contains(&2));
        assert_eq!(recovered.segments[1].segment_id, 9);
    }

    #[test]
    fn recover_strict_errors_on_corrupt_checkpoint() {
        let dir = memory_dir();
        CheckpointFile::new(dir.clone())
            .write_postcard("checkpoints/c1.chk", 0, &CheckpointState { segments: vec![] })
            .unwrap();

        // Flip every bit of the first byte so the header check fails.
        let mut corrupted = read_file(&dir, "checkpoints/c1.chk");
        corrupted[0] ^= 0xFF;
        dir.atomic_write("checkpoints/c1.chk", &corrupted).unwrap();

        let strict_err = RecoveryManager::new(dir.clone())
            .recover(Some("checkpoints/c1.chk"))
            .unwrap_err();
        assert!(strict_err.to_string().contains("invalid checkpoint magic"));

        let lenient = RecoveryManager::new(dir)
            .recover_best_effort(Some("checkpoints/c1.chk"))
            .unwrap();
        assert_eq!(lenient.segments.len(), 0);
        assert_eq!(lenient.last_entry_id, 0);
    }

    #[test]
    fn to_checkpoint_state_is_deterministic() {
        let build = |delete_order: &[u32], segment_order: &[u64]| RecoveredState {
            segments: segment_order
                .iter()
                .map(|&segment_id| RecoveredSegment {
                    segment_id,
                    doc_count: 10,
                    deleted_docs: delete_order.iter().copied().collect::<HashSet<u32>>(),
                })
                .collect(),
            last_entry_id: 123,
        };

        let first = RecoveryManager::to_checkpoint_state(&build(&[3, 1, 2], &[9, 7]));
        let second = RecoveryManager::to_checkpoint_state(&build(&[2, 3, 1], &[7, 9]));
        assert_eq!(first.segments.len(), 2);
        assert_eq!(second.segments.len(), 2);

        let dir = memory_dir();
        let checkpoints = CheckpointFile::new(dir.clone());
        checkpoints
            .write_postcard("checkpoints/a.chk", 123, &first)
            .unwrap();
        checkpoints
            .write_postcard("checkpoints/b.chk", 123, &second)
            .unwrap();

        assert_eq!(
            read_file(&dir, "checkpoints/a.chk"),
            read_file(&dir, "checkpoints/b.chk")
        );
    }
}