1use std::collections::HashMap;
2
3use petgraph::algo::toposort;
4use petgraph::graph::{DiGraph, NodeIndex};
5
6use crate::error::{SdkError, SdkResult, TaskId};
7use crate::types::task::Task;
8
9pub struct TaskGraph {
10 graph: DiGraph<TaskId, ()>,
11 node_map: HashMap<TaskId, NodeIndex>,
12}
13
14impl Default for TaskGraph {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl TaskGraph {
21 pub fn new() -> Self {
22 Self {
23 graph: DiGraph::new(),
24 node_map: HashMap::new(),
25 }
26 }
27
28 pub fn from_tasks(tasks: &[Task]) -> SdkResult<Self> {
29 let mut tg = Self::new();
30
31 for task in tasks {
32 tg.add_task(task.id);
33 }
34
35 for task in tasks {
36 for dep_id in &task.dependencies {
37 if !tg.node_map.contains_key(dep_id) {
38 return Err(SdkError::TaskNotFound { task_id: *dep_id });
39 }
40 tg.add_dependency(task.id, *dep_id)?;
41 }
42 }
43
44 tg.check_cycles()?;
45 Ok(tg)
46 }
47
48 pub fn add_task(&mut self, task_id: TaskId) {
49 if !self.node_map.contains_key(&task_id) {
50 let idx = self.graph.add_node(task_id);
51 self.node_map.insert(task_id, idx);
52 }
53 }
54
55 pub fn add_dependency(&mut self, task_id: TaskId, depends_on: TaskId) -> SdkResult<()> {
56 let from = self
57 .node_map
58 .get(&depends_on)
59 .ok_or(SdkError::TaskNotFound {
60 task_id: depends_on,
61 })?;
62 let to = self
63 .node_map
64 .get(&task_id)
65 .ok_or(SdkError::TaskNotFound { task_id })?;
66
67 self.graph.add_edge(*from, *to, ());
68 Ok(())
69 }
70
71 pub fn check_cycles(&self) -> SdkResult<()> {
72 match toposort(&self.graph, None) {
73 Ok(_) => Ok(()),
74 Err(cycle) => {
75 let task_id = self.graph[cycle.node_id()];
76 Err(SdkError::DependencyCycle {
77 task_ids: vec![task_id],
78 })
79 }
80 }
81 }
82
83 pub fn topological_order(&self) -> SdkResult<Vec<TaskId>> {
84 match toposort(&self.graph, None) {
85 Ok(indices) => Ok(indices.into_iter().map(|idx| self.graph[idx]).collect()),
86 Err(cycle) => {
87 let task_id = self.graph[cycle.node_id()];
88 Err(SdkError::DependencyCycle {
89 task_ids: vec![task_id],
90 })
91 }
92 }
93 }
94
95 pub fn root_tasks(&self) -> Vec<TaskId> {
96 self.graph
97 .node_indices()
98 .filter(|&idx| {
99 self.graph
100 .neighbors_directed(idx, petgraph::Direction::Incoming)
101 .count()
102 == 0
103 })
104 .map(|idx| self.graph[idx])
105 .collect()
106 }
107
108 pub fn dependents_of(&self, task_id: TaskId) -> Vec<TaskId> {
109 if let Some(&idx) = self.node_map.get(&task_id) {
110 self.graph
111 .neighbors_directed(idx, petgraph::Direction::Outgoing)
112 .map(|idx| self.graph[idx])
113 .collect()
114 } else {
115 Vec::new()
116 }
117 }
118
119 pub fn len(&self) -> usize {
120 self.graph.node_count()
121 }
122
123 pub fn is_empty(&self) -> bool {
124 self.graph.node_count() == 0
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Builds a graph containing `ids` as tasks with no dependencies.
    fn graph_with(ids: &[Uuid]) -> TaskGraph {
        let mut graph = TaskGraph::new();
        for &id in ids {
            graph.add_task(id);
        }
        graph
    }

    #[test]
    fn test_topological_order() {
        let (id_a, id_b, id_c) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());
        let mut graph = graph_with(&[id_a, id_b, id_c]);

        // Chain: c depends on b, which depends on a.
        graph.add_dependency(id_b, id_a).unwrap();
        graph.add_dependency(id_c, id_b).unwrap();

        let order = graph.topological_order().unwrap();
        assert_eq!(order.len(), 3);
        let pos = |id: Uuid| order.iter().position(|&x| x == id).unwrap();
        assert!(pos(id_a) < pos(id_b));
        assert!(pos(id_b) < pos(id_c));
    }

    #[test]
    fn test_cycle_detection() {
        let (id_a, id_b) = (Uuid::new_v4(), Uuid::new_v4());
        let mut graph = graph_with(&[id_a, id_b]);

        graph.add_dependency(id_b, id_a).unwrap();
        graph.add_dependency(id_a, id_b).unwrap();

        assert!(graph.check_cycles().is_err());
        match graph.topological_order() {
            Err(SdkError::DependencyCycle { task_ids, .. }) => {
                assert!(!task_ids.is_empty());
                assert!(task_ids.iter().all(|&id| id == id_a || id == id_b));
            }
            other => panic!("expected DependencyCycle, got {:?}", other),
        }
    }

    #[test]
    fn test_self_dependency_is_cycle() {
        let id_a = Uuid::new_v4();
        let mut graph = graph_with(&[id_a]);

        graph.add_dependency(id_a, id_a).unwrap();

        assert!(graph.check_cycles().is_err());
    }

    #[test]
    fn test_unknown_task_is_rejected() {
        let known = Uuid::new_v4();
        let unknown = Uuid::new_v4();
        let mut graph = graph_with(&[known]);

        assert!(matches!(
            graph.add_dependency(known, unknown),
            Err(SdkError::TaskNotFound { task_id, .. }) if task_id == unknown
        ));
        assert!(matches!(
            graph.add_dependency(unknown, known),
            Err(SdkError::TaskNotFound { task_id, .. }) if task_id == unknown
        ));
    }

    #[test]
    fn test_roots_and_dependents() {
        let (root, left, right) = (Uuid::new_v4(), Uuid::new_v4(), Uuid::new_v4());
        let mut graph = graph_with(&[root, left, right]);

        // Both `left` and `right` depend on `root`.
        graph.add_dependency(left, root).unwrap();
        graph.add_dependency(right, root).unwrap();

        assert_eq!(graph.root_tasks(), vec![root]);

        let mut dependents = graph.dependents_of(root);
        dependents.sort();
        let mut expected = vec![left, right];
        expected.sort();
        assert_eq!(dependents, expected);

        assert!(graph.dependents_of(left).is_empty());
        assert!(graph.dependents_of(Uuid::new_v4()).is_empty());
    }

    #[test]
    fn test_len_and_empty() {
        let mut graph = TaskGraph::default();
        assert!(graph.is_empty());
        assert_eq!(graph.len(), 0);
        assert!(graph.topological_order().unwrap().is_empty());

        let id = Uuid::new_v4();
        graph.add_task(id);
        graph.add_task(id); // a repeated add must not create a second node
        assert!(!graph.is_empty());
        assert_eq!(graph.len(), 1);
        assert_eq!(graph.root_tasks(), vec![id]);
    }
}