1use crate::{TaskError, TaskId, TaskResult};
36use serde::{Deserialize, Serialize};
37use std::collections::{HashMap, HashSet, VecDeque};
38
39#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
50pub enum TaskNodeStatus {
51 Pending,
53 Ready,
55 Running,
57 Completed,
59 Failed,
61}
62
63#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct TaskNode {
77 id: TaskId,
79 dependencies: Vec<TaskId>,
81 status: TaskNodeStatus,
83}
84
85impl TaskNode {
86 pub fn new(id: TaskId) -> Self {
97 Self {
98 id,
99 dependencies: Vec::new(),
100 status: TaskNodeStatus::Pending,
101 }
102 }
103
104 pub fn id(&self) -> TaskId {
116 self.id
117 }
118
119 pub fn dependencies(&self) -> &[TaskId] {
130 &self.dependencies
131 }
132
133 pub fn status(&self) -> TaskNodeStatus {
144 self.status
145 }
146
147 pub fn add_dependency(&mut self, task_id: TaskId) {
160 if !self.dependencies.contains(&task_id) {
161 self.dependencies.push(task_id);
162 }
163 }
164
165 pub fn set_status(&mut self, status: TaskNodeStatus) {
177 self.status = status;
178 }
179
180 pub(crate) fn remove_dependency(&mut self, task_id: TaskId) {
182 self.dependencies.retain(|&id| id != task_id);
183 }
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct TaskDAG {
208 nodes: HashMap<TaskId, TaskNode>,
210 dependents: HashMap<TaskId, Vec<TaskId>>,
212}
213
214impl TaskDAG {
215 pub fn new() -> Self {
226 Self {
227 nodes: HashMap::new(),
228 dependents: HashMap::new(),
229 }
230 }
231
232 pub fn add_task(&mut self, task_id: TaskId) -> TaskResult<()> {
249 if self.nodes.contains_key(&task_id) {
250 return Err(TaskError::ExecutionFailed(format!(
251 "Task {} already exists in DAG",
252 task_id
253 )));
254 }
255
256 self.nodes.insert(task_id, TaskNode::new(task_id));
257 self.dependents.insert(task_id, Vec::new());
258 Ok(())
259 }
260
261 pub fn add_dependency(&mut self, task_id: TaskId, depends_on: TaskId) -> TaskResult<()> {
285 if !self.nodes.contains_key(&task_id) {
287 return Err(TaskError::TaskNotFound(task_id.to_string()));
288 }
289 if !self.nodes.contains_key(&depends_on) {
290 return Err(TaskError::TaskNotFound(depends_on.to_string()));
291 }
292
293 if let Some(node) = self.nodes.get_mut(&task_id) {
295 node.add_dependency(depends_on);
296 }
297
298 if let Some(deps) = self.dependents.get_mut(&depends_on)
300 && !deps.contains(&task_id)
301 {
302 deps.push(task_id);
303 }
304
305 if let Err(e) = self.detect_cycle() {
307 if let Some(node) = self.nodes.get_mut(&task_id) {
308 node.remove_dependency(depends_on);
309 }
310 if let Some(deps) = self.dependents.get_mut(&depends_on) {
311 deps.retain(|&id| id != task_id);
312 }
313 return Err(e);
314 }
315
316 Ok(())
317 }
318
319 pub fn task_count(&self) -> usize {
332 self.nodes.len()
333 }
334
335 pub fn get_task(&self, task_id: TaskId) -> Option<&TaskNode> {
350 self.nodes.get(&task_id)
351 }
352
353 pub fn get_ready_tasks(&self) -> Vec<TaskId> {
372 self.nodes
373 .values()
374 .filter(|node| {
375 node.status() == TaskNodeStatus::Pending
377 && node
378 .dependencies()
379 .iter()
380 .all(|dep_id| match self.nodes.get(dep_id) {
381 Some(dep_node) => dep_node.status() == TaskNodeStatus::Completed,
382 None => false,
383 })
384 })
385 .map(|node| node.id())
386 .collect()
387 }
388
389 pub fn mark_completed(&mut self, task_id: TaskId) -> TaskResult<()> {
408 let node = self
409 .nodes
410 .get_mut(&task_id)
411 .ok_or_else(|| TaskError::TaskNotFound(task_id.to_string()))?;
412
413 node.set_status(TaskNodeStatus::Completed);
414 Ok(())
415 }
416
417 pub fn mark_failed(&mut self, task_id: TaskId) -> TaskResult<()> {
436 let node = self
437 .nodes
438 .get_mut(&task_id)
439 .ok_or_else(|| TaskError::TaskNotFound(task_id.to_string()))?;
440
441 node.set_status(TaskNodeStatus::Failed);
442 Ok(())
443 }
444
445 pub fn mark_running(&mut self, task_id: TaskId) -> TaskResult<()> {
464 let node = self
465 .nodes
466 .get_mut(&task_id)
467 .ok_or_else(|| TaskError::TaskNotFound(task_id.to_string()))?;
468
469 node.set_status(TaskNodeStatus::Running);
470 Ok(())
471 }
472
473 pub fn topological_sort(&self) -> TaskResult<Vec<TaskId>> {
507 let mut in_degree: HashMap<TaskId, usize> = HashMap::new();
509 for (task_id, node) in &self.nodes {
510 in_degree.insert(*task_id, node.dependencies().len());
511 }
512
513 let mut queue: VecDeque<TaskId> = in_degree
515 .iter()
516 .filter(|(_, degree)| **degree == 0)
517 .map(|(task_id, _)| *task_id)
518 .collect();
519
520 let mut sorted = Vec::new();
521
522 while let Some(task_id) = queue.pop_front() {
523 sorted.push(task_id);
524
525 if let Some(deps) = self.dependents.get(&task_id) {
527 for &dependent in deps {
528 if let Some(degree) = in_degree.get_mut(&dependent) {
529 *degree -= 1;
530 if *degree == 0 {
531 queue.push_back(dependent);
532 }
533 }
534 }
535 }
536 }
537
538 if sorted.len() != self.nodes.len() {
540 return Err(TaskError::ExecutionFailed(
541 "Cycle detected in task dependencies".to_string(),
542 ));
543 }
544
545 Ok(sorted)
546 }
547
548 fn detect_cycle(&self) -> TaskResult<()> {
557 let mut visited = HashSet::new();
558 let mut rec_stack = HashSet::new();
559
560 for &start_id in self.nodes.keys() {
561 if visited.contains(&start_id) {
562 continue;
563 }
564
565 let mut stack: Vec<(TaskId, usize, bool)> = vec![(start_id, 0, true)];
568
569 while let Some((task_id, dep_idx, is_entering)) = stack.last_mut() {
570 if *is_entering {
571 visited.insert(*task_id);
572 rec_stack.insert(*task_id);
573 *is_entering = false;
574 }
575
576 let deps = self
577 .nodes
578 .get(task_id)
579 .map(|n| n.dependencies())
580 .unwrap_or(&[]);
581
582 if *dep_idx < deps.len() {
583 let dep_id = deps[*dep_idx];
584 *dep_idx += 1;
585
586 if rec_stack.contains(&dep_id) {
587 return Err(TaskError::ExecutionFailed(format!(
588 "Cycle detected: {} -> {}",
589 task_id, dep_id
590 )));
591 }
592
593 if !visited.contains(&dep_id) {
594 stack.push((dep_id, 0, true));
595 }
596 } else {
597 rec_stack.remove(task_id);
599 stack.pop();
600 }
601 }
602 }
603
604 Ok(())
605 }
606}
607
608impl Default for TaskDAG {
609 fn default() -> Self {
610 Self::new()
611 }
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;
    use std::collections::HashMap;

    /// Builds a graph holding `count` freshly created, unconnected tasks.
    fn dag_with_tasks(count: usize) -> (TaskDAG, Vec<TaskId>) {
        let mut dag = TaskDAG::new();
        let ids: Vec<TaskId> = (0..count).map(|_| TaskId::new()).collect();
        for &id in &ids {
            dag.add_task(id).unwrap();
        }
        (dag, ids)
    }

    /// Builds a chain in which every task depends on the one before it.
    fn linear_chain(depth: usize) -> (TaskDAG, Vec<TaskId>) {
        let (mut dag, ids) = dag_with_tasks(depth);
        for pair in ids.windows(2) {
            dag.add_dependency(pair[1], pair[0]).unwrap();
        }
        (dag, ids)
    }

    /// Builds the diamond `a <- {b, c} <- d` and returns `[a, b, c, d]`.
    fn diamond() -> (TaskDAG, [TaskId; 4]) {
        let (mut dag, ids) = dag_with_tasks(4);
        let [a, b, c, d] = [ids[0], ids[1], ids[2], ids[3]];
        dag.add_dependency(b, a).unwrap();
        dag.add_dependency(c, a).unwrap();
        dag.add_dependency(d, b).unwrap();
        dag.add_dependency(d, c).unwrap();
        (dag, [a, b, c, d])
    }

    /// Maps each task to its index in `order`, so ordering checks cost one
    /// pass instead of a linear scan per lookup.
    fn positions(order: &[TaskId]) -> HashMap<TaskId, usize> {
        order
            .iter()
            .copied()
            .enumerate()
            .map(|(index, id)| (id, index))
            .collect()
    }

    /// Asserts that `order` contains the whole chain with every link in sequence.
    fn assert_chain_order(order: &[TaskId], chain: &[TaskId]) {
        assert_eq!(order.len(), chain.len());
        let index_of = positions(order);
        for (i, pair) in chain.windows(2).enumerate() {
            assert!(
                index_of[&pair[0]] < index_of[&pair[1]],
                "task_ids[{}] must precede task_ids[{}]",
                i,
                i + 1
            );
        }
    }

    #[rstest]
    fn test_dag_creation() {
        let dag = TaskDAG::new();

        assert_eq!(dag.task_count(), 0);
    }

    #[rstest]
    fn test_add_task() {
        let (dag, ids) = dag_with_tasks(1);

        assert_eq!(dag.task_count(), 1);
        assert!(dag.get_task(ids[0]).is_some());
    }

    #[rstest]
    fn test_add_duplicate_task() {
        let (mut dag, ids) = dag_with_tasks(1);

        let result = dag.add_task(ids[0]);

        assert!(result.is_err());
    }

    #[rstest]
    fn test_add_dependency() {
        let (mut dag, ids) = dag_with_tasks(2);
        let (task_a, task_b) = (ids[0], ids[1]);

        dag.add_dependency(task_b, task_a).unwrap();

        let node_b = dag.get_task(task_b).unwrap();
        assert_eq!(node_b.dependencies().len(), 1);
        assert_eq!(node_b.dependencies()[0], task_a);
    }

    #[rstest]
    fn test_add_dependency_nonexistent_task() {
        let (mut dag, ids) = dag_with_tasks(1);
        let unregistered = TaskId::new();

        let result = dag.add_dependency(ids[0], unregistered);

        assert!(result.is_err());
    }

    #[rstest]
    fn test_cycle_detection() {
        let (mut dag, ids) = linear_chain(3);

        // Closing the chain c -> b -> a with a -> c must be refused.
        let result = dag.add_dependency(ids[0], ids[2]);

        assert!(result.is_err());
    }

    #[rstest]
    fn test_topological_sort_simple() {
        let (dag, ids) = linear_chain(3);

        let order = dag.topological_sort().unwrap();

        assert_eq!(order.len(), 3);
        assert_chain_order(&order, &ids);
    }

    #[rstest]
    fn test_topological_sort_diamond() {
        let (dag, [task_a, task_b, task_c, task_d]) = diamond();

        let order = dag.topological_sort().unwrap();

        assert_eq!(order.len(), 4);
        let index_of = positions(&order);
        assert!(index_of[&task_a] < index_of[&task_b]);
        assert!(index_of[&task_a] < index_of[&task_c]);
        assert!(index_of[&task_b] < index_of[&task_d]);
        assert!(index_of[&task_c] < index_of[&task_d]);
    }

    #[rstest]
    fn test_get_ready_tasks() {
        let (mut dag, ids) = linear_chain(3);

        // Exactly one task is runnable at each step of a linear chain.
        assert_eq!(dag.get_ready_tasks(), vec![ids[0]]);

        dag.mark_completed(ids[0]).unwrap();
        assert_eq!(dag.get_ready_tasks(), vec![ids[1]]);

        dag.mark_completed(ids[1]).unwrap();
        assert_eq!(dag.get_ready_tasks(), vec![ids[2]]);
    }

    #[rstest]
    fn test_mark_status() {
        let (mut dag, ids) = dag_with_tasks(1);
        let task_id = ids[0];

        assert_eq!(
            dag.get_task(task_id).unwrap().status(),
            TaskNodeStatus::Pending
        );

        dag.mark_running(task_id).unwrap();
        assert_eq!(
            dag.get_task(task_id).unwrap().status(),
            TaskNodeStatus::Running
        );

        dag.mark_completed(task_id).unwrap();
        assert_eq!(
            dag.get_task(task_id).unwrap().status(),
            TaskNodeStatus::Completed
        );
    }

    #[rstest]
    fn test_mark_failed() {
        let (mut dag, ids) = dag_with_tasks(1);

        dag.mark_failed(ids[0]).unwrap();

        assert_eq!(
            dag.get_task(ids[0]).unwrap().status(),
            TaskNodeStatus::Failed
        );
    }

    #[rstest]
    fn test_parallel_execution_detection() {
        let (mut dag, [task_a, task_b, task_c, _task_d]) = diamond();

        dag.mark_completed(task_a).unwrap();
        let ready = dag.get_ready_tasks();

        // Both middle tasks of the diamond are unblocked together.
        assert_eq!(ready.len(), 2);
        assert!(ready.contains(&task_b));
        assert!(ready.contains(&task_c));
    }

    #[rstest]
    fn test_deep_dependency_chain_does_not_stack_overflow() {
        let depth = 1000;
        let (dag, task_ids) = linear_chain(depth);

        let order = dag.topological_sort().unwrap();

        assert_eq!(order.len(), depth);
        assert_chain_order(&order, &task_ids);
    }

    #[rstest]
    fn test_cycle_detection_on_deep_chain_with_back_edge() {
        let depth = 500;
        let (mut dag, task_ids) = linear_chain(depth);

        let result = dag.add_dependency(task_ids[0], task_ids[depth - 1]);

        assert!(result.is_err());
    }

    #[rstest]
    fn test_deep_chain_10k_nodes_does_not_stack_overflow() {
        let depth = 10_000;
        let (dag, task_ids) = linear_chain(depth);

        let order = dag.topological_sort().unwrap();

        assert_eq!(order.len(), depth);
        assert_chain_order(&order, &task_ids);
    }

    #[rstest]
    fn test_deep_chain_10k_nodes_back_edge_detected() {
        let depth = 10_000;
        let (mut dag, task_ids) = linear_chain(depth);

        let result = dag.add_dependency(task_ids[0], task_ids[depth - 1]);

        assert!(
            result.is_err(),
            "back-edge cycle must be detected in 10k chain"
        );
    }
}