Skip to main content

forge_agent/workflow/
rollback.rs

//! Rollback engine for workflow failure recovery using the Saga compensation pattern.
//!
//! The rollback engine provides selective rollback by traversing the workflow DAG
//! forward (along outgoing edges) from the failed task. When a task fails, only that
//! task and its direct or transitive dependents are rolled back, leaving independent
//! tasks in their completed state.
//!
//! # Saga Compensation Pattern
//!
//! Rollback uses the Saga pattern where each task optionally provides a
//! compensation action that undoes its side effects:
//! - `UndoFunction`: Executes a compensating transaction (e.g., delete created file)
//! - `Skip`: No compensation needed (read-only operations like queries)
//! - `Retry`: Recommends retry instead of compensation (transient failures)
//!
//! # Rollback Strategies
//!
//! - `AllDependent`: Roll back the failed task and all tasks reachable from it (default)
//! - `FailedOnly`: Roll back only the failed task
//! - `Custom`: Reserved for a caller-provided filter; currently behaves like `AllDependent`

21use crate::audit::AuditLog;
22use crate::workflow::dag::Workflow;
23use crate::workflow::task::{CompensationAction, CompensationType, TaskContext, TaskId, TaskError, TaskResult, WorkflowTask};
24use chrono::Utc;
25use petgraph::graph::NodeIndex;
26use petgraph::visit::IntoNeighborsDirected;
27use petgraph::Direction;
28use std::collections::{HashMap, HashSet, VecDeque};
29use std::fs;
30use std::path::Path;
31use std::sync::Arc;
32use thiserror::Error;
33
34/// Extended compensation action with undo function support.
35///
36/// This extends the base CompensationAction with executable undo logic.
37/// The base type is serializable for audit logs, while this type adds
38/// runtime execution capabilities.
39#[derive(Clone)]
40pub struct ExecutableCompensation {
41    /// Base compensation action
42    pub action: CompensationAction,
43    /// Optional undo function (used for UndoFunction type)
44    #[allow(clippy::type_complexity)]
45    undo_fn: Option<Arc<dyn Fn(&TaskContext) -> Result<TaskResult, TaskError> + Send + Sync>>,
46}
47
48impl ExecutableCompensation {
49    /// Creates a new ExecutableCompensation from an action.
50    pub fn new(action: CompensationAction) -> Self {
51        Self {
52            action,
53            undo_fn: None,
54        }
55    }
56
57    /// Creates an UndoFunction compensation with the given undo function.
58    pub fn with_undo<F>(description: impl Into<String>, undo_fn: F) -> Self
59    where
60        F: Fn(&TaskContext) -> Result<TaskResult, TaskError> + Send + Sync + 'static,
61    {
62        Self {
63            action: CompensationAction::undo(description),
64            undo_fn: Some(Arc::new(undo_fn)),
65        }
66    }
67
68    /// Creates a Skip compensation (no undo needed).
69    pub fn skip(description: impl Into<String>) -> Self {
70        Self::new(CompensationAction::skip(description))
71    }
72
73    /// Creates a Retry compensation (recommends retry instead of undo).
74    pub fn retry(description: impl Into<String>) -> Self {
75        Self::new(CompensationAction::retry(description))
76    }
77
78    /// Executes the compensation action.
79    pub fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
80        match self.action.action_type {
81            CompensationType::UndoFunction => {
82                if let Some(undo_fn) = &self.undo_fn {
83                    undo_fn(context)
84                } else {
85                    Ok(TaskResult::Skipped)
86                }
87            }
88            CompensationType::Skip => Ok(TaskResult::Skipped),
89            CompensationType::Retry => Ok(TaskResult::Skipped),
90        }
91    }
92}
93
94impl From<ExecutableCompensation> for CompensationAction {
95    fn from(exec: ExecutableCompensation) -> Self {
96        exec.action
97    }
98}
99
100/// Compensation action for external tool side effects.
101///
102/// ToolCompensation wraps an undo function that compensates for external
103/// tool actions (file edits, process spawns, etc.) that cannot be rolled
104/// back through normal workflow operations.
105#[derive(Clone)]
106pub struct ToolCompensation {
107    /// Human-readable description of the compensation
108    pub description: String,
109    /// Optional undo function (executed during rollback)
110    #[allow(clippy::type_complexity)]
111    compensate: Arc<dyn Fn(&TaskContext) -> Result<TaskResult, TaskError> + Send + Sync>,
112}
113
114impl ToolCompensation {
115    /// Creates a new ToolCompensation with the given description and undo function.
116    pub fn new<F>(description: impl Into<String>, compensate_fn: F) -> Self
117    where
118        F: Fn(&TaskContext) -> Result<TaskResult, TaskError> + Send + Sync + 'static,
119    {
120        Self {
121            description: description.into(),
122            compensate: Arc::new(compensate_fn),
123        }
124    }
125
126    /// Executes the compensation action.
127    pub fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
128        (self.compensate)(context)
129    }
130
131    /// Creates a file deletion compensation for undoing file creation.
132    ///
133    /// # Arguments
134    ///
135    /// * `file_path` - Path to the file that will be deleted during rollback
136    ///
137    /// # Returns
138    ///
139    /// A ToolCompensation that deletes the specified file
140    ///
141    /// # Example
142    ///
143    /// ```ignore
144    /// let comp = ToolCompensation::file_compensation("/tmp/work_output.txt");
145    /// ```
146    pub fn file_compensation(file_path: impl Into<String>) -> Self {
147        let path = file_path.into();
148        Self::new(format!("Delete file: {}", path), move |_context| {
149            // Delete the file if it exists
150            if Path::new(&path).exists() {
151                fs::remove_file(&path).map_err(|e| {
152                    TaskError::ExecutionFailed(format!("Failed to delete file {}: {}", path, e))
153                })?;
154            }
155            Ok(TaskResult::Success)
156        })
157    }
158
159    /// Creates a process termination compensation for undoing process spawns.
160    ///
161    /// # Arguments
162    ///
163    /// * `pid` - Process ID to terminate
164    ///
165    /// # Returns
166    ///
167    /// A ToolCompensation that terminates the specified process
168    ///
169    /// # Example
170    ///
171    /// ```ignore
172    /// let comp = ToolCompensation::process_compensation(12345);
173    /// ```
174    pub fn process_compensation(pid: u32) -> Self {
175        Self::new(format!("Terminate process: {}", pid), move |_context| {
176            // Try to kill the process gracefully
177            #[cfg(unix)]
178            {
179                use std::process::Command;
180                let result = Command::new("kill")
181                    .arg("-TERM")
182                    .arg(pid.to_string())
183                    .output();
184
185                match result {
186                    Ok(_) => Ok(TaskResult::Success),
187                    Err(e) => Ok(TaskResult::Failed(format!("Failed to terminate process {}: {}", pid, e))),
188                }
189            }
190
191            #[cfg(not(unix))]
192            {
193                Ok(TaskResult::Failed(format!("Process termination not supported on this platform")))
194            }
195        })
196    }
197
198    /// Creates a skip compensation (no undo needed).
199    ///
200    /// Used for tasks that don't have side effects or don't need compensation.
201    pub fn skip(description: impl Into<String>) -> Self {
202        Self::new(description, |_context| Ok(TaskResult::Skipped))
203    }
204
205    /// Creates a retry compensation (recommends retry instead of undo).
206    ///
207    /// Used for transient failures where retry is preferred over compensation.
208    pub fn retry(description: impl Into<String>) -> Self {
209        Self::new(description, |_context| Ok(TaskResult::Skipped))
210    }
211}
212
213impl From<CompensationAction> for ToolCompensation {
214    fn from(action: CompensationAction) -> Self {
215        match action.action_type {
216            CompensationType::Skip => ToolCompensation::skip(action.description),
217            CompensationType::Retry => ToolCompensation::retry(action.description),
218            CompensationType::UndoFunction => {
219                // Note: Can't create undo from serializable action
220                // This is a no-op compensation
221                ToolCompensation::skip(format!(
222                    "{} (no undo function available)",
223                    action.description
224                ))
225            }
226        }
227    }
228}
229
230/// Registry for tracking compensation actions for workflow tasks.
231///
232/// CompensationRegistry maintains a mapping from task IDs to their
233/// corresponding compensation actions. During rollback, the registry
234/// is consulted to find and execute compensations in reverse order.
235pub struct CompensationRegistry {
236    /// Mapping of task IDs to their compensation actions
237    compensations: HashMap<TaskId, ToolCompensation>,
238}
239
240impl CompensationRegistry {
241    /// Creates a new empty compensation registry.
242    pub fn new() -> Self {
243        Self {
244            compensations: HashMap::new(),
245        }
246    }
247
248    /// Registers a compensation action for a task.
249    ///
250    /// # Arguments
251    ///
252    /// * `task_id` - The task ID to register compensation for
253    /// * `compensation` - The compensation action to register
254    ///
255    /// # Example
256    ///
257    /// ```ignore
258    /// registry.register(
259    ///     TaskId::new("task-1"),
260    ///     ToolCompensation::file_compensation("/tmp/output.txt")
261    /// );
262    /// ```
263    pub fn register(&mut self, task_id: TaskId, compensation: ToolCompensation) {
264        self.compensations.insert(task_id, compensation);
265    }
266
267    /// Gets the compensation action for a task.
268    ///
269    /// # Arguments
270    ///
271    /// * `task_id` - The task ID to look up
272    ///
273    /// # Returns
274    ///
275    /// - `Some(&ToolCompensation)` if the task has a compensation
276    /// - `None` if the task has no compensation registered
277    pub fn get(&self, task_id: &TaskId) -> Option<&ToolCompensation> {
278        self.compensations.get(task_id)
279    }
280
281    /// Checks if a task has a compensation action registered.
282    ///
283    /// # Arguments
284    ///
285    /// * `task_id` - The task ID to check
286    ///
287    /// # Returns
288    ///
289    /// - `true` if the task has compensation
290    /// - `false` if the task has no compensation
291    pub fn has_compensation(&self, task_id: &TaskId) -> bool {
292        self.compensations.contains_key(task_id)
293    }
294
295    /// Removes a compensation action from the registry.
296    ///
297    /// Typically called after successful rollback execution.
298    ///
299    /// # Arguments
300    ///
301    /// * `task_id` - The task ID to remove compensation for
302    ///
303    /// # Returns
304    ///
305    /// - `Some(ToolCompensation)` if the task had compensation
306    /// - `None` if the task had no compensation
307    pub fn remove(&mut self, task_id: &TaskId) -> Option<ToolCompensation> {
308        self.compensations.remove(task_id)
309    }
310
311    /// Validates compensation coverage for a set of tasks.
312    ///
313    /// Reports which tasks have compensation actions defined and which don't.
314    ///
315    /// # Arguments
316    ///
317    /// * `task_ids` - The task IDs to validate
318    ///
319    /// # Returns
320    ///
321    /// A CompensationReport showing coverage statistics
322    pub fn validate_coverage(&self, task_ids: &[TaskId]) -> CompensationReport {
323        let mut with_compensation = Vec::new();
324        let mut without_compensation = Vec::new();
325
326        for task_id in task_ids {
327            if self.has_compensation(task_id) {
328                with_compensation.push(task_id.clone());
329            } else {
330                without_compensation.push(task_id.clone());
331            }
332        }
333
334        let total = task_ids.len();
335        let coverage = CompensationReport::calculate(with_compensation.len(), total);
336
337        CompensationReport {
338            tasks_with_compensation: with_compensation,
339            tasks_without_compensation: without_compensation,
340            coverage_percentage: coverage,
341        }
342    }
343
344    /// Registers a file creation compensation for a task.
345    ///
346    /// Convenience method that automatically creates a file deletion compensation.
347    ///
348    /// # Arguments
349    ///
350    /// * `task_id` - The task ID to register compensation for
351    /// * `file_path` - Path to the file that will be deleted during rollback
352    pub fn register_file_creation(&mut self, task_id: TaskId, file_path: impl Into<String>) {
353        self.register(task_id, ToolCompensation::file_compensation(file_path));
354    }
355
356    /// Registers a process spawn compensation for a task.
357    ///
358    /// Convenience method that automatically creates a process termination compensation.
359    ///
360    /// # Arguments
361    ///
362    /// * `task_id` - The task ID to register compensation for
363    /// * `pid` - Process ID to terminate during rollback
364    pub fn register_process_spawn(&mut self, task_id: TaskId, pid: u32) {
365        self.register(task_id, ToolCompensation::process_compensation(pid));
366    }
367
368    /// Returns the number of registered compensations.
369    pub fn len(&self) -> usize {
370        self.compensations.len()
371    }
372
373    /// Returns true if the registry is empty.
374    pub fn is_empty(&self) -> bool {
375        self.compensations.is_empty()
376    }
377
378    /// Returns all task IDs with registered compensations.
379    pub fn task_ids(&self) -> Vec<TaskId> {
380        self.compensations.keys().cloned().collect()
381    }
382}
383
384impl Default for CompensationRegistry {
385    fn default() -> Self {
386        Self::new()
387    }
388}
389
/// Selects which tasks are rolled back after a failure.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RollbackStrategy {
    /// Roll back the failed task and every task reachable from it (all dependents).
    AllDependent,
    /// Roll back only the failed task.
    FailedOnly,
    /// Intended for a custom filter function; the engine currently treats it like `AllDependent`.
    Custom,
}
400
401/// Report from rollback execution.
402#[derive(Clone, Debug)]
403pub struct RollbackReport {
404    /// Tasks that were successfully rolled back
405    pub rolled_back_tasks: Vec<TaskId>,
406    /// Tasks that were skipped (no compensation defined)
407    pub skipped_tasks: Vec<TaskId>,
408    /// Tasks where compensation failed
409    pub failed_compensations: Vec<(TaskId, String)>,
410}
411
412impl RollbackReport {
413    /// Creates a new empty rollback report.
414    fn new() -> Self {
415        Self {
416            rolled_back_tasks: Vec::new(),
417            skipped_tasks: Vec::new(),
418            failed_compensations: Vec::new(),
419        }
420    }
421
422    /// Returns the total number of tasks processed.
423    pub fn total_processed(&self) -> usize {
424        self.rolled_back_tasks.len() + self.skipped_tasks.len() + self.failed_compensations.len()
425    }
426}
427
428/// Report from compensation coverage validation.
429#[derive(Clone, Debug)]
430pub struct CompensationReport {
431    /// Tasks that have compensation defined
432    pub tasks_with_compensation: Vec<TaskId>,
433    /// Tasks that lack compensation
434    pub tasks_without_compensation: Vec<TaskId>,
435    /// Percentage of tasks with compensation (0.0 to 1.0)
436    pub coverage_percentage: f64,
437}
438
439impl CompensationReport {
440    /// Calculates coverage percentage from task counts.
441    fn calculate(with_compensation: usize, total: usize) -> f64 {
442        if total == 0 {
443            1.0
444        } else {
445            with_compensation as f64 / total as f64
446        }
447    }
448}
449
450/// Errors that can occur during rollback.
451#[derive(Error, Debug)]
452pub enum RollbackError {
453    /// Failed to find rollback set (DAG traversal error)
454    #[error("Failed to determine rollback set: {0}")]
455    RollbackSetFailed(String),
456
457    /// Task not found in workflow during rollback
458    #[error("Task not found during rollback: {0}")]
459    TaskNotFound(TaskId),
460
461    /// Compensation execution failed
462    #[error("Compensation failed for task {0}: {1}")]
463    CompensationFailed(TaskId, String),
464
465    /// Graph traversal error
466    #[error("Graph traversal error: {0}")]
467    TraversalError(String),
468}
469
/// Stateless engine that plans and executes workflow rollbacks.
///
/// It determines which tasks must be rolled back after a failure by walking
/// the workflow DAG, then runs their registered compensations (Saga pattern).
pub struct RollbackEngine {
    /// Zero-sized private field: keeps construction behind `new()` / `Default`.
    _private: (),
}
477
478impl RollbackEngine {
479    /// Creates a new rollback engine.
480    pub fn new() -> Self {
481        Self { _private: () }
482    }
483
484    /// Finds the set of tasks to roll back based on failure and strategy.
485    ///
486    /// Uses reverse graph traversal starting from the failed task to find
487    /// all dependent tasks. The rollback order is reverse execution order.
488    ///
489    /// # Arguments
490    ///
491    /// * `workflow` - The workflow to analyze
492    /// * `failed_task` - The task that failed (rollback origin)
493    /// * `strategy` - Rollback strategy to apply
494    ///
495    /// # Returns
496    ///
497    /// - `Ok(Vec<TaskId>)` - Tasks in rollback order (reverse execution)
498    /// - `Err(RollbackError)` - If traversal fails
499    ///
500    /// # Example
501    ///
502    /// ```ignore
503    /// let engine = RollbackEngine::new();
504    /// let rollback_set = engine.find_rollback_set(
505    ///     &workflow,
506    ///     &TaskId::new("task_d"),
507    ///     RollbackStrategy::AllDependent
508    /// )?;
509    /// ```
510    pub fn find_rollback_set(
511        &self,
512        workflow: &Workflow,
513        failed_task: &TaskId,
514        strategy: RollbackStrategy,
515    ) -> Result<Vec<TaskId>, RollbackError> {
516        // Find failed task node index
517        let failed_idx = *workflow
518            .task_map
519            .get(failed_task)
520            .ok_or_else(|| RollbackError::TaskNotFound(failed_task.clone()))?;
521
522        match strategy {
523            RollbackStrategy::FailedOnly => {
524                // Only roll back the failed task
525                Ok(vec![failed_task.clone()])
526            }
527            RollbackStrategy::AllDependent => {
528                // Find all nodes reachable from failed task in reverse graph
529                let dependent_set = self.find_dependent_tasks(workflow, failed_idx)?;
530                // Sort in reverse execution order
531                self.reverse_execution_order(workflow, dependent_set)
532            }
533            RollbackStrategy::Custom => {
534                // Custom strategy not yet implemented
535                // For now, treat as AllDependent
536                let dependent_set = self.find_dependent_tasks(workflow, failed_idx)?;
537                self.reverse_execution_order(workflow, dependent_set)
538            }
539        }
540    }
541
542    /// Finds all tasks dependent on the failed task using forward traversal.
543    ///
544    /// Traverses the graph following edges from the failed task to find
545    /// all nodes that depend on it (directly or transitively).
546    ///
547    /// In the DAG a -> b, edge direction is "a executes before b",
548    /// so b depends on a. If a fails, we traverse forward to find b.
549    ///
550    /// # Arguments
551    ///
552    /// * `workflow` - The workflow to analyze
553    /// * `failed_idx` - Node index of the failed task
554    ///
555    /// # Returns
556    ///
557    /// Set of TaskIds that depend on the failed task
558    fn find_dependent_tasks(
559        &self,
560        workflow: &Workflow,
561        failed_idx: NodeIndex,
562    ) -> Result<HashSet<TaskId>, RollbackError> {
563        let mut dependent_set = HashSet::new();
564        let mut visited = HashSet::new();
565        let mut stack = VecDeque::new();
566
567        // Start from failed task, traverse forward edges (outgoing direction)
568        // to find all tasks that depend on the failed task
569        stack.push_back(failed_idx);
570        visited.insert(failed_idx);
571
572        while let Some(current_idx) = stack.pop_front() {
573            // Get node weight to extract TaskId
574            if let Some(node) = workflow.graph.node_weight(current_idx) {
575                let task_id = node.id().clone();
576                dependent_set.insert(task_id);
577            }
578
579            // Find all nodes that depend on current node
580            // Edges go FROM prerequisite TO dependent
581            // So Outgoing neighbors are the dependents
582            for neighbor in workflow
583                .graph
584                .neighbors_directed(current_idx, Direction::Outgoing)
585            {
586                if !visited.contains(&neighbor) {
587                    visited.insert(neighbor);
588                    stack.push_back(neighbor);
589                }
590            }
591        }
592
593        Ok(dependent_set)
594    }
595
596    /// Sorts tasks in reverse execution order (for correct rollback).
597    ///
598    /// Rollback must execute in reverse order of execution to maintain
599    /// dependency correctness (later tasks rolled back before earlier tasks).
600    ///
601    /// # Arguments
602    ///
603    /// * `workflow` - The workflow for execution order
604    /// * `tasks` - Tasks to sort
605    ///
606    /// # Returns
607    ///
608    /// Tasks sorted in reverse execution order
609    fn reverse_execution_order(
610        &self,
611        workflow: &Workflow,
612        tasks: HashSet<TaskId>,
613    ) -> Result<Vec<TaskId>, RollbackError> {
614        // Get execution order
615        let execution_order = workflow
616            .execution_order()
617            .map_err(|e| RollbackError::TraversalError(e.to_string()))?;
618
619        // Create position map for O(1) lookup
620        let position_map: HashMap<TaskId, usize> = execution_order
621            .iter()
622            .enumerate()
623            .map(|(pos, task_id)| (task_id.clone(), pos))
624            .collect();
625
626        // Filter to only tasks in the rollback set
627        let mut rollback_tasks: Vec<TaskId> = tasks.into_iter().collect();
628
629        // Sort by reverse execution order (higher position first)
630        rollback_tasks.sort_by(|a, b| {
631            let pos_a = position_map.get(a).copied().unwrap_or(0);
632            let pos_b = position_map.get(b).copied().unwrap_or(0);
633            pos_b.cmp(&pos_a) // Reverse order
634        });
635
636        Ok(rollback_tasks)
637    }
638
639    /// Executes rollback for a set of tasks.
640    ///
641    /// Executes compensation actions for each task in rollback order.
642    /// Tasks without compensation are skipped. Failed compensations are logged.
643    ///
644    /// # Arguments
645    ///
646    /// * `workflow` - The workflow being rolled back
647    /// * `tasks` - Tasks to roll back (in rollback order)
648    /// * `workflow_id` - Workflow ID for audit logging
649    /// * `audit_log` - Audit log for recording rollback events
650    /// * `compensation_registry` - Registry containing compensation actions
651    ///
652    /// # Returns
653    ///
654    /// - `Ok(RollbackReport)` - Report of rollback execution
655    /// - `Err(RollbackError)` - If critical failure occurs
656    pub async fn execute_rollback(
657        &self,
658        workflow: &Workflow,
659        tasks: Vec<TaskId>,
660        workflow_id: &str,
661        audit_log: &mut AuditLog,
662        compensation_registry: &CompensationRegistry,
663    ) -> Result<RollbackReport, RollbackError> {
664        let mut report = RollbackReport::new();
665
666        for task_id in &tasks {
667            // Get task node
668            let node_idx = workflow
669                .task_map
670                .get(task_id)
671                .ok_or_else(|| RollbackError::TaskNotFound(task_id.clone()))?;
672
673            let _node = workflow
674                .graph
675                .node_weight(*node_idx)
676                .expect("Node index should be valid");
677
678            // Try to get compensation from registry
679            if let Some(compensation) = compensation_registry.get(task_id) {
680                // Create task context for compensation execution
681                let context = TaskContext::new(workflow_id, task_id.clone());
682
683                // Execute the compensation
684                match compensation.execute(&context) {
685                    Ok(_) => {
686                        // Record successful rollback in audit log
687                        let _ = audit_log
688                            .record(crate::audit::AuditEvent::WorkflowTaskRolledBack {
689                                timestamp: Utc::now(),
690                                workflow_id: workflow_id.to_string(),
691                                task_id: task_id.to_string(),
692                                compensation: compensation.description.clone(),
693                            })
694                            .await;
695
696                        report.rolled_back_tasks.push(task_id.clone());
697                    }
698                    Err(e) => {
699                        // Record compensation failure
700                        let error_msg = e.to_string();
701                        let _ = audit_log
702                            .record(crate::audit::AuditEvent::WorkflowTaskRolledBack {
703                                timestamp: Utc::now(),
704                                workflow_id: workflow_id.to_string(),
705                                task_id: task_id.to_string(),
706                                compensation: format!("Failed: {}", error_msg),
707                            })
708                            .await;
709
710                        report.failed_compensations.push((task_id.clone(), error_msg));
711                    }
712                }
713            } else {
714                // No compensation registered - skip task
715                report.skipped_tasks.push(task_id.clone());
716
717                // Record skipped rollback in audit log
718                let _ = audit_log
719                    .record(crate::audit::AuditEvent::WorkflowTaskRolledBack {
720                        timestamp: Utc::now(),
721                        workflow_id: workflow_id.to_string(),
722                        task_id: task_id.to_string(),
723                        compensation: "No compensation registered".to_string(),
724                    })
725                    .await;
726            }
727        }
728
729        // Record workflow rolled back event
730        let _ = audit_log
731            .record(crate::audit::AuditEvent::WorkflowRolledBack {
732                timestamp: Utc::now(),
733                workflow_id: workflow_id.to_string(),
734                reason: "Task failure triggered rollback".to_string(),
735                rolled_back_tasks: tasks.iter().map(|id| id.to_string()).collect(),
736            })
737            .await;
738
739        Ok(report)
740    }
741
742    /// Validates compensation coverage for all tasks in workflow.
743    ///
744    /// Reports which tasks have compensation actions defined and which don't.
745    ///
746    /// # Arguments
747    ///
748    /// * `workflow` - The workflow to validate
749    ///
750    /// # Returns
751    ///
752    /// Compensation coverage report
753    pub fn validate_compensation_coverage(
754        &self,
755        workflow: &Workflow,
756    ) -> CompensationReport {
757        let total_tasks = workflow.task_count();
758        let with_compensation = Vec::new();
759        let mut without_compensation = Vec::new();
760
761        // Note: We can't check actual compensation without the task instances
762        // This is a placeholder that will be enhanced when we redesign TaskNode
763        // to store compensation metadata
764        for task_id in workflow.task_ids() {
765            // For now, assume all tasks need compensation (conservative)
766            without_compensation.push(task_id);
767        }
768
769        let coverage = CompensationReport::calculate(with_compensation.len(), total_tasks);
770
771        CompensationReport {
772            tasks_with_compensation: with_compensation,
773            tasks_without_compensation: without_compensation,
774            coverage_percentage: coverage,
775        }
776    }
777}
778
779impl Default for RollbackEngine {
780    fn default() -> Self {
781        Self::new()
782    }
783}
784
785#[cfg(test)]
786mod tests {
787    use super::*;
788    use crate::workflow::task::{TaskContext, TaskError, TaskResult, WorkflowTask};
789    use async_trait::async_trait;
790    use std::fs::File;
791    use std::io::Write;
792
793    #[test]
794    fn test_tool_compensation_creation() {
795        let comp = ToolCompensation::new("Test compensation", |_ctx| Ok(TaskResult::Success));
796        assert_eq!(comp.description, "Test compensation");
797    }
798
799    #[test]
800    fn test_tool_compensation_execute() {
801        let comp = ToolCompensation::new("Execute test", |_ctx| Ok(TaskResult::Success));
802        let context = TaskContext::new("test", TaskId::new("a"));
803        let result = comp.execute(&context).unwrap();
804        assert_eq!(result, TaskResult::Success);
805    }
806
807    #[test]
808    fn test_tool_compensation_execute_error() {
809        let comp = ToolCompensation::new("Execute test", |_ctx| {
810            Err(TaskError::ExecutionFailed("Test error".to_string()))
811        });
812        let context = TaskContext::new("test", TaskId::new("a"));
813        let result = comp.execute(&context);
814        assert!(result.is_err());
815    }
816
817    #[test]
818    fn test_tool_compensation_skip() {
819        let comp = ToolCompensation::skip("No action needed");
820        let context = TaskContext::new("test", TaskId::new("a"));
821        let result = comp.execute(&context).unwrap();
822        assert_eq!(result, TaskResult::Skipped);
823    }
824
825    #[test]
826    fn test_tool_compensation_retry() {
827        let comp = ToolCompensation::retry("Retry recommended");
828        let context = TaskContext::new("test", TaskId::new("a"));
829        let result = comp.execute(&context).unwrap();
830        assert_eq!(result, TaskResult::Skipped);
831    }
832
833    #[test]
834    fn test_tool_compensation_file() {
835        // Create a temporary file
836        let temp_file = "/tmp/test_tool_compensation.txt";
837        let mut file = File::create(temp_file).unwrap();
838        writeln!(file, "test content").unwrap();
839        drop(file);
840
841        // Verify file exists
842        assert!(Path::new(temp_file).exists());
843
844        // Create file compensation and execute it
845        let comp = ToolCompensation::file_compensation(temp_file);
846        let context = TaskContext::new("test", TaskId::new("a"));
847        let result = comp.execute(&context);
848
849        assert!(result.is_ok());
850        assert!(!Path::new(temp_file).exists()); // File should be deleted
851    }
852
853    #[test]
854    fn test_tool_compensation_from_compensation_action() {
855        let skip_action = CompensationAction::skip("Skip action");
856        let skip_comp: ToolCompensation = skip_action.into();
857        assert_eq!(skip_comp.description, "Skip action");
858
859        let retry_action = CompensationAction::retry("Retry action");
860        let retry_comp: ToolCompensation = retry_action.into();
861        assert_eq!(retry_comp.description, "Retry action");
862
863        let undo_action = CompensationAction::undo("Undo action");
864        let undo_comp: ToolCompensation = undo_action.into();
865        assert!(undo_comp.description.contains("no undo function available"));
866    }
867
868    #[test]
869    fn test_compensation_registry_new() {
870        let registry = CompensationRegistry::new();
871        assert!(registry.is_empty());
872        assert_eq!(registry.len(), 0);
873    }
874
875    #[test]
876    fn test_compensation_registry_register() {
877        let mut registry = CompensationRegistry::new();
878        let task_id = TaskId::new("task-1");
879        let comp = ToolCompensation::skip("Test");
880
881        registry.register(task_id.clone(), comp);
882
883        assert_eq!(registry.len(), 1);
884        assert!(registry.has_compensation(&task_id));
885    }
886
887    #[test]
888    fn test_compensation_registry_get() {
889        let mut registry = CompensationRegistry::new();
890        let task_id = TaskId::new("task-1");
891        let comp = ToolCompensation::new("Test", |_ctx| Ok(TaskResult::Success));
892
893        registry.register(task_id.clone(), comp);
894
895        let retrieved = registry.get(&task_id);
896        assert!(retrieved.is_some());
897        assert_eq!(retrieved.unwrap().description, "Test");
898
899        // Non-existent task
900        let missing = registry.get(&TaskId::new("missing"));
901        assert!(missing.is_none());
902    }
903
904    #[test]
905    fn test_compensation_registry_remove() {
906        let mut registry = CompensationRegistry::new();
907        let task_id = TaskId::new("task-1");
908        let comp = ToolCompensation::skip("Test");
909
910        registry.register(task_id.clone(), comp);
911        assert_eq!(registry.len(), 1);
912
913        let removed = registry.remove(&task_id);
914        assert!(removed.is_some());
915        assert_eq!(registry.len(), 0);
916        assert!(!registry.has_compensation(&task_id));
917
918        // Remove non-existent task
919        let removed_again = registry.remove(&task_id);
920        assert!(removed_again.is_none());
921    }
922
923    #[test]
924    fn test_compensation_registry_validate_coverage() {
925        let mut registry = CompensationRegistry::new();
926
927        let task1 = TaskId::new("task-1");
928        let task2 = TaskId::new("task-2");
929        let task3 = TaskId::new("task-3");
930
931        registry.register(task1.clone(), ToolCompensation::skip("Test 1"));
932        registry.register(task2.clone(), ToolCompensation::skip("Test 2"));
933
934        let report = registry.validate_coverage(&[task1.clone(), task2.clone(), task3.clone()]);
935
936        assert_eq!(report.tasks_with_compensation.len(), 2);
937        assert!(report.tasks_with_compensation.contains(&task1));
938        assert!(report.tasks_with_compensation.contains(&task2));
939
940        assert_eq!(report.tasks_without_compensation.len(), 1);
941        assert!(report.tasks_without_compensation.contains(&task3));
942
943        assert!((report.coverage_percentage - 0.666).abs() < 0.01);
944    }
945
946    #[test]
947    fn test_compensation_registry_register_file_creation() {
948        let mut registry = CompensationRegistry::new();
949        let task_id = TaskId::new("task-1");
950
951        registry.register_file_creation(task_id.clone(), "/tmp/test.txt");
952
953        assert!(registry.has_compensation(&task_id));
954        let comp = registry.get(&task_id).unwrap();
955        assert!(comp.description.contains("Delete file"));
956    }
957
958    #[test]
959    fn test_compensation_registry_register_process_spawn() {
960        let mut registry = CompensationRegistry::new();
961        let task_id = TaskId::new("task-1");
962
963        registry.register_process_spawn(task_id.clone(), 12345);
964
965        assert!(registry.has_compensation(&task_id));
966        let comp = registry.get(&task_id).unwrap();
967        assert!(comp.description.contains("Terminate process"));
968    }
969
970    #[test]
971    fn test_compensation_registry_task_ids() {
972        let mut registry = CompensationRegistry::new();
973
974        let task1 = TaskId::new("task-1");
975        let task2 = TaskId::new("task-2");
976
977        registry.register(task1.clone(), ToolCompensation::skip("Test 1"));
978        registry.register(task2.clone(), ToolCompensation::skip("Test 2"));
979
980        let ids = registry.task_ids();
981        assert_eq!(ids.len(), 2);
982        assert!(ids.contains(&task1));
983        assert!(ids.contains(&task2));
984    }
985
986    #[test]
987    fn test_compensation_registry_default() {
988        let registry = CompensationRegistry::default();
989        assert!(registry.is_empty());
990    }
991
992    // Mock task with compensation for testing
993    struct MockTaskWithCompensation {
994        id: TaskId,
995        name: String,
996        deps: Vec<TaskId>,
997        compensation: Option<CompensationAction>,
998    }
999
1000    impl MockTaskWithCompensation {
1001        fn new(id: impl Into<TaskId>, name: &str) -> Self {
1002            Self {
1003                id: id.into(),
1004                name: name.to_string(),
1005                deps: Vec::new(),
1006                compensation: None,
1007            }
1008        }
1009
1010        fn with_dep(mut self, dep: impl Into<TaskId>) -> Self {
1011            self.deps.push(dep.into());
1012            self
1013        }
1014
1015        fn with_compensation(mut self, action: CompensationAction) -> Self {
1016            self.compensation = Some(action);
1017            self
1018        }
1019    }
1020
1021    #[async_trait]
1022    impl WorkflowTask for MockTaskWithCompensation {
1023        async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
1024            Ok(TaskResult::Success)
1025        }
1026
1027        fn id(&self) -> TaskId {
1028            self.id.clone()
1029        }
1030
1031        fn name(&self) -> &str {
1032            &self.name
1033        }
1034
1035        fn dependencies(&self) -> Vec<TaskId> {
1036            self.deps.clone()
1037        }
1038    }
1039
1040    #[test]
1041    fn test_compensation_action_creation() {
1042        let skip = CompensationAction::skip("Read-only operation");
1043        assert_eq!(skip.action_type, CompensationType::Skip);
1044        assert_eq!(skip.description, "Read-only operation");
1045
1046        let retry = CompensationAction::retry("Transient network error");
1047        assert_eq!(retry.action_type, CompensationType::Retry);
1048
1049        let undo = CompensationAction::undo("Delete file");
1050        assert_eq!(undo.action_type, CompensationType::UndoFunction);
1051    }
1052
1053    #[test]
1054    fn test_executable_compensation_creation() {
1055        let skip = ExecutableCompensation::skip("No action needed");
1056        assert_eq!(skip.action.action_type, CompensationType::Skip);
1057
1058        let retry = ExecutableCompensation::retry("Retry later");
1059        assert_eq!(retry.action.action_type, CompensationType::Retry);
1060
1061        let undo = ExecutableCompensation::with_undo("Execute undo", |_ctx| {
1062            Ok(TaskResult::Success)
1063        });
1064        assert_eq!(undo.action.action_type, CompensationType::UndoFunction);
1065    }
1066
1067    #[test]
1068    fn test_executable_compensation_execute() {
1069        let skip = ExecutableCompensation::skip("No action needed");
1070        let context = TaskContext::new("test", TaskId::new("a"));
1071        let result = skip.execute(&context).unwrap();
1072        assert_eq!(result, TaskResult::Skipped);
1073
1074        let retry = ExecutableCompensation::retry("Retry later");
1075        let result = retry.execute(&context).unwrap();
1076        assert_eq!(result, TaskResult::Skipped);
1077
1078        let undo = ExecutableCompensation::with_undo("Execute undo", |_ctx| {
1079            Ok(TaskResult::Success)
1080        });
1081        let result = undo.execute(&context).unwrap();
1082        assert_eq!(result, TaskResult::Success);
1083    }
1084
1085    #[test]
1086    fn test_rollback_engine_creation() {
1087        let engine = RollbackEngine::new();
1088        let _ = &engine; // Use engine to avoid unused warning
1089    }
1090
1091    #[tokio::test]
1092    async fn test_rollback_report_creation() {
1093        let report = RollbackReport::new();
1094        assert_eq!(report.total_processed(), 0);
1095        assert!(report.rolled_back_tasks.is_empty());
1096        assert!(report.skipped_tasks.is_empty());
1097        assert!(report.failed_compensations.is_empty());
1098    }
1099
1100    #[test]
1101    fn test_compensation_report_calculation() {
1102        let coverage = CompensationReport::calculate(5, 10);
1103        assert_eq!(coverage, 0.5);
1104
1105        let full_coverage = CompensationReport::calculate(10, 10);
1106        assert_eq!(full_coverage, 1.0);
1107
1108        let no_tasks = CompensationReport::calculate(0, 0);
1109        assert_eq!(no_tasks, 1.0); // No tasks = full coverage
1110    }
1111
1112    #[test]
1113    fn test_find_dependent_tasks() {
1114        let mut workflow = Workflow::new();
1115
1116        // Create diamond DAG: a -> b, a -> c, b -> d, c -> d
1117        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
1118        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
1119        workflow.add_task(Box::new(MockTaskWithCompensation::new("c", "Task C")));
1120        workflow.add_task(Box::new(MockTaskWithCompensation::new("d", "Task D")));
1121
1122        workflow.add_dependency("a", "b").unwrap();
1123        workflow.add_dependency("a", "c").unwrap();
1124        workflow.add_dependency("b", "d").unwrap();
1125        workflow.add_dependency("c", "d").unwrap();
1126
1127        let engine = RollbackEngine::new();
1128        let failed_idx = *workflow.task_map.get(&TaskId::new("d")).unwrap();
1129
1130        // Find dependents of d (should be none in forward direction)
1131        let dependents = engine.find_dependent_tasks(&workflow, failed_idx).unwrap();
1132
1133        // d has no dependents, only dependencies
1134        // So rollback set should only contain d itself
1135        assert_eq!(dependents.len(), 1);
1136        assert!(dependents.contains(&TaskId::new("d")));
1137    }
1138
1139    #[test]
1140    fn test_diamond_dependency_rollback() {
1141        let mut workflow = Workflow::new();
1142
1143        // Diamond: a -> b, a -> c, b -> d, c -> d
1144        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
1145        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
1146        workflow.add_task(Box::new(MockTaskWithCompensation::new("c", "Task C")));
1147        workflow.add_task(Box::new(MockTaskWithCompensation::new("d", "Task D")));
1148
1149        workflow.add_dependency("a", "b").unwrap();
1150        workflow.add_dependency("a", "c").unwrap();
1151        workflow.add_dependency("b", "d").unwrap();
1152        workflow.add_dependency("c", "d").unwrap();
1153
1154        let engine = RollbackEngine::new();
1155
1156        // When d fails, only d should be rolled back (no dependents)
1157        let rollback_set = engine
1158            .find_rollback_set(&workflow, &TaskId::new("d"), RollbackStrategy::AllDependent)
1159            .unwrap();
1160
1161        // Only d is rolled back (it has no dependents)
1162        assert_eq!(rollback_set.len(), 1);
1163        assert_eq!(rollback_set[0], TaskId::new("d"));
1164    }
1165
1166    #[test]
1167    fn test_reverse_execution_order() {
1168        let mut workflow = Workflow::new();
1169
1170        // Create linear chain: a -> b -> c
1171        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
1172        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
1173        workflow.add_task(Box::new(MockTaskWithCompensation::new("c", "Task C")));
1174
1175        workflow.add_dependency("a", "b").unwrap();
1176        workflow.add_dependency("b", "c").unwrap();
1177
1178        let engine = RollbackEngine::new();
1179        let failed_idx = *workflow.task_map.get(&TaskId::new("c")).unwrap();
1180
1181        let dependents = engine.find_dependent_tasks(&workflow, failed_idx).unwrap();
1182        let rollback_order = engine.reverse_execution_order(&workflow, dependents).unwrap();
1183
1184        // Execution order: a, b, c
1185        // Rollback order should be reverse: c
1186        assert_eq!(rollback_order.len(), 1);
1187        assert_eq!(rollback_order[0], TaskId::new("c"));
1188    }
1189
1190    #[tokio::test]
1191    async fn test_execute_rollback() {
1192        let mut workflow = Workflow::new();
1193
1194        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
1195        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
1196
1197        workflow.add_dependency("a", "b").unwrap();
1198
1199        let engine = RollbackEngine::new();
1200        let mut audit_log = AuditLog::new();
1201        let registry = CompensationRegistry::new();
1202
1203        // Roll back task b (no compensation registered, so skipped)
1204        let report = engine
1205            .execute_rollback(
1206                &workflow,
1207                vec![TaskId::new("b")],
1208                "test_workflow",
1209                &mut audit_log,
1210                &registry,
1211            )
1212            .await
1213            .unwrap();
1214
1215        // Task b should be skipped (no compensation registered)
1216        assert_eq!(report.skipped_tasks.len(), 1);
1217        assert_eq!(report.skipped_tasks[0], TaskId::new("b"));
1218        assert!(report.rolled_back_tasks.is_empty());
1219        assert!(report.failed_compensations.is_empty());
1220
1221        // Verify audit events
1222        let events = audit_log.replay();
1223        assert!(events.iter().any(|e| matches!(e, crate::audit::AuditEvent::WorkflowTaskRolledBack { .. })));
1224        assert!(events.iter().any(|e| matches!(e, crate::audit::AuditEvent::WorkflowRolledBack { .. })));
1225    }
1226
1227    #[tokio::test]
1228    async fn test_execute_rollback_with_compensation() {
1229        let mut workflow = Workflow::new();
1230
1231        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
1232        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
1233
1234        workflow.add_dependency("a", "b").unwrap();
1235
1236        let engine = RollbackEngine::new();
1237        let mut audit_log = AuditLog::new();
1238        let mut registry = CompensationRegistry::new();
1239
1240        // Register compensation for task b
1241        registry.register(TaskId::new("b"), ToolCompensation::skip("Test compensation"));
1242
1243        // Roll back task b
1244        let report = engine
1245            .execute_rollback(
1246                &workflow,
1247                vec![TaskId::new("b")],
1248                "test_workflow",
1249                &mut audit_log,
1250                &registry,
1251            )
1252            .await
1253            .unwrap();
1254
1255        // Task b should be rolled back successfully
1256        assert_eq!(report.rolled_back_tasks.len(), 1);
1257        assert_eq!(report.rolled_back_tasks[0], TaskId::new("b"));
1258        assert!(report.skipped_tasks.is_empty());
1259        assert!(report.failed_compensations.is_empty());
1260
1261        // Verify audit events
1262        let events = audit_log.replay();
1263        assert!(events.iter().any(|e| matches!(e, crate::audit::AuditEvent::WorkflowTaskRolledBack { .. })));
1264        assert!(events.iter().any(|e| matches!(e, crate::audit::AuditEvent::WorkflowRolledBack { .. })));
1265    }
1266
1267    #[tokio::test]
1268    async fn test_execute_rollback_mixed_compensation() {
1269        let mut workflow = Workflow::new();
1270
1271        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
1272        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
1273        workflow.add_task(Box::new(MockTaskWithCompensation::new("c", "Task C")));
1274
1275        workflow.add_dependency("a", "b").unwrap();
1276        workflow.add_dependency("b", "c").unwrap();
1277
1278        let engine = RollbackEngine::new();
1279        let mut audit_log = AuditLog::new();
1280        let mut registry = CompensationRegistry::new();
1281
1282        // Register compensation only for task a
1283        registry.register(TaskId::new("a"), ToolCompensation::skip("Test compensation"));
1284
1285        // Roll back all three tasks
1286        let report = engine
1287            .execute_rollback(
1288                &workflow,
1289                vec![TaskId::new("a"), TaskId::new("b"), TaskId::new("c")],
1290                "test_workflow",
1291                &mut audit_log,
1292                &registry,
1293            )
1294            .await
1295            .unwrap();
1296
1297        // Task a rolled back, b and c skipped
1298        assert_eq!(report.rolled_back_tasks.len(), 1);
1299        assert_eq!(report.rolled_back_tasks[0], TaskId::new("a"));
1300        assert_eq!(report.skipped_tasks.len(), 2);
1301        assert!(report.skipped_tasks.contains(&TaskId::new("b")));
1302        assert!(report.skipped_tasks.contains(&TaskId::new("c")));
1303        assert!(report.failed_compensations.is_empty());
1304    }
1305
1306    #[test]
1307    fn test_validate_compensation_coverage() {
1308        let mut workflow = Workflow::new();
1309
1310        workflow.add_task(Box::new(MockTaskWithCompensation::new("a", "Task A")));
1311        workflow.add_task(Box::new(MockTaskWithCompensation::new("b", "Task B")));
1312
1313        workflow.add_dependency("a", "b").unwrap();
1314
1315        let engine = RollbackEngine::new();
1316        let report = engine.validate_compensation_coverage(&workflow);
1317
1318        // All tasks are marked as without compensation (placeholder logic)
1319        assert_eq!(report.tasks_without_compensation.len(), 2);
1320        assert_eq!(report.tasks_with_compensation.len(), 0);
1321        assert_eq!(report.coverage_percentage, 0.0);
1322    }
1323}