use crate::incremental::index_manager::{IndexUpdate, UpdateResult};
use crate::RragResult;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Configuration for [`RollbackManager`]: operation-log size, snapshot policy,
/// and rollback behaviour flags.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RollbackConfig {
/// Maximum number of entries retained by the operation log; once exceeded,
/// the oldest entries are evicted.
pub max_operation_log_size: usize,
/// Enables automatic snapshots from `RollbackManager::log_operation` and the
/// background snapshot-cleanup task started by `RollbackManager::new`.
pub enable_snapshots: bool,
/// Number of logged operations between automatic snapshots.
pub snapshot_interval: usize,
/// Maximum number of snapshots kept in memory; the oldest are dropped first.
pub max_snapshots: usize,
/// NOTE(review): not read anywhere in this module — confirm whether another
/// component acts on it.
pub enable_auto_rollback: bool,
/// NOTE(review): not enforced in this module; presumably a rollback time
/// limit in seconds — confirm.
pub rollback_timeout_secs: u64,
/// When true, `RollbackManager::rollback` runs verification checks after the
/// rollback and records their results.
pub enable_verification: bool,
/// Preferred rollback strategy. Stored only; `RollbackManager` does not branch
/// on it in this module.
pub rollback_strategy: RollbackStrategy,
}
/// Strategies for choosing what to roll back to.
///
/// NOTE(review): nothing in this module matches on these variants; the
/// per-variant descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RollbackStrategy {
/// Presumably: return to the most recent state believed to be good.
LastKnownGood,
/// Presumably: restore a caller-chosen snapshot.
SpecificSnapshot,
/// Presumably: roll back only a subset of documents or operations.
Selective,
/// Presumably: roll back everything.
Complete,
/// Caller-defined strategy; the meaning of the string is up to the caller.
Custom(String),
}
impl Default for RollbackConfig {
fn default() -> Self {
Self {
max_operation_log_size: 10000,
enable_snapshots: true,
snapshot_interval: 100,
max_snapshots: 50,
enable_auto_rollback: true,
rollback_timeout_secs: 300,
enable_verification: true,
rollback_strategy: RollbackStrategy::LastKnownGood,
}
}
}
/// A rollback request accepted by `RollbackManager::rollback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RollbackOperation {
/// Restore the snapshot identified by `snapshot_id`.
///
/// NOTE(review): `RollbackManager::rollback` forwards only `snapshot_id` to
/// its restore handler; `target_state` is not used there.
RestoreSnapshot {
snapshot_id: String,
target_state: SystemState,
},
/// Undo the listed operations.
/// NOTE(review): unclear whether these are log `entry_id`s or
/// `IndexUpdate::operation_id`s — confirm with callers.
UndoOperations { operation_ids: Vec<String> },
/// Revert the system to its state at `timestamp` (UTC).
RevertToTimestamp {
timestamp: chrono::DateTime<chrono::Utc>,
},
/// Roll back only the given documents. `target_versions` presumably maps
/// document ID to the version ID to restore — confirm.
/// Not implemented by `RollbackManager::rollback` (reported as an error).
SelectiveRollback {
document_ids: Vec<String>,
target_versions: HashMap<String, String>,
},
/// Reset the whole system to the given snapshot.
/// Not implemented by `RollbackManager::rollback` (reported as an error).
SystemReset { reset_to_snapshot: String },
}
/// Point-in-time capture of system state, produced by
/// `RollbackManager::create_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemState {
/// Unique snapshot identifier (a UUID v4 string when created by `create_snapshot`).
pub snapshot_id: String,
/// When the snapshot was taken.
pub created_at: chrono::DateTime<chrono::Utc>,
/// Per-document state; keys are presumably document IDs — confirm.
pub document_states: HashMap<String, DocumentState>,
/// Per-index state; keys are presumably index names — confirm.
pub index_states: HashMap<String, IndexState>,
/// Free-form metadata attached to the snapshot.
pub system_metadata: HashMap<String, serde_json::Value>,
/// The operation log's running operation count when the snapshot was taken.
pub operations_count: u64,
/// Snapshot size in bytes. NOTE(review): `create_snapshot` currently always writes 0.
pub size_bytes: u64,
/// Compression ratio. NOTE(review): `create_snapshot` currently always writes 1.0.
pub compression_ratio: f64,
}
/// Captured state of a single document inside a [`SystemState`].
///
/// NOTE(review): not constructed in this module; the hash algorithm behind
/// the `*_hash` fields is not visible here — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DocumentState {
pub document_id: String,
pub version_id: String,
pub content_hash: String,
pub metadata_hash: String,
/// State of each chunk derived from the document.
pub chunk_states: Vec<ChunkState>,
/// State of each embedding derived from the document.
pub embedding_states: Vec<EmbeddingState>,
}
/// Captured state of one chunk of a document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChunkState {
/// Position of the chunk; presumably zero-based within its document — confirm.
pub chunk_index: usize,
pub chunk_hash: String,
pub size_bytes: usize,
}
/// Captured state of one embedding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingState {
pub embedding_id: String,
/// ID of the item the embedding was computed from (presumably a document
/// or chunk — confirm).
pub source_id: String,
pub vector_hash: String,
pub metadata: HashMap<String, serde_json::Value>,
}
/// Captured state of one index.
///
/// NOTE(review): `index_type` and `health_status` are free-form strings; their
/// allowed values are not defined in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexState {
pub index_name: String,
pub index_type: String,
pub document_count: usize,
pub metadata: HashMap<String, serde_json::Value>,
pub health_status: String,
}
/// One journaled index operation, as stored by [`OperationLog`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationLogEntry {
/// Unique entry ID (a UUID v4 string when created by `OperationLog::log_operation`).
pub entry_id: String,
/// The index update that was performed.
pub operation: IndexUpdate,
/// Outcome of the update, if it was known when the entry was logged.
pub result: Option<UpdateResult>,
/// Caller-supplied hash of the state before the operation.
pub pre_state_hash: String,
/// Caller-supplied hash of the state after the operation, if available.
pub post_state_hash: Option<String>,
/// When the entry was logged.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Origin label; `OperationLog::log_operation` always sets `"operation_log"`.
pub source: String,
/// Rollback metadata for this entry.
pub rollback_info: RollbackOperationInfo,
}
/// Rollback-related metadata attached to each [`OperationLogEntry`].
///
/// `OperationLog::log_operation` fills in defaults: rollbackable, priority 5,
/// no dependencies or custom steps, and a 1000 ms estimate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RollbackOperationInfo {
pub is_rollbackable: bool,
/// NOTE(review): ordering of priorities (higher = more urgent?) is not
/// defined in this module — confirm.
pub rollback_priority: u8,
/// IDs of operations this rollback depends on.
pub rollback_dependencies: Vec<String>,
pub custom_rollback_steps: Vec<CustomRollbackStep>,
/// Estimated rollback duration in milliseconds.
pub estimated_rollback_time_ms: u64,
}
/// A caller-defined step to run as part of rolling back an operation.
///
/// NOTE(review): no code in this module executes these steps.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomRollbackStep {
pub step_name: String,
pub step_type: RollbackStepType,
pub parameters: HashMap<String, serde_json::Value>,
/// Presumably the step's execution position, ascending — confirm.
pub order: u32,
}
/// Kind of action a [`CustomRollbackStep`] performs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum RollbackStepType {
Delete,
Restore,
Update,
RebuildIndex,
/// Caller-defined step kind; the string's meaning is up to the caller.
Custom(String),
}
/// A named restore point created by `RollbackManager::create_rollback_point`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RollbackPoint {
pub rollback_point_id: String,
pub created_at: chrono::DateTime<chrono::Utc>,
/// Caller-supplied description.
pub description: String,
/// Caller-supplied IDs of operations associated with this point.
pub operation_ids: Vec<String>,
/// Snapshot taken when the rollback point was created. It is held by value,
/// independently of the manager's snapshot list.
pub system_state: SystemState,
pub metadata: HashMap<String, serde_json::Value>,
/// Caller-supplied flag; not read by `RollbackManager` in this module.
pub auto_rollback_eligible: bool,
}
/// Outcome of one `RollbackManager::rollback` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecoveryResult {
/// Unique ID of this recovery attempt (UUID v4 string).
pub recovery_id: String,
pub success: bool,
/// IDs reported as rolled back by the rollback handler.
pub rolled_back_operations: Vec<String>,
/// NOTE(review): not populated by `RollbackManager::rollback` in this module.
pub final_state: Option<SystemState>,
/// Milliseconds spent in the rollback handler, measured before verification runs.
pub recovery_time_ms: u64,
/// Results of post-rollback verification, if verification is enabled.
pub verification_results: Vec<VerificationResult>,
/// Error messages collected during the rollback.
pub errors: Vec<String>,
pub metadata: HashMap<String, serde_json::Value>,
}
/// Result of a single post-rollback verification check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VerificationResult {
pub check_name: String,
pub passed: bool,
/// Human-readable explanation of the result.
pub details: String,
/// Optional expected-vs-actual data for the check.
pub comparison: Option<HashMap<String, serde_json::Value>>,
}
/// Bounded in-memory journal of index operations.
///
/// At most `max_size` entries are retained. When a new entry pushes the log
/// past that limit, the oldest entries are discarded. `total_operations`
/// counts every entry ever logged, including discarded ones.
pub struct OperationLog {
    /// Retained entries, oldest at the front and newest at the back.
    entries: VecDeque<OperationLogEntry>,
    /// Upper bound on `entries.len()`.
    max_size: usize,
    /// Running count of all logged operations; never decremented.
    total_operations: u64,
}
impl OperationLog {
    /// Creates an empty log that retains at most `max_size` entries.
    pub fn new(max_size: usize) -> Self {
        OperationLog {
            total_operations: 0,
            max_size,
            entries: VecDeque::new(),
        }
    }

    /// Records `operation` as a new entry, then evicts the oldest entries
    /// until at most `max_size` remain.
    ///
    /// The entry receives a fresh UUID, the current UTC time, the source label
    /// `"operation_log"`, and default rollback metadata.
    pub fn log_operation(
        &mut self,
        operation: IndexUpdate,
        result: Option<UpdateResult>,
        pre_state_hash: String,
        post_state_hash: Option<String>,
    ) {
        let rollback_info = RollbackOperationInfo {
            is_rollbackable: true,
            rollback_priority: 5,
            rollback_dependencies: Vec::new(),
            custom_rollback_steps: Vec::new(),
            estimated_rollback_time_ms: 1000,
        };
        let entry_id = Uuid::new_v4().to_string();
        let timestamp = chrono::Utc::now();
        self.entries.push_back(OperationLogEntry {
            entry_id,
            operation,
            result,
            pre_state_hash,
            post_state_hash,
            timestamp,
            source: "operation_log".to_string(),
            rollback_info,
        });
        self.total_operations += 1;
        // Whatever exceeds `max_size` is surplus, taken from the oldest end.
        let surplus = self.entries.len().saturating_sub(self.max_size);
        self.entries.drain(..surplus);
    }

    /// Returns up to `count` of the most recent entries, newest first.
    pub fn get_recent_operations(&self, count: usize) -> Vec<&OperationLogEntry> {
        let first_kept = self.entries.len().saturating_sub(count);
        self.entries.range(first_kept..).rev().collect()
    }

    /// Returns every retained entry for which `predicate` holds, oldest first.
    pub fn find_operations<F>(&self, predicate: F) -> Vec<&OperationLogEntry>
    where
        F: Fn(&OperationLogEntry) -> bool,
    {
        let mut matches = Vec::new();
        for entry in &self.entries {
            if predicate(entry) {
                matches.push(entry);
            }
        }
        matches
    }
}
/// Coordinates operation logging, snapshots, rollback points and rollbacks.
///
/// All mutable state sits behind `Arc`-wrapped tokio locks.
pub struct RollbackManager {
config: RollbackConfig,
/// Journal of logged operations.
operation_log: Arc<RwLock<OperationLog>>,
/// Retained snapshots, oldest at the front, capped at `config.max_snapshots`.
snapshots: Arc<RwLock<VecDeque<SystemState>>>,
/// Rollback points keyed by `rollback_point_id`.
rollback_points: Arc<RwLock<HashMap<String, RollbackPoint>>>,
/// Bounded history of recent `rollback` results, oldest at the front.
recovery_history: Arc<RwLock<VecDeque<RecoveryResult>>>,
stats: Arc<RwLock<RollbackStats>>,
/// Handles of spawned background tasks; `health_check` reports unhealthy if
/// any of them has finished.
task_handles: Arc<tokio::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>>,
}
/// Counters and timestamps maintained by [`RollbackManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RollbackStats {
/// Incremented once per `RollbackManager::log_operation` call.
pub total_operations_logged: u64,
/// Number of `rollback` calls; equals successful + failed.
pub total_rollbacks: u64,
pub successful_rollbacks: u64,
pub failed_rollbacks: u64,
/// Average rollback time in milliseconds, as maintained by `RollbackManager::rollback`.
pub avg_rollback_time_ms: f64,
/// Snapshots created since startup, including ones later evicted.
pub total_snapshots: u64,
/// NOTE(review): not updated anywhere in this module; stays 0.
pub storage_usage_bytes: u64,
/// Time of the most recent snapshot, if any.
pub last_snapshot_at: Option<chrono::DateTime<chrono::Utc>>,
/// Time of the most recent change to these stats.
pub last_updated: chrono::DateTime<chrono::Utc>,
}
impl RollbackManager {
pub async fn new(config: RollbackConfig) -> RragResult<Self> {
let manager = Self {
config: config.clone(),
operation_log: Arc::new(RwLock::new(OperationLog::new(
config.max_operation_log_size,
))),
snapshots: Arc::new(RwLock::new(VecDeque::new())),
rollback_points: Arc::new(RwLock::new(HashMap::new())),
recovery_history: Arc::new(RwLock::new(VecDeque::new())),
stats: Arc::new(RwLock::new(RollbackStats {
total_operations_logged: 0,
total_rollbacks: 0,
successful_rollbacks: 0,
failed_rollbacks: 0,
avg_rollback_time_ms: 0.0,
total_snapshots: 0,
storage_usage_bytes: 0,
last_snapshot_at: None,
last_updated: chrono::Utc::now(),
})),
task_handles: Arc::new(tokio::sync::Mutex::new(Vec::new())),
};
manager.start_background_tasks().await?;
Ok(manager)
}
pub async fn log_operation(
&self,
operation: IndexUpdate,
result: Option<UpdateResult>,
pre_state_hash: String,
post_state_hash: Option<String>,
) -> RragResult<()> {
let mut log = self.operation_log.write().await;
log.log_operation(operation, result, pre_state_hash, post_state_hash);
{
let mut stats = self.stats.write().await;
stats.total_operations_logged += 1;
stats.last_updated = chrono::Utc::now();
}
if self.config.enable_snapshots
&& log.total_operations % self.config.snapshot_interval as u64 == 0
{
drop(log);
self.create_snapshot("auto_snapshot".to_string()).await?;
}
Ok(())
}
pub async fn create_snapshot(&self, _description: String) -> RragResult<String> {
let snapshot_id = Uuid::new_v4().to_string();
let snapshot = SystemState {
snapshot_id: snapshot_id.clone(),
created_at: chrono::Utc::now(),
document_states: self.collect_document_states().await?,
index_states: self.collect_index_states().await?,
system_metadata: HashMap::new(),
operations_count: {
let log = self.operation_log.read().await;
log.total_operations
},
size_bytes: 0, compression_ratio: 1.0,
};
{
let mut snapshots = self.snapshots.write().await;
snapshots.push_back(snapshot);
while snapshots.len() > self.config.max_snapshots {
snapshots.pop_front();
}
}
{
let mut stats = self.stats.write().await;
stats.total_snapshots += 1;
stats.last_snapshot_at = Some(chrono::Utc::now());
stats.last_updated = chrono::Utc::now();
}
Ok(snapshot_id)
}
pub async fn create_rollback_point(
&self,
description: String,
operation_ids: Vec<String>,
auto_eligible: bool,
) -> RragResult<String> {
let rollback_point_id = Uuid::new_v4().to_string();
let snapshot_id = self
.create_snapshot(format!("rollback_point_{}", description))
.await?;
let snapshot = {
let snapshots = self.snapshots.read().await;
snapshots
.iter()
.find(|s| s.snapshot_id == snapshot_id)
.unwrap()
.clone()
};
let rollback_point = RollbackPoint {
rollback_point_id: rollback_point_id.clone(),
created_at: chrono::Utc::now(),
description,
operation_ids,
system_state: snapshot,
metadata: HashMap::new(),
auto_rollback_eligible: auto_eligible,
};
{
let mut points = self.rollback_points.write().await;
points.insert(rollback_point_id.clone(), rollback_point);
}
Ok(rollback_point_id)
}
pub async fn rollback(&self, rollback_op: RollbackOperation) -> RragResult<RecoveryResult> {
let start_time = std::time::Instant::now();
let recovery_id = Uuid::new_v4().to_string();
let mut recovery_result = RecoveryResult {
recovery_id: recovery_id.clone(),
success: false,
rolled_back_operations: Vec::new(),
final_state: None,
recovery_time_ms: 0,
verification_results: Vec::new(),
errors: Vec::new(),
metadata: HashMap::new(),
};
match rollback_op {
RollbackOperation::RestoreSnapshot { snapshot_id, .. } => {
match self.restore_from_snapshot(&snapshot_id).await {
Ok(operations) => {
recovery_result.rolled_back_operations = operations;
recovery_result.success = true;
}
Err(e) => {
recovery_result.errors.push(e.to_string());
}
}
}
RollbackOperation::UndoOperations { operation_ids } => {
match self.undo_operations(&operation_ids).await {
Ok(operations) => {
recovery_result.rolled_back_operations = operations;
recovery_result.success = true;
}
Err(e) => {
recovery_result.errors.push(e.to_string());
}
}
}
RollbackOperation::RevertToTimestamp { timestamp } => {
match self.revert_to_timestamp(timestamp).await {
Ok(operations) => {
recovery_result.rolled_back_operations = operations;
recovery_result.success = true;
}
Err(e) => {
recovery_result.errors.push(e.to_string());
}
}
}
_ => {
recovery_result
.errors
.push("Rollback operation not implemented".to_string());
}
}
recovery_result.recovery_time_ms = start_time.elapsed().as_millis() as u64;
if self.config.enable_verification {
recovery_result.verification_results = self.verify_rollback(&recovery_result).await?;
}
{
let mut history = self.recovery_history.write().await;
history.push_back(recovery_result.clone());
if history.len() > 100 {
history.pop_front();
}
}
{
let mut stats = self.stats.write().await;
stats.total_rollbacks += 1;
if recovery_result.success {
stats.successful_rollbacks += 1;
} else {
stats.failed_rollbacks += 1;
}
stats.avg_rollback_time_ms =
(stats.avg_rollback_time_ms + recovery_result.recovery_time_ms as f64) / 2.0;
stats.last_updated = chrono::Utc::now();
}
Ok(recovery_result)
}
pub async fn get_stats(&self) -> RollbackStats {
self.stats.read().await.clone()
}
pub async fn get_snapshots(&self) -> RragResult<Vec<SystemState>> {
let snapshots = self.snapshots.read().await;
Ok(snapshots.iter().cloned().collect())
}
pub async fn get_rollback_points(&self) -> RragResult<Vec<RollbackPoint>> {
let points = self.rollback_points.read().await;
Ok(points.values().cloned().collect())
}
pub async fn health_check(&self) -> RragResult<bool> {
let handles = self.task_handles.lock().await;
let all_running = handles.iter().all(|handle| !handle.is_finished());
let stats = self.get_stats().await;
let healthy_stats = stats.failed_rollbacks < stats.successful_rollbacks * 2;
Ok(all_running && healthy_stats)
}
async fn start_background_tasks(&self) -> RragResult<()> {
let mut handles = self.task_handles.lock().await;
if self.config.enable_snapshots {
handles.push(self.start_snapshot_cleanup_task().await);
}
Ok(())
}
async fn start_snapshot_cleanup_task(&self) -> tokio::task::JoinHandle<()> {
let snapshots = Arc::clone(&self.snapshots);
let config = self.config.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(3600));
loop {
interval.tick().await;
let mut snapshots_guard = snapshots.write().await;
while snapshots_guard.len() > config.max_snapshots {
snapshots_guard.pop_front();
}
}
})
}
async fn collect_document_states(&self) -> RragResult<HashMap<String, DocumentState>> {
Ok(HashMap::new())
}
async fn collect_index_states(&self) -> RragResult<HashMap<String, IndexState>> {
Ok(HashMap::new())
}
async fn restore_from_snapshot(&self, _snapshot_id: &str) -> RragResult<Vec<String>> {
Ok(Vec::new())
}
async fn undo_operations(&self, operation_ids: &[String]) -> RragResult<Vec<String>> {
Ok(operation_ids.to_vec())
}
async fn revert_to_timestamp(
&self,
_timestamp: chrono::DateTime<chrono::Utc>,
) -> RragResult<Vec<String>> {
Ok(Vec::new())
}
async fn verify_rollback(
&self,
_result: &RecoveryResult,
) -> RragResult<Vec<VerificationResult>> {
Ok(vec![VerificationResult {
check_name: "system_integrity".to_string(),
passed: true,
details: "System integrity verified".to_string(),
comparison: None,
}])
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::incremental::index_manager::IndexOperation;
    use crate::Document;

    /// Builds a minimal `IndexOperation::Add` update for the logging tests.
    fn make_update() -> IndexUpdate {
        IndexUpdate {
            operation_id: Uuid::new_v4().to_string(),
            operation: IndexOperation::Add {
                document: Document::new("Test content"),
                chunks: Vec::new(),
                embeddings: Vec::new(),
            },
            priority: 5,
            timestamp: chrono::Utc::now(),
            source: "test".to_string(),
            metadata: HashMap::new(),
            dependencies: Vec::new(),
            max_retries: 3,
            retry_count: 0,
        }
    }

    #[tokio::test]
    async fn test_rollback_manager_creation() {
        let config = RollbackConfig::default();
        let manager = RollbackManager::new(config).await.unwrap();
        assert!(manager.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_operation_logging() {
        let manager = RollbackManager::new(RollbackConfig::default())
            .await
            .unwrap();
        manager
            .log_operation(
                make_update(),
                None,
                "pre_hash".to_string(),
                Some("post_hash".to_string()),
            )
            .await
            .unwrap();
        let stats = manager.get_stats().await;
        assert_eq!(stats.total_operations_logged, 1);
    }

    #[test]
    fn test_operation_log_eviction() {
        let mut log = OperationLog::new(2);
        for i in 0..3 {
            log.log_operation(make_update(), None, format!("pre_{}", i), None);
        }
        // The running counter includes evicted entries.
        assert_eq!(log.total_operations, 3);
        let recent = log.get_recent_operations(10);
        assert_eq!(recent.len(), 2);
        // Newest first; the oldest entry ("pre_0") was evicted.
        assert_eq!(recent[0].pre_state_hash, "pre_2");
        assert_eq!(recent[1].pre_state_hash, "pre_1");
    }

    #[tokio::test]
    async fn test_snapshot_creation() {
        let manager = RollbackManager::new(RollbackConfig::default())
            .await
            .unwrap();
        let snapshot_id = manager
            .create_snapshot("test_snapshot".to_string())
            .await
            .unwrap();
        assert!(!snapshot_id.is_empty());
        let snapshots = manager.get_snapshots().await.unwrap();
        assert_eq!(snapshots.len(), 1);
        assert_eq!(snapshots[0].snapshot_id, snapshot_id);
        let stats = manager.get_stats().await;
        assert_eq!(stats.total_snapshots, 1);
    }

    #[tokio::test]
    async fn test_rollback_point_creation() {
        let manager = RollbackManager::new(RollbackConfig::default())
            .await
            .unwrap();
        let point_id = manager
            .create_rollback_point(
                "test_point".to_string(),
                vec!["op1".to_string(), "op2".to_string()],
                true,
            )
            .await
            .unwrap();
        assert!(!point_id.is_empty());
        let points = manager.get_rollback_points().await.unwrap();
        assert_eq!(points.len(), 1);
        assert_eq!(points[0].rollback_point_id, point_id);
        assert_eq!(points[0].operation_ids.len(), 2);
    }

    #[tokio::test]
    async fn test_rollback_operation() {
        let manager = RollbackManager::new(RollbackConfig::default())
            .await
            .unwrap();
        let snapshot_id = manager
            .create_snapshot("test_snapshot".to_string())
            .await
            .unwrap();
        let rollback_op = RollbackOperation::RestoreSnapshot {
            snapshot_id,
            target_state: SystemState {
                snapshot_id: "dummy".to_string(),
                created_at: chrono::Utc::now(),
                document_states: HashMap::new(),
                index_states: HashMap::new(),
                system_metadata: HashMap::new(),
                operations_count: 0,
                size_bytes: 0,
                compression_ratio: 1.0,
            },
        };
        let result = manager.rollback(rollback_op).await.unwrap();
        assert!(result.success);
        // Do not assert `recovery_time_ms > 0`: the stub handler finishes in
        // well under a millisecond and `as_millis()` truncates to 0.
        assert!(result.errors.is_empty());
        assert!(result.verification_results.iter().all(|v| v.passed));
        let stats = manager.get_stats().await;
        assert_eq!(stats.total_rollbacks, 1);
        assert_eq!(stats.successful_rollbacks, 1);
    }

    #[test]
    fn test_rollback_strategies() {
        let strategies = vec![
            RollbackStrategy::LastKnownGood,
            RollbackStrategy::SpecificSnapshot,
            RollbackStrategy::Selective,
            RollbackStrategy::Complete,
            RollbackStrategy::Custom("custom".to_string()),
        ];
        for (i, strategy1) in strategies.iter().enumerate() {
            for (j, strategy2) in strategies.iter().enumerate() {
                if i != j {
                    assert_ne!(format!("{:?}", strategy1), format!("{:?}", strategy2));
                }
            }
        }
    }

    #[test]
    fn test_rollback_step_types() {
        let step_types = vec![
            RollbackStepType::Delete,
            RollbackStepType::Restore,
            RollbackStepType::Update,
            RollbackStepType::RebuildIndex,
            RollbackStepType::Custom("custom".to_string()),
        ];
        for (i, type1) in step_types.iter().enumerate() {
            for (j, type2) in step_types.iter().enumerate() {
                if i != j {
                    assert_ne!(format!("{:?}", type1), format!("{:?}", type2));
                }
            }
        }
    }
}