1use crate::workflow::task::{TaskContext, TaskError, TaskId, TaskResult, WorkflowTask};
7use async_trait::async_trait;
8
9pub struct ConditionalTask {
14 condition_task: Box<dyn WorkflowTask>,
16 then_task: Box<dyn WorkflowTask>,
18 else_task: Option<Box<dyn WorkflowTask>>,
20}
21
22impl ConditionalTask {
23 pub fn new(
47 condition_task: Box<dyn WorkflowTask>,
48 then_task: Box<dyn WorkflowTask>,
49 else_task: Option<Box<dyn WorkflowTask>>,
50 ) -> Self {
51 Self {
52 condition_task,
53 then_task,
54 else_task,
55 }
56 }
57
58 pub fn with_else(
60 condition_task: Box<dyn WorkflowTask>,
61 then_task: Box<dyn WorkflowTask>,
62 else_task: Box<dyn WorkflowTask>,
63 ) -> Self {
64 Self::new(condition_task, then_task, Some(else_task))
65 }
66}
67
68#[async_trait]
69impl WorkflowTask for ConditionalTask {
70 async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
71 let condition_result = self.condition_task.execute(context).await?;
73
74 match condition_result {
75 TaskResult::Success => {
76 self.then_task.execute(context).await
78 }
79 TaskResult::Failed(_) | TaskResult::Skipped => {
80 if let Some(else_task) = &self.else_task {
82 else_task.execute(context).await
83 } else {
84 Ok(condition_result)
85 }
86 }
87 TaskResult::WithCompensation { .. } => {
88 self.then_task.execute(context).await
91 }
92 }
93 }
94
95 fn id(&self) -> TaskId {
96 self.condition_task.id()
97 }
98
99 fn name(&self) -> &str {
100 self.condition_task.name()
101 }
102
103 fn dependencies(&self) -> Vec<TaskId> {
104 self.condition_task.dependencies()
105 }
106}
107
108pub struct TryCatchTask {
113 try_task: Box<dyn WorkflowTask>,
115 catch_task: Box<dyn WorkflowTask>,
117}
118
119impl TryCatchTask {
120 pub fn new(try_task: Box<dyn WorkflowTask>, catch_task: Box<dyn WorkflowTask>) -> Self {
143 Self {
144 try_task,
145 catch_task,
146 }
147 }
148}
149
150#[async_trait]
151impl WorkflowTask for TryCatchTask {
152 async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
153 let try_result = self.try_task.execute(context).await;
155
156 match try_result {
157 Ok(TaskResult::Success) => try_result,
158 Ok(TaskResult::Failed(_)) | Ok(TaskResult::Skipped) => {
159 self.catch_task.execute(context).await
161 }
162 Ok(TaskResult::WithCompensation { .. }) => {
163 self.catch_task.execute(context).await
166 }
167 Err(_) => {
168 self.catch_task.execute(context).await
170 }
171 }
172 }
173
174 fn id(&self) -> TaskId {
175 self.try_task.id()
176 }
177
178 fn name(&self) -> &str {
179 self.try_task.name()
180 }
181
182 fn dependencies(&self) -> Vec<TaskId> {
183 self.try_task.dependencies()
184 }
185}
186
187pub struct ParallelTasks {
192 tasks: Vec<Box<dyn WorkflowTask>>,
194}
195
196impl ParallelTasks {
197 pub fn new(tasks: Vec<Box<dyn WorkflowTask>>) -> Self {
219 Self { tasks }
220 }
221}
222
223#[async_trait]
224impl WorkflowTask for ParallelTasks {
225 async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
226 for task in &self.tasks {
230 let task_result = task.execute(context).await?;
231 match task_result {
232 TaskResult::Success => continue,
233 TaskResult::Failed(msg) => return Ok(TaskResult::Failed(msg)),
234 TaskResult::Skipped => continue,
235 TaskResult::WithCompensation { result, compensation } => {
236 match *result {
240 TaskResult::Success => continue,
241 TaskResult::Failed(msg) => return Ok(TaskResult::Failed(msg)),
242 TaskResult::Skipped => continue,
243 TaskResult::WithCompensation { .. } => continue,
244 }
245 }
246 }
247 }
248
249 Ok(TaskResult::Success)
250 }
251
252 fn id(&self) -> TaskId {
253 TaskId::new("parallel_tasks")
254 }
255
256 fn name(&self) -> &str {
257 "Parallel Tasks"
258 }
259
260 fn dependencies(&self) -> Vec<TaskId> {
261 Vec::new()
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::tasks::FunctionTask;
    use std::time::Duration;
    use std::time::Instant;

    #[tokio::test]
    async fn test_conditional_task_then_branch() {
        let condition = Box::new(FunctionTask::new(
            TaskId::new("check"),
            "Check".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let then_task = Box::new(FunctionTask::new(
            TaskId::new("then"),
            "Then".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let conditional = ConditionalTask::new(condition, then_task, None);
        let context = TaskContext::new("workflow-1", TaskId::new("check"));

        let result = conditional.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_conditional_task_else_branch() {
        let condition = Box::new(FunctionTask::new(
            TaskId::new("check"),
            "Check".to_string(),
            |_ctx| async { Ok(TaskResult::Failed("error".to_string())) },
        ));

        let then_task = Box::new(FunctionTask::new(
            TaskId::new("then"),
            "Then".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let else_task = Box::new(FunctionTask::new(
            TaskId::new("else"),
            "Else".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let conditional = ConditionalTask::with_else(condition, then_task, else_task);
        let context = TaskContext::new("workflow-1", TaskId::new("check"));

        let result = conditional.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_conditional_task_no_else_returns_failure() {
        let condition = Box::new(FunctionTask::new(
            TaskId::new("check"),
            "Check".to_string(),
            |_ctx| async { Ok(TaskResult::Failed("error".to_string())) },
        ));

        let then_task = Box::new(FunctionTask::new(
            TaskId::new("then"),
            "Then".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let conditional = ConditionalTask::new(condition, then_task, None);
        let context = TaskContext::new("workflow-1", TaskId::new("check"));

        let result = conditional.execute(&context).await.unwrap();
        assert!(matches!(result, TaskResult::Failed(_)));
    }

    #[tokio::test]
    async fn test_conditional_task_skipped_condition_without_else() {
        let condition = Box::new(FunctionTask::new(
            TaskId::new("check"),
            "Check".to_string(),
            |_ctx| async { Ok(TaskResult::Skipped) },
        ));

        let then_task = Box::new(FunctionTask::new(
            TaskId::new("then"),
            "Then".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let conditional = ConditionalTask::new(condition, then_task, None);
        let context = TaskContext::new("workflow-1", TaskId::new("check"));

        // Without an else branch the condition's own result is surfaced.
        let result = conditional.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Skipped);
    }

    #[tokio::test]
    async fn test_try_catch_task_success() {
        let try_task = Box::new(FunctionTask::new(
            TaskId::new("risky"),
            "Risky".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let catch_task = Box::new(FunctionTask::new(
            TaskId::new("recover"),
            "Recover".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let try_catch = TryCatchTask::new(try_task, catch_task);
        let context = TaskContext::new("workflow-1", TaskId::new("risky"));

        let result = try_catch.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_try_catch_task_failure_recovery() {
        let try_task = Box::new(FunctionTask::new(
            TaskId::new("risky"),
            "Risky".to_string(),
            |_ctx| async { Ok(TaskResult::Failed("error".to_string())) },
        ));

        let catch_task = Box::new(FunctionTask::new(
            TaskId::new("recover"),
            "Recover".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let try_catch = TryCatchTask::new(try_task, catch_task);
        let context = TaskContext::new("workflow-1", TaskId::new("risky"));

        let result = try_catch.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_try_catch_task_error_recovery() {
        let try_task = Box::new(FunctionTask::new(
            TaskId::new("risky"),
            "Risky".to_string(),
            |_ctx| async { Err(TaskError::ExecutionFailed("boom".to_string())) },
        ));

        let catch_task = Box::new(FunctionTask::new(
            TaskId::new("recover"),
            "Recover".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let try_catch = TryCatchTask::new(try_task, catch_task);
        let context = TaskContext::new("workflow-1", TaskId::new("risky"));

        // A hard error from the try task is swallowed and replaced by the catch result.
        let result = try_catch.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_parallel_tasks_sequential_stub() {
        let task1 = Box::new(FunctionTask::new(
            TaskId::new("task1"),
            "Task 1".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let task2 = Box::new(FunctionTask::new(
            TaskId::new("task2"),
            "Task 2".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let parallel = ParallelTasks::new(vec![task1, task2]);
        let context = TaskContext::new("workflow-1", TaskId::new("parallel_tasks"));

        let result = parallel.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Success);
    }

    #[tokio::test]
    async fn test_parallel_tasks_failure_stops() {
        let task1 = Box::new(FunctionTask::new(
            TaskId::new("task1"),
            "Task 1".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        ));

        let task2 = Box::new(FunctionTask::new(
            TaskId::new("task2"),
            "Task 2".to_string(),
            |_ctx| async { Err(TaskError::ExecutionFailed("error".to_string())) },
        ));

        let parallel = ParallelTasks::new(vec![task1, task2]);
        let context = TaskContext::new("workflow-1", TaskId::new("parallel_tasks"));

        let result = parallel.execute(&context).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_parallel_tasks_failed_result_short_circuits() {
        let task1 = Box::new(FunctionTask::new(
            TaskId::new("task1"),
            "Task 1".to_string(),
            |_ctx| async { Ok(TaskResult::Failed("boom".to_string())) },
        ));

        // If this task ran, the overall result would become an Err.
        let task2 = Box::new(FunctionTask::new(
            TaskId::new("task2"),
            "Task 2".to_string(),
            |_ctx| async { Err(TaskError::ExecutionFailed("should not run".to_string())) },
        ));

        let parallel = ParallelTasks::new(vec![task1, task2]);
        let context = TaskContext::new("workflow-1", TaskId::new("parallel_tasks"));

        let result = parallel.execute(&context).await.unwrap();
        assert_eq!(result, TaskResult::Failed("boom".to_string()));
    }

    #[tokio::test]
    async fn test_parallel_tasks_sequential_execution() {
        let task1 = Box::new(FunctionTask::new(
            TaskId::new("task1"),
            "Task 1".to_string(),
            |_ctx| async {
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(TaskResult::Success)
            },
        ));

        let task2 = Box::new(FunctionTask::new(
            TaskId::new("task2"),
            "Task 2".to_string(),
            |_ctx| async {
                tokio::time::sleep(Duration::from_millis(50)).await;
                Ok(TaskResult::Success)
            },
        ));

        let parallel = ParallelTasks::new(vec![task1, task2]);
        let context = TaskContext::new("workflow-1", TaskId::new("parallel_tasks"));

        let start = Instant::now();
        let result = parallel.execute(&context).await;
        let elapsed = start.elapsed();

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), TaskResult::Success);

        // Two 50ms sleeps run back to back take ~100ms; a concurrent run would take ~50ms.
        // Only the lower bound is asserted: an upper bound is flaky on loaded CI machines.
        assert!(
            elapsed.as_millis() >= 80,
            "Expected ~100ms sequential but got {}ms",
            elapsed.as_millis()
        );
    }
}