1use crate::error::{WorkflowError, WorkflowResult};
4use crate::models::{StepResult, StepStatus, Workflow, WorkflowState, WorkflowStatus};
5use chrono::Utc;
6use std::collections::HashMap;
7use std::path::Path;
8
9pub struct StateManager;
11
12impl StateManager {
13 pub fn create_state(workflow: &Workflow) -> WorkflowState {
15 WorkflowState {
16 workflow_id: workflow.id.clone(),
17 status: WorkflowStatus::Pending,
18 current_step: None,
19 completed_steps: Vec::new(),
20 step_results: HashMap::new(),
21 started_at: Utc::now(),
22 updated_at: Utc::now(),
23 }
24 }
25
26 pub fn start_workflow(state: &mut WorkflowState) {
28 state.status = WorkflowStatus::Running;
29 state.started_at = Utc::now();
30 state.updated_at = Utc::now();
31 }
32
33 pub fn start_step(state: &mut WorkflowState, step_id: String) {
35 state.current_step = Some(step_id.clone());
36 state.step_results.insert(
37 step_id,
38 StepResult {
39 status: StepStatus::Running,
40 output: None,
41 error: None,
42 duration_ms: 0,
43 },
44 );
45 state.updated_at = Utc::now();
46 }
47
48 pub fn complete_step(
50 state: &mut WorkflowState,
51 step_id: String,
52 output: Option<serde_json::Value>,
53 duration_ms: u64,
54 ) {
55 if let Some(result) = state.step_results.get_mut(&step_id) {
56 result.status = StepStatus::Completed;
57 result.output = output;
58 result.duration_ms = duration_ms;
59 }
60
61 state.completed_steps.push(step_id);
62 state.updated_at = Utc::now();
63 }
64
65 pub fn fail_step(state: &mut WorkflowState, step_id: String, error: String, duration_ms: u64) {
67 if let Some(result) = state.step_results.get_mut(&step_id) {
68 result.status = StepStatus::Failed;
69 result.error = Some(error);
70 result.duration_ms = duration_ms;
71 }
72
73 state.updated_at = Utc::now();
74 }
75
76 pub fn skip_step(state: &mut WorkflowState, step_id: String) {
78 if let Some(result) = state.step_results.get_mut(&step_id) {
79 result.status = StepStatus::Skipped;
80 }
81
82 state.completed_steps.push(step_id);
83 state.updated_at = Utc::now();
84 }
85
86 pub fn wait_for_approval(state: &mut WorkflowState) {
88 state.status = WorkflowStatus::WaitingApproval;
89 state.updated_at = Utc::now();
90 }
91
92 pub fn complete_workflow(state: &mut WorkflowState) {
94 state.status = WorkflowStatus::Completed;
95 state.current_step = None;
96 state.updated_at = Utc::now();
97 }
98
99 pub fn fail_workflow(state: &mut WorkflowState) {
101 state.status = WorkflowStatus::Failed;
102 state.updated_at = Utc::now();
103 }
104
105 pub fn cancel_workflow(state: &mut WorkflowState) {
107 state.status = WorkflowStatus::Cancelled;
108 state.updated_at = Utc::now();
109 }
110
111 pub fn pause_workflow(state: &mut WorkflowState) -> WorkflowResult<()> {
113 if state.status != WorkflowStatus::Running
115 && state.status != WorkflowStatus::WaitingApproval
116 {
117 return Err(WorkflowError::StateError(format!(
118 "Cannot pause workflow in {:?} status",
119 state.status
120 )));
121 }
122
123 state.status = WorkflowStatus::Paused;
124 state.updated_at = Utc::now();
125 Ok(())
126 }
127
128 pub fn resume_workflow(state: &mut WorkflowState) -> WorkflowResult<()> {
130 if state.status != WorkflowStatus::Paused {
132 return Err(WorkflowError::StateError(format!(
133 "Cannot resume workflow in {:?} status",
134 state.status
135 )));
136 }
137
138 state.status = WorkflowStatus::Running;
139 state.updated_at = Utc::now();
140 Ok(())
141 }
142
143 pub fn is_step_completed(state: &WorkflowState, step_id: &str) -> bool {
145 state.completed_steps.contains(&step_id.to_string())
146 }
147
148 pub fn get_next_step_to_execute(
150 state: &WorkflowState,
151 available_steps: &[String],
152 ) -> Option<String> {
153 available_steps
154 .iter()
155 .find(|step_id| !Self::is_step_completed(state, step_id))
156 .cloned()
157 }
158
159 pub fn persist_state(state: &WorkflowState, path: &Path) -> WorkflowResult<()> {
161 if let Some(parent) = path.parent() {
163 std::fs::create_dir_all(parent).map_err(|e| {
164 WorkflowError::StateError(format!("Failed to create state directory: {}", e))
165 })?;
166 }
167
168 let yaml = serde_yaml::to_string(state)
169 .map_err(|e| WorkflowError::StateError(format!("Failed to serialize state: {}", e)))?;
170
171 std::fs::write(path, yaml)
172 .map_err(|e| WorkflowError::StateError(format!("Failed to write state file: {}", e)))?;
173
174 Ok(())
175 }
176
177 pub fn persist_state_json(state: &WorkflowState, path: &Path) -> WorkflowResult<()> {
179 if let Some(parent) = path.parent() {
181 std::fs::create_dir_all(parent).map_err(|e| {
182 WorkflowError::StateError(format!("Failed to create state directory: {}", e))
183 })?;
184 }
185
186 let json = serde_json::to_string_pretty(state)
187 .map_err(|e| WorkflowError::StateError(format!("Failed to serialize state: {}", e)))?;
188
189 std::fs::write(path, json)
190 .map_err(|e| WorkflowError::StateError(format!("Failed to write state file: {}", e)))?;
191
192 Ok(())
193 }
194
195 pub fn load_state(path: &Path) -> WorkflowResult<WorkflowState> {
197 if !path.exists() {
198 return Err(WorkflowError::StateError(format!(
199 "State file not found: {}",
200 path.display()
201 )));
202 }
203
204 let content = std::fs::read_to_string(path)
205 .map_err(|e| WorkflowError::StateError(format!("Failed to read state file: {}", e)))?;
206
207 if let Ok(state) = serde_json::from_str::<WorkflowState>(&content) {
209 return Ok(state);
210 }
211
212 serde_yaml::from_str::<WorkflowState>(&content)
214 .map_err(|e| WorkflowError::StateError(format!("Failed to deserialize state: {}", e)))
215 }
216
217 pub fn validate_state(state: &WorkflowState) -> WorkflowResult<()> {
219 if state.workflow_id.is_empty() {
221 return Err(WorkflowError::StateError(
222 "Workflow ID cannot be empty".to_string(),
223 ));
224 }
225
226 for step_id in &state.completed_steps {
228 if !state.step_results.contains_key(step_id) {
229 return Err(WorkflowError::StateError(format!(
230 "Completed step '{}' has no result",
231 step_id
232 )));
233 }
234 }
235
236 if let Some(current_step) = &state.current_step {
238 if !state.step_results.contains_key(current_step) {
239 return Err(WorkflowError::StateError(format!(
240 "Current step '{}' has no result",
241 current_step
242 )));
243 }
244 }
245
246 Ok(())
247 }
248
249 pub fn load_state_validated(path: &Path) -> WorkflowResult<WorkflowState> {
251 let state = Self::load_state(path)?;
252 Self::validate_state(&state)?;
253 Ok(state)
254 }
255
256 pub fn load_state_with_recovery(path: &Path) -> WorkflowResult<WorkflowState> {
258 match Self::load_state_validated(path) {
259 Ok(state) => Ok(state),
260 Err(e) => {
261 eprintln!("Warning: Failed to load state file: {}", e);
263 Err(e)
264 }
265 }
266 }
267
268 pub fn get_progress(state: &WorkflowState, total_steps: usize) -> u32 {
270 if total_steps == 0 {
271 return 0;
272 }
273
274 ((state.completed_steps.len() as u32 * 100) / total_steps as u32).min(100)
275 }
276}
277
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{ErrorAction, RiskFactors, StepType, WorkflowConfig, WorkflowStep};

    /// Builds a minimal workflow containing a single agent step with id `step1`.
    fn create_test_workflow() -> Workflow {
        Workflow {
            id: "test-workflow".to_string(),
            name: "Test Workflow".to_string(),
            description: "A test workflow".to_string(),
            parameters: vec![],
            steps: vec![WorkflowStep {
                id: "step1".to_string(),
                name: "Step 1".to_string(),
                step_type: StepType::Agent(crate::models::AgentStep {
                    agent_id: "test-agent".to_string(),
                    task: "test-task".to_string(),
                }),
                config: crate::models::StepConfig {
                    config: serde_json::json!({}),
                },
                dependencies: vec![],
                approval_required: false,
                on_error: ErrorAction::Fail,
                risk_score: None,
                risk_factors: RiskFactors::default(),
            }],
            config: WorkflowConfig {
                timeout_ms: None,
                max_parallel: None,
            },
        }
    }

    /// Returns a path under the OS temp directory that is unique to this process and to
    /// `name`, so parallel tests and concurrent test runs do not share state files.
    fn temp_state_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!(
            "state_manager_test_{}_{}",
            std::process::id(),
            name
        ))
    }

    /// Persists a state with one completed step using `persist`, reloads it through
    /// `load_state_validated`, and checks that the observable fields survived.
    fn assert_roundtrip(
        file_name: &str,
        persist: fn(&WorkflowState, &Path) -> WorkflowResult<()>,
    ) {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);
        StateManager::start_step(&mut state, "step1".to_string());
        StateManager::complete_step(&mut state, "step1".to_string(), None, 42);

        let path = temp_state_path(file_name);
        let persisted = persist(&state, &path);
        let loaded = StateManager::load_state_validated(&path);
        // Clean up before asserting so a failing assertion does not leave files behind.
        let _ = std::fs::remove_file(&path);

        assert!(persisted.is_ok(), "persisting state to {} failed", path.display());
        let loaded = match loaded {
            Ok(loaded) => loaded,
            Err(_) => panic!(
                "persisted state at {} did not load and validate",
                path.display()
            ),
        };
        assert_eq!(loaded.workflow_id, state.workflow_id);
        assert_eq!(loaded.status, state.status);
        assert_eq!(loaded.current_step, state.current_step);
        assert_eq!(loaded.completed_steps, state.completed_steps);
        let result = loaded
            .step_results
            .get("step1")
            .expect("step1 result should survive the round trip");
        assert_eq!(result.status, StepStatus::Completed);
        assert_eq!(result.duration_ms, 42);
    }

    #[test]
    fn test_create_state() {
        let workflow = create_test_workflow();
        let state = StateManager::create_state(&workflow);

        assert_eq!(state.workflow_id, "test-workflow");
        assert_eq!(state.status, WorkflowStatus::Pending);
        assert!(state.current_step.is_none());
        assert!(state.completed_steps.is_empty());
    }

    #[test]
    fn test_start_workflow() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);

        StateManager::start_workflow(&mut state);
        assert_eq!(state.status, WorkflowStatus::Running);
    }

    #[test]
    fn test_complete_step() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);

        StateManager::start_step(&mut state, "step1".to_string());
        StateManager::complete_step(
            &mut state,
            "step1".to_string(),
            Some(serde_json::json!({"result": "success"})),
            100,
        );

        assert!(state.completed_steps.contains(&"step1".to_string()));
        assert_eq!(
            state.step_results.get("step1").unwrap().status,
            StepStatus::Completed
        );
    }

    #[test]
    fn test_get_progress() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);

        assert_eq!(StateManager::get_progress(&state, 10), 0);

        state.completed_steps.push("step1".to_string());
        assert_eq!(StateManager::get_progress(&state, 10), 10);

        state.completed_steps.push("step2".to_string());
        assert_eq!(StateManager::get_progress(&state, 10), 20);
    }

    #[test]
    fn test_get_progress_zero_total_steps() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);
        state.completed_steps.push("step1".to_string());

        assert_eq!(StateManager::get_progress(&state, 0), 0);
    }

    #[test]
    fn test_pause_resume_workflow() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);

        StateManager::start_workflow(&mut state);
        assert_eq!(state.status, WorkflowStatus::Running);

        let result = StateManager::pause_workflow(&mut state);
        assert!(result.is_ok());
        assert_eq!(state.status, WorkflowStatus::Paused);

        let result = StateManager::resume_workflow(&mut state);
        assert!(result.is_ok());
        assert_eq!(state.status, WorkflowStatus::Running);
    }

    #[test]
    fn test_pause_non_running_workflow_fails() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);

        let result = StateManager::pause_workflow(&mut state);
        assert!(matches!(result, Err(WorkflowError::StateError(_))));
        // A rejected transition must not modify the status.
        assert_eq!(state.status, WorkflowStatus::Pending);
    }

    #[test]
    fn test_resume_non_paused_workflow_fails() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);

        let result = StateManager::resume_workflow(&mut state);
        assert!(matches!(result, Err(WorkflowError::StateError(_))));
        assert_eq!(state.status, WorkflowStatus::Pending);
    }

    #[test]
    fn test_is_step_completed() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);

        assert!(!StateManager::is_step_completed(&state, "step1"));

        state.completed_steps.push("step1".to_string());
        assert!(StateManager::is_step_completed(&state, "step1"));
    }

    #[test]
    fn test_get_next_step_to_execute() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);

        let available_steps = vec![
            "step1".to_string(),
            "step2".to_string(),
            "step3".to_string(),
        ];

        let next = StateManager::get_next_step_to_execute(&state, &available_steps);
        assert_eq!(next, Some("step1".to_string()));

        state.completed_steps.push("step1".to_string());

        let next = StateManager::get_next_step_to_execute(&state, &available_steps);
        assert_eq!(next, Some("step2".to_string()));

        state.completed_steps.push("step2".to_string());
        state.completed_steps.push("step3".to_string());

        let next = StateManager::get_next_step_to_execute(&state, &available_steps);
        assert_eq!(next, None);
    }

    #[test]
    fn test_validate_state_success() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);

        StateManager::start_step(&mut state, "step1".to_string());
        StateManager::complete_step(&mut state, "step1".to_string(), None, 100);

        let result = StateManager::validate_state(&state);
        assert!(result.is_ok());
    }

    #[test]
    fn test_validate_state_missing_result() {
        let workflow = create_test_workflow();
        let mut state = StateManager::create_state(&workflow);

        // Bypass the StateManager API to build a completed step with no recorded result.
        state.completed_steps.push("step1".to_string());

        let result = StateManager::validate_state(&state);
        assert!(matches!(result, Err(WorkflowError::StateError(_))));
    }

    #[test]
    fn test_persist_and_load_yaml_roundtrip() {
        assert_roundtrip("roundtrip.yaml", StateManager::persist_state);
    }

    #[test]
    fn test_persist_and_load_json_roundtrip() {
        assert_roundtrip("roundtrip.json", StateManager::persist_state_json);
    }

    #[test]
    fn test_load_state_missing_file_fails() {
        let path = temp_state_path("does_not_exist.yaml");
        let _ = std::fs::remove_file(&path);

        let result = StateManager::load_state(&path);
        assert!(matches!(result, Err(WorkflowError::StateError(_))));
    }
}