1use async_trait::async_trait;
75use crate::workflow::{
76 builder::WorkflowBuilder,
77 cancellation::CancellationToken,
78 dag::WorkflowError,
79 task::{TaskContext, TaskError, TaskId, TaskResult, WorkflowTask},
80 tasks::{AgentLoopTask, FunctionTask, GraphQueryTask},
81 Workflow,
82};
83
/// Builds an example workflow made of three trivial steps.
///
/// The steps `init`, `process` and `finalize` are [`FunctionTask`]s that each
/// resolve straight to [`TaskResult::Success`]. They are passed to
/// [`WorkflowBuilder::sequential`] in that order.
///
/// # Errors
///
/// Returns whatever [`WorkflowError`] `WorkflowBuilder::sequential` reports.
#[cfg(doc)]
pub fn example_linear_workflow() -> Result<Workflow, WorkflowError> {
    let init = Box::new(FunctionTask::new(
        TaskId::new("init"),
        "Initialize".to_string(),
        |_ctx| async { Ok(TaskResult::Success) },
    ));

    let process = Box::new(FunctionTask::new(
        TaskId::new("process"),
        "Process".to_string(),
        |_ctx| async { Ok(TaskResult::Success) },
    ));

    let finalize = Box::new(FunctionTask::new(
        TaskId::new("finalize"),
        "Finalize".to_string(),
        |_ctx| async { Ok(TaskResult::Success) },
    ));

    // The vector order is the order the tasks are handed to the builder.
    WorkflowBuilder::sequential(vec![init, process, finalize])
}
117
/// Builds an example diamond-shaped workflow.
///
/// Four no-op [`FunctionTask`]s (`check`, `then_task`, `else_task`,
/// `finalize`) are registered, followed by the dependency pairs
/// `(check, then_task)`, `(check, else_task)`, `(then_task, finalize)` and
/// `(else_task, finalize)`.
///
/// # Errors
///
/// Returns whatever [`WorkflowError`] `WorkflowBuilder::build` reports.
#[cfg(doc)]
pub fn example_branching_workflow() -> Result<Workflow, WorkflowError> {
    WorkflowBuilder::new()
        // Entry point of the diamond.
        .add_task(Box::new(FunctionTask::new(
            TaskId::new("check"),
            "Check Condition".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        )))
        // The two alternative branches.
        .add_task(Box::new(FunctionTask::new(
            TaskId::new("then_task"),
            "Then Task".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        )))
        .add_task(Box::new(FunctionTask::new(
            TaskId::new("else_task"),
            "Else Task".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        )))
        // Join point shared by both branches.
        .add_task(Box::new(FunctionTask::new(
            TaskId::new("finalize"),
            "Finalize".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        )))
        .dependency(TaskId::new("check"), TaskId::new("then_task"))
        .dependency(TaskId::new("check"), TaskId::new("else_task"))
        .dependency(TaskId::new("then_task"), TaskId::new("finalize"))
        .dependency(TaskId::new("else_task"), TaskId::new("finalize"))
        .build()
}
173
/// Builds an example workflow that combines graph queries with a follow-up
/// analysis step.
///
/// Two [`GraphQueryTask`]s (a symbol lookup and a reference lookup, both for
/// `process_data`) and one no-op [`FunctionTask`] named `analyze` are
/// registered. The dependency pairs are
/// `(graph_query_FindSymbol, graph_query_References)` and
/// `(graph_query_References, analyze)`.
///
/// # Errors
///
/// Returns whatever [`WorkflowError`] `WorkflowBuilder::build` reports.
#[cfg(doc)]
pub fn example_graph_aware_workflow() -> Result<Workflow, WorkflowError> {
    // Both graph queries target the same symbol.
    let symbol = "process_data";

    let find_symbol = Box::new(GraphQueryTask::find_symbol(symbol));
    let find_references = Box::new(GraphQueryTask::references(symbol));
    let analyze = Box::new(FunctionTask::new(
        TaskId::new("analyze"),
        "Analyze Results".to_string(),
        |_ctx| async { Ok(TaskResult::Success) },
    ));

    WorkflowBuilder::new()
        .add_task(find_symbol)
        .add_task(find_references)
        .add_task(analyze)
        .dependency(
            TaskId::new("graph_query_FindSymbol"),
            TaskId::new("graph_query_References"),
        )
        .dependency(TaskId::new("graph_query_References"), TaskId::new("analyze"))
        .build()
}
204
/// Builds an example workflow that feeds a graph query into an agent loop and
/// then into a reporting step.
///
/// Registers a [`GraphQueryTask`] looking up `main`, an [`AgentLoopTask`]
/// with id `agent_analysis`, and a no-op [`FunctionTask`] with id `report`.
/// The dependency pairs are `(graph_query_FindSymbol, agent_analysis)` and
/// `(agent_analysis, report)`.
///
/// # Errors
///
/// Returns whatever [`WorkflowError`] `WorkflowBuilder::build` reports.
#[cfg(doc)]
pub fn example_agent_workflow() -> Result<Workflow, WorkflowError> {
    WorkflowBuilder::new()
        // Graph lookup for the symbol the agent should analyse.
        .add_task(Box::new(GraphQueryTask::find_symbol("main")))
        // Agent loop driven by a natural-language prompt.
        .add_task(Box::new(AgentLoopTask::new(
            TaskId::new("agent_analysis"),
            "Agent Analysis".to_string(),
            "Analyze the main function and suggest improvements",
        )))
        // Final reporting step.
        .add_task(Box::new(FunctionTask::new(
            TaskId::new("report"),
            "Generate Report".to_string(),
            |_ctx| async { Ok(TaskResult::Success) },
        )))
        .dependency(
            TaskId::new("graph_query_FindSymbol"),
            TaskId::new("agent_analysis"),
        )
        .dependency(TaskId::new("agent_analysis"), TaskId::new("report"))
        .build()
}
245
/// Example task that simulates work as a series of short sleeps and checks
/// the context's cancellation token before each one.
///
/// See the `WorkflowTask` implementation in this file for the exact loop.
pub struct CancellationAwareTask {
    /// Identifier returned from `WorkflowTask::id`.
    id: TaskId,
    /// Human-readable name returned from `WorkflowTask::name`.
    name: String,
    /// Number of sleep iterations performed when the task is not cancelled.
    iterations: usize,
    /// Length of each sleep, in milliseconds.
    delay_ms: u64,
}
271
272impl CancellationAwareTask {
273 pub fn new(id: TaskId, name: String, iterations: usize, delay_ms: u64) -> Self {
282 Self {
283 id,
284 name,
285 iterations,
286 delay_ms,
287 }
288 }
289}
290
291#[async_trait]
292impl WorkflowTask for CancellationAwareTask {
293 async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
294 let mut completed_iterations = 0;
295
296 while completed_iterations < self.iterations {
298 if let Some(token) = context.cancellation_token() {
300 if token.poll_cancelled() {
301 return Ok(TaskResult::Success); }
303 }
304
305 tokio::time::sleep(tokio::time::Duration::from_millis(self.delay_ms)).await;
307 completed_iterations += 1;
308 }
309
310 Ok(TaskResult::Success)
311 }
312
313 fn id(&self) -> TaskId {
314 self.id.clone()
315 }
316
317 fn name(&self) -> &str {
318 &self.name
319 }
320}
321
/// Example task that simulates one long piece of work as a single sleep and
/// races it against cancellation with `tokio::select!`.
///
/// See the `WorkflowTask` implementation in this file for the race itself.
pub struct PollingTask {
    /// Identifier returned from `WorkflowTask::id`.
    id: TaskId,
    /// Human-readable name returned from `WorkflowTask::name`.
    name: String,
    /// Length of the simulated work, in milliseconds.
    total_duration_ms: u64,
}
344
345impl PollingTask {
346 pub fn new(id: TaskId, name: String, total_duration_ms: u64) -> Self {
354 Self {
355 id,
356 name,
357 total_duration_ms,
358 }
359 }
360}
361
362#[async_trait]
363impl WorkflowTask for PollingTask {
364 async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
365 let work_duration = tokio::time::Duration::from_millis(self.total_duration_ms);
367 let work = tokio::time::sleep(work_duration);
368
369 tokio::select! {
370 _ = work => {
371 Ok(TaskResult::Success)
373 }
374 _ = async {
375 if let Some(token) = context.cancellation_token() {
377 token.wait_until_cancelled().await;
378 }
379 } => {
380 Ok(TaskResult::Success)
382 }
383 }
384 }
385
386 fn id(&self) -> TaskId {
387 self.id.clone()
388 }
389
390 fn name(&self) -> &str {
391 &self.name
392 }
393}
394
/// Example task that races simulated work against both a fixed internal
/// timeout and the context's cancellation token.
///
/// See the `WorkflowTask` implementation in this file for the three-way race.
pub struct TimeoutAndCancellationTask {
    /// Identifier returned from `WorkflowTask::id`.
    id: TaskId,
    /// Human-readable name returned from `WorkflowTask::name`.
    name: String,
    /// Length of the simulated work, in milliseconds.
    work_duration_ms: u64,
}
420
421impl TimeoutAndCancellationTask {
422 pub fn new(id: TaskId, name: String, work_duration_ms: u64) -> Self {
430 Self {
431 id,
432 name,
433 work_duration_ms,
434 }
435 }
436}
437
438#[async_trait]
439impl WorkflowTask for TimeoutAndCancellationTask {
440 async fn execute(&self, context: &TaskContext) -> Result<TaskResult, TaskError> {
441 let work = tokio::time::sleep(tokio::time::Duration::from_millis(self.work_duration_ms));
443
444 tokio::select! {
445 _ = work => {
446 Ok(TaskResult::Success)
448 }
449 _ = tokio::time::sleep(tokio::time::Duration::from_secs(30)) => {
450 Ok(TaskResult::Success) }
453 _ = async {
454 if let Some(token) = context.cancellation_token() {
456 token.wait_until_cancelled().await;
457 }
458 } => {
459 Ok(TaskResult::Success)
461 }
462 }
463 }
464
465 fn id(&self) -> TaskId {
466 self.id.clone()
467 }
468
469 fn name(&self) -> &str {
470 &self.name
471 }
472}
473
474pub fn cooperative_cancellation_example() -> Workflow {
494 let task1 = Box::new(CancellationAwareTask::new(
495 TaskId::new("task1"),
496 "Cancellation Aware Task 1".to_string(),
497 100,
498 10,
499 ));
500
501 let task2 = Box::new(CancellationAwareTask::new(
502 TaskId::new("task2"),
503 "Cancellation Aware Task 2".to_string(),
504 100,
505 10,
506 ));
507
508 let task3 = Box::new(CancellationAwareTask::new(
509 TaskId::new("task3"),
510 "Cancellation Aware Task 3".to_string(),
511 100,
512 10,
513 ));
514
515 WorkflowBuilder::sequential(vec![task1, task2, task3]).unwrap()
516}
517
518pub fn timeout_cancellation_example() -> Workflow {
539 let task1 = Box::new(TimeoutAndCancellationTask::new(
540 TaskId::new("task1"),
541 "Timeout Task 1".to_string(),
542 50, ));
544
545 let task2 = Box::new(TimeoutAndCancellationTask::new(
546 TaskId::new("task2"),
547 "Timeout Task 2".to_string(),
548 50,
549 ));
550
551 WorkflowBuilder::sequential(vec![task1, task2]).unwrap()
552}
553
554#[cfg(test)]
555mod tests {
556 use super::*;
557
558 #[test]
559 fn test_linear_workflow_example() {
560 let result = WorkflowBuilder::sequential(vec![
562 Box::new(FunctionTask::new(
563 TaskId::new("init"),
564 "Initialize".to_string(),
565 |_ctx| async { Ok(TaskResult::Success) },
566 )),
567 Box::new(FunctionTask::new(
568 TaskId::new("process"),
569 "Process".to_string(),
570 |_ctx| async { Ok(TaskResult::Success) },
571 )),
572 Box::new(FunctionTask::new(
573 TaskId::new("finalize"),
574 "Finalize".to_string(),
575 |_ctx| async { Ok(TaskResult::Success) },
576 )),
577 ]);
578
579 assert!(result.is_ok());
580 let workflow = result.unwrap();
581 assert_eq!(workflow.task_count(), 3);
582 }
583
584 #[test]
585 fn test_branching_workflow_example() {
586 let condition = Box::new(FunctionTask::new(
587 TaskId::new("check"),
588 "Check Condition".to_string(),
589 |_ctx| async { Ok(TaskResult::Success) },
590 ));
591
592 let then_task = Box::new(FunctionTask::new(
593 TaskId::new("then_task"),
594 "Then Task".to_string(),
595 |_ctx| async { Ok(TaskResult::Success) },
596 ));
597
598 let else_task = Box::new(FunctionTask::new(
599 TaskId::new("else_task"),
600 "Else Task".to_string(),
601 |_ctx| async { Ok(TaskResult::Success) },
602 ));
603
604 let finalize = Box::new(FunctionTask::new(
605 TaskId::new("finalize"),
606 "Finalize".to_string(),
607 |_ctx| async { Ok(TaskResult::Success) },
608 ));
609
610 let result = WorkflowBuilder::new()
611 .add_task(condition)
612 .add_task(then_task)
613 .add_task(else_task)
614 .add_task(finalize)
615 .dependency(TaskId::new("check"), TaskId::new("then_task"))
616 .dependency(TaskId::new("check"), TaskId::new("else_task"))
617 .dependency(TaskId::new("then_task"), TaskId::new("finalize"))
618 .dependency(TaskId::new("else_task"), TaskId::new("finalize"))
619 .build();
620
621 assert!(result.is_ok());
622 let workflow = result.unwrap();
623 assert_eq!(workflow.task_count(), 4);
624 }
625
626 #[test]
627 fn test_graph_aware_workflow_example() {
628 let result = WorkflowBuilder::new()
629 .add_task(Box::new(GraphQueryTask::find_symbol("process_data")))
630 .add_task(Box::new(GraphQueryTask::references("process_data")))
631 .add_task(Box::new(FunctionTask::new(
632 TaskId::new("analyze"),
633 "Analyze Results".to_string(),
634 |_ctx| async { Ok(TaskResult::Success) },
635 )))
636 .dependency(
637 TaskId::new("graph_query_FindSymbol"),
638 TaskId::new("graph_query_References"),
639 )
640 .dependency(TaskId::new("graph_query_References"), TaskId::new("analyze"))
641 .build();
642
643 assert!(result.is_ok());
644 let workflow = result.unwrap();
645 assert_eq!(workflow.task_count(), 3);
646 }
647
648 #[test]
649 fn test_agent_workflow_example() {
650 let graph_query = Box::new(GraphQueryTask::find_symbol("main"));
651
652 let agent_task = Box::new(AgentLoopTask::new(
653 TaskId::new("agent_analysis"),
654 "Agent Analysis".to_string(),
655 "Analyze the main function and suggest improvements",
656 ));
657
658 let report = Box::new(FunctionTask::new(
659 TaskId::new("report"),
660 "Generate Report".to_string(),
661 |_ctx| async { Ok(TaskResult::Success) },
662 ));
663
664 let result = WorkflowBuilder::new()
665 .add_task(graph_query)
666 .add_task(agent_task)
667 .add_task(report)
668 .dependency(
669 TaskId::new("graph_query_FindSymbol"),
670 TaskId::new("agent_analysis"),
671 )
672 .dependency(TaskId::new("agent_analysis"), TaskId::new("report"))
673 .build();
674
675 assert!(result.is_ok());
676 let workflow = result.unwrap();
677 assert_eq!(workflow.task_count(), 3);
678 }
679
680 #[tokio::test]
683 async fn test_cancellation_aware_task_stops_on_cancel() {
684 use crate::workflow::cancellation::CancellationTokenSource;
685 use crate::workflow::task::TaskContext;
686
687 let source = CancellationTokenSource::new();
688 let task = CancellationAwareTask::new(
689 TaskId::new("task1"),
690 "Test Task".to_string(),
691 1000, 10,
693 );
694
695 let mut context = TaskContext::new("test-workflow", task.id());
697 context = context.with_cancellation_token(source.token());
698
699 tokio::spawn(async move {
701 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
702 source.cancel();
703 });
704
705 let start = std::time::Instant::now();
707 let result = task.execute(&context).await;
708 let elapsed = start.elapsed();
709
710 assert!(result.is_ok());
711 assert!(elapsed < tokio::time::Duration::from_millis(500)); }
713
714 #[tokio::test]
715 async fn test_cancellation_aware_task_completes_without_cancel() {
716 use crate::workflow::task::TaskContext;
717
718 let task = CancellationAwareTask::new(
719 TaskId::new("task1"),
720 "Test Task".to_string(),
721 5, 10,
723 );
724
725 let context = TaskContext::new("test-workflow", task.id());
727
728 let result = task.execute(&context).await;
730
731 assert!(result.is_ok());
732 assert_eq!(result.unwrap(), TaskResult::Success);
733 }
734
735 #[tokio::test]
736 async fn test_polling_task_with_tokio_select() {
737 use crate::workflow::cancellation::CancellationTokenSource;
738 use crate::workflow::task::TaskContext;
739
740 let source = CancellationTokenSource::new();
741 let task = PollingTask::new(
742 TaskId::new("task1"),
743 "Polling Task".to_string(),
744 5000, );
746
747 let mut context = TaskContext::new("test-workflow", task.id());
749 context = context.with_cancellation_token(source.token());
750
751 tokio::spawn(async move {
753 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
754 source.cancel();
755 });
756
757 let start = std::time::Instant::now();
759 let result = task.execute(&context).await;
760 let elapsed = start.elapsed();
761
762 assert!(result.is_ok());
763 assert!(elapsed < tokio::time::Duration::from_millis(500)); }
765
766 #[tokio::test]
767 async fn test_cooperative_cancellation_example() {
768 let workflow = cooperative_cancellation_example();
769
770 assert_eq!(workflow.task_count(), 3);
771 }
772
773 #[tokio::test]
776 async fn test_task_exits_on_timeout_before_cancellation() {
777 use crate::workflow::cancellation::CancellationTokenSource;
778 use crate::workflow::task::TaskContext;
779
780 let source = CancellationTokenSource::new();
781 let task = TimeoutAndCancellationTask::new(
782 TaskId::new("task1"),
783 "Timeout Task".to_string(),
784 5000, );
786
787 let mut context = TaskContext::new("test-workflow", task.id());
789 context = context.with_cancellation_token(source.token());
790
791 let start = std::time::Instant::now();
795 let result = tokio::time::timeout(
796 tokio::time::Duration::from_millis(100),
797 task.execute(&context),
798 ).await;
799 let elapsed = start.elapsed();
800
801 assert!(result.is_err()); assert!(elapsed < tokio::time::Duration::from_millis(200));
804 }
805
806 #[tokio::test]
807 async fn test_task_exits_on_cancellation_before_timeout() {
808 use crate::workflow::cancellation::CancellationTokenSource;
809 use crate::workflow::task::TaskContext;
810
811 let source = CancellationTokenSource::new();
812 let task = TimeoutAndCancellationTask::new(
813 TaskId::new("task1"),
814 "Timeout Task".to_string(),
815 5000, );
817
818 let mut context = TaskContext::new("test-workflow", task.id());
820 context = context.with_cancellation_token(source.token());
821
822 tokio::spawn(async move {
824 tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
825 source.cancel();
826 });
827
828 let start = std::time::Instant::now();
830 let result = task.execute(&context).await;
831 let elapsed = start.elapsed();
832
833 assert!(result.is_ok());
834 assert!(elapsed < tokio::time::Duration::from_millis(200)); }
836
837 #[tokio::test]
838 async fn test_task_completes_before_timeout_and_cancellation() {
839 use crate::workflow::cancellation::CancellationTokenSource;
840 use crate::workflow::task::TaskContext;
841
842 let source = CancellationTokenSource::new();
843 let task = TimeoutAndCancellationTask::new(
844 TaskId::new("task1"),
845 "Fast Task".to_string(),
846 50, );
848
849 let mut context = TaskContext::new("test-workflow", task.id());
851 context = context.with_cancellation_token(source.token());
852
853 let start = std::time::Instant::now();
855 let result = task.execute(&context).await;
856 let elapsed = start.elapsed();
857
858 assert!(result.is_ok());
859 assert_eq!(result.unwrap(), TaskResult::Success);
860 assert!(elapsed < tokio::time::Duration::from_millis(100)); assert!(elapsed >= tokio::time::Duration::from_millis(40)); }
863
864 #[tokio::test]
865 async fn test_timeout_cancellation_example() {
866 let workflow = timeout_cancellation_example();
867
868 assert_eq!(workflow.task_count(), 2);
869 }
870}