// agent_sdk/task/graph.rs

1use std::collections::HashMap;
2
3use petgraph::algo::toposort;
4use petgraph::graph::{DiGraph, NodeIndex};
5
6use crate::error::{SdkError, SdkResult, TaskId};
7use crate::types::task::Task;
8
9pub struct TaskGraph {
10    graph: DiGraph<TaskId, ()>,
11    node_map: HashMap<TaskId, NodeIndex>,
12}
13
14impl Default for TaskGraph {
15    fn default() -> Self {
16        Self::new()
17    }
18}
19
20impl TaskGraph {
21    pub fn new() -> Self {
22        Self {
23            graph: DiGraph::new(),
24            node_map: HashMap::new(),
25        }
26    }
27
28    pub fn from_tasks(tasks: &[Task]) -> SdkResult<Self> {
29        let mut tg = Self::new();
30
31        for task in tasks {
32            tg.add_task(task.id);
33        }
34
35        for task in tasks {
36            for dep_id in &task.dependencies {
37                if !tg.node_map.contains_key(dep_id) {
38                    return Err(SdkError::TaskNotFound { task_id: *dep_id });
39                }
40                tg.add_dependency(task.id, *dep_id)?;
41            }
42        }
43
44        tg.check_cycles()?;
45        Ok(tg)
46    }
47
48    pub fn add_task(&mut self, task_id: TaskId) {
49        if !self.node_map.contains_key(&task_id) {
50            let idx = self.graph.add_node(task_id);
51            self.node_map.insert(task_id, idx);
52        }
53    }
54
55    pub fn add_dependency(&mut self, task_id: TaskId, depends_on: TaskId) -> SdkResult<()> {
56        let from = self
57            .node_map
58            .get(&depends_on)
59            .ok_or(SdkError::TaskNotFound {
60                task_id: depends_on,
61            })?;
62        let to = self
63            .node_map
64            .get(&task_id)
65            .ok_or(SdkError::TaskNotFound { task_id })?;
66
67        self.graph.add_edge(*from, *to, ());
68        Ok(())
69    }
70
71    pub fn check_cycles(&self) -> SdkResult<()> {
72        match toposort(&self.graph, None) {
73            Ok(_) => Ok(()),
74            Err(cycle) => {
75                let task_id = self.graph[cycle.node_id()];
76                Err(SdkError::DependencyCycle {
77                    task_ids: vec![task_id],
78                })
79            }
80        }
81    }
82
83    pub fn topological_order(&self) -> SdkResult<Vec<TaskId>> {
84        match toposort(&self.graph, None) {
85            Ok(indices) => Ok(indices.into_iter().map(|idx| self.graph[idx]).collect()),
86            Err(cycle) => {
87                let task_id = self.graph[cycle.node_id()];
88                Err(SdkError::DependencyCycle {
89                    task_ids: vec![task_id],
90                })
91            }
92        }
93    }
94
95    pub fn root_tasks(&self) -> Vec<TaskId> {
96        self.graph
97            .node_indices()
98            .filter(|&idx| {
99                self.graph
100                    .neighbors_directed(idx, petgraph::Direction::Incoming)
101                    .count()
102                    == 0
103            })
104            .map(|idx| self.graph[idx])
105            .collect()
106    }
107
108    pub fn dependents_of(&self, task_id: TaskId) -> Vec<TaskId> {
109        if let Some(&idx) = self.node_map.get(&task_id) {
110            self.graph
111                .neighbors_directed(idx, petgraph::Direction::Outgoing)
112                .map(|idx| self.graph[idx])
113                .collect()
114        } else {
115            Vec::new()
116        }
117    }
118
119    pub fn len(&self) -> usize {
120        self.graph.node_count()
121    }
122
123    pub fn is_empty(&self) -> bool {
124        self.graph.node_count() == 0
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    /// Creates a graph whose nodes are `ids`, with no dependencies yet.
    fn graph_with(ids: &[TaskId]) -> TaskGraph {
        let mut graph = TaskGraph::new();
        for &id in ids {
            graph.add_task(id);
        }
        graph
    }

    /// Index of `id` within `order`, panicking if it is absent.
    fn position_of(order: &[TaskId], id: TaskId) -> usize {
        order
            .iter()
            .position(|&candidate| candidate == id)
            .expect("task missing from topological order")
    }

    #[test]
    fn test_topological_order() {
        let id_a = Uuid::new_v4();
        let id_b = Uuid::new_v4();
        let id_c = Uuid::new_v4();

        let mut graph = graph_with(&[id_a, id_b, id_c]);
        graph.add_dependency(id_b, id_a).unwrap();
        graph.add_dependency(id_c, id_b).unwrap();

        let order = graph.topological_order().unwrap();
        assert_eq!(order.len(), 3);
        assert!(position_of(&order, id_a) < position_of(&order, id_b));
        assert!(position_of(&order, id_b) < position_of(&order, id_c));
    }

    #[test]
    fn test_cycle_detection() {
        let id_a = Uuid::new_v4();
        let id_b = Uuid::new_v4();

        let mut graph = graph_with(&[id_a, id_b]);
        graph.add_dependency(id_b, id_a).unwrap();
        graph.add_dependency(id_a, id_b).unwrap();

        assert!(graph.check_cycles().is_err());
        match graph.topological_order() {
            Err(SdkError::DependencyCycle { task_ids }) => {
                assert!(!task_ids.is_empty());
                assert!(task_ids.iter().all(|&id| id == id_a || id == id_b));
            }
            other => panic!("expected DependencyCycle, got {:?}", other),
        }
    }

    #[test]
    fn test_self_dependency_is_a_cycle() {
        let id_a = Uuid::new_v4();
        let mut graph = graph_with(&[id_a]);

        graph.add_dependency(id_a, id_a).unwrap();
        assert!(graph.check_cycles().is_err());
    }

    #[test]
    fn test_add_dependency_rejects_unknown_tasks() {
        let known = Uuid::new_v4();
        let unknown = Uuid::new_v4();
        let mut graph = graph_with(&[known]);

        match graph.add_dependency(known, unknown) {
            Err(SdkError::TaskNotFound { task_id }) => assert_eq!(task_id, unknown),
            other => panic!("expected TaskNotFound, got {:?}", other),
        }
        match graph.add_dependency(unknown, known) {
            Err(SdkError::TaskNotFound { task_id }) => assert_eq!(task_id, unknown),
            other => panic!("expected TaskNotFound, got {:?}", other),
        }
    }

    #[test]
    fn test_roots_and_dependents() {
        let id_a = Uuid::new_v4();
        let id_b = Uuid::new_v4();
        let id_c = Uuid::new_v4();

        let mut graph = graph_with(&[id_a, id_b, id_c]);
        graph.add_dependency(id_b, id_a).unwrap();
        graph.add_dependency(id_c, id_a).unwrap();

        assert_eq!(graph.root_tasks(), vec![id_a]);

        // Neighbour order is an implementation detail; compare sorted.
        let mut dependents = graph.dependents_of(id_a);
        dependents.sort();
        let mut expected = vec![id_b, id_c];
        expected.sort();
        assert_eq!(dependents, expected);

        assert!(graph.dependents_of(id_b).is_empty());
        assert!(graph.dependents_of(Uuid::new_v4()).is_empty());
    }

    #[test]
    fn test_len_and_is_empty() {
        let mut graph = TaskGraph::new();
        assert!(graph.is_empty());
        assert_eq!(graph.len(), 0);

        let id_a = Uuid::new_v4();
        graph.add_task(id_a);
        graph.add_task(id_a); // re-adding an existing id is a no-op
        assert!(!graph.is_empty());
        assert_eq!(graph.len(), 1);
    }
}
170}