use std::fs::{self, File};
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};
use super::arena::{ArenaValidation, CharNodeArena, HEADER_SIZE as ARENA_HEADER_SIZE};
use super::dict_impl_char::{CharTrieFileHeader, CHAR_FILE_HEADER_SIZE, CHAR_TRIE_MAGIC};
use crate::persistent_artrie::disk_manager::BLOCK_SIZE;
use crate::persistent_artrie::error::{PersistentARTrieError, Result};
use crate::persistent_artrie::wal::{Lsn, WalConfig, WalReader, WalRecord, WalWriter};
pub use crate::persistent_artrie::recovery::{
find_wal_archive_segments, rebuild_from_wal_segments, IncrementalRecovery, RecoveredState,
RecoveryError, RecoveryStats,
};
/// Builds a `PersistentARTrieError::IoError` from a failed I/O call.
///
/// `operation` names the step that failed, `path` is the file involved
/// (stored in its display form) and `e` is kept as the error source.
fn io_err(operation: &str, path: &Path, e: std::io::Error) -> PersistentARTrieError {
    let operation = operation.to_owned();
    let path = path.display().to_string();
    PersistentARTrieError::IoError {
        operation,
        path,
        source: e,
    }
}
/// Category of outcome for a recovery attempt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RecoveryMode {
    /// No corruption was found; only WAL replay took place.
    Normal {
        /// Number of operations replayed from the WAL.
        wal_records_replayed: usize,
    },
    /// Some arenas were corrupted but records were still recovered.
    PartialRecovery {
        /// Identifiers of the arenas found to be corrupted.
        corrupted_arenas: Vec<u32>,
        /// Number of records that were recovered.
        recovered_records: usize,
    },
    /// State was rebuilt by replaying WAL segments.
    RebuildFromWal {
        /// Number of WAL segments considered.
        segments_processed: usize,
        /// Number of operations replayed from those segments.
        records_replayed: usize,
    },
    /// Recovery was not possible.
    Unrecoverable {
        /// Human-readable explanation.
        reason: String,
    },
}

impl RecoveryMode {
    /// Returns `true` for every mode except [`RecoveryMode::Unrecoverable`].
    pub fn is_success(&self) -> bool {
        match self {
            Self::Unrecoverable { .. } => false,
            Self::Normal { .. } | Self::PartialRecovery { .. } | Self::RebuildFromWal { .. } => {
                true
            }
        }
    }

    /// Returns the replay/recovery counter carried by the mode, or `0` for
    /// [`RecoveryMode::Unrecoverable`].
    pub fn records_replayed(&self) -> usize {
        match self {
            Self::Normal {
                wal_records_replayed: n,
            }
            | Self::PartialRecovery {
                recovered_records: n,
                ..
            }
            | Self::RebuildFromWal {
                records_replayed: n,
                ..
            } => *n,
            Self::Unrecoverable { .. } => 0,
        }
    }
}
/// Summary returned to the caller after a recovery run.
#[derive(Debug, Clone)]
pub struct RecoveryReport {
    /// Category of recovery that was performed.
    pub mode: RecoveryMode,
    /// Wall-clock time the recovery took.
    pub duration: Duration,
    /// Number of operations handed to the apply callback.
    pub records_replayed: usize,
    /// Checkpoint LSN associated with the run, if any.
    pub checkpoint_lsn: Option<Lsn>,
    /// Number of WAL segments reported as processed.
    pub segments_processed: usize,
    /// Number of unreadable WAL records encountered.
    pub corrupted_records_skipped: usize,
}

impl RecoveryReport {
    /// Returns an all-zero report in [`RecoveryMode::Normal`]: nothing
    /// replayed, zero duration, no checkpoint LSN.
    pub fn normal() -> Self {
        let mode = RecoveryMode::Normal {
            wal_records_replayed: 0,
        };
        Self {
            mode,
            duration: Duration::ZERO,
            records_replayed: 0,
            checkpoint_lsn: None,
            segments_processed: 0,
            corrupted_records_skipped: 0,
        }
    }
}
/// Finding produced by `detect_corruption` when the trie file is damaged.
#[derive(Debug, Clone)]
pub struct CorruptionInfo {
/// Kind of corruption that was detected.
pub corruption_type: CorruptionType,
/// Block indices flagged during arena validation; empty for findings that
/// are not arena-level.
pub corrupted_arenas: Vec<u32>,
/// Whether detection reported a sibling `.wal` file as present.
pub wal_available: bool,
/// Number of WAL segments reported by detection (live WAL plus archived
/// segments).
pub wal_segments: usize,
}
/// Kinds of damage that `detect_corruption` can report.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CorruptionType {
/// The file header failed checksum verification; carries the checksum
/// stored in the header and the one recomputed from it.
HeaderChecksum { stored: u32, computed: u32 },
/// The header's magic value did not match the expected trie magic.
InvalidMagic,
/// One or more blocks failed arena validation.
ArenaChecksum {
/// Number of blocks that were flagged.
count: usize,
},
/// The file is shorter than the fixed header; `expected` is the header
/// size and `actual` the observed file size, both in bytes.
Truncated { expected: u64, actual: u64 },
/// The file exists but could not be opened; `reason` is the I/O error text.
FileNotReadable { reason: String },
}
/// Caller-selected reaction to detected corruption.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RecoveryPolicy {
/// Default policy: rebuild from the WAL when corruption is detected.
#[default]
AutoRecover,
/// Return an error as soon as corruption is detected.
FailOnCorruption,
/// NOTE(review): `RecoveryManager::recover_with_callback` currently handles
/// this exactly like `AutoRecover` — confirm whether a distinct partial
/// path is intended.
RecoverPartial,
}
pub fn detect_corruption(path: &Path, validate_arenas: bool) -> Result<Option<CorruptionInfo>> {
if !path.exists() {
return Ok(None); }
let file = match File::open(path) {
Ok(f) => f,
Err(e) => {
return Ok(Some(CorruptionInfo {
corruption_type: CorruptionType::FileNotReadable {
reason: e.to_string(),
},
corrupted_arenas: vec![],
wal_available: false,
wal_segments: 0,
}));
}
};
let metadata = file.metadata().map_err(|e| io_err("metadata", path, e))?;
let file_size = metadata.len();
if file_size < CHAR_FILE_HEADER_SIZE as u64 {
return Ok(Some(CorruptionInfo {
corruption_type: CorruptionType::Truncated {
expected: CHAR_FILE_HEADER_SIZE as u64,
actual: file_size,
},
corrupted_arenas: vec![],
wal_available: check_wal_available(path),
wal_segments: count_wal_segments(path),
}));
}
let mut reader = BufReader::new(file);
let mut header_buf = [0u8; CHAR_FILE_HEADER_SIZE];
reader
.read_exact(&mut header_buf)
.map_err(|e| io_err("read header", path, e))?;
let header = CharTrieFileHeader::from_bytes(&header_buf);
if header.magic != CHAR_TRIE_MAGIC {
return Ok(Some(CorruptionInfo {
corruption_type: CorruptionType::InvalidMagic,
corrupted_arenas: vec![],
wal_available: check_wal_available(path),
wal_segments: count_wal_segments(path),
}));
}
if header.has_checksum() && !header.verify_checksum() {
return Ok(Some(CorruptionInfo {
corruption_type: CorruptionType::HeaderChecksum {
stored: header.header_checksum,
computed: header.compute_checksum(),
},
corrupted_arenas: vec![],
wal_available: check_wal_available(path),
wal_segments: count_wal_segments(path),
}));
}
if validate_arenas {
let corrupted = validate_all_arenas(&mut reader, file_size);
if !corrupted.is_empty() {
return Ok(Some(CorruptionInfo {
corruption_type: CorruptionType::ArenaChecksum {
count: corrupted.len(),
},
corrupted_arenas: corrupted,
wal_available: check_wal_available(path),
wal_segments: count_wal_segments(path),
}));
}
}
Ok(None)
}
/// Scans every full block after the file header and returns the indices of
/// the blocks that failed arena validation.
///
/// Block index 0 is the first block boundary at or after the end of the file
/// header. A trailing partial block (fewer than `BLOCK_SIZE` bytes before
/// `file_size`) is not examined. Blocks whose magic is not an arena magic are
/// treated as non-arena blocks and are not reported.
///
/// NOTE(review): a seek or read failure ends the scan silently; blocks from
/// that point on are neither validated nor reported.
fn validate_all_arenas<R: Read + Seek>(reader: &mut R, file_size: u64) -> Vec<u32> {
    let block_len = BLOCK_SIZE as u64;
    // Round the header size up to the next block boundary.
    let first_block = (CHAR_FILE_HEADER_SIZE as u64 + block_len - 1) / block_len * block_len;

    let mut flagged = Vec::new();
    // One scratch buffer reused for every block; a successful `read_exact`
    // overwrites it completely.
    let mut scratch = vec![0u8; BLOCK_SIZE];
    let mut position = first_block;
    let mut index = 0u32;

    while position + block_len <= file_size {
        let loaded = reader.seek(SeekFrom::Start(position)).is_ok()
            && reader.read_exact(&mut scratch).is_ok();
        if !loaded {
            break;
        }

        if scratch.len() >= ARENA_HEADER_SIZE {
            let is_corrupt = match CharNodeArena::validate_checksums(&scratch) {
                Ok(ArenaValidation::Valid) | Ok(ArenaValidation::InvalidMagic) => false,
                Ok(ArenaValidation::HeaderChecksumMismatch { .. })
                | Ok(ArenaValidation::DataChecksumMismatch { .. })
                | Ok(ArenaValidation::Truncated { .. })
                | Err(_) => true,
            };
            if is_corrupt {
                flagged.push(index);
            }
        }

        position += block_len;
        index += 1;
    }

    flagged
}
/// Returns `true` when a file with the trie's name and a `wal` extension
/// exists next to `trie_path`.
fn check_wal_available(trie_path: &Path) -> bool {
    trie_path.with_extension("wal").exists()
}
/// Counts the WAL segments that belong to `trie_path`.
///
/// The live WAL (same name, `wal` extension) counts as one when it exists.
/// Every entry with a `segment` extension inside the `wal_archive` directory
/// next to the trie file counts as one more. An archive directory that is
/// missing or unreadable contributes zero.
fn count_wal_segments(trie_path: &Path) -> usize {
    let live = usize::from(trie_path.with_extension("wal").exists());

    let archive_dir = trie_path
        .parent()
        .unwrap_or(Path::new("."))
        .join("wal_archive");

    let archived = if archive_dir.exists() {
        fs::read_dir(&archive_dir)
            .map(|entries| {
                entries
                    .filter_map(Result::ok)
                    .filter(|entry| {
                        entry
                            .path()
                            .extension()
                            .map_or(false, |ext| ext == "segment")
                    })
                    .count()
            })
            .unwrap_or(0)
    } else {
        0
    };

    live + archived
}
/// Drives corruption detection and WAL-based recovery for one trie file.
pub struct RecoveryManager {
/// Path of the trie file being recovered.
trie_path: PathBuf,
/// Path of the live WAL file; `new` derives it from `trie_path` by swapping
/// the extension for `wal`.
wal_path: PathBuf,
/// WAL configuration, passed on when WAL segments are collected for a
/// rebuild.
wal_config: WalConfig,
}
impl RecoveryManager {
    /// Creates a manager for the trie file at `trie_path`.
    ///
    /// The WAL path is the trie path with its extension replaced by `wal`.
    pub fn new(trie_path: impl AsRef<Path>, wal_config: WalConfig) -> Self {
        let trie_path = PathBuf::from(trie_path.as_ref());
        Self {
            wal_path: trie_path.with_extension("wal"),
            trie_path,
            wal_config,
        }
    }

    /// Returns `true` when the cheap header-level checks (no arena
    /// validation) find corruption in the trie file.
    ///
    /// # Errors
    ///
    /// Propagates any error from `detect_corruption`.
    pub fn needs_recovery(&self) -> Result<bool> {
        let finding = detect_corruption(&self.trie_path, false)?;
        Ok(finding.is_some())
    }

    /// Runs recovery, handing each recovered operation to `apply_fn`.
    ///
    /// With no corruption detected (arena validation included), WAL records
    /// after the checkpoint are replayed and a `Normal` report is returned.
    /// With corruption detected, `policy` decides: `FailOnCorruption` returns
    /// a `CorruptedFile` error, while `AutoRecover` and `RecoverPartial`
    /// rebuild from the WAL.
    ///
    /// # Errors
    ///
    /// * `CorruptedFile` under `FailOnCorruption`.
    /// * `RecoveryError` when corruption was found but no WAL is available.
    /// * Any error returned by `apply_fn` or by the underlying I/O.
    pub fn recover_with_callback<F>(
        &self,
        policy: RecoveryPolicy,
        mut apply_fn: F,
    ) -> Result<RecoveryReport>
    where
        F: FnMut(RecoveredOperation) -> Result<()>,
    {
        let start = Instant::now();

        let info = match detect_corruption(&self.trie_path, true)? {
            Some(info) => info,
            None => {
                // Healthy file: the policy is irrelevant, just catch up on
                // the WAL.
                let replay_count = self.replay_wal_after_checkpoint(&mut apply_fn)?;
                return Ok(RecoveryReport {
                    mode: RecoveryMode::Normal {
                        wal_records_replayed: replay_count,
                    },
                    duration: start.elapsed(),
                    records_replayed: replay_count,
                    checkpoint_lsn: self.get_checkpoint_lsn()?,
                    // NOTE(review): reported as 1 even when no WAL file
                    // exists — confirm whether that is intended.
                    segments_processed: 1,
                    corrupted_records_skipped: 0,
                });
            }
        };

        match policy {
            RecoveryPolicy::FailOnCorruption => Err(PersistentARTrieError::CorruptedFile {
                reason: format!("{:?}", info.corruption_type),
            }),
            RecoveryPolicy::AutoRecover | RecoveryPolicy::RecoverPartial => {
                let has_wal = info.wal_available || info.wal_segments != 0;
                if !has_wal {
                    return Err(PersistentARTrieError::RecoveryError {
                        reason: "No WAL available for recovery".to_string(),
                    });
                }
                self.rebuild_from_wal(&mut apply_fn, start)
            }
        }
    }

    /// Replays live-WAL records whose LSN is greater than the checkpoint LSN
    /// stored in the trie header (0 when there is none) and returns how many
    /// operations were applied.
    ///
    /// A missing WAL file or a WAL that cannot be opened yields `Ok(0)`.
    /// Replay stops at the first unreadable record.
    fn replay_wal_after_checkpoint<F>(&self, apply_fn: &mut F) -> Result<usize>
    where
        F: FnMut(RecoveredOperation) -> Result<()>,
    {
        if !self.wal_path.exists() {
            return Ok(0);
        }

        let checkpoint_lsn = self.get_checkpoint_lsn()?.unwrap_or(0);
        let reader = match WalReader::new(&self.wal_path) {
            Ok(r) => r,
            Err(_) => return Ok(0),
        };

        let mut replayed = 0;
        for result in reader.iter() {
            let (lsn, record) = match result {
                Ok(entry) => entry,
                // An unreadable record ends the replay.
                Err(_) => break,
            };
            if lsn <= checkpoint_lsn {
                continue;
            }
            for op in self.record_to_operations(lsn, record) {
                apply_fn(op)?;
                replayed += 1;
            }
        }
        Ok(replayed)
    }

    /// Replays every record of every collected WAL segment through
    /// `apply_fn` and returns a `RebuildFromWal` report.
    ///
    /// Segments that cannot be opened are skipped. The first unreadable
    /// record stops the whole rebuild and is counted in
    /// `corrupted_records_skipped`.
    ///
    /// NOTE(review): `segments_processed` is the number of segments
    /// collected, including ones that were skipped or never reached.
    ///
    /// # Errors
    ///
    /// `RecoveryError` when the live WAL file does not exist; otherwise any
    /// error from opening the WAL, collecting segments, or `apply_fn`.
    fn rebuild_from_wal<F>(&self, apply_fn: &mut F, start: Instant) -> Result<RecoveryReport>
    where
        F: FnMut(RecoveredOperation) -> Result<()>,
    {
        if !self.wal_path.exists() {
            return Err(PersistentARTrieError::RecoveryError {
                reason: "WAL file not found".to_string(),
            });
        }
        let wal_writer = WalWriter::open(&self.wal_path)?;
        let segments = wal_writer.collect_wal_segments(&self.wal_config)?;
        let segment_count = segments.len();

        let mut replayed = 0;
        let mut corrupted_skipped = 0;

        // With no segments the loop body never runs and the report below is
        // all zeros.
        'segments: for segment_path in &segments {
            let reader = match WalReader::new(segment_path) {
                Ok(r) => r,
                Err(_) => continue,
            };
            for result in reader.iter() {
                let (lsn, record) = match result {
                    Ok(entry) => entry,
                    Err(_) => {
                        corrupted_skipped += 1;
                        break 'segments;
                    }
                };
                for op in self.record_to_operations(lsn, record) {
                    apply_fn(op)?;
                    replayed += 1;
                }
            }
        }

        Ok(RecoveryReport {
            mode: RecoveryMode::RebuildFromWal {
                segments_processed: segment_count,
                records_replayed: replayed,
            },
            duration: start.elapsed(),
            records_replayed: replayed,
            checkpoint_lsn: None,
            segments_processed: segment_count,
            corrupted_records_skipped: corrupted_skipped,
        })
    }

    /// Reads the checkpoint LSN from the trie file header.
    ///
    /// Returns `Ok(None)` when the trie file does not exist or the stored
    /// LSN is 0.
    ///
    /// # Errors
    ///
    /// I/O error when the file cannot be opened or the header cannot be read
    /// in full.
    fn get_checkpoint_lsn(&self) -> Result<Option<Lsn>> {
        let trie_path = &self.trie_path;
        if !trie_path.exists() {
            return Ok(None);
        }

        let mut file = File::open(trie_path).map_err(|e| io_err("open", trie_path, e))?;
        let mut header_buf = [0u8; CHAR_FILE_HEADER_SIZE];
        file.read_exact(&mut header_buf)
            .map_err(|e| io_err("read header", trie_path, e))?;

        let header = CharTrieFileHeader::from_bytes(&header_buf);
        let lsn = header.checkpoint_lsn;
        Ok(if lsn > 0 { Some(lsn) } else { None })
    }

    /// Expands one WAL record into this module's `RecoveredOperation`
    /// values, using the shared recovery module for the decoding.
    fn record_to_operations(&self, lsn: Lsn, record: WalRecord) -> Vec<RecoveredOperation> {
        let shared_ops =
            crate::persistent_artrie::recovery::recovered_operations_from_record(lsn, record);
        shared_ops
            .into_iter()
            .map(RecoveredOperation::from)
            .collect()
    }
}
/// One operation recovered from the WAL and handed to the apply callback.
///
/// Mirrors `crate::persistent_artrie::recovery::RecoveredOperation` variant
/// for variant; values are produced from it through the `From` impl in this
/// file. Every variant carries the record's `lsn` and the affected `term`
/// as raw bytes.
#[derive(Debug, Clone)]
pub enum RecoveredOperation {
/// Insertion of `term`, with an optional value.
Insert {
lsn: Lsn,
term: Vec<u8>,
value: Option<Vec<u8>>,
},
/// Removal of `term`.
Remove {
lsn: Lsn,
term: Vec<u8>,
},
/// Increment of `term` by `delta`.
Increment {
lsn: Lsn,
term: Vec<u8>,
delta: i64,
// NOTE(review): presumably the value after the increment as logged —
// confirm against the WAL writer.
result: Option<i64>,
},
/// Upsert of `term` with `value`.
Upsert {
lsn: Lsn,
term: Vec<u8>,
value: Vec<u8>,
},
/// Compare-and-swap on `term`.
CompareAndSwap {
lsn: Lsn,
term: Vec<u8>,
new_value: Vec<u8>,
// NOTE(review): presumably whether the original swap succeeded —
// confirm against the WAL writer.
success: bool,
},
}
/// Field-for-field conversion from the shared recovery module's operation
/// type into this module's `RecoveredOperation`.
impl From<crate::persistent_artrie::recovery::RecoveredOperation> for RecoveredOperation {
    fn from(op: crate::persistent_artrie::recovery::RecoveredOperation) -> Self {
        // Local alias keeps the match arms readable.
        use crate::persistent_artrie::recovery::RecoveredOperation as Shared;

        match op {
            Shared::Insert { lsn, term, value } => Self::Insert { lsn, term, value },
            Shared::Remove { lsn, term } => Self::Remove { lsn, term },
            Shared::Increment {
                lsn,
                term,
                delta,
                result,
            } => Self::Increment {
                lsn,
                term,
                delta,
                result,
            },
            Shared::Upsert { lsn, term, value } => Self::Upsert { lsn, term, value },
            Shared::CompareAndSwap {
                lsn,
                term,
                new_value,
                success,
            } => Self::CompareAndSwap {
                lsn,
                term,
                new_value,
                success,
            },
        }
    }
}
impl RecoveredOperation {
    /// Returns the operation's term as `&str`, or `None` when the term bytes
    /// are not valid UTF-8.
    pub fn term_str(&self) -> Option<&str> {
        match self {
            Self::Insert { term, .. }
            | Self::Remove { term, .. }
            | Self::Increment { term, .. }
            | Self::Upsert { term, .. }
            | Self::CompareAndSwap { term, .. } => std::str::from_utf8(term).ok(),
        }
    }

    /// Returns the LSN carried by the operation.
    pub fn lsn(&self) -> Lsn {
        match self {
            Self::Insert { lsn, .. }
            | Self::Remove { lsn, .. }
            | Self::Increment { lsn, .. }
            | Self::Upsert { lsn, .. }
            | Self::CompareAndSwap { lsn, .. } => *lsn,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Every mode except `Unrecoverable` reports success.
    #[test]
    fn test_recovery_mode_is_success() {
        let succeeding = [
            RecoveryMode::Normal {
                wal_records_replayed: 0,
            },
            RecoveryMode::PartialRecovery {
                corrupted_arenas: vec![],
                recovered_records: 0,
            },
            RecoveryMode::RebuildFromWal {
                segments_processed: 0,
                records_replayed: 0,
            },
        ];
        for mode in succeeding.iter() {
            assert!(mode.is_success());
        }

        let failed = RecoveryMode::Unrecoverable {
            reason: "test".to_string(),
        };
        assert!(!failed.is_success());
    }

    /// A trie file that does not exist is not a corruption finding.
    #[test]
    fn test_detect_corruption_missing_file() {
        let dir = tempdir().expect("create tempdir");
        let missing = dir.path().join("nonexistent.artrie");

        let outcome = detect_corruption(&missing, false).expect("detect_corruption");

        assert!(outcome.is_none(), "Missing file should not be corruption");
    }

    /// A file shorter than the header is reported as truncated with both sizes.
    #[test]
    fn test_detect_corruption_truncated_file() {
        let dir = tempdir().expect("create tempdir");
        let short_file = dir.path().join("truncated.artrie");
        fs::write(&short_file, &[0u8; 10]).expect("write file");

        let outcome = detect_corruption(&short_file, false).expect("detect_corruption");

        assert!(outcome.is_some());
        let info = outcome.unwrap();
        if let CorruptionType::Truncated { expected, actual } = info.corruption_type {
            assert_eq!(expected, CHAR_FILE_HEADER_SIZE as u64);
            assert_eq!(actual, 10);
        } else {
            panic!("Expected Truncated corruption type");
        }
    }

    /// A header-sized file with the wrong magic is reported as such.
    #[test]
    fn test_detect_corruption_invalid_magic() {
        let dir = tempdir().expect("create tempdir");
        let bad_file = dir.path().join("bad_magic.artrie");
        let mut header = [0u8; CHAR_FILE_HEADER_SIZE];
        header[..4].copy_from_slice(b"XXXX");
        fs::write(&bad_file, &header).expect("write file");

        let outcome = detect_corruption(&bad_file, false).expect("detect_corruption");

        assert!(outcome.is_some());
        let kind = outcome.unwrap().corruption_type;
        assert!(matches!(kind, CorruptionType::InvalidMagic));
    }

    /// A sibling `.wal` file is seen by both WAL helpers.
    #[test]
    fn test_corruption_info_wal_check() {
        let dir = tempdir().expect("create tempdir");
        let trie_path = dir.path().join("test.artrie");
        let wal_path = dir.path().join("test.wal");
        fs::write(&trie_path, &[0u8; 10]).expect("write trie");
        fs::write(&wal_path, &[0u8; 100]).expect("write wal");

        let wal_seen = check_wal_available(&trie_path);
        let segment_total = count_wal_segments(&trie_path);

        assert!(wal_seen);
        assert!(segment_total >= 1);
    }

    /// With no trie file on disk, no recovery is needed.
    #[test]
    fn test_recovery_manager_no_file() {
        let dir = tempdir().expect("create tempdir");
        let missing = dir.path().join("missing.artrie");
        let manager = RecoveryManager::new(&missing, WalConfig::default());

        let needed = manager.needs_recovery().expect("needs_recovery");

        assert!(!needed);
    }

    /// `term_str` yields the term for valid UTF-8 and `None` otherwise.
    #[test]
    fn test_recovered_operation_term_str() {
        let readable = RecoveredOperation::Insert {
            lsn: 1,
            term: b"hello".to_vec(),
            value: None,
        };
        assert_eq!(readable.term_str(), Some("hello"));
        assert_eq!(readable.lsn(), 1);

        let not_utf8 = RecoveredOperation::Insert {
            lsn: 2,
            term: vec![0xFF, 0xFE],
            value: None,
        };
        assert_eq!(not_utf8.term_str(), None);
    }

    /// The default "normal" report is a success with nothing replayed.
    #[test]
    fn test_recovery_report_normal() {
        let report = RecoveryReport::normal();

        assert!(report.mode.is_success());
        assert_eq!(report.records_replayed, 0);
    }
}