1use crate::error::{WorkflowError, WorkflowResult};
4use crate::models::{Workflow, WorkflowState};
5use crate::state::StateManager;
6use std::collections::HashMap;
7
/// Stateless namespace for rollback operations.
///
/// The type carries no data; its functionality is exposed as associated
/// functions that operate on a [`RollbackPlan`] and on workflow state passed
/// in by the caller. The common traits are derived so the type can be
/// logged, copied and default-constructed like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct RollbackManager;
15
/// Bookkeeping needed to undo a workflow run.
///
/// A plan tracks two things: which rollback step ids are registered for each
/// forward step id, and the order in which forward steps were recorded as
/// executed.
#[derive(Debug, Clone)]
pub struct RollbackPlan {
    /// Rollback step ids keyed by forward step id, kept in registration order.
    pub rollback_steps: HashMap<String, Vec<String>>,
    /// Forward step ids in the order they were recorded.
    pub execution_order: Vec<String>,
}

impl RollbackPlan {
    /// Creates a plan with no registered rollback steps and no recorded
    /// executions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `rollback_step` as a rollback step for `step_id`.
    ///
    /// Several rollback steps may be registered for the same `step_id`; they
    /// are kept in the order they were added.
    pub fn add_rollback_step(&mut self, step_id: String, rollback_step: String) {
        let registered = self.rollback_steps.entry(step_id).or_default();
        registered.push(rollback_step);
    }

    /// Appends `step_id` to the recorded execution order.
    pub fn record_execution(&mut self, step_id: String) {
        self.execution_order.push(step_id);
    }

    /// Returns the rollback step ids in the order they should run.
    ///
    /// Recorded executions are visited newest-first. For each one, its
    /// registered rollback steps are emitted in registration order. Recorded
    /// steps without any registered rollback step contribute nothing.
    pub fn get_rollback_order(&self) -> Vec<String> {
        self.execution_order
            .iter()
            .rev()
            .filter_map(|executed| self.rollback_steps.get(executed))
            .flatten()
            .cloned()
            .collect()
    }
}

impl Default for RollbackPlan {
    /// Equivalent to [`RollbackPlan::new`]: both collections start empty.
    fn default() -> Self {
        Self {
            rollback_steps: HashMap::new(),
            execution_order: Vec::new(),
        }
    }
}
67
68impl RollbackManager {
69 pub fn create_rollback_plan(workflow: &Workflow) -> RollbackPlan {
71 let mut plan = RollbackPlan::new();
72
73 for step in &workflow.steps {
75 plan.record_execution(step.id.clone());
78 }
79
80 plan
81 }
82
83 pub fn execute_rollback(
88 workflow: &Workflow,
89 state: &mut WorkflowState,
90 rollback_plan: &RollbackPlan,
91 ) -> WorkflowResult<()> {
92 let rollback_order = rollback_plan.get_rollback_order();
93
94 for rollback_step_id in rollback_order {
96 let _step = workflow
98 .steps
99 .iter()
100 .find(|s| s.id == rollback_step_id)
101 .ok_or_else(|| {
102 WorkflowError::NotFound(format!(
103 "Rollback step not found: {}",
104 rollback_step_id
105 ))
106 })?;
107
108 StateManager::start_step(state, rollback_step_id.clone());
110
111 StateManager::complete_step(
114 state,
115 rollback_step_id,
116 Some(serde_json::json!({"rollback": true})),
117 0,
118 );
119 }
120
121 Ok(())
122 }
123
124 pub fn restore_state(state: &mut WorkflowState) -> WorkflowResult<()> {
129 state.completed_steps.clear();
131
132 state.step_results.clear();
134
135 state.current_step = None;
137
138 Ok(())
139 }
140
141 pub fn has_rollback_steps(rollback_plan: &RollbackPlan, step_id: &str) -> bool {
143 rollback_plan
144 .rollback_steps
145 .get(step_id)
146 .map(|steps| !steps.is_empty())
147 .unwrap_or(false)
148 }
149
150 pub fn get_rollback_steps(rollback_plan: &RollbackPlan, step_id: &str) -> Vec<String> {
152 rollback_plan
153 .rollback_steps
154 .get(step_id)
155 .cloned()
156 .unwrap_or_default()
157 }
158
159 pub fn add_rollback_step(
161 rollback_plan: &mut RollbackPlan,
162 step_id: String,
163 rollback_step: String,
164 ) {
165 rollback_plan.add_rollback_step(step_id, rollback_step);
166 }
167
168 pub fn record_step_execution(rollback_plan: &mut RollbackPlan, step_id: String) {
170 rollback_plan.record_execution(step_id);
171 }
172}
173
#[cfg(test)]
mod tests {
    //! Unit tests for `RollbackPlan` and `RollbackManager`.

    use super::*;
    use crate::models::{
        AgentStep, ErrorAction, RiskFactors, StepConfig, StepType, WorkflowConfig, WorkflowStep,
    };

    /// Builds a two-step workflow (`step1` -> `step2`) used as a fixture.
    fn create_simple_workflow() -> Workflow {
        Workflow {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            description: "A test workflow".to_string(),
            parameters: vec![],
            steps: vec![
                WorkflowStep {
                    id: "step1".to_string(),
                    name: "Step 1".to_string(),
                    step_type: StepType::Agent(AgentStep {
                        agent_id: "test-agent".to_string(),
                        task: "test-task".to_string(),
                    }),
                    config: StepConfig {
                        config: serde_json::json!({"param": "value"}),
                    },
                    dependencies: vec![],
                    approval_required: false,
                    on_error: ErrorAction::Fail,
                    risk_score: None,
                    risk_factors: RiskFactors::default(),
                },
                WorkflowStep {
                    id: "step2".to_string(),
                    name: "Step 2".to_string(),
                    step_type: StepType::Agent(AgentStep {
                        agent_id: "test-agent".to_string(),
                        task: "test-task".to_string(),
                    }),
                    config: StepConfig {
                        config: serde_json::json!({"param": "value"}),
                    },
                    dependencies: vec!["step1".to_string()],
                    approval_required: false,
                    on_error: ErrorAction::Fail,
                    risk_score: None,
                    risk_factors: RiskFactors::default(),
                },
            ],
            config: WorkflowConfig {
                timeout_ms: None,
                max_parallel: None,
            },
        }
    }

    #[test]
    fn test_rollback_plan_creation() {
        let plan = RollbackPlan::new();
        assert!(plan.rollback_steps.is_empty());
        assert!(plan.execution_order.is_empty());
    }

    #[test]
    fn test_add_rollback_step() {
        let mut plan = RollbackPlan::new();
        plan.add_rollback_step("step1".to_string(), "rollback1".to_string());

        assert!(plan.rollback_steps.contains_key("step1"));
        assert_eq!(
            plan.rollback_steps.get("step1").unwrap(),
            &vec!["rollback1".to_string()]
        );
    }

    #[test]
    fn test_record_execution() {
        let mut plan = RollbackPlan::new();
        plan.record_execution("step1".to_string());
        plan.record_execution("step2".to_string());

        assert_eq!(
            plan.execution_order,
            vec!["step1".to_string(), "step2".to_string()]
        );
    }

    #[test]
    fn test_get_rollback_order() {
        let mut plan = RollbackPlan::new();

        plan.add_rollback_step("step1".to_string(), "rollback1".to_string());
        plan.add_rollback_step("step2".to_string(), "rollback2".to_string());

        plan.record_execution("step1".to_string());
        plan.record_execution("step2".to_string());

        let rollback_order = plan.get_rollback_order();

        // Rollback runs newest-first.
        assert_eq!(
            rollback_order,
            vec!["rollback2".to_string(), "rollback1".to_string()]
        );
    }

    #[test]
    fn test_get_rollback_order_skips_steps_without_rollback() {
        let mut plan = RollbackPlan::new();

        // Only step1 has rollback steps; step2 was executed but has none, and
        // step3 has a rollback step but was never executed.
        plan.add_rollback_step("step1".to_string(), "rollback1a".to_string());
        plan.add_rollback_step("step1".to_string(), "rollback1b".to_string());
        plan.add_rollback_step("step3".to_string(), "rollback3".to_string());

        plan.record_execution("step1".to_string());
        plan.record_execution("step2".to_string());

        assert_eq!(
            plan.get_rollback_order(),
            vec!["rollback1a".to_string(), "rollback1b".to_string()]
        );
    }

    #[test]
    fn test_create_rollback_plan() {
        let workflow = create_simple_workflow();
        let plan = RollbackManager::create_rollback_plan(&workflow);

        assert_eq!(plan.execution_order.len(), 2);
        assert_eq!(plan.execution_order[0], "step1");
        assert_eq!(plan.execution_order[1], "step2");
        // A freshly created plan carries no rollback steps.
        assert!(plan.rollback_steps.is_empty());
    }

    #[test]
    fn test_execute_rollback_with_no_rollback_steps() {
        let workflow = create_simple_workflow();
        let mut state = StateManager::create_state(&workflow);
        let plan = RollbackManager::create_rollback_plan(&workflow);

        // With nothing registered the rollback order is empty, so the call
        // succeeds without visiting any step.
        let result = RollbackManager::execute_rollback(&workflow, &mut state, &plan);
        assert!(result.is_ok());
    }

    #[test]
    fn test_execute_rollback_unknown_step_fails() {
        let workflow = create_simple_workflow();
        let mut state = StateManager::create_state(&workflow);

        let mut plan = RollbackPlan::new();
        plan.add_rollback_step("step1".to_string(), "missing-step".to_string());
        plan.record_execution("step1".to_string());

        let result = RollbackManager::execute_rollback(&workflow, &mut state, &plan);
        match result {
            Err(WorkflowError::NotFound(message)) => {
                assert!(message.contains("missing-step"));
            }
            _ => panic!("expected WorkflowError::NotFound for an unknown rollback step"),
        }
    }

    #[test]
    fn test_restore_state() {
        let workflow = create_simple_workflow();
        let mut state = StateManager::create_state(&workflow);

        state.completed_steps.push("step1".to_string());
        state.current_step = Some("step2".to_string());

        let result = RollbackManager::restore_state(&mut state);
        assert!(result.is_ok());

        assert!(state.completed_steps.is_empty());
        assert!(state.step_results.is_empty());
        assert!(state.current_step.is_none());
    }

    #[test]
    fn test_has_rollback_steps() {
        let mut plan = RollbackPlan::new();
        plan.add_rollback_step("step1".to_string(), "rollback1".to_string());

        assert!(RollbackManager::has_rollback_steps(&plan, "step1"));
        assert!(!RollbackManager::has_rollback_steps(&plan, "step2"));
    }

    #[test]
    fn test_has_rollback_steps_with_empty_entry() {
        let mut plan = RollbackPlan::new();
        // An entry that exists but holds no rollback steps counts as "none".
        plan.rollback_steps.insert("step1".to_string(), Vec::new());

        assert!(!RollbackManager::has_rollback_steps(&plan, "step1"));
    }

    #[test]
    fn test_get_rollback_steps() {
        let mut plan = RollbackPlan::new();
        plan.add_rollback_step("step1".to_string(), "rollback1".to_string());
        plan.add_rollback_step("step1".to_string(), "rollback2".to_string());

        let steps = RollbackManager::get_rollback_steps(&plan, "step1");
        assert_eq!(
            steps,
            vec!["rollback1".to_string(), "rollback2".to_string()]
        );
    }

    #[test]
    fn test_get_rollback_steps_unknown_step() {
        let plan = RollbackPlan::new();

        assert!(RollbackManager::get_rollback_steps(&plan, "step1").is_empty());
    }

    #[test]
    fn test_add_rollback_step_to_plan() {
        let mut plan = RollbackPlan::new();
        RollbackManager::add_rollback_step(&mut plan, "step1".to_string(), "rollback1".to_string());

        assert!(RollbackManager::has_rollback_steps(&plan, "step1"));
    }

    #[test]
    fn test_record_step_execution() {
        let mut plan = RollbackPlan::new();
        RollbackManager::record_step_execution(&mut plan, "step1".to_string());

        assert_eq!(plan.execution_order, vec!["step1".to_string()]);
    }
}