// erio_workflow/engine.rs
1//! Workflow execution engine with parallel step execution.
2
3use std::sync::Arc;
4
5use tokio::sync::Mutex;
6
7use std::path::Path;
8
9use crate::WorkflowError;
10use crate::builder::Workflow;
11use crate::checkpoint::Checkpoint;
12use crate::context::WorkflowContext;
13use crate::step::StepOutput;
14
/// Executes workflows by resolving the DAG and running steps.
///
/// Independent steps are executed in parallel using tokio tasks.
///
/// The engine itself is stateless (a unit struct); all per-run state lives
/// in the `WorkflowContext` produced by a run, so one engine value can be
/// freely cloned and reused across runs.
#[derive(Debug, Clone, Default)]
pub struct WorkflowEngine;
20
21impl WorkflowEngine {
22    /// Creates a new workflow engine.
23    pub fn new() -> Self {
24        Self
25    }
26
27    /// Runs a workflow to completion.
28    ///
29    /// Steps are executed in parallel groups determined by the DAG.
30    /// If any step fails, dependent steps are skipped and the error is returned.
31    pub async fn run(&self, workflow: Workflow) -> Result<WorkflowContext, WorkflowError> {
32        let groups = workflow.parallel_groups()?;
33        let ctx = Arc::new(Mutex::new(WorkflowContext::new()));
34        let failed: Arc<Mutex<Option<WorkflowError>>> = Arc::new(Mutex::new(None));
35
36        for group in groups {
37            // Check if a previous step already failed
38            if failed.lock().await.is_some() {
39                break;
40            }
41
42            if group.len() == 1 {
43                // Single step — run directly (no spawn overhead)
44                let step_id = group[0];
45                let step = workflow.step(step_id).expect("DAG validated step exists");
46
47                let mut ctx_guard = ctx.lock().await;
48                match step.execute(&mut ctx_guard).await {
49                    Ok(output) => {
50                        ctx_guard.set_output(step_id, output);
51                    }
52                    Err(e) => {
53                        return Err(e);
54                    }
55                }
56            } else {
57                // Multiple independent steps — run in parallel
58                let mut handles = Vec::with_capacity(group.len());
59
60                for step_id in &group {
61                    let step = workflow.step(step_id).expect("DAG validated step exists");
62                    let ctx_clone = ctx.clone();
63                    let failed_clone = failed.clone();
64                    let step_id_owned = (*step_id).to_string();
65
66                    let handle = tokio::spawn(async move {
67                        // Take a snapshot of context for this step
68                        let mut ctx_snapshot = ctx_clone.lock().await.clone();
69                        drop(ctx_clone); // Release lock during execution
70
71                        match step.execute(&mut ctx_snapshot).await {
72                            Ok(output) => Ok((step_id_owned, output)),
73                            Err(e) => {
74                                *failed_clone.lock().await = Some(WorkflowError::StepFailed {
75                                    step_id: step_id_owned.clone(),
76                                    message: e.to_string(),
77                                });
78                                Err(e)
79                            }
80                        }
81                    });
82
83                    handles.push(handle);
84                }
85
86                // Collect results
87                let mut first_error: Option<WorkflowError> = None;
88                let mut outputs: Vec<(String, StepOutput)> = Vec::new();
89
90                for handle in handles {
91                    match handle.await {
92                        Ok(Ok((id, output))) => outputs.push((id, output)),
93                        Ok(Err(e)) => {
94                            if first_error.is_none() {
95                                first_error = Some(e);
96                            }
97                        }
98                        Err(join_err) => {
99                            if first_error.is_none() {
100                                first_error = Some(WorkflowError::StepFailed {
101                                    step_id: "unknown".into(),
102                                    message: format!("Task panicked: {join_err}"),
103                                });
104                            }
105                        }
106                    }
107                }
108
109                // If any step in this group failed, return the error
110                if let Some(err) = first_error {
111                    return Err(err);
112                }
113
114                // Store all outputs
115                let mut ctx_guard = ctx.lock().await;
116                for (id, output) in outputs {
117                    ctx_guard.set_output(&id, output);
118                }
119            }
120        }
121
122        let result = ctx.lock().await.clone();
123        Ok(result)
124    }
125
126    /// Runs a workflow with checkpointing after each group completes.
127    ///
128    /// If a checkpoint file already exists at the path, completed steps are skipped.
129    pub async fn run_with_checkpoint(
130        &self,
131        workflow: Workflow,
132        checkpoint_path: &Path,
133    ) -> Result<WorkflowContext, WorkflowError> {
134        let groups = workflow.parallel_groups()?;
135
136        // Load existing checkpoint or create new
137        let mut checkpoint = if checkpoint_path.exists() {
138            Checkpoint::load(checkpoint_path).await?
139        } else {
140            Checkpoint::new()
141        };
142
143        let ctx = Arc::new(Mutex::new(checkpoint.clone().into_context()));
144
145        for group in groups {
146            // Filter out already-completed steps
147            let pending: Vec<&str> = group
148                .iter()
149                .filter(|id| !checkpoint.is_completed(id))
150                .copied()
151                .collect();
152
153            if pending.is_empty() {
154                continue;
155            }
156
157            if pending.len() == 1 {
158                let step_id = pending[0];
159                let step = workflow.step(step_id).expect("DAG validated");
160                let mut ctx_guard = ctx.lock().await;
161                let output = step.execute(&mut ctx_guard).await?;
162                ctx_guard.set_output(step_id, output.clone());
163                checkpoint.mark_completed(step_id, output);
164            } else {
165                let mut handles = Vec::with_capacity(pending.len());
166
167                for step_id in &pending {
168                    let step = workflow.step(step_id).expect("DAG validated");
169                    let ctx_clone = ctx.clone();
170                    let step_id_owned = (*step_id).to_string();
171
172                    let handle = tokio::spawn(async move {
173                        let mut ctx_snapshot = ctx_clone.lock().await.clone();
174                        drop(ctx_clone);
175                        let output = step.execute(&mut ctx_snapshot).await?;
176                        Ok::<_, WorkflowError>((step_id_owned, output))
177                    });
178                    handles.push(handle);
179                }
180
181                let mut first_error: Option<WorkflowError> = None;
182                let mut outputs: Vec<(String, StepOutput)> = Vec::new();
183
184                for handle in handles {
185                    match handle.await {
186                        Ok(Ok((id, output))) => outputs.push((id, output)),
187                        Ok(Err(e)) => {
188                            if first_error.is_none() {
189                                first_error = Some(e);
190                            }
191                        }
192                        Err(join_err) => {
193                            if first_error.is_none() {
194                                first_error = Some(WorkflowError::StepFailed {
195                                    step_id: "unknown".into(),
196                                    message: format!("Task panicked: {join_err}"),
197                                });
198                            }
199                        }
200                    }
201                }
202
203                if let Some(err) = first_error {
204                    // Save checkpoint before returning error
205                    checkpoint.save(checkpoint_path).await?;
206                    return Err(err);
207                }
208
209                let mut ctx_guard = ctx.lock().await;
210                for (id, output) in outputs {
211                    ctx_guard.set_output(&id, output.clone());
212                    checkpoint.mark_completed(&id, output);
213                }
214            }
215
216            // Save checkpoint after each group
217            checkpoint.save(checkpoint_path).await?;
218        }
219
220        let result = ctx.lock().await.clone();
221        Ok(result)
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::WorkflowError;
    use crate::builder::Workflow;
    use crate::context::WorkflowContext;
    use crate::step::{Step, StepOutput};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    // === Mock Steps ===

    /// Step that always succeeds with a fixed output value.
    struct ValueStep {
        step_id: String,
        output: String,
    }

    impl ValueStep {
        fn new(id: &str, output: &str) -> Self {
            Self {
                step_id: id.into(),
                output: output.into(),
            }
        }
    }

    #[async_trait::async_trait]
    impl Step for ValueStep {
        fn id(&self) -> &str {
            &self.step_id
        }

        async fn execute(&self, _ctx: &mut WorkflowContext) -> Result<StepOutput, WorkflowError> {
            Ok(StepOutput::new(&self.output))
        }
    }

    /// Step that reads a dependency's output and appends to it.
    struct AppendStep {
        step_id: String,
        dep_id: String,
        suffix: String,
    }

    impl AppendStep {
        fn new(id: &str, dep_id: &str, suffix: &str) -> Self {
            Self {
                step_id: id.into(),
                dep_id: dep_id.into(),
                suffix: suffix.into(),
            }
        }
    }

    #[async_trait::async_trait]
    impl Step for AppendStep {
        fn id(&self) -> &str {
            &self.step_id
        }

        async fn execute(&self, ctx: &mut WorkflowContext) -> Result<StepOutput, WorkflowError> {
            // Missing dependency output falls back to "" rather than failing,
            // so these tests exercise context propagation, not error paths.
            let prev = ctx
                .output(&self.dep_id)
                .map(|o| o.value().to_string())
                .unwrap_or_default();
            Ok(StepOutput::new(&format!("{prev}{}", self.suffix)))
        }
    }

    /// Step that fails.
    struct FailStep {
        step_id: String,
        message: String,
    }

    impl FailStep {
        fn new(id: &str, message: &str) -> Self {
            Self {
                step_id: id.into(),
                message: message.into(),
            }
        }
    }

    #[async_trait::async_trait]
    impl Step for FailStep {
        fn id(&self) -> &str {
            &self.step_id
        }

        async fn execute(&self, _ctx: &mut WorkflowContext) -> Result<StepOutput, WorkflowError> {
            Err(WorkflowError::StepFailed {
                step_id: self.step_id.clone(),
                message: self.message.clone(),
            })
        }
    }

    /// Step that tracks execution via an atomic counter.
    struct CountStep {
        step_id: String,
        counter: Arc<AtomicUsize>,
        // Optional artificial delay, used to detect parallel execution
        // by measuring wall-clock time.
        delay: Option<Duration>,
    }

    impl CountStep {
        fn new(id: &str, counter: Arc<AtomicUsize>) -> Self {
            Self {
                step_id: id.into(),
                counter,
                delay: None,
            }
        }

        fn with_delay(mut self, delay: Duration) -> Self {
            self.delay = Some(delay);
            self
        }
    }

    #[async_trait::async_trait]
    impl Step for CountStep {
        fn id(&self) -> &str {
            &self.step_id
        }

        async fn execute(&self, _ctx: &mut WorkflowContext) -> Result<StepOutput, WorkflowError> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            if let Some(d) = self.delay {
                tokio::time::sleep(d).await;
            }
            Ok(StepOutput::new("done"))
        }
    }

    // === Engine Tests ===

    #[tokio::test]
    async fn runs_single_step() {
        let workflow = Workflow::builder()
            .step(ValueStep::new("a", "hello"), &[])
            .build()
            .unwrap();

        let engine = WorkflowEngine::new();
        let result = engine.run(workflow).await.unwrap();

        assert!(result.is_completed("a"));
        assert_eq!(result.output("a").unwrap().value(), "hello");
    }

    #[tokio::test]
    async fn runs_linear_chain_passing_context() {
        // a -> b -> c: each step must see its predecessor's output.
        let workflow = Workflow::builder()
            .step(ValueStep::new("a", "start"), &[])
            .step(AppendStep::new("b", "a", "_middle"), &["a"])
            .step(AppendStep::new("c", "b", "_end"), &["b"])
            .build()
            .unwrap();

        let engine = WorkflowEngine::new();
        let result = engine.run(workflow).await.unwrap();

        assert_eq!(result.output("c").unwrap().value(), "start_middle_end");
    }

    #[tokio::test]
    async fn runs_parallel_independent_steps() {
        let counter = Arc::new(AtomicUsize::new(0));

        // Three independent 50 ms steps: sequential execution would take
        // >= 150 ms, parallel execution ~50 ms.
        let workflow = Workflow::builder()
            .step(
                CountStep::new("a", counter.clone()).with_delay(Duration::from_millis(50)),
                &[],
            )
            .step(
                CountStep::new("b", counter.clone()).with_delay(Duration::from_millis(50)),
                &[],
            )
            .step(
                CountStep::new("c", counter.clone()).with_delay(Duration::from_millis(50)),
                &[],
            )
            .build()
            .unwrap();

        let engine = WorkflowEngine::new();
        let start = std::time::Instant::now();
        let result = engine.run(workflow).await.unwrap();
        let elapsed = start.elapsed();

        // All 3 should have run
        assert_eq!(counter.load(Ordering::SeqCst), 3);
        assert!(result.is_completed("a"));
        assert!(result.is_completed("b"));
        assert!(result.is_completed("c"));

        // Should run in parallel (< 120ms), not sequentially (>= 150ms)
        // NOTE(review): wall-clock bound — may flake on a heavily loaded CI
        // machine; consider tokio::time::pause if that happens.
        assert!(elapsed < Duration::from_millis(120));
    }

    #[tokio::test]
    async fn step_failure_propagates_error() {
        let workflow = Workflow::builder()
            .step(FailStep::new("a", "boom"), &[])
            .build()
            .unwrap();

        let engine = WorkflowEngine::new();
        let result = engine.run(workflow).await;

        assert!(result.is_err());
        // The error must identify the failing step, not just be any error.
        assert!(matches!(
            result.unwrap_err(),
            WorkflowError::StepFailed { step_id, .. } if step_id == "a"
        ));
    }

    #[tokio::test]
    async fn dependent_step_skipped_when_dependency_fails() {
        let counter = Arc::new(AtomicUsize::new(0));

        let workflow = Workflow::builder()
            .step(FailStep::new("a", "boom"), &[])
            .step(CountStep::new("b", counter.clone()), &["a"])
            .build()
            .unwrap();

        let engine = WorkflowEngine::new();
        let result = engine.run(workflow).await;

        // Workflow fails
        assert!(result.is_err());
        // Step b never ran
        assert_eq!(counter.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn diamond_workflow_executes_correctly() {
        //     a
        //    / \
        //   b   c
        //    \ /
        //     d
        let workflow = Workflow::builder()
            .step(ValueStep::new("a", "A"), &[])
            .step(AppendStep::new("b", "a", "_B"), &["a"])
            .step(AppendStep::new("c", "a", "_C"), &["a"])
            .step(AppendStep::new("d", "b", "_D"), &["b", "c"])
            .build()
            .unwrap();

        let engine = WorkflowEngine::new();
        let result = engine.run(workflow).await.unwrap();

        assert_eq!(result.output("a").unwrap().value(), "A");
        assert_eq!(result.output("b").unwrap().value(), "A_B");
        assert_eq!(result.output("c").unwrap().value(), "A_C");
        // d depends on b, reads b's output
        assert_eq!(result.output("d").unwrap().value(), "A_B_D");
    }

    // === Checkpointed Run Tests ===

    #[tokio::test]
    async fn checkpointed_run_saves_checkpoint_file() {
        let dir = tempfile::tempdir().unwrap();
        let ckpt_path = dir.path().join("checkpoint.json");

        let workflow = Workflow::builder()
            .step(ValueStep::new("a", "A"), &[])
            .step(ValueStep::new("b", "B"), &["a"])
            .build()
            .unwrap();

        let engine = WorkflowEngine::new();
        let result = engine
            .run_with_checkpoint(workflow, &ckpt_path)
            .await
            .unwrap();

        assert!(ckpt_path.exists());
        assert!(result.is_completed("a"));
        assert!(result.is_completed("b"));
    }

    #[tokio::test]
    async fn checkpointed_run_skips_completed_steps() {
        let dir = tempfile::tempdir().unwrap();
        let ckpt_path = dir.path().join("checkpoint.json");

        // Pre-populate checkpoint with step "a" completed
        let mut pre_checkpoint = crate::checkpoint::Checkpoint::new();
        pre_checkpoint.mark_completed("a", StepOutput::new("A"));
        pre_checkpoint.save(&ckpt_path).await.unwrap();

        let counter = Arc::new(AtomicUsize::new(0));

        let workflow = Workflow::builder()
            .step(CountStep::new("a", counter.clone()), &[])
            .step(CountStep::new("b", counter.clone()), &["a"])
            .build()
            .unwrap();

        let engine = WorkflowEngine::new();
        let result = engine
            .run_with_checkpoint(workflow, &ckpt_path)
            .await
            .unwrap();

        // Step "a" was already in checkpoint, should not run again
        // Only step "b" should have run
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(result.is_completed("a"));
        assert!(result.is_completed("b"));
    }

    #[tokio::test]
    async fn returns_all_completed_step_ids() {
        let workflow = Workflow::builder()
            .step(ValueStep::new("x", "1"), &[])
            .step(ValueStep::new("y", "2"), &[])
            .build()
            .unwrap();

        let engine = WorkflowEngine::new();
        let result = engine.run(workflow).await.unwrap();

        // Sort before comparing: parallel completion order is unspecified.
        let mut ids = result.completed_step_ids();
        ids.sort_unstable();
        assert_eq!(ids, vec!["x", "y"]);
    }
}