// dag_executor/dag/graph.rs
use crate::error::{Result, ValidationError};
use crate::tasks::Task;
use petgraph::algo::toposort;
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::Direction;
use std::collections::HashMap;
use std::sync::Arc;

11#[derive(Default)]
18pub struct Dag {
19 graph: DiGraph<String, ()>,
20 indices: HashMap<String, NodeIndex>,
21 tasks: HashMap<String, Arc<dyn Task>>,
22}
23
24impl Dag {
25 pub fn new() -> Self {
27 Self::default()
28 }
29
30 fn node_for(&mut self, id: &str) -> NodeIndex {
31 if let Some(idx) = self.indices.get(id) {
32 return *idx;
33 }
34 let idx = self.graph.add_node(id.to_string());
35 self.indices.insert(id.to_string(), idx);
36 idx
37 }
38
39 pub fn add_task(&mut self, task: Arc<dyn Task>) -> Result<()> {
44 let id = task.id().to_string();
45 if self.tasks.contains_key(&id) {
46 return Err(ValidationError::DuplicateTask(id).into());
47 }
48
49 let deps = task.dependencies();
50 let task_idx = self.node_for(&id);
51 for dep in &deps {
52 if dep == &id {
53 return Err(ValidationError::SelfDependency(id).into());
54 }
55 let dep_idx = self.node_for(dep);
56 self.graph.update_edge(dep_idx, task_idx, ());
58 }
59
60 self.tasks.insert(id, task);
61 Ok(())
62 }
63
64 pub fn len(&self) -> usize {
66 self.tasks.len()
67 }
68
69 pub fn is_empty(&self) -> bool {
71 self.tasks.is_empty()
72 }
73
74 pub fn task(&self, id: &str) -> Option<Arc<dyn Task>> {
76 self.tasks.get(id).cloned()
77 }
78
79 pub fn task_ids(&self) -> Vec<String> {
81 self.tasks.keys().cloned().collect()
82 }
83
84 pub fn dependencies_of(&self, id: &str) -> Vec<String> {
86 self.neighbors(id, Direction::Incoming)
87 }
88
89 pub fn dependents_of(&self, id: &str) -> Vec<String> {
91 self.neighbors(id, Direction::Outgoing)
92 }
93
94 fn neighbors(&self, id: &str, dir: Direction) -> Vec<String> {
95 match self.indices.get(id) {
96 Some(&idx) => self
97 .graph
98 .neighbors_directed(idx, dir)
99 .map(|n| self.graph[n].clone())
100 .collect(),
101 None => Vec::new(),
102 }
103 }
104
105 pub fn validate(&self) -> Result<()> {
108 for (id, task) in &self.tasks {
109 for dep in task.dependencies() {
110 if !self.tasks.contains_key(&dep) {
111 return Err(ValidationError::MissingDependency {
112 task: id.clone(),
113 dep,
114 }
115 .into());
116 }
117 }
118 }
119
120 if let Err(cycle) = toposort(&self.graph, None) {
121 let id = self.graph[cycle.node_id()].clone();
122 return Err(ValidationError::Cycle(id).into());
123 }
124 Ok(())
125 }
126
127 pub fn topological_order(&self) -> Result<Vec<String>> {
131 match toposort(&self.graph, None) {
132 Ok(order) => Ok(order
133 .into_iter()
134 .map(|n| self.graph[n].clone())
135 .filter(|id| self.tasks.contains_key(id))
138 .collect()),
139 Err(cycle) => Err(ValidationError::Cycle(self.graph[cycle.node_id()].clone()).into()),
140 }
141 }
142}