1use crate::incremental::index_manager::{IndexUpdate, UpdateResult};
7use crate::RragResult;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, VecDeque};
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct RollbackConfig {
17 pub max_operation_log_size: usize,
19
20 pub enable_snapshots: bool,
22
23 pub snapshot_interval: usize,
25
26 pub max_snapshots: usize,
28
29 pub enable_auto_rollback: bool,
31
32 pub rollback_timeout_secs: u64,
34
35 pub enable_verification: bool,
37
38 pub rollback_strategy: RollbackStrategy,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub enum RollbackStrategy {
45 LastKnownGood,
47 SpecificSnapshot,
49 Selective,
51 Complete,
53 Custom(String),
55}
56
57impl Default for RollbackConfig {
58 fn default() -> Self {
59 Self {
60 max_operation_log_size: 10000,
61 enable_snapshots: true,
62 snapshot_interval: 100,
63 max_snapshots: 50,
64 enable_auto_rollback: true,
65 rollback_timeout_secs: 300,
66 enable_verification: true,
67 rollback_strategy: RollbackStrategy::LastKnownGood,
68 }
69 }
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74pub enum RollbackOperation {
75 RestoreSnapshot {
77 snapshot_id: String,
78 target_state: SystemState,
79 },
80
81 UndoOperations { operation_ids: Vec<String> },
83
84 RevertToTimestamp {
86 timestamp: chrono::DateTime<chrono::Utc>,
87 },
88
89 SelectiveRollback {
91 document_ids: Vec<String>,
92 target_versions: HashMap<String, String>,
93 },
94
95 SystemReset { reset_to_snapshot: String },
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct SystemState {
102 pub snapshot_id: String,
104
105 pub created_at: chrono::DateTime<chrono::Utc>,
107
108 pub document_states: HashMap<String, DocumentState>,
110
111 pub index_states: HashMap<String, IndexState>,
113
114 pub system_metadata: HashMap<String, serde_json::Value>,
116
117 pub operations_count: u64,
119
120 pub size_bytes: u64,
122
123 pub compression_ratio: f64,
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct DocumentState {
130 pub document_id: String,
132
133 pub version_id: String,
135
136 pub content_hash: String,
138
139 pub metadata_hash: String,
141
142 pub chunk_states: Vec<ChunkState>,
144
145 pub embedding_states: Vec<EmbeddingState>,
147}
148
149#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct ChunkState {
152 pub chunk_index: usize,
154
155 pub chunk_hash: String,
157
158 pub size_bytes: usize,
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
164pub struct EmbeddingState {
165 pub embedding_id: String,
167
168 pub source_id: String,
170
171 pub vector_hash: String,
173
174 pub metadata: HashMap<String, serde_json::Value>,
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
180pub struct IndexState {
181 pub index_name: String,
183
184 pub index_type: String,
186
187 pub document_count: usize,
189
190 pub metadata: HashMap<String, serde_json::Value>,
192
193 pub health_status: String,
195}
196
197#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct OperationLogEntry {
200 pub entry_id: String,
202
203 pub operation: IndexUpdate,
205
206 pub result: Option<UpdateResult>,
208
209 pub pre_state_hash: String,
211
212 pub post_state_hash: Option<String>,
214
215 pub timestamp: chrono::DateTime<chrono::Utc>,
217
218 pub source: String,
220
221 pub rollback_info: RollbackOperationInfo,
223}
224
225#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct RollbackOperationInfo {
228 pub is_rollbackable: bool,
230
231 pub rollback_priority: u8,
233
234 pub rollback_dependencies: Vec<String>,
236
237 pub custom_rollback_steps: Vec<CustomRollbackStep>,
239
240 pub estimated_rollback_time_ms: u64,
242}
243
244#[derive(Debug, Clone, Serialize, Deserialize)]
246pub struct CustomRollbackStep {
247 pub step_name: String,
249
250 pub step_type: RollbackStepType,
252
253 pub parameters: HashMap<String, serde_json::Value>,
255
256 pub order: u32,
258}
259
260#[derive(Debug, Clone, Serialize, Deserialize)]
262pub enum RollbackStepType {
263 Delete,
265 Restore,
267 Update,
269 RebuildIndex,
271 Custom(String),
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
277pub struct RollbackPoint {
278 pub rollback_point_id: String,
280
281 pub created_at: chrono::DateTime<chrono::Utc>,
283
284 pub description: String,
286
287 pub operation_ids: Vec<String>,
289
290 pub system_state: SystemState,
292
293 pub metadata: HashMap<String, serde_json::Value>,
295
296 pub auto_rollback_eligible: bool,
298}
299
300#[derive(Debug, Clone, Serialize, Deserialize)]
302pub struct RecoveryResult {
303 pub recovery_id: String,
305
306 pub success: bool,
308
309 pub rolled_back_operations: Vec<String>,
311
312 pub final_state: Option<SystemState>,
314
315 pub recovery_time_ms: u64,
317
318 pub verification_results: Vec<VerificationResult>,
320
321 pub errors: Vec<String>,
323
324 pub metadata: HashMap<String, serde_json::Value>,
326}
327
328#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct VerificationResult {
331 pub check_name: String,
333
334 pub passed: bool,
336
337 pub details: String,
339
340 pub comparison: Option<HashMap<String, serde_json::Value>>,
342}
343
344pub struct OperationLog {
346 entries: VecDeque<OperationLogEntry>,
348
349 max_size: usize,
351
352 total_operations: u64,
354}
355
356impl OperationLog {
357 pub fn new(max_size: usize) -> Self {
359 Self {
360 entries: VecDeque::new(),
361 max_size,
362 total_operations: 0,
363 }
364 }
365
366 pub fn log_operation(
368 &mut self,
369 operation: IndexUpdate,
370 result: Option<UpdateResult>,
371 pre_state_hash: String,
372 post_state_hash: Option<String>,
373 ) {
374 let entry = OperationLogEntry {
375 entry_id: Uuid::new_v4().to_string(),
376 operation,
377 result,
378 pre_state_hash,
379 post_state_hash,
380 timestamp: chrono::Utc::now(),
381 source: "operation_log".to_string(),
382 rollback_info: RollbackOperationInfo {
383 is_rollbackable: true,
384 rollback_priority: 5,
385 rollback_dependencies: Vec::new(),
386 custom_rollback_steps: Vec::new(),
387 estimated_rollback_time_ms: 1000,
388 },
389 };
390
391 self.entries.push_back(entry);
392 self.total_operations += 1;
393
394 while self.entries.len() > self.max_size {
396 self.entries.pop_front();
397 }
398 }
399
400 pub fn get_recent_operations(&self, count: usize) -> Vec<&OperationLogEntry> {
402 self.entries.iter().rev().take(count).collect()
403 }
404
405 pub fn find_operations<F>(&self, predicate: F) -> Vec<&OperationLogEntry>
407 where
408 F: Fn(&OperationLogEntry) -> bool,
409 {
410 self.entries
411 .iter()
412 .filter(|entry| predicate(entry))
413 .collect()
414 }
415}
416
417pub struct RollbackManager {
419 config: RollbackConfig,
421
422 operation_log: Arc<RwLock<OperationLog>>,
424
425 snapshots: Arc<RwLock<VecDeque<SystemState>>>,
427
428 rollback_points: Arc<RwLock<HashMap<String, RollbackPoint>>>,
430
431 recovery_history: Arc<RwLock<VecDeque<RecoveryResult>>>,
433
434 stats: Arc<RwLock<RollbackStats>>,
436
437 task_handles: Arc<tokio::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>>,
439}
440
441#[derive(Debug, Clone, Serialize, Deserialize)]
443pub struct RollbackStats {
444 pub total_operations_logged: u64,
446
447 pub total_rollbacks: u64,
449
450 pub successful_rollbacks: u64,
452
453 pub failed_rollbacks: u64,
455
456 pub avg_rollback_time_ms: f64,
458
459 pub total_snapshots: u64,
461
462 pub storage_usage_bytes: u64,
464
465 pub last_snapshot_at: Option<chrono::DateTime<chrono::Utc>>,
467
468 pub last_updated: chrono::DateTime<chrono::Utc>,
470}
471
472impl RollbackManager {
473 pub async fn new(config: RollbackConfig) -> RragResult<Self> {
475 let manager = Self {
476 config: config.clone(),
477 operation_log: Arc::new(RwLock::new(OperationLog::new(
478 config.max_operation_log_size,
479 ))),
480 snapshots: Arc::new(RwLock::new(VecDeque::new())),
481 rollback_points: Arc::new(RwLock::new(HashMap::new())),
482 recovery_history: Arc::new(RwLock::new(VecDeque::new())),
483 stats: Arc::new(RwLock::new(RollbackStats {
484 total_operations_logged: 0,
485 total_rollbacks: 0,
486 successful_rollbacks: 0,
487 failed_rollbacks: 0,
488 avg_rollback_time_ms: 0.0,
489 total_snapshots: 0,
490 storage_usage_bytes: 0,
491 last_snapshot_at: None,
492 last_updated: chrono::Utc::now(),
493 })),
494 task_handles: Arc::new(tokio::sync::Mutex::new(Vec::new())),
495 };
496
497 manager.start_background_tasks().await?;
498 Ok(manager)
499 }
500
501 pub async fn log_operation(
503 &self,
504 operation: IndexUpdate,
505 result: Option<UpdateResult>,
506 pre_state_hash: String,
507 post_state_hash: Option<String>,
508 ) -> RragResult<()> {
509 let mut log = self.operation_log.write().await;
510 log.log_operation(operation, result, pre_state_hash, post_state_hash);
511
512 {
514 let mut stats = self.stats.write().await;
515 stats.total_operations_logged += 1;
516 stats.last_updated = chrono::Utc::now();
517 }
518
519 if self.config.enable_snapshots
521 && log.total_operations % self.config.snapshot_interval as u64 == 0
522 {
523 drop(log);
524 self.create_snapshot("auto_snapshot".to_string()).await?;
525 }
526
527 Ok(())
528 }
529
530 pub async fn create_snapshot(&self, _description: String) -> RragResult<String> {
532 let snapshot_id = Uuid::new_v4().to_string();
533
534 let snapshot = SystemState {
536 snapshot_id: snapshot_id.clone(),
537 created_at: chrono::Utc::now(),
538 document_states: self.collect_document_states().await?,
539 index_states: self.collect_index_states().await?,
540 system_metadata: HashMap::new(),
541 operations_count: {
542 let log = self.operation_log.read().await;
543 log.total_operations
544 },
545 size_bytes: 0, compression_ratio: 1.0,
547 };
548
549 {
551 let mut snapshots = self.snapshots.write().await;
552 snapshots.push_back(snapshot);
553
554 while snapshots.len() > self.config.max_snapshots {
556 snapshots.pop_front();
557 }
558 }
559
560 {
562 let mut stats = self.stats.write().await;
563 stats.total_snapshots += 1;
564 stats.last_snapshot_at = Some(chrono::Utc::now());
565 stats.last_updated = chrono::Utc::now();
566 }
567
568 Ok(snapshot_id)
569 }
570
571 pub async fn create_rollback_point(
573 &self,
574 description: String,
575 operation_ids: Vec<String>,
576 auto_eligible: bool,
577 ) -> RragResult<String> {
578 let rollback_point_id = Uuid::new_v4().to_string();
579
580 let snapshot_id = self
582 .create_snapshot(format!("rollback_point_{}", description))
583 .await?;
584
585 let snapshot = {
586 let snapshots = self.snapshots.read().await;
587 snapshots
588 .iter()
589 .find(|s| s.snapshot_id == snapshot_id)
590 .unwrap()
591 .clone()
592 };
593
594 let rollback_point = RollbackPoint {
595 rollback_point_id: rollback_point_id.clone(),
596 created_at: chrono::Utc::now(),
597 description,
598 operation_ids,
599 system_state: snapshot,
600 metadata: HashMap::new(),
601 auto_rollback_eligible: auto_eligible,
602 };
603
604 {
605 let mut points = self.rollback_points.write().await;
606 points.insert(rollback_point_id.clone(), rollback_point);
607 }
608
609 Ok(rollback_point_id)
610 }
611
612 pub async fn rollback(&self, rollback_op: RollbackOperation) -> RragResult<RecoveryResult> {
614 let start_time = std::time::Instant::now();
615 let recovery_id = Uuid::new_v4().to_string();
616
617 let mut recovery_result = RecoveryResult {
618 recovery_id: recovery_id.clone(),
619 success: false,
620 rolled_back_operations: Vec::new(),
621 final_state: None,
622 recovery_time_ms: 0,
623 verification_results: Vec::new(),
624 errors: Vec::new(),
625 metadata: HashMap::new(),
626 };
627
628 match rollback_op {
629 RollbackOperation::RestoreSnapshot { snapshot_id, .. } => {
630 match self.restore_from_snapshot(&snapshot_id).await {
631 Ok(operations) => {
632 recovery_result.rolled_back_operations = operations;
633 recovery_result.success = true;
634 }
635 Err(e) => {
636 recovery_result.errors.push(e.to_string());
637 }
638 }
639 }
640
641 RollbackOperation::UndoOperations { operation_ids } => {
642 match self.undo_operations(&operation_ids).await {
643 Ok(operations) => {
644 recovery_result.rolled_back_operations = operations;
645 recovery_result.success = true;
646 }
647 Err(e) => {
648 recovery_result.errors.push(e.to_string());
649 }
650 }
651 }
652
653 RollbackOperation::RevertToTimestamp { timestamp } => {
654 match self.revert_to_timestamp(timestamp).await {
655 Ok(operations) => {
656 recovery_result.rolled_back_operations = operations;
657 recovery_result.success = true;
658 }
659 Err(e) => {
660 recovery_result.errors.push(e.to_string());
661 }
662 }
663 }
664
665 _ => {
666 recovery_result
667 .errors
668 .push("Rollback operation not implemented".to_string());
669 }
670 }
671
672 recovery_result.recovery_time_ms = start_time.elapsed().as_millis() as u64;
673
674 if self.config.enable_verification {
676 recovery_result.verification_results = self.verify_rollback(&recovery_result).await?;
677 }
678
679 {
681 let mut history = self.recovery_history.write().await;
682 history.push_back(recovery_result.clone());
683
684 if history.len() > 100 {
686 history.pop_front();
687 }
688 }
689
690 {
692 let mut stats = self.stats.write().await;
693 stats.total_rollbacks += 1;
694 if recovery_result.success {
695 stats.successful_rollbacks += 1;
696 } else {
697 stats.failed_rollbacks += 1;
698 }
699 stats.avg_rollback_time_ms =
700 (stats.avg_rollback_time_ms + recovery_result.recovery_time_ms as f64) / 2.0;
701 stats.last_updated = chrono::Utc::now();
702 }
703
704 Ok(recovery_result)
705 }
706
707 pub async fn get_stats(&self) -> RollbackStats {
709 self.stats.read().await.clone()
710 }
711
712 pub async fn get_snapshots(&self) -> RragResult<Vec<SystemState>> {
714 let snapshots = self.snapshots.read().await;
715 Ok(snapshots.iter().cloned().collect())
716 }
717
718 pub async fn get_rollback_points(&self) -> RragResult<Vec<RollbackPoint>> {
720 let points = self.rollback_points.read().await;
721 Ok(points.values().cloned().collect())
722 }
723
724 pub async fn health_check(&self) -> RragResult<bool> {
726 let handles = self.task_handles.lock().await;
727 let all_running = handles.iter().all(|handle| !handle.is_finished());
728
729 let stats = self.get_stats().await;
730 let healthy_stats = stats.failed_rollbacks < stats.successful_rollbacks * 2; Ok(all_running && healthy_stats)
733 }
734
735 async fn start_background_tasks(&self) -> RragResult<()> {
737 let mut handles = self.task_handles.lock().await;
738
739 if self.config.enable_snapshots {
740 handles.push(self.start_snapshot_cleanup_task().await);
741 }
742
743 Ok(())
744 }
745
746 async fn start_snapshot_cleanup_task(&self) -> tokio::task::JoinHandle<()> {
748 let snapshots = Arc::clone(&self.snapshots);
749 let config = self.config.clone();
750
751 tokio::spawn(async move {
752 let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(3600)); loop {
755 interval.tick().await;
756
757 let mut snapshots_guard = snapshots.write().await;
758
759 while snapshots_guard.len() > config.max_snapshots {
761 snapshots_guard.pop_front();
762 }
763
764 }
766 })
767 }
768
769 async fn collect_document_states(&self) -> RragResult<HashMap<String, DocumentState>> {
771 Ok(HashMap::new())
773 }
774
775 async fn collect_index_states(&self) -> RragResult<HashMap<String, IndexState>> {
777 Ok(HashMap::new())
779 }
780
781 async fn restore_from_snapshot(&self, _snapshot_id: &str) -> RragResult<Vec<String>> {
783 Ok(Vec::new())
785 }
786
787 async fn undo_operations(&self, operation_ids: &[String]) -> RragResult<Vec<String>> {
789 Ok(operation_ids.to_vec())
791 }
792
793 async fn revert_to_timestamp(
795 &self,
796 _timestamp: chrono::DateTime<chrono::Utc>,
797 ) -> RragResult<Vec<String>> {
798 Ok(Vec::new())
800 }
801
802 async fn verify_rollback(
804 &self,
805 _result: &RecoveryResult,
806 ) -> RragResult<Vec<VerificationResult>> {
807 Ok(vec![VerificationResult {
809 check_name: "system_integrity".to_string(),
810 passed: true,
811 details: "System integrity verified".to_string(),
812 comparison: None,
813 }])
814 }
815}
816
#[cfg(test)]
mod tests {
    use super::*;
    use crate::incremental::index_manager::IndexOperation;
    use crate::Document;

    /// Builds a minimal `Add` update for the logging tests.
    fn make_update() -> IndexUpdate {
        IndexUpdate {
            operation_id: Uuid::new_v4().to_string(),
            operation: IndexOperation::Add {
                document: Document::new("Test content"),
                chunks: Vec::new(),
                embeddings: Vec::new(),
            },
            priority: 5,
            timestamp: chrono::Utc::now(),
            source: "test".to_string(),
            metadata: HashMap::new(),
            dependencies: Vec::new(),
            max_retries: 3,
            retry_count: 0,
        }
    }

    #[tokio::test]
    async fn test_rollback_manager_creation() {
        let config = RollbackConfig::default();
        let manager = RollbackManager::new(config).await.unwrap();
        assert!(manager.health_check().await.unwrap());
    }

    #[tokio::test]
    async fn test_operation_logging() {
        let manager = RollbackManager::new(RollbackConfig::default())
            .await
            .unwrap();

        manager
            .log_operation(
                make_update(),
                None,
                "pre_hash".to_string(),
                Some("post_hash".to_string()),
            )
            .await
            .unwrap();

        let stats = manager.get_stats().await;
        assert_eq!(stats.total_operations_logged, 1);
    }

    #[test]
    fn test_operation_log_eviction() {
        let mut log = OperationLog::new(2);
        for i in 0..3 {
            log.log_operation(make_update(), None, format!("pre_{}", i), None);
        }

        // Every call is counted, but only the newest `max_size` entries are kept.
        assert_eq!(log.total_operations, 3);
        assert_eq!(log.entries.len(), 2);

        let recent = log.get_recent_operations(10);
        assert_eq!(recent.len(), 2);
        assert_eq!(recent[0].pre_state_hash, "pre_2");
        assert_eq!(recent[1].pre_state_hash, "pre_1");
    }

    #[tokio::test]
    async fn test_snapshot_creation() {
        let manager = RollbackManager::new(RollbackConfig::default())
            .await
            .unwrap();

        let snapshot_id = manager
            .create_snapshot("test_snapshot".to_string())
            .await
            .unwrap();
        assert!(!snapshot_id.is_empty());

        let snapshots = manager.get_snapshots().await.unwrap();
        assert_eq!(snapshots.len(), 1);
        assert_eq!(snapshots[0].snapshot_id, snapshot_id);

        let stats = manager.get_stats().await;
        assert_eq!(stats.total_snapshots, 1);
    }

    #[tokio::test]
    async fn test_rollback_point_creation() {
        let manager = RollbackManager::new(RollbackConfig::default())
            .await
            .unwrap();

        let point_id = manager
            .create_rollback_point(
                "test_point".to_string(),
                vec!["op1".to_string(), "op2".to_string()],
                true,
            )
            .await
            .unwrap();

        assert!(!point_id.is_empty());

        let points = manager.get_rollback_points().await.unwrap();
        assert_eq!(points.len(), 1);
        assert_eq!(points[0].rollback_point_id, point_id);
        assert_eq!(points[0].operation_ids.len(), 2);
    }

    #[tokio::test]
    async fn test_rollback_operation() {
        let manager = RollbackManager::new(RollbackConfig::default())
            .await
            .unwrap();

        let snapshot_id = manager
            .create_snapshot("test_snapshot".to_string())
            .await
            .unwrap();

        let rollback_op = RollbackOperation::RestoreSnapshot {
            snapshot_id,
            target_state: SystemState {
                snapshot_id: "dummy".to_string(),
                created_at: chrono::Utc::now(),
                document_states: HashMap::new(),
                index_states: HashMap::new(),
                system_metadata: HashMap::new(),
                operations_count: 0,
                size_bytes: 0,
                compression_ratio: 1.0,
            },
        };

        let result = manager.rollback(rollback_op).await.unwrap();
        assert!(result.success);
        // The stub restore finishes in well under a millisecond, so asserting a
        // non-zero duration would be flaky; check the outcome instead.
        assert!(result.errors.is_empty());
        assert!(result.verification_results.iter().all(|check| check.passed));

        let stats = manager.get_stats().await;
        assert_eq!(stats.total_rollbacks, 1);
        assert_eq!(stats.successful_rollbacks, 1);
    }

    #[tokio::test]
    async fn test_undo_operations_rollback() {
        let manager = RollbackManager::new(RollbackConfig::default())
            .await
            .unwrap();

        let ids = vec!["op1".to_string(), "op2".to_string()];
        let result = manager
            .rollback(RollbackOperation::UndoOperations {
                operation_ids: ids.clone(),
            })
            .await
            .unwrap();

        assert!(result.success);
        assert_eq!(result.rolled_back_operations, ids);
    }

    #[test]
    fn test_rollback_strategies() {
        let strategies = vec![
            RollbackStrategy::LastKnownGood,
            RollbackStrategy::SpecificSnapshot,
            RollbackStrategy::Selective,
            RollbackStrategy::Complete,
            RollbackStrategy::Custom("custom".to_string()),
        ];

        for (i, strategy1) in strategies.iter().enumerate() {
            for (j, strategy2) in strategies.iter().enumerate() {
                if i != j {
                    assert_ne!(format!("{:?}", strategy1), format!("{:?}", strategy2));
                }
            }
        }
    }

    #[test]
    fn test_rollback_step_types() {
        let step_types = vec![
            RollbackStepType::Delete,
            RollbackStepType::Restore,
            RollbackStepType::Update,
            RollbackStepType::RebuildIndex,
            RollbackStepType::Custom("custom".to_string()),
        ];

        for (i, type1) in step_types.iter().enumerate() {
            for (j, type2) in step_types.iter().enumerate() {
                if i != j {
                    assert_ne!(format!("{:?}", type1), format!("{:?}", type2));
                }
            }
        }
    }
}