1use std::sync::Arc;
4
5use tokio::sync::Mutex;
6
7use std::path::Path;
8
9use crate::WorkflowError;
10use crate::builder::Workflow;
11use crate::checkpoint::Checkpoint;
12use crate::context::WorkflowContext;
13use crate::step::StepOutput;
14
/// Stateless executor for [`Workflow`] values.
///
/// The engine carries no data of its own; all state for a run lives in the
/// [`WorkflowContext`] that the `run*` methods build and return.
#[derive(Debug, Clone, Default)]
pub struct WorkflowEngine;
20
21impl WorkflowEngine {
22 pub fn new() -> Self {
24 Self
25 }
26
27 pub async fn run(&self, workflow: Workflow) -> Result<WorkflowContext, WorkflowError> {
32 let groups = workflow.parallel_groups()?;
33 let ctx = Arc::new(Mutex::new(WorkflowContext::new()));
34 let failed: Arc<Mutex<Option<WorkflowError>>> = Arc::new(Mutex::new(None));
35
36 for group in groups {
37 if failed.lock().await.is_some() {
39 break;
40 }
41
42 if group.len() == 1 {
43 let step_id = group[0];
45 let step = workflow.step(step_id).expect("DAG validated step exists");
46
47 let mut ctx_guard = ctx.lock().await;
48 match step.execute(&mut ctx_guard).await {
49 Ok(output) => {
50 ctx_guard.set_output(step_id, output);
51 }
52 Err(e) => {
53 return Err(e);
54 }
55 }
56 } else {
57 let mut handles = Vec::with_capacity(group.len());
59
60 for step_id in &group {
61 let step = workflow.step(step_id).expect("DAG validated step exists");
62 let ctx_clone = ctx.clone();
63 let failed_clone = failed.clone();
64 let step_id_owned = (*step_id).to_string();
65
66 let handle = tokio::spawn(async move {
67 let mut ctx_snapshot = ctx_clone.lock().await.clone();
69 drop(ctx_clone); match step.execute(&mut ctx_snapshot).await {
72 Ok(output) => Ok((step_id_owned, output)),
73 Err(e) => {
74 *failed_clone.lock().await = Some(WorkflowError::StepFailed {
75 step_id: step_id_owned.clone(),
76 message: e.to_string(),
77 });
78 Err(e)
79 }
80 }
81 });
82
83 handles.push(handle);
84 }
85
86 let mut first_error: Option<WorkflowError> = None;
88 let mut outputs: Vec<(String, StepOutput)> = Vec::new();
89
90 for handle in handles {
91 match handle.await {
92 Ok(Ok((id, output))) => outputs.push((id, output)),
93 Ok(Err(e)) => {
94 if first_error.is_none() {
95 first_error = Some(e);
96 }
97 }
98 Err(join_err) => {
99 if first_error.is_none() {
100 first_error = Some(WorkflowError::StepFailed {
101 step_id: "unknown".into(),
102 message: format!("Task panicked: {join_err}"),
103 });
104 }
105 }
106 }
107 }
108
109 if let Some(err) = first_error {
111 return Err(err);
112 }
113
114 let mut ctx_guard = ctx.lock().await;
116 for (id, output) in outputs {
117 ctx_guard.set_output(&id, output);
118 }
119 }
120 }
121
122 let result = ctx.lock().await.clone();
123 Ok(result)
124 }
125
126 pub async fn run_with_checkpoint(
130 &self,
131 workflow: Workflow,
132 checkpoint_path: &Path,
133 ) -> Result<WorkflowContext, WorkflowError> {
134 let groups = workflow.parallel_groups()?;
135
136 let mut checkpoint = if checkpoint_path.exists() {
138 Checkpoint::load(checkpoint_path).await?
139 } else {
140 Checkpoint::new()
141 };
142
143 let ctx = Arc::new(Mutex::new(checkpoint.clone().into_context()));
144
145 for group in groups {
146 let pending: Vec<&str> = group
148 .iter()
149 .filter(|id| !checkpoint.is_completed(id))
150 .copied()
151 .collect();
152
153 if pending.is_empty() {
154 continue;
155 }
156
157 if pending.len() == 1 {
158 let step_id = pending[0];
159 let step = workflow.step(step_id).expect("DAG validated");
160 let mut ctx_guard = ctx.lock().await;
161 let output = step.execute(&mut ctx_guard).await?;
162 ctx_guard.set_output(step_id, output.clone());
163 checkpoint.mark_completed(step_id, output);
164 } else {
165 let mut handles = Vec::with_capacity(pending.len());
166
167 for step_id in &pending {
168 let step = workflow.step(step_id).expect("DAG validated");
169 let ctx_clone = ctx.clone();
170 let step_id_owned = (*step_id).to_string();
171
172 let handle = tokio::spawn(async move {
173 let mut ctx_snapshot = ctx_clone.lock().await.clone();
174 drop(ctx_clone);
175 let output = step.execute(&mut ctx_snapshot).await?;
176 Ok::<_, WorkflowError>((step_id_owned, output))
177 });
178 handles.push(handle);
179 }
180
181 let mut first_error: Option<WorkflowError> = None;
182 let mut outputs: Vec<(String, StepOutput)> = Vec::new();
183
184 for handle in handles {
185 match handle.await {
186 Ok(Ok((id, output))) => outputs.push((id, output)),
187 Ok(Err(e)) => {
188 if first_error.is_none() {
189 first_error = Some(e);
190 }
191 }
192 Err(join_err) => {
193 if first_error.is_none() {
194 first_error = Some(WorkflowError::StepFailed {
195 step_id: "unknown".into(),
196 message: format!("Task panicked: {join_err}"),
197 });
198 }
199 }
200 }
201 }
202
203 if let Some(err) = first_error {
204 checkpoint.save(checkpoint_path).await?;
206 return Err(err);
207 }
208
209 let mut ctx_guard = ctx.lock().await;
210 for (id, output) in outputs {
211 ctx_guard.set_output(&id, output.clone());
212 checkpoint.mark_completed(&id, output);
213 }
214 }
215
216 checkpoint.save(checkpoint_path).await?;
218 }
219
220 let result = ctx.lock().await.clone();
221 Ok(result)
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use crate::WorkflowError;
    use crate::builder::Workflow;
    use crate::context::WorkflowContext;
    use crate::step::{Step, StepOutput};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Step that always succeeds with a fixed output value.
    struct ValueStep {
        step_id: String,
        output: String,
    }

    impl ValueStep {
        fn new(id: &str, output: &str) -> Self {
            let step_id = id.to_owned();
            let output = output.to_owned();
            Self { step_id, output }
        }
    }

    #[async_trait::async_trait]
    impl Step for ValueStep {
        fn id(&self) -> &str {
            self.step_id.as_str()
        }

        async fn execute(&self, _ctx: &mut WorkflowContext) -> Result<StepOutput, WorkflowError> {
            let produced = StepOutput::new(&self.output);
            Ok(produced)
        }
    }

    /// Step that appends `suffix` to the output of `dep_id`, treating a
    /// missing dependency output as the empty string.
    struct AppendStep {
        step_id: String,
        dep_id: String,
        suffix: String,
    }

    impl AppendStep {
        fn new(id: &str, dep_id: &str, suffix: &str) -> Self {
            let step_id = id.to_owned();
            let dep_id = dep_id.to_owned();
            let suffix = suffix.to_owned();
            Self {
                step_id,
                dep_id,
                suffix,
            }
        }
    }

    #[async_trait::async_trait]
    impl Step for AppendStep {
        fn id(&self) -> &str {
            self.step_id.as_str()
        }

        async fn execute(&self, ctx: &mut WorkflowContext) -> Result<StepOutput, WorkflowError> {
            let prev = match ctx.output(&self.dep_id) {
                Some(dep_output) => dep_output.value().to_string(),
                None => String::new(),
            };
            let combined = format!("{prev}{}", self.suffix);
            Ok(StepOutput::new(&combined))
        }
    }

    /// Step that always fails with `WorkflowError::StepFailed`.
    struct FailStep {
        step_id: String,
        message: String,
    }

    impl FailStep {
        fn new(id: &str, message: &str) -> Self {
            let step_id = id.to_owned();
            let message = message.to_owned();
            Self { step_id, message }
        }
    }

    #[async_trait::async_trait]
    impl Step for FailStep {
        fn id(&self) -> &str {
            self.step_id.as_str()
        }

        async fn execute(&self, _ctx: &mut WorkflowContext) -> Result<StepOutput, WorkflowError> {
            let failure = WorkflowError::StepFailed {
                step_id: self.step_id.clone(),
                message: self.message.clone(),
            };
            Err(failure)
        }
    }

    /// Step that bumps a shared counter on every execution and optionally
    /// sleeps before reporting success.
    struct CountStep {
        step_id: String,
        counter: Arc<AtomicUsize>,
        delay: Option<Duration>,
    }

    impl CountStep {
        fn new(id: &str, counter: Arc<AtomicUsize>) -> Self {
            let step_id = id.to_owned();
            Self {
                step_id,
                counter,
                delay: None,
            }
        }

        fn with_delay(self, delay: Duration) -> Self {
            Self {
                delay: Some(delay),
                ..self
            }
        }
    }

    #[async_trait::async_trait]
    impl Step for CountStep {
        fn id(&self) -> &str {
            self.step_id.as_str()
        }

        async fn execute(&self, _ctx: &mut WorkflowContext) -> Result<StepOutput, WorkflowError> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            if let Some(pause) = self.delay {
                tokio::time::sleep(pause).await;
            }
            Ok(StepOutput::new("done"))
        }
    }

    // A lone step runs and its output lands in the returned context.
    #[tokio::test]
    async fn runs_single_step() {
        let engine = WorkflowEngine::new();
        let workflow = Workflow::builder()
            .step(ValueStep::new("a", "hello"), &[])
            .build()
            .unwrap();

        let ctx = engine.run(workflow).await.unwrap();

        assert!(ctx.is_completed("a"));
        assert_eq!(ctx.output("a").unwrap().value(), "hello");
    }

    // Each step of a chain sees the output of the step before it.
    #[tokio::test]
    async fn runs_linear_chain_passing_context() {
        let engine = WorkflowEngine::new();
        let workflow = Workflow::builder()
            .step(ValueStep::new("a", "start"), &[])
            .step(AppendStep::new("b", "a", "_middle"), &["a"])
            .step(AppendStep::new("c", "b", "_end"), &["b"])
            .build()
            .unwrap();

        let ctx = engine.run(workflow).await.unwrap();

        assert_eq!(ctx.output("c").unwrap().value(), "start_middle_end");
    }

    // Three independent 50 ms steps must overlap rather than run back to back.
    #[tokio::test]
    async fn runs_parallel_independent_steps() {
        let counter = Arc::new(AtomicUsize::new(0));
        let delayed =
            |id: &str| CountStep::new(id, counter.clone()).with_delay(Duration::from_millis(50));

        let workflow = Workflow::builder()
            .step(delayed("a"), &[])
            .step(delayed("b"), &[])
            .step(delayed("c"), &[])
            .build()
            .unwrap();

        let engine = WorkflowEngine::new();
        let started = std::time::Instant::now();
        let ctx = engine.run(workflow).await.unwrap();
        let took = started.elapsed();

        assert_eq!(counter.load(Ordering::SeqCst), 3);
        for id in ["a", "b", "c"] {
            assert!(ctx.is_completed(id));
        }

        assert!(took < Duration::from_millis(120));
    }

    // The error of a failing step is what `run` returns.
    #[tokio::test]
    async fn step_failure_propagates_error() {
        let engine = WorkflowEngine::new();
        let workflow = Workflow::builder()
            .step(FailStep::new("a", "boom"), &[])
            .build()
            .unwrap();

        let outcome = engine.run(workflow).await;

        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            WorkflowError::StepFailed { step_id, .. } if step_id == "a"
        ));
    }

    // A step whose dependency failed must never execute.
    #[tokio::test]
    async fn dependent_step_skipped_when_dependency_fails() {
        let executions = Arc::new(AtomicUsize::new(0));
        let engine = WorkflowEngine::new();

        let workflow = Workflow::builder()
            .step(FailStep::new("a", "boom"), &[])
            .step(CountStep::new("b", executions.clone()), &["a"])
            .build()
            .unwrap();

        let outcome = engine.run(workflow).await;

        assert!(outcome.is_err());
        assert_eq!(executions.load(Ordering::SeqCst), 0);
    }

    // a -> {b, c} -> d: the fan-out and the join both see the right inputs.
    #[tokio::test]
    async fn diamond_workflow_executes_correctly() {
        let engine = WorkflowEngine::new();
        let workflow = Workflow::builder()
            .step(ValueStep::new("a", "A"), &[])
            .step(AppendStep::new("b", "a", "_B"), &["a"])
            .step(AppendStep::new("c", "a", "_C"), &["a"])
            .step(AppendStep::new("d", "b", "_D"), &["b", "c"])
            .build()
            .unwrap();

        let ctx = engine.run(workflow).await.unwrap();

        let expectations = [("a", "A"), ("b", "A_B"), ("c", "A_C"), ("d", "A_B_D")];
        for (id, expected) in expectations {
            assert_eq!(ctx.output(id).unwrap().value(), expected);
        }
    }

    // A checkpointed run leaves a checkpoint file behind.
    #[tokio::test]
    async fn checkpointed_run_saves_checkpoint_file() {
        let scratch = tempfile::tempdir().unwrap();
        let ckpt_path = scratch.path().join("checkpoint.json");
        let engine = WorkflowEngine::new();

        let workflow = Workflow::builder()
            .step(ValueStep::new("a", "A"), &[])
            .step(ValueStep::new("b", "B"), &["a"])
            .build()
            .unwrap();

        let ctx = engine
            .run_with_checkpoint(workflow, &ckpt_path)
            .await
            .unwrap();

        assert!(ckpt_path.exists());
        assert!(ctx.is_completed("a"));
        assert!(ctx.is_completed("b"));
    }

    // Steps already recorded in the checkpoint are not executed again.
    #[tokio::test]
    async fn checkpointed_run_skips_completed_steps() {
        let scratch = tempfile::tempdir().unwrap();
        let ckpt_path = scratch.path().join("checkpoint.json");

        // Seed a checkpoint in which "a" has already finished.
        let mut seeded = crate::checkpoint::Checkpoint::new();
        seeded.mark_completed("a", StepOutput::new("A"));
        seeded.save(&ckpt_path).await.unwrap();

        let executions = Arc::new(AtomicUsize::new(0));
        let engine = WorkflowEngine::new();

        let workflow = Workflow::builder()
            .step(CountStep::new("a", executions.clone()), &[])
            .step(CountStep::new("b", executions.clone()), &["a"])
            .build()
            .unwrap();

        let ctx = engine
            .run_with_checkpoint(workflow, &ckpt_path)
            .await
            .unwrap();

        assert_eq!(executions.load(Ordering::SeqCst), 1);
        assert!(ctx.is_completed("a"));
        assert!(ctx.is_completed("b"));
    }

    // The returned context lists every step that ran.
    #[tokio::test]
    async fn returns_all_completed_step_ids() {
        let engine = WorkflowEngine::new();
        let workflow = Workflow::builder()
            .step(ValueStep::new("x", "1"), &[])
            .step(ValueStep::new("y", "2"), &[])
            .build()
            .unwrap();

        let ctx = engine.run(workflow).await.unwrap();

        let mut completed = ctx.completed_step_ids();
        completed.sort_unstable();
        assert_eq!(completed, vec!["x", "y"]);
    }
}