Skip to main content

forge_agent/workflow/
combinators.rs

//! Task composition helpers for complex workflows.
//!
//! Provides combinators for conditional execution, error recovery,
//! and grouping of independent tasks (`ParallelTasks`, which currently
//! runs its members sequentially).

6use crate::workflow::task::{TaskContext, TaskError, TaskId, TaskResult, WorkflowTask};
7use async_trait::async_trait;
8
9/// Task that executes conditionally based on another task's result.
10///
11/// The condition task is executed first, then based on its result,
12/// either the then_task or else_task is executed.
13pub struct ConditionalTask {
14    /// Task that determines which branch to execute
15    condition_task: Box<dyn WorkflowTask>,
16    /// Task to execute if condition succeeds
17    then_task: Box<dyn WorkflowTask>,
18    /// Optional task to execute if condition fails
19    else_task: Option<Box<dyn WorkflowTask>>,
20}
21
22impl ConditionalTask {
23    /// Creates a new conditional task.
24    ///
25    /// # Arguments
26    ///
27    /// * `condition_task` - Task whose result determines the branch
28    /// * `then_task` - Task executed on success
29    /// * `else_task` - Optional task executed on failure
30    ///
31    /// # Example
32    ///
33    /// ```ignore
34    /// let condition = Box::new(FunctionTask::new(
35    ///     TaskId::new("check"),
36    ///     "Check".to_string(),
37    ///     |_ctx| async { Ok(TaskResult::Success) }
38    /// ));
39    /// let then_branch = Box::new(FunctionTask::new(
40    ///     TaskId::new("then"),
41    ///     "Then".to_string(),
42    ///     |_ctx| async { Ok(TaskResult::Success) }
43    /// ));
44    /// let task = ConditionalTask::new(condition, then_branch, None);
45    /// ```
46    pub fn new(
47        condition_task: Box<dyn WorkflowTask>,
48        then_task: Box<dyn WorkflowTask>,
49        else_task: Option<Box<dyn WorkflowTask>>,
50    ) -> Self {
51        Self {
52            condition_task,
53            then_task,
54            else_task,
55        }
56    }
57
58    /// Creates a conditional task with an else branch.
59    pub fn with_else(
60        condition_task: Box<dyn WorkflowTask>,
61        then_task: Box<dyn WorkflowTask>,
62        else_task: Box<dyn WorkflowTask>,
63    ) -> Self {
64        Self::new(condition_task, then_task, Some(else_task))
65    }
66}
67
68#[async_trait]
69impl WorkflowTask for ConditionalTask {
70    async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
71        // Execute condition task
72        let condition_result = self.condition_task.execute(context).await?;
73
74        match condition_result {
75            TaskResult::Success => {
76                // Execute then task
77                self.then_task.execute(context).await
78            }
79            TaskResult::Failed(_) | TaskResult::Skipped => {
80                // Execute else task if available, otherwise return condition result
81                if let Some(else_task) = &self.else_task {
82                    else_task.execute(context).await
83                } else {
84                    Ok(condition_result)
85                }
86            }
87            TaskResult::WithCompensation { .. } => {
88                // For now, treat WithCompensation as Success and execute then task
89                // TODO: Consider if compensation should propagate
90                self.then_task.execute(context).await
91            }
92        }
93    }
94
95    fn id(&self) -> TaskId {
96        self.condition_task.id()
97    }
98
99    fn name(&self) -> &str {
100        self.condition_task.name()
101    }
102
103    fn dependencies(&self) -> Vec<TaskId> {
104        self.condition_task.dependencies()
105    }
106}
107
108/// Task that executes with error recovery.
109///
110/// The try_task is executed first. If it fails, the catch_task is
111/// executed instead, allowing the workflow to continue gracefully.
112pub struct TryCatchTask {
113    /// Task to attempt
114    try_task: Box<dyn WorkflowTask>,
115    /// Task to execute on failure
116    catch_task: Box<dyn WorkflowTask>,
117}
118
119impl TryCatchTask {
120    /// Creates a new try-catch task.
121    ///
122    /// # Arguments
123    ///
124    /// * `try_task` - Task to attempt
125    /// * `catch_task` - Task executed on try_task failure
126    ///
127    /// # Example
128    ///
129    /// ```ignore
130    /// let try_task = Box::new(FunctionTask::new(
131    ///     TaskId::new("risky"),
132    ///     "Risky Operation".to_string(),
133    ///     |_ctx| async { Ok(TaskResult::Failed("error".to_string())) }
134    /// ));
135    /// let catch_task = Box::new(FunctionTask::new(
136    ///     TaskId::new("recover"),
137    ///     "Recovery".to_string(),
138    ///     |_ctx| async { Ok(TaskResult::Success) }
139    /// ));
140    /// let task = TryCatchTask::new(try_task, catch_task);
141    /// ```
142    pub fn new(try_task: Box<dyn WorkflowTask>, catch_task: Box<dyn WorkflowTask>) -> Self {
143        Self {
144            try_task,
145            catch_task,
146        }
147    }
148}
149
150#[async_trait]
151impl WorkflowTask for TryCatchTask {
152    async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
153        // Execute try task
154        let try_result = self.try_task.execute(context).await;
155
156        match try_result {
157            Ok(TaskResult::Success) => try_result,
158            Ok(TaskResult::Failed(_)) | Ok(TaskResult::Skipped) => {
159                // Execute catch task on failure
160                self.catch_task.execute(context).await
161            }
162            Ok(TaskResult::WithCompensation { .. }) => {
163                // For now, execute catch task on compensation result
164                // TODO: Consider if compensation should execute before catch
165                self.catch_task.execute(context).await
166            }
167            Err(_) => {
168                // Execute catch task on error
169                self.catch_task.execute(context).await
170            }
171        }
172    }
173
174    fn id(&self) -> TaskId {
175        self.try_task.id()
176    }
177
178    fn name(&self) -> &str {
179        self.try_task.name()
180    }
181
182    fn dependencies(&self) -> Vec<TaskId> {
183        self.try_task.dependencies()
184    }
185}
186
187/// Task that executes multiple tasks in parallel using JoinSet.
188///
189/// Tasks are spawned concurrently and all must succeed for the
190/// parallel task to succeed. Uses fail-fast behavior on first error.
191pub struct ParallelTasks {
192    /// Tasks to execute
193    tasks: Vec<Box<dyn WorkflowTask>>,
194}
195
196impl ParallelTasks {
197    /// Creates a new parallel tasks wrapper.
198    ///
199    /// # Arguments
200    ///
201    /// * `tasks` - Vector of tasks to execute
202    ///
203    /// # Example
204    ///
205    /// ```ignore
206    /// let task1 = Box::new(FunctionTask::new(
207    ///     TaskId::new("task1"),
208    ///     "Task 1".to_string(),
209    ///     |_ctx| async { Ok(TaskResult::Success) }
210    /// ));
211    /// let task2 = Box::new(FunctionTask::new(
212    ///     TaskId::new("task2"),
213    ///     "Task 2".to_string(),
214    ///     |_ctx| async { Ok(TaskResult::Success) }
215    /// ));
216    /// let parallel = ParallelTasks::new(vec![task1, task2]);
217    /// ```
218    pub fn new(tasks: Vec<Box<dyn WorkflowTask>>) -> Self {
219        Self { tasks }
220    }
221}
222
223#[async_trait]
224impl WorkflowTask for ParallelTasks {
225    async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
226        // Execute tasks sequentially for now.
227        // True parallelism is available at the DAG level via execute_parallel().
228        // This combinator provides logical grouping of tasks that can run together.
229        for task in &self.tasks {
230            let task_result = task.execute(context).await?;
231            match task_result {
232                TaskResult::Success => continue,
233                TaskResult::Failed(msg) => return Ok(TaskResult::Failed(msg)),
234                TaskResult::Skipped => continue,
235                TaskResult::WithCompensation { result, compensation } => {
236                    // Note: ParallelTasks doesn't have access to compensation registry
237                    // Compensations are lost - this is a known limitation
238                    // For proper compensation handling, use DAG-level parallel execution
239                    match *result {
240                        TaskResult::Success => continue,
241                        TaskResult::Failed(msg) => return Ok(TaskResult::Failed(msg)),
242                        TaskResult::Skipped => continue,
243                        TaskResult::WithCompensation { .. } => continue,
244                    }
245                }
246            }
247        }
248
249        Ok(TaskResult::Success)
250    }
251
252    fn id(&self) -> TaskId {
253        TaskId::new("parallel_tasks")
254    }
255
256    fn name(&self) -> &str {
257        "Parallel Tasks"
258    }
259
260    fn dependencies(&self) -> Vec<TaskId> {
261        Vec::new()
262    }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::tasks::FunctionTask;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    #[tokio::test]
    async fn test_conditional_task_then_branch() {
        let condition = Box::new(FunctionTask::new(
            TaskId::new("check"),
            "Check".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let then_task = Box::new(FunctionTask::new(
            TaskId::new("then"),
            "Then".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let conditional = ConditionalTask::new(condition, then_task, None);
        let context = TaskContext::new("workflow-1", TaskId::new("check"));

        let result = conditional.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_conditional_task_else_branch() {
        let condition = Box::new(FunctionTask::new(
            TaskId::new("check"),
            "Check".to_string(),
            |_ctx| async { Ok(TaskResult::Failed("error".to_string())) },
        ));

        let then_task = Box::new(FunctionTask::new(
            TaskId::new("then"),
            "Then".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let else_task = Box::new(FunctionTask::new(
            TaskId::new("else"),
            "Else".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let conditional = ConditionalTask::with_else(condition, then_task, else_task);
        let context = TaskContext::new("workflow-1", TaskId::new("check"));

        let result = conditional.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_conditional_task_skipped_condition_takes_else_branch() {
        let condition = Box::new(FunctionTask::new(
            TaskId::new("check"),
            "Check".to_string(),
            |_ctx| async { Ok(TaskResult::Skipped) },
        ));

        // The then-branch reports failure so taking it would be visible.
        let then_task = Box::new(FunctionTask::new(
            TaskId::new("then"),
            "Then".to_string(),
            |_ctx| async { Ok(TaskResult::Failed("then should not run".to_string())) },
        ));

        let else_task = Box::new(FunctionTask::new(
            TaskId::new("else"),
            "Else".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let conditional = ConditionalTask::with_else(condition, then_task, else_task);
        let context = TaskContext::new("workflow-1", TaskId::new("check"));

        let result = conditional.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_conditional_task_no_else_returns_failure() {
        let condition = Box::new(FunctionTask::new(
            TaskId::new("check"),
            "Check".to_string(),
            |_ctx| async { Ok(TaskResult::Failed("error".to_string())) },
        ));

        let then_task = Box::new(FunctionTask::new(
            TaskId::new("then"),
            "Then".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let conditional = ConditionalTask::new(condition, then_task, None);
        let context = TaskContext::new("workflow-1", TaskId::new("check"));

        let result = conditional.execute(&context).await.unwrap();
        assert!(matches!(result, TaskResult::Failed(_)));
    }

    #[tokio::test]
    async fn test_conditional_task_condition_error_propagates() {
        let condition = Box::new(FunctionTask::new(
            TaskId::new("check"),
            "Check".to_string(),
            |_ctx| async { Err(TaskError::ExecutionFailed("error".to_string())) },
        ));

        let then_task = Box::new(FunctionTask::new(
            TaskId::new("then"),
            "Then".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let conditional = ConditionalTask::new(condition, then_task, None);
        let context = TaskContext::new("workflow-1", TaskId::new("check"));

        let result = conditional.execute(&context).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_try_catch_task_success() {
        let try_task = Box::new(FunctionTask::new(
            TaskId::new("risky"),
            "Risky".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let catch_task = Box::new(FunctionTask::new(
            TaskId::new("recover"),
            "Recover".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let try_catch = TryCatchTask::new(try_task, catch_task);
        let context = TaskContext::new("workflow-1", TaskId::new("risky"));

        let result = try_catch.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_try_catch_task_failure_recovery() {
        let try_task = Box::new(FunctionTask::new(
            TaskId::new("risky"),
            "Risky".to_string(),
            |_ctx| async { Ok(TaskResult::Failed("error".to_string())) },
        ));

        let catch_task = Box::new(FunctionTask::new(
            TaskId::new("recover"),
            "Recover".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let try_catch = TryCatchTask::new(try_task, catch_task);
        let context = TaskContext::new("workflow-1", TaskId::new("risky"));

        let result = try_catch.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_try_catch_task_error_recovery() {
        let try_task = Box::new(FunctionTask::new(
            TaskId::new("risky"),
            "Risky".to_string(),
            |_ctx| async { Err(TaskError::ExecutionFailed("error".to_string())) },
        ));

        let catch_task = Box::new(FunctionTask::new(
            TaskId::new("recover"),
            "Recover".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let try_catch = TryCatchTask::new(try_task, catch_task);
        let context = TaskContext::new("workflow-1", TaskId::new("risky"));

        let result = try_catch.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_parallel_tasks_sequential_stub() {
        let task1 = Box::new(FunctionTask::new(
            TaskId::new("task1"),
            "Task 1".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let task2 = Box::new(FunctionTask::new(
            TaskId::new("task2"),
            "Task 2".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let parallel = ParallelTasks::new(vec![task1, task2]);
        let context = TaskContext::new("workflow-1", TaskId::new("parallel_tasks"));

        let result = parallel.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_parallel_tasks_empty_succeeds() {
        let parallel = ParallelTasks::new(Vec::new());
        let context = TaskContext::new("workflow-1", TaskId::new("parallel_tasks"));

        let result = parallel.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_parallel_tasks_failure_stops() {
        // Fail-fast: the second task errors, so the group returns that error.
        let task1 = Box::new(FunctionTask::new(
            TaskId::new("task1"),
            "Task 1".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let task2 = Box::new(FunctionTask::new(
            TaskId::new("task2"),
            "Task 2".to_string(),
            |_ctx| async { Err(TaskError::ExecutionFailed("error".to_string())) },
        ));

        let parallel = ParallelTasks::new(vec![task1, task2]);
        let context = TaskContext::new("workflow-1", TaskId::new("parallel_tasks"));

        let result = parallel.execute(&context).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_parallel_tasks_failed_result_short_circuits() {
        // A `Failed` result must stop the group before later tasks run.
        let third_ran = Arc::new(AtomicBool::new(false));
        let third_flag = Arc::clone(&third_ran);

        let task1 = Box::new(FunctionTask::new(
            TaskId::new("task1"),
            "Task 1".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let task2 = Box::new(FunctionTask::new(
            TaskId::new("task2"),
            "Task 2".to_string(),
            |_ctx| async { Ok(TaskResult::Failed("boom".to_string())) },
        ));

        let task3 = Box::new(FunctionTask::new(
            TaskId::new("task3"),
            "Task 3".to_string(),
            move |_ctx| {
                let flag = Arc::clone(&third_flag);
                async move {
                    flag.store(true, Ordering::SeqCst);
                    Ok(TaskResult::Success)
                }
            },
        ));

        let parallel = ParallelTasks::new(vec![task1, task2, task3]);
        let context = TaskContext::new("workflow-1", TaskId::new("parallel_tasks"));

        let result = parallel.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Failed("boom".to_string()));
        assert!(!third_ran.load(Ordering::SeqCst), "task3 ran after a failure");
    }

    #[tokio::test]
    async fn test_parallel_tasks_sequential_execution() {
        // NOTE: ParallelTasks executes sequentially, not in parallel.
        // For true parallel task execution, use the DAG's execute_parallel().
        // This combinator provides logical grouping of tasks that can be
        // executed together when placed in a parallel execution layer.

        // Two tasks that each sleep 50ms: sequential execution takes ~100ms.
        let task1 = Box::new(FunctionTask::new(
            TaskId::new("task1"),
            "Task 1".to_string(),
            |_ctx| async {
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(TaskResult::Success)
            },
        ));

        let task2 = Box::new(FunctionTask::new(
            TaskId::new("task2"),
            "Task 2".to_string(),
            |_ctx| async {
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(TaskResult::Success)
            },
        ));

        let parallel = ParallelTasks::new(vec![task1, task2]);
        let context = TaskContext::new("workflow-1", TaskId::new("parallel_tasks"));

        let start = Instant::now();
        let result = parallel.execute(&context).await;
        let elapsed = start.elapsed();

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), TaskResult::Success);

        // The lower bound proves sequential execution (~100ms, not ~50ms).
        // The upper bound is generous so a loaded CI machine does not cause flakes.
        assert!(
            elapsed.as_millis() >= 80,
            "Expected ~100ms sequential but got {}ms",
            elapsed.as_millis()
        );
        assert!(
            elapsed.as_millis() < 500,
            "Expected ~100ms but got {}ms",
            elapsed.as_millis()
        );
    }
}