1use crate::audit::AuditLog;
22use crate::workflow::dag::Workflow;
23use crate::workflow::task::{CompensationAction, CompensationType, TaskContext, TaskId, TaskError, TaskResult, WorkflowTask};
24use chrono::Utc;
25use petgraph::graph::NodeIndex;
26use petgraph::visit::IntoNeighborsDirected;
27use petgraph::Direction;
28use std::collections::{HashMap, HashSet, VecDeque};
29use std::fs;
30use std::path::Path;
31use std::sync::Arc;
32use thiserror::Error;
33
34#[derive(Clone)]
40pub struct ExecutableCompensation {
41 pub action: CompensationAction,
43 #[allow(clippy::type_complexity)]
45 undo_fn: Option<Arc<dyn Fn(&TaskContext) -> Result<TaskResult, TaskError> + Send + Sync>>,
46}
47
48impl ExecutableCompensation {
49 pub fn new(action: CompensationAction) -> Self {
51 Self {
52 action,
53 undo_fn: None,
54 }
55 }
56
57 pub fn with_undo<F>(description: impl Into<String>, undo_fn: F) -> Self
59 where
60 F: Fn(&TaskContext) -> Result<TaskResult, TaskError> + Send + Sync + 'static,
61 {
62 Self {
63 action: CompensationAction::undo(description),
64 undo_fn: Some(Arc::new(undo_fn)),
65 }
66 }
67
68 pub fn skip(description: impl Into<String>) -> Self {
70 Self::new(CompensationAction::skip(description))
71 }
72
73 pub fn retry(description: impl Into<String>) -> Self {
75 Self::new(CompensationAction::retry(description))
76 }
77
78 pub fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
80 match self.action.action_type {
81 CompensationType::UndoFunction => {
82 if let Some(undo_fn) = &self.undo_fn {
83 undo_fn(context)
84 } else {
85 Ok(TaskResult::Skipped)
86 }
87 }
88 CompensationType::Skip => Ok(TaskResult::Skipped),
89 CompensationType::Retry => Ok(TaskResult::Skipped),
90 }
91 }
92}
93
94impl From<ExecutableCompensation> for CompensationAction {
95 fn from(exec: ExecutableCompensation) -> Self {
96 exec.action
97 }
98}
99
100#[derive(Clone)]
106pub struct ToolCompensation {
107 pub description: String,
109 #[allow(clippy::type_complexity)]
111 compensate: Arc<dyn Fn(&TaskContext) -> Result<TaskResult, TaskError> + Send + Sync>,
112}
113
114impl ToolCompensation {
115 pub fn new<F>(description: impl Into<String>, compensate_fn: F) -> Self
117 where
118 F: Fn(&TaskContext) -> Result<TaskResult, TaskError> + Send + Sync + 'static,
119 {
120 Self {
121 description: description.into(),
122 compensate: Arc::new(compensate_fn),
123 }
124 }
125
126 pub fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
128 (self.compensate)(context)
129 }
130
131 pub fn file_compensation(file_path: impl Into<String>) -> Self {
147 let path = file_path.into();
148 Self::new(format!("Delete file: {}", path), move |_context| {
149 if Path::new(&path).exists() {
151 fs::remove_file(&path).map_err(|e| {
152 TaskError::ExecutionFailed(format!("Failed to delete file {}: {}", path, e))
153 })?;
154 }
155 Ok(TaskResult::Success)
156 })
157 }
158
159 pub fn process_compensation(pid: u32) -> Self {
175 Self::new(format!("Terminate process: {}", pid), move |_context| {
176 #[cfg(unix)]
178 {
179 use std::process::Command;
180 let result = Command::new("kill")
181 .arg("-TERM")
182 .arg(pid.to_string())
183 .output();
184
185 match result {
186 Ok(_) => Ok(TaskResult::Success),
187 Err(e) => Ok(TaskResult::Failed(format!("Failed to terminate process {}: {}", pid, e))),
188 }
189 }
190
191 #[cfg(not(unix))]
192 {
193 Ok(TaskResult::Failed(format!("Process termination not supported on this platform")))
194 }
195 })
196 }
197
198 pub fn skip(description: impl Into<String>) -> Self {
202 Self::new(description, |_context| Ok(TaskResult::Skipped))
203 }
204
205 pub fn retry(description: impl Into<String>) -> Self {
209 Self::new(description, |_context| Ok(TaskResult::Skipped))
210 }
211}
212
213impl From<CompensationAction> for ToolCompensation {
214 fn from(action: CompensationAction) -> Self {
215 match action.action_type {
216 CompensationType::Skip => ToolCompensation::skip(action.description),
217 CompensationType::Retry => ToolCompensation::retry(action.description),
218 CompensationType::UndoFunction => {
219 ToolCompensation::skip(format!(
222 "{} (no undo function available)",
223 action.description
224 ))
225 }
226 }
227 }
228}
229
230pub struct CompensationRegistry {
236 compensations: HashMap<TaskId, ToolCompensation>,
238}
239
240impl CompensationRegistry {
241 pub fn new() -> Self {
243 Self {
244 compensations: HashMap::new(),
245 }
246 }
247
248 pub fn register(&mut self, task_id: TaskId, compensation: ToolCompensation) {
264 self.compensations.insert(task_id, compensation);
265 }
266
267 pub fn get(&self, task_id: &TaskId) -> Option<&ToolCompensation> {
278 self.compensations.get(task_id)
279 }
280
281 pub fn has_compensation(&self, task_id: &TaskId) -> bool {
292 self.compensations.contains_key(task_id)
293 }
294
295 pub fn remove(&mut self, task_id: &TaskId) -> Option<ToolCompensation> {
308 self.compensations.remove(task_id)
309 }
310
311 pub fn validate_coverage(&self, task_ids: &[TaskId]) -> CompensationReport {
323 let mut with_compensation = Vec::new();
324 let mut without_compensation = Vec::new();
325
326 for task_id in task_ids {
327 if self.has_compensation(task_id) {
328 with_compensation.push(task_id.clone());
329 } else {
330 without_compensation.push(task_id.clone());
331 }
332 }
333
334 let total = task_ids.len();
335 let coverage = CompensationReport::calculate(with_compensation.len(), total);
336
337 CompensationReport {
338 tasks_with_compensation: with_compensation,
339 tasks_without_compensation: without_compensation,
340 coverage_percentage: coverage,
341 }
342 }
343
344 pub fn register_file_creation(&mut self, task_id: TaskId, file_path: impl Into<String>) {
353 self.register(task_id, ToolCompensation::file_compensation(file_path));
354 }
355
356 pub fn register_process_spawn(&mut self, task_id: TaskId, pid: u32) {
365 self.register(task_id, ToolCompensation::process_compensation(pid));
366 }
367
368 pub fn len(&self) -> usize {
370 self.compensations.len()
371 }
372
373 pub fn is_empty(&self) -> bool {
375 self.compensations.is_empty()
376 }
377
378 pub fn task_ids(&self) -> Vec<TaskId> {
380 self.compensations.keys().cloned().collect()
381 }
382}
383
384impl Default for CompensationRegistry {
385 fn default() -> Self {
386 Self::new()
387 }
388}
389
/// Selects which tasks `RollbackEngine::find_rollback_set` returns after a failure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RollbackStrategy {
    /// The failed task plus every task reachable from it along outgoing graph edges.
    AllDependent,
    /// Only the failed task itself.
    FailedOnly,
    /// Reserved for custom selection; currently handled exactly like `AllDependent`.
    Custom,
}
400
401#[derive(Clone, Debug)]
403pub struct RollbackReport {
404 pub rolled_back_tasks: Vec<TaskId>,
406 pub skipped_tasks: Vec<TaskId>,
408 pub failed_compensations: Vec<(TaskId, String)>,
410}
411
412impl RollbackReport {
413 fn new() -> Self {
415 Self {
416 rolled_back_tasks: Vec::new(),
417 skipped_tasks: Vec::new(),
418 failed_compensations: Vec::new(),
419 }
420 }
421
422 pub fn total_processed(&self) -> usize {
424 self.rolled_back_tasks.len() + self.skipped_tasks.len() + self.failed_compensations.len()
425 }
426}
427
428#[derive(Clone, Debug)]
430pub struct CompensationReport {
431 pub tasks_with_compensation: Vec<TaskId>,
433 pub tasks_without_compensation: Vec<TaskId>,
435 pub coverage_percentage: f64,
437}
438
439impl CompensationReport {
440 fn calculate(with_compensation: usize, total: usize) -> f64 {
442 if total == 0 {
443 1.0
444 } else {
445 with_compensation as f64 / total as f64
446 }
447 }
448}
449
450#[derive(Error, Debug)]
452pub enum RollbackError {
453 #[error("Failed to determine rollback set: {0}")]
455 RollbackSetFailed(String),
456
457 #[error("Task not found during rollback: {0}")]
459 TaskNotFound(TaskId),
460
461 #[error("Compensation failed for task {0}: {1}")]
463 CompensationFailed(TaskId, String),
464
465 #[error("Graph traversal error: {0}")]
467 TraversalError(String),
468}
469
470pub struct RollbackEngine {
475 _private: (),
476}
477
478impl RollbackEngine {
479 pub fn new() -> Self {
481 Self { _private: () }
482 }
483
484 pub fn find_rollback_set(
511 &self,
512 workflow: &Workflow,
513 failed_task: &TaskId,
514 strategy: RollbackStrategy,
515 ) -> Result<Vec<TaskId>, RollbackError> {
516 let failed_idx = *workflow
518 .task_map
519 .get(failed_task)
520 .ok_or_else(|| RollbackError::TaskNotFound(failed_task.clone()))?;
521
522 match strategy {
523 RollbackStrategy::FailedOnly => {
524 Ok(vec![failed_task.clone()])
526 }
527 RollbackStrategy::AllDependent => {
528 let dependent_set = self.find_dependent_tasks(workflow, failed_idx)?;
530 self.reverse_execution_order(workflow, dependent_set)
532 }
533 RollbackStrategy::Custom => {
534 let dependent_set = self.find_dependent_tasks(workflow, failed_idx)?;
537 self.reverse_execution_order(workflow, dependent_set)
538 }
539 }
540 }
541
542 fn find_dependent_tasks(
559 &self,
560 workflow: &Workflow,
561 failed_idx: NodeIndex,
562 ) -> Result<HashSet<TaskId>, RollbackError> {
563 let mut dependent_set = HashSet::new();
564 let mut visited = HashSet::new();
565 let mut stack = VecDeque::new();
566
567 stack.push_back(failed_idx);
570 visited.insert(failed_idx);
571
572 while let Some(current_idx) = stack.pop_front() {
573 if let Some(node) = workflow.graph.node_weight(current_idx) {
575 let task_id = node.id().clone();
576 dependent_set.insert(task_id);
577 }
578
579 for neighbor in workflow
583 .graph
584 .neighbors_directed(current_idx, Direction::Outgoing)
585 {
586 if !visited.contains(&neighbor) {
587 visited.insert(neighbor);
588 stack.push_back(neighbor);
589 }
590 }
591 }
592
593 Ok(dependent_set)
594 }
595
596 fn reverse_execution_order(
610 &self,
611 workflow: &Workflow,
612 tasks: HashSet<TaskId>,
613 ) -> Result<Vec<TaskId>, RollbackError> {
614 let execution_order = workflow
616 .execution_order()
617 .map_err(|e| RollbackError::TraversalError(e.to_string()))?;
618
619 let position_map: HashMap<TaskId, usize> = execution_order
621 .iter()
622 .enumerate()
623 .map(|(pos, task_id)| (task_id.clone(), pos))
624 .collect();
625
626 let mut rollback_tasks: Vec<TaskId> = tasks.into_iter().collect();
628
629 rollback_tasks.sort_by(|a, b| {
631 let pos_a = position_map.get(a).copied().unwrap_or(0);
632 let pos_b = position_map.get(b).copied().unwrap_or(0);
633 pos_b.cmp(&pos_a) });
635
636 Ok(rollback_tasks)
637 }
638
639 pub async fn execute_rollback(
657 &self,
658 workflow: &Workflow,
659 tasks: Vec<TaskId>,
660 workflow_id: &str,
661 audit_log: &mut AuditLog,
662 compensation_registry: &CompensationRegistry,
663 ) -> Result<RollbackReport, RollbackError> {
664 let mut report = RollbackReport::new();
665
666 for task_id in &tasks {
667 let node_idx = workflow
669 .task_map
670 .get(task_id)
671 .ok_or_else(|| RollbackError::TaskNotFound(task_id.clone()))?;
672
673 let _node = workflow
674 .graph
675 .node_weight(*node_idx)
676 .expect("Node index should be valid");
677
678 if let Some(compensation) = compensation_registry.get(task_id) {
680 let context = TaskContext::new(workflow_id, task_id.clone());
682
683 match compensation.execute(&context) {
685 Ok(_) => {
686 let _ = audit_log
688 .record(crate::audit::AuditEvent::WorkflowTaskRolledBack {
689 timestamp: Utc::now(),
690 workflow_id: workflow_id.to_string(),
691 task_id: task_id.to_string(),
692 compensation: compensation.description.clone(),
693 })
694 .await;
695
696 report.rolled_back_tasks.push(task_id.clone());
697 }
698 Err(e) => {
699 let error_msg = e.to_string();
701 let _ = audit_log
702 .record(crate::audit::AuditEvent::WorkflowTaskRolledBack {
703 timestamp: Utc::now(),
704 workflow_id: workflow_id.to_string(),
705 task_id: task_id.to_string(),
706 compensation: format!("Failed: {}", error_msg),
707 })
708 .await;
709
710 report.failed_compensations.push((task_id.clone(), error_msg));
711 }
712 }
713 } else {
714 report.skipped_tasks.push(task_id.clone());
716
717 let _ = audit_log
719 .record(crate::audit::AuditEvent::WorkflowTaskRolledBack {
720 timestamp: Utc::now(),
721 workflow_id: workflow_id.to_string(),
722 task_id: task_id.to_string(),
723 compensation: "No compensation registered".to_string(),
724 })
725 .await;
726 }
727 }
728
729 let _ = audit_log
731 .record(crate::audit::AuditEvent::WorkflowRolledBack {
732 timestamp: Utc::now(),
733 workflow_id: workflow_id.to_string(),
734 reason: "Task failure triggered rollback".to_string(),
735 rolled_back_tasks: tasks.iter().map(|id| id.to_string()).collect(),
736 })
737 .await;
738
739 Ok(report)
740 }
741
742 pub fn validate_compensation_coverage(
754 &self,
755 workflow: &Workflow,
756 ) -> CompensationReport {
757 let total_tasks = workflow.task_count();
758 let with_compensation = Vec::new();
759 let mut without_compensation = Vec::new();
760
761 for task_id in workflow.task_ids() {
765 without_compensation.push(task_id);
767 }
768
769 let coverage = CompensationReport::calculate(with_compensation.len(), total_tasks);
770
771 CompensationReport {
772 tasks_with_compensation: with_compensation,
773 tasks_without_compensation: without_compensation,
774 coverage_percentage: coverage,
775 }
776 }
777}
778
779impl Default for RollbackEngine {
780 fn default() -> Self {
781 Self::new()
782 }
783}
784
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::task::{TaskContext, TaskError, TaskResult, WorkflowTask};
    use async_trait::async_trait;
    use std::fs::File;
    use std::io::Write;
    use std::path::PathBuf;

    /// Per-process path in the platform temp directory.
    ///
    /// Keeps file-based tests portable and avoids collisions with concurrent runs.
    fn temp_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{}_{}.txt", name, std::process::id()))
    }

    #[test]
    fn test_tool_compensation_creation() {
        let comp = ToolCompensation::new("Test compensation", |_ctx| Ok(TaskResult::Success));
        assert_eq!(comp.description, "Test compensation");
    }

    #[test]
    fn test_tool_compensation_execute() {
        let comp = ToolCompensation::new("Execute test", |_ctx| Ok(TaskResult::Success));
        let context = TaskContext::new("test", TaskId::new("a"));
        let result = comp.execute(&context).unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[test]
    fn test_tool_compensation_execute_error() {
        let comp = ToolCompensation::new("Execute test", |_ctx| {
            Err(TaskError::ExecutionFailed("Test error".to_string()))
        });
        let context = TaskContext::new("test", TaskId::new("a"));
        let result = comp.execute(&context);
        assert!(result.is_err());
    }

    #[test]
    fn test_tool_compensation_skip() {
        let comp = ToolCompensation::skip("No action needed");
        let context = TaskContext::new("test", TaskId::new("a"));
        let result = comp.execute(&context).unwrap();
        assert_eq!(result, TaskResult::Skipped);
    }

    #[test]
    fn test_tool_compensation_retry() {
        let comp = ToolCompensation::retry("Retry recommended");
        let context = TaskContext::new("test", TaskId::new("a"));
        let result = comp.execute(&context).unwrap();
        assert_eq!(result, TaskResult::Skipped);
    }

    #[test]
    fn test_tool_compensation_file() {
        let temp_file = temp_path("test_tool_compensation");

        let mut file = File::create(&temp_file).unwrap();
        writeln!(file, "test content").unwrap();
        drop(file);

        assert!(temp_file.exists());

        let comp = ToolCompensation::file_compensation(temp_file.to_string_lossy().into_owned());
        let context = TaskContext::new("test", TaskId::new("a"));
        let result = comp.execute(&context);

        assert!(result.is_ok());
        assert!(!temp_file.exists());
    }

    #[test]
    fn test_tool_compensation_file_missing_is_ok() {
        // Deleting a file that is already gone is considered a successful compensation.
        let missing = temp_path("test_tool_compensation_missing");
        let _ = std::fs::remove_file(&missing);

        let comp = ToolCompensation::file_compensation(missing.to_string_lossy().into_owned());
        let context = TaskContext::new("test", TaskId::new("a"));

        assert_eq!(comp.execute(&context).unwrap(), TaskResult::Success);
    }

    #[test]
    fn test_tool_compensation_from_compensation_action() {
        let skip_action = CompensationAction::skip("Skip action");
        let skip_comp: ToolCompensation = skip_action.into();
        assert_eq!(skip_comp.description, "Skip action");

        let retry_action = CompensationAction::retry("Retry action");
        let retry_comp: ToolCompensation = retry_action.into();
        assert_eq!(retry_comp.description, "Retry action");

        let undo_action = CompensationAction::undo("Undo action");
        let undo_comp: ToolCompensation = undo_action.into();
        assert!(undo_comp.description.contains("no undo function available"));
    }

    #[test]
    fn test_compensation_registry_new() {
        let registry = CompensationRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_compensation_registry_register() {
        let mut registry = CompensationRegistry::new();
        let task_id = TaskId::new("task-1");
        let comp = ToolCompensation::skip("Test");

        registry.register(task_id.clone(), comp);

        assert_eq!(registry.len(), 1);
        assert!(registry.has_compensation(&task_id));
    }

    #[test]
    fn test_compensation_registry_get() {
        let mut registry = CompensationRegistry::new();
        let task_id = TaskId::new("task-1");
        let comp = ToolCompensation::new("Test", |_ctx| Ok(TaskResult::Success));

        registry.register(task_id.clone(), comp);

        let retrieved = registry.get(&task_id);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().description, "Test");

        let missing = registry.get(&TaskId::new("missing"));
        assert!(missing.is_none());
    }

    #[test]
    fn test_compensation_registry_remove() {
        let mut registry = CompensationRegistry::new();
        let task_id = TaskId::new("task-1");
        let comp = ToolCompensation::skip("Test");

        registry.register(task_id.clone(), comp);
        assert_eq!(registry.len(), 1);

        let removed = registry.remove(&task_id);
        assert!(removed.is_some());
        assert_eq!(registry.len(), 0);
        assert!(!registry.has_compensation(&task_id));

        let removed_again = registry.remove(&task_id);
        assert!(removed_again.is_none());
    }

    #[test]
    fn test_compensation_registry_validate_coverage() {
        let mut registry = CompensationRegistry::new();

        let task1 = TaskId::new("task-1");
        let task2 = TaskId::new("task-2");
        let task3 = TaskId::new("task-3");

        registry.register(task1.clone(), ToolCompensation::skip("Test 1"));
        registry.register(task2.clone(), ToolCompensation::skip("Test 2"));

        let report = registry.validate_coverage(&[task1.clone(), task2.clone(), task3.clone()]);

        assert_eq!(report.tasks_with_compensation.len(), 2);
        assert!(report.tasks_with_compensation.contains(&task1));
        assert!(report.tasks_with_compensation.contains(&task2));

        assert_eq!(report.tasks_without_compensation.len(), 1);
        assert!(report.tasks_without_compensation.contains(&task3));

        assert!((report.coverage_percentage - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_compensation_registry_register_file_creation() {
        let mut registry = CompensationRegistry::new();
        let task_id = TaskId::new("task-1");

        registry.register_file_creation(task_id.clone(), "/tmp/test.txt");

        assert!(registry.has_compensation(&task_id));
        let comp = registry.get(&task_id).unwrap();
        assert!(comp.description.contains("Delete file"));
    }

    #[test]
    fn test_compensation_registry_register_process_spawn() {
        let mut registry = CompensationRegistry::new();
        let task_id = TaskId::new("task-1");

        registry.register_process_spawn(task_id.clone(), 12345);

        assert!(registry.has_compensation(&task_id));
        let comp = registry.get(&task_id).unwrap();
        assert!(comp.description.contains("Terminate process"));
    }

    #[test]
    fn test_compensation_registry_task_ids() {
        let mut registry = CompensationRegistry::new();

        let task1 = TaskId::new("task-1");
        let task2 = TaskId::new("task-2");

        registry.register(task1.clone(), ToolCompensation::skip("Test 1"));
        registry.register(task2.clone(), ToolCompensation::skip("Test 2"));

        let ids = registry.task_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&task1));
        assert!(ids.contains(&task2));
    }

    #[test]
    fn test_compensation_registry_default() {
        let registry = CompensationRegistry::default();
        assert!(registry.is_empty());
    }

    /// Minimal `WorkflowTask` used to build test workflows.
    // The `compensation` field and the builder helpers are not used by every test.
    #[allow(dead_code)]
    struct MockTaskWithCompensation {
        id: TaskId,
        name: String,
        deps: Vec<TaskId>,
        compensation: Option<CompensationAction>,
    }

    #[allow(dead_code)]
    impl MockTaskWithCompensation {
        fn new(id: impl Into<TaskId>, name: &str) -> Self {
            Self {
                id: id.into(),
                name: name.to_string(),
                deps: Vec::new(),
                compensation: None,
            }
        }

        fn with_dep(mut self, dep: impl Into<TaskId>) -> Self {
            self.deps.push(dep.into());
            self
        }

        fn with_compensation(mut self, action: CompensationAction) -> Self {
            self.compensation = Some(action);
            self
        }
    }

    #[async_trait]
    impl WorkflowTask for MockTaskWithCompensation {
        async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
            Ok(TaskResult::Success)
        }

        fn id(&self) -> TaskId {
            self.id.clone()
        }

        fn name(&self) -> &str {
            &self.name
        }

        fn dependencies(&self) -> Vec<TaskId> {
            self.deps.clone()
        }
    }

    #[test]
    fn test_compensation_action_creation() {
        let skip = CompensationAction::skip("Read-only operation");
        assert_eq!(skip.action_type, CompensationType::Skip);
        assert_eq!(skip.description, "Read-only operation");

        let retry = CompensationAction::retry("Transient network error");
        assert_eq!(retry.action_type, CompensationType::Retry);

        let undo = CompensationAction::undo("Delete file");
        assert_eq!(undo.action_type, CompensationType::UndoFunction);
    }

    #[test]
    fn test_executable_compensation_creation() {
        let skip = ExecutableCompensation::skip("No action needed");
        assert_eq!(skip.action.action_type, CompensationType::Skip);

        let retry = ExecutableCompensation::retry("Retry later");
        assert_eq!(retry.action.action_type, CompensationType::Retry);

        let undo = ExecutableCompensation::with_undo("Execute undo", |_ctx| {
            Ok(TaskResult::Success)
        });
        assert_eq!(undo.action.action_type, CompensationType::UndoFunction);
    }

    #[test]
    fn test_executable_compensation_execute() {
        let skip = ExecutableCompensation::skip("No action needed");
        let context = TaskContext::new("test", TaskId::new("a"));
        let result = skip.execute(&context).unwrap();
        assert_eq!(result, TaskResult::Skipped);

        let retry = ExecutableCompensation::retry("Retry later");
        let result = retry.execute(&context).unwrap();
        assert_eq!(result, TaskResult::Skipped);

        let undo = ExecutableCompensation::with_undo("Execute undo", |_ctx| {
            Ok(TaskResult::Success)
        });
        let result = undo.execute(&context).unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[test]
    fn test_rollback_engine_creation() {
        let engine = RollbackEngine::new();
        let _ = &engine;
    }

    #[tokio::test]
    async fn test_rollback_report_creation() {
        let report = RollbackReport::new();
        assert_eq!(report.total_processed(), 0);
        assert!(report.rolled_back_tasks.is_empty());
        assert!(report.skipped_tasks.is_empty());
        assert!(report.failed_compensations.is_empty());
    }

    #[test]
    fn test_compensation_report_calculation() {
        let coverage = CompensationReport::calculate(5, 10);
        assert_eq!(coverage, 0.5);

        let full_coverage = CompensationReport::calculate(10, 10);
        assert_eq!(full_coverage, 1.0);

        let no_tasks = CompensationReport::calculate(0, 0);
        assert_eq!(no_tasks, 1.0);
    }

    #[test]
    fn test_find_dependent_tasks() {
        let mut workflow = Workflow::new();

        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("c", "Task C")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("d", "Task D")));

        workflow.add_dependency("a", "b").unwrap();
        workflow.add_dependency("a", "c").unwrap();
        workflow.add_dependency("b", "d").unwrap();
        workflow.add_dependency("c", "d").unwrap();

        let engine = RollbackEngine::new();
        let failed_idx = *workflow.task_map.get(&TaskId::new("d")).unwrap();

        let dependents = engine.find_dependent_tasks(&workflow, failed_idx).unwrap();

        assert_eq!(dependents.len(), 1);
        assert!(dependents.contains(&TaskId::new("d")));
    }

    #[test]
    fn test_diamond_dependency_rollback() {
        let mut workflow = Workflow::new();

        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("c", "Task C")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("d", "Task D")));

        workflow.add_dependency("a", "b").unwrap();
        workflow.add_dependency("a", "c").unwrap();
        workflow.add_dependency("b", "d").unwrap();
        workflow.add_dependency("c", "d").unwrap();

        let engine = RollbackEngine::new();

        let rollback_set = engine
            .find_rollback_set(&workflow, &TaskId::new("d"), RollbackStrategy::AllDependent)
            .unwrap();

        assert_eq!(rollback_set.len(), 1);
        assert_eq!(rollback_set[0], TaskId::new("d"));
    }

    #[test]
    fn test_reverse_execution_order() {
        let mut workflow = Workflow::new();

        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("c", "Task C")));

        workflow.add_dependency("a", "b").unwrap();
        workflow.add_dependency("b", "c").unwrap();

        let engine = RollbackEngine::new();
        let failed_idx = *workflow.task_map.get(&TaskId::new("c")).unwrap();

        let dependents = engine.find_dependent_tasks(&workflow, failed_idx).unwrap();
        let rollback_order = engine.reverse_execution_order(&workflow, dependents).unwrap();

        assert_eq!(rollback_order.len(), 1);
        assert_eq!(rollback_order[0], TaskId::new("c"));
    }

    #[tokio::test]
    async fn test_execute_rollback() {
        let mut workflow = Workflow::new();

        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));

        workflow.add_dependency("a", "b").unwrap();

        let engine = RollbackEngine::new();
        let mut audit_log = AuditLog::new();
        let registry = CompensationRegistry::new();

        let report = engine
            .execute_rollback(
                &workflow,
                vec![TaskId::new("b")],
                "test_workflow",
                &mut audit_log,
                &registry,
            )
            .await
            .unwrap();

        assert_eq!(report.skipped_tasks.len(), 1);
        assert_eq!(report.skipped_tasks[0], TaskId::new("b"));
        assert!(report.rolled_back_tasks.is_empty());
        assert!(report.failed_compensations.is_empty());

        let events = audit_log.replay();
        assert!(events
            .iter()
            .any(|e| matches!(e, crate::audit::AuditEvent::WorkflowTaskRolledBack { .. })));
        assert!(events
            .iter()
            .any(|e| matches!(e, crate::audit::AuditEvent::WorkflowRolledBack { .. })));
    }

    #[tokio::test]
    async fn test_execute_rollback_with_compensation() {
        let mut workflow = Workflow::new();

        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));

        workflow.add_dependency("a", "b").unwrap();

        let engine = RollbackEngine::new();
        let mut audit_log = AuditLog::new();
        let mut registry = CompensationRegistry::new();

        registry.register(TaskId::new("b"), ToolCompensation::skip("Test compensation"));

        let report = engine
            .execute_rollback(
                &workflow,
                vec![TaskId::new("b")],
                "test_workflow",
                &mut audit_log,
                &registry,
            )
            .await
            .unwrap();

        assert_eq!(report.rolled_back_tasks.len(), 1);
        assert_eq!(report.rolled_back_tasks[0], TaskId::new("b"));
        assert!(report.skipped_tasks.is_empty());
        assert!(report.failed_compensations.is_empty());

        let events = audit_log.replay();
        assert!(events
            .iter()
            .any(|e| matches!(e, crate::audit::AuditEvent::WorkflowTaskRolledBack { .. })));
        assert!(events
            .iter()
            .any(|e| matches!(e, crate::audit::AuditEvent::WorkflowRolledBack { .. })));
    }

    #[tokio::test]
    async fn test_execute_rollback_mixed_compensation() {
        let mut workflow = Workflow::new();

        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("c", "Task C")));

        workflow.add_dependency("a", "b").unwrap();
        workflow.add_dependency("b", "c").unwrap();

        let engine = RollbackEngine::new();
        let mut audit_log = AuditLog::new();
        let mut registry = CompensationRegistry::new();

        registry.register(TaskId::new("a"), ToolCompensation::skip("Test compensation"));

        let report = engine
            .execute_rollback(
                &workflow,
                vec![TaskId::new("a"), TaskId::new("b"), TaskId::new("c")],
                "test_workflow",
                &mut audit_log,
                &registry,
            )
            .await
            .unwrap();

        assert_eq!(report.rolled_back_tasks.len(), 1);
        assert_eq!(report.rolled_back_tasks[0], TaskId::new("a"));
        assert_eq!(report.skipped_tasks.len(), 2);
        assert!(report.skipped_tasks.contains(&TaskId::new("b")));
        assert!(report.skipped_tasks.contains(&TaskId::new("c")));
        assert!(report.failed_compensations.is_empty());
    }

    #[test]
    fn test_validate_compensation_coverage() {
        let mut workflow = Workflow::new();

        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));

        workflow.add_dependency("a", "b").unwrap();

        let engine = RollbackEngine::new();
        let report = engine.validate_compensation_coverage(&workflow);

        assert_eq!(report.tasks_without_compensation.len(), 2);
        assert_eq!(report.tasks_with_compensation.len(), 0);
        assert_eq!(report.coverage_percentage, 0.0);
    }
}