use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::SystemTime;
use tokio::sync::RwLock;
use tracing::{error, info};
use crate::raft::OxirsNodeId;
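/// Configuration for checkpoint creation, rotation, and compression.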
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointConfig {
pub checkpoint_interval_secs: u64,
pub max_checkpoints: usize,
pub checkpoint_dir: PathBuf,
pub enable_compression: bool,
pub compression_level: u32,
pub enable_incremental: bool,
}
impl Default for CheckpointConfig {
fn default() -> Self {
Self {
            checkpoint_interval_secs: 300,
            max_checkpoints: 10,
checkpoint_dir: PathBuf::from("./checkpoints"),
enable_compression: true,
compression_level: 6,
enable_incremental: true,
}
}
}
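/// Configuration for the write-ahead log (WAL).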
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalConfig {
pub wal_dir: PathBuf,
pub max_wal_size_bytes: usize,
pub sync_on_write: bool,
pub buffer_size: usize,
pub enable_compression: bool,
}
impl Default for WalConfig {
fn default() -> Self {
Self {
wal_dir: PathBuf::from("./wal"),
            max_wal_size_bytes: 100 * 1024 * 1024,
            sync_on_write: true,
buffer_size: 1000,
enable_compression: false,
}
}
}
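/// Top-level crash-recovery configuration, combining checkpoint and WAL
/// settings with recovery-behavior toggles.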
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryConfig {
pub checkpoint_config: CheckpointConfig,
pub wal_config: WalConfig,
pub enable_auto_recovery: bool,
pub max_recovery_time_secs: u64,
pub enable_corruption_detection: bool,
pub enable_recovery_verification: bool,
}
impl Default for RecoveryConfig {
fn default() -> Self {
Self {
checkpoint_config: CheckpointConfig::default(),
wal_config: WalConfig::default(),
enable_auto_recovery: true,
max_recovery_time_secs: 300,
enable_corruption_detection: true,
enable_recovery_verification: true,
}
}
}
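/// Metadata describing a single stored checkpoint.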
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
pub checkpoint_id: String,
pub node_id: OxirsNodeId,
pub timestamp: SystemTime,
pub sequence_number: u64,
pub state_size_bytes: usize,
pub compressed_size_bytes: usize,
pub checksum: String,
pub is_incremental: bool,
pub base_checkpoint_id: Option<String>,
}
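/// A single write-ahead log record, checksummed so corrupted entries can be
/// detected during replay.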
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalEntry {
pub sequence_number: u64,
pub timestamp: SystemTime,
pub operation_type: WalOperationType,
pub data: Vec<u8>,
pub checksum: String,
}
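/// The kind of operation a WAL entry records.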
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum WalOperationType {
Insert,
Delete,
Update,
TransactionBegin,
TransactionCommit,
TransactionRollback,
Checkpoint,
}
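/// Phases of the recovery pipeline, in the order they normally occur.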
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RecoveryState {
Idle,
LoadingCheckpoint,
ReplayingWal,
Verifying,
Completed,
Failed,
}
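/// Counters and running averages describing checkpoint, WAL, and recovery
/// activity.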
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryStats {
pub total_checkpoints: u64,
pub total_wal_entries: u64,
pub total_recovery_attempts: u64,
pub successful_recoveries: u64,
pub failed_recoveries: u64,
pub last_checkpoint: Option<SystemTime>,
pub last_recovery: Option<SystemTime>,
pub avg_checkpoint_size_bytes: f64,
pub avg_recovery_time_ms: f64,
pub current_wal_size_bytes: usize,
pub corruption_events: u64,
}
impl Default for RecoveryStats {
fn default() -> Self {
Self {
total_checkpoints: 0,
total_wal_entries: 0,
total_recovery_attempts: 0,
successful_recoveries: 0,
failed_recoveries: 0,
last_checkpoint: None,
last_recovery: None,
avg_checkpoint_size_bytes: 0.0,
avg_recovery_time_ms: 0.0,
current_wal_size_bytes: 0,
corruption_events: 0,
}
}
}
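/// Coordinates checkpoint creation, write-ahead logging, and crash recovery
/// for a single node. All mutable state lives behind `Arc<RwLock<...>>`, so a
/// manager wrapped in an `Arc` can be shared across tasks.
///
/// A minimal usage sketch (assumes a tokio runtime; the disk, checksum, and
/// compression helpers below are currently placeholders):
///
/// ```ignore
/// let manager = CrashRecoveryManager::new(1, RecoveryConfig::default());
/// let checkpoint_id = manager.create_checkpoint(b"serialized state").await?;
/// let seq = manager
///     .write_wal_entry(WalOperationType::Insert, vec![1, 2, 3])
///     .await?;
/// manager.recover().await?;
/// ```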
pub struct CrashRecoveryManager {
config: RecoveryConfig,
node_id: OxirsNodeId,
checkpoints: Arc<RwLock<Vec<CheckpointMetadata>>>,
wal_buffer: Arc<RwLock<VecDeque<WalEntry>>>,
sequence_number: Arc<RwLock<u64>>,
recovery_state: Arc<RwLock<RecoveryState>>,
stats: Arc<RwLock<RecoveryStats>>,
}
impl CrashRecoveryManager {
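    /// Creates a manager for `node_id` with empty checkpoint and WAL state.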
pub fn new(node_id: OxirsNodeId, config: RecoveryConfig) -> Self {
Self {
config,
node_id,
checkpoints: Arc::new(RwLock::new(Vec::new())),
wal_buffer: Arc::new(RwLock::new(VecDeque::new())),
sequence_number: Arc::new(RwLock::new(0)),
recovery_state: Arc::new(RwLock::new(RecoveryState::Idle)),
stats: Arc::new(RwLock::new(RecoveryStats::default())),
}
}
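    /// Creates a full (non-incremental) checkpoint of `state_data`: optionally
    /// compresses it, checksums it, persists it, and rotates out the oldest
    /// checkpoint once `max_checkpoints` is exceeded. Returns the checkpoint id.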
pub async fn create_checkpoint(&self, state_data: &[u8]) -> Result<String, String> {
        let start = std::time::Instant::now();
        let mut sequence_number = self.sequence_number.write().await;
        *sequence_number += 1;
        // Include the sequence number so ids stay unique even when several
        // checkpoints are created within the same second.
        let checkpoint_id = format!(
            "checkpoint-{}-{}-{}",
            self.node_id,
            SystemTime::now()
                .duration_since(SystemTime::UNIX_EPOCH)
                .expect("SystemTime should be after UNIX_EPOCH")
                .as_secs(),
            *sequence_number
        );
        let (compressed_data, compressed_size) = if self.config.checkpoint_config.enable_compression
        {
            Self::compress_data(state_data, self.config.checkpoint_config.compression_level)
        } else {
            (state_data.to_vec(), state_data.len())
        };
        // Checksum the bytes as stored on disk so that load-time verification
        // (which runs before decompression) compares like with like.
        let checksum = Self::calculate_checksum(&compressed_data);
        let metadata = CheckpointMetadata {
            checkpoint_id: checkpoint_id.clone(),
            node_id: self.node_id,
            timestamp: SystemTime::now(),
            sequence_number: *sequence_number,
            state_size_bytes: state_data.len(),
            compressed_size_bytes: compressed_size,
            checksum,
            is_incremental: false,
            base_checkpoint_id: None,
        };
        self.save_checkpoint_to_disk(&checkpoint_id, &compressed_data)
            .await?;
        let mut checkpoints = self.checkpoints.write().await;
        checkpoints.push(metadata);
if checkpoints.len() > self.config.checkpoint_config.max_checkpoints {
let old_checkpoint = checkpoints.remove(0);
self.delete_checkpoint(&old_checkpoint.checkpoint_id)
.await?;
}
let mut stats = self.stats.write().await;
stats.total_checkpoints += 1;
stats.last_checkpoint = Some(SystemTime::now());
let total = stats.total_checkpoints as f64;
stats.avg_checkpoint_size_bytes =
(stats.avg_checkpoint_size_bytes * (total - 1.0) + state_data.len() as f64) / total;
info!(
"Created checkpoint {} ({} bytes, compressed to {} bytes) in {:?}",
checkpoint_id,
state_data.len(),
compressed_size,
start.elapsed()
);
Ok(checkpoint_id)
}
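    /// Appends an operation to the in-memory WAL buffer, flushing when the
    /// buffer is full or `sync_on_write` is set. Returns the sequence number
    /// assigned to the entry.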
pub async fn write_wal_entry(
&self,
operation_type: WalOperationType,
data: Vec<u8>,
) -> Result<u64, String> {
let mut sequence_number = self.sequence_number.write().await;
*sequence_number += 1;
        let checksum = Self::calculate_checksum(&data);
        let data_len = data.len();
        let entry = WalEntry {
            sequence_number: *sequence_number,
            timestamp: SystemTime::now(),
            operation_type,
            data,
            checksum,
        };
        let mut wal_buffer = self.wal_buffer.write().await;
        wal_buffer.push_back(entry);
        let mut stats = self.stats.write().await;
        stats.total_wal_entries += 1;
        stats.current_wal_size_bytes += data_len;
if wal_buffer.len() >= self.config.wal_config.buffer_size
|| self.config.wal_config.sync_on_write
{
drop(wal_buffer);
drop(stats);
self.flush_wal().await?;
}
Ok(*sequence_number)
}
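    /// Drains the in-memory WAL buffer; on-disk persistence is a placeholder.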
async fn flush_wal(&self) -> Result<(), String> {
let mut wal_buffer = self.wal_buffer.write().await;
if wal_buffer.is_empty() {
return Ok(());
}
        // Placeholder: a full implementation would serialize each buffered
        // entry here and append it (with an optional fsync) to the WAL file
        // in `wal_config.wal_dir` before clearing the buffer.
        wal_buffer.clear();
Ok(())
}
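    /// Runs the recovery pipeline: load the latest checkpoint, replay WAL
    /// entries recorded after it, and (optionally) verify the result,
    /// tracking progress in `recovery_state` and the stats.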
    pub async fn recover(&self) -> Result<(), String> {
        if !self.config.enable_auto_recovery {
            return Err("Auto recovery is disabled".to_string());
        }
        let start = std::time::Instant::now();
        self.stats.write().await.total_recovery_attempts += 1;
        *self.recovery_state.write().await = RecoveryState::LoadingCheckpoint;
        info!("Starting crash recovery for node {}", self.node_id);
        if self.load_latest_checkpoint().await.is_err() {
            if self.config.enable_corruption_detection {
                error!("Checkpoint loading failed, attempting repair");
                if let Err(e) = self.detect_and_repair_corruption().await {
                    return self.fail_recovery(e).await;
                }
            } else {
                return self
                    .fail_recovery("Checkpoint loading failed".to_string())
                    .await;
            }
        }
        *self.recovery_state.write().await = RecoveryState::ReplayingWal;
        if let Err(e) = self.replay_wal_from_checkpoint().await {
            return self.fail_recovery(e).await;
        }
        if self.config.enable_recovery_verification {
            *self.recovery_state.write().await = RecoveryState::Verifying;
            if let Err(e) = self.verify_recovery().await {
                return self.fail_recovery(e).await;
            }
        }
        *self.recovery_state.write().await = RecoveryState::Completed;
        let recovery_time = start.elapsed();
        let mut stats = self.stats.write().await;
        stats.successful_recoveries += 1;
        stats.last_recovery = Some(SystemTime::now());
        let total = stats.successful_recoveries as f64;
        stats.avg_recovery_time_ms =
            (stats.avg_recovery_time_ms * (total - 1.0) + recovery_time.as_millis() as f64) / total;
        info!("Crash recovery completed in {:?}", recovery_time);
        Ok(())
    }
    /// Marks the in-progress recovery as failed, records it in the stats, and
    /// returns the error.
    async fn fail_recovery(&self, err: String) -> Result<(), String> {
        *self.recovery_state.write().await = RecoveryState::Failed;
        self.stats.write().await.failed_recoveries += 1;
        Err(err)
    }
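    /// Loads the most recent checkpoint, verifies its checksum against the
    /// on-disk bytes, and decompresses it if compression is enabled.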
async fn load_latest_checkpoint(&self) -> Result<Vec<u8>, String> {
let checkpoints = self.checkpoints.read().await;
if checkpoints.is_empty() {
return Err("No checkpoints available".to_string());
}
let latest = checkpoints
.last()
.expect("collection validated to be non-empty");
info!(
"Loading checkpoint {} (sequence: {})",
latest.checkpoint_id, latest.sequence_number
);
let data = self
.load_checkpoint_from_disk(&latest.checkpoint_id)
.await?;
let checksum = Self::calculate_checksum(&data);
if checksum != latest.checksum {
return Err("Checkpoint checksum mismatch".to_string());
}
if self.config.checkpoint_config.enable_compression {
Self::decompress_data(&data)
} else {
Ok(data)
}
}
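    /// Replays WAL entries newer than the latest checkpoint, skipping
    /// corrupted entries when corruption detection is enabled and failing
    /// otherwise.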
async fn replay_wal_from_checkpoint(&self) -> Result<(), String> {
let checkpoints = self.checkpoints.read().await;
if checkpoints.is_empty() {
return Ok(());
}
let latest = checkpoints
.last()
.expect("collection validated to be non-empty");
let checkpoint_seq = latest.sequence_number;
info!("Replaying WAL from sequence {}", checkpoint_seq);
let wal_entries = self.load_wal_entries_after(checkpoint_seq).await?;
for entry in wal_entries {
if Self::calculate_checksum(&entry.data) != entry.checksum {
if self.config.enable_corruption_detection {
let mut stats = self.stats.write().await;
stats.corruption_events += 1;
                    continue;
                } else {
return Err(format!(
"WAL entry {} checksum mismatch",
entry.sequence_number
));
}
}
self.replay_operation(&entry).await?;
}
Ok(())
}
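    /// Placeholder verification hook; a full implementation would validate
    /// the recovered state against known invariants.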
async fn verify_recovery(&self) -> Result<(), String> {
info!("Verifying recovery...");
Ok(())
}
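    /// Records a corruption event; actual detection and repair of damaged
    /// checkpoints is a placeholder.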
async fn detect_and_repair_corruption(&self) -> Result<(), String> {
let mut stats = self.stats.write().await;
stats.corruption_events += 1;
drop(stats);
info!("Detecting and repairing corruption...");
Ok(())
}
pub async fn get_stats(&self) -> RecoveryStats {
self.stats.read().await.clone()
}
pub async fn get_recovery_state(&self) -> RecoveryState {
*self.recovery_state.read().await
}
pub async fn get_checkpoints(&self) -> Vec<CheckpointMetadata> {
self.checkpoints.read().await.clone()
}
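    /// Resets all checkpoints, WAL state, counters, and the recovery state
    /// (useful in tests).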
pub async fn clear(&self) {
self.checkpoints.write().await.clear();
self.wal_buffer.write().await.clear();
*self.sequence_number.write().await = 0;
*self.recovery_state.write().await = RecoveryState::Idle;
*self.stats.write().await = RecoveryStats::default();
}
    fn calculate_checksum(data: &[u8]) -> String {
        // Placeholder checksum: hex-encoded length only. A real implementation
        // would use a proper digest (e.g. CRC32 or SHA-256) so that same-length
        // corruption is also caught.
        format!("{:x}", data.len())
    }
    fn compress_data(data: &[u8], _level: u32) -> (Vec<u8>, usize) {
        // Placeholder: pass-through with no actual compression; the requested
        // level is ignored.
        let compressed = data.to_vec();
        let len = compressed.len();
        (compressed, len)
    }
    fn decompress_data(data: &[u8]) -> Result<Vec<u8>, String> {
        // Placeholder: inverse of the pass-through `compress_data`.
        Ok(data.to_vec())
    }
    async fn save_checkpoint_to_disk(&self, _id: &str, _data: &[u8]) -> Result<(), String> {
        // Placeholder: would write the (optionally compressed) checkpoint to
        // `checkpoint_config.checkpoint_dir`.
        Ok(())
    }
    async fn load_checkpoint_from_disk(&self, _id: &str) -> Result<Vec<u8>, String> {
        // Placeholder: returns dummy bytes instead of reading from disk.
        Ok(vec![0u8; 100])
    }
    async fn delete_checkpoint(&self, _id: &str) -> Result<(), String> {
        // Placeholder: would remove the rotated-out checkpoint file.
        Ok(())
    }
    async fn load_wal_entries_after(&self, _seq: u64) -> Result<Vec<WalEntry>, String> {
        // Placeholder: would read WAL entries with sequence numbers greater
        // than `_seq` from `wal_config.wal_dir`.
        Ok(vec![])
    }
    async fn replay_operation(&self, _entry: &WalEntry) -> Result<(), String> {
        // Placeholder: would apply the logged operation to the state machine.
        Ok(())
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_crash_recovery_creation() {
let config = RecoveryConfig::default();
let manager = CrashRecoveryManager::new(1, config);
let stats = manager.get_stats().await;
assert_eq!(stats.total_checkpoints, 0);
assert_eq!(stats.total_wal_entries, 0);
}
#[tokio::test]
async fn test_create_checkpoint() {
let config = RecoveryConfig::default();
let manager = CrashRecoveryManager::new(1, config);
let state_data = vec![1, 2, 3, 4, 5];
let result = manager.create_checkpoint(&state_data).await;
assert!(result.is_ok());
let stats = manager.get_stats().await;
assert_eq!(stats.total_checkpoints, 1);
assert!(stats.last_checkpoint.is_some());
}
#[tokio::test]
async fn test_write_wal_entry() {
let config = RecoveryConfig::default();
let manager = CrashRecoveryManager::new(1, config);
let data = vec![1, 2, 3];
let result = manager
.write_wal_entry(WalOperationType::Insert, data)
.await;
assert!(result.is_ok());
let stats = manager.get_stats().await;
assert_eq!(stats.total_wal_entries, 1);
}
#[tokio::test]
async fn test_multiple_wal_entries() {
let config = RecoveryConfig::default();
let manager = CrashRecoveryManager::new(1, config);
for i in 0..10 {
let data = vec![i];
let result = manager
.write_wal_entry(WalOperationType::Insert, data)
.await;
assert!(result.is_ok());
}
let stats = manager.get_stats().await;
assert_eq!(stats.total_wal_entries, 10);
}
#[tokio::test]
async fn test_checkpoint_rotation() {
        let config = RecoveryConfig {
            checkpoint_config: CheckpointConfig {
                max_checkpoints: 3,
                ..CheckpointConfig::default()
            },
            ..RecoveryConfig::default()
        };
let manager = CrashRecoveryManager::new(1, config);
for i in 0..5 {
let data = vec![i; 100];
let _result = manager.create_checkpoint(&data).await;
}
let checkpoints = manager.get_checkpoints().await;
assert_eq!(checkpoints.len(), 3);
}
#[tokio::test]
async fn test_recovery_state_transitions() {
let config = RecoveryConfig::default();
let manager = CrashRecoveryManager::new(1, config);
let state = manager.get_recovery_state().await;
assert_eq!(state, RecoveryState::Idle);
let data = vec![1, 2, 3];
let _result = manager.create_checkpoint(&data).await;
let _result = manager.recover().await;
let state = manager.get_recovery_state().await;
assert!(state == RecoveryState::Completed || state == RecoveryState::Failed);
}
#[tokio::test]
async fn test_recovery_stats() {
let config = RecoveryConfig::default();
let manager = CrashRecoveryManager::new(1, config);
let data = vec![1, 2, 3];
let _result = manager.create_checkpoint(&data).await;
let _result = manager.recover().await;
let stats = manager.get_stats().await;
assert_eq!(stats.total_recovery_attempts, 1);
}
#[tokio::test]
async fn test_compression_disabled() {
        let config = RecoveryConfig {
            checkpoint_config: CheckpointConfig {
                enable_compression: false,
                ..CheckpointConfig::default()
            },
            ..RecoveryConfig::default()
        };
let manager = CrashRecoveryManager::new(1, config);
let data = vec![1, 2, 3, 4, 5];
let result = manager.create_checkpoint(&data).await;
assert!(result.is_ok());
let checkpoints = manager.get_checkpoints().await;
assert_eq!(checkpoints[0].state_size_bytes, 5);
}
#[tokio::test]
async fn test_clear() {
let config = RecoveryConfig::default();
let manager = CrashRecoveryManager::new(1, config);
let data = vec![1, 2, 3];
let _result = manager.create_checkpoint(&data).await;
let _result = manager
.write_wal_entry(WalOperationType::Insert, data)
.await;
manager.clear().await;
let stats = manager.get_stats().await;
assert_eq!(stats.total_checkpoints, 0);
assert_eq!(stats.total_wal_entries, 0);
let checkpoints = manager.get_checkpoints().await;
assert!(checkpoints.is_empty());
}
#[tokio::test]
async fn test_wal_operation_types() {
let config = RecoveryConfig::default();
let manager = CrashRecoveryManager::new(1, config);
let operations = [
WalOperationType::Insert,
WalOperationType::Delete,
WalOperationType::Update,
WalOperationType::TransactionBegin,
WalOperationType::TransactionCommit,
WalOperationType::TransactionRollback,
];
for op in operations.iter() {
let data = vec![1, 2, 3];
let result = manager.write_wal_entry(*op, data).await;
assert!(result.is_ok());
}
let stats = manager.get_stats().await;
assert_eq!(stats.total_wal_entries, 6);
}
#[tokio::test]
async fn test_checkpoint_metadata() {
let config = RecoveryConfig::default();
let manager = CrashRecoveryManager::new(1, config);
let data = vec![1, 2, 3, 4, 5];
let _result = manager.create_checkpoint(&data).await;
let checkpoints = manager.get_checkpoints().await;
assert_eq!(checkpoints.len(), 1);
let checkpoint = &checkpoints[0];
assert_eq!(checkpoint.node_id, 1);
assert_eq!(checkpoint.state_size_bytes, 5);
assert!(!checkpoint.is_incremental);
assert!(checkpoint.base_checkpoint_id.is_none());
}
#[test]
fn test_wal_operation_type_ordering() {
assert!(WalOperationType::Insert < WalOperationType::Delete);
assert!(WalOperationType::TransactionBegin < WalOperationType::TransactionCommit);
}
}