Skip to main content

dag_executor/dag/
graph.rs

1//! The task dependency graph, backed by `petgraph`.
2
3use crate::error::{Result, ValidationError};
4use crate::tasks::Task;
5use petgraph::algo::toposort;
6use petgraph::graph::{DiGraph, NodeIndex};
7use petgraph::Direction;
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// A directed acyclic graph of tasks.
///
/// Nodes are task ids; an edge `dep -> task` means `task` depends on `dep`
/// (so a topological sort yields dependencies before the tasks that need them).
/// Dependency nodes may be referenced before the dependency task itself is
/// added; [`Dag::validate`] catches any that are never registered.
#[derive(Default)]
pub struct Dag {
    /// Dependency graph; each node's weight is the task id it stands for.
    graph: DiGraph<String, ()>,
    /// Maps a task id to its node in `graph`.
    indices: HashMap<String, NodeIndex>,
    /// Registered tasks keyed by id. An id can be present in `indices` but
    /// absent here when it has only been named as a dependency so far.
    tasks: HashMap<String, Arc<dyn Task>>,
}
23
24impl Dag {
25    /// Create an empty DAG.
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    fn node_for(&mut self, id: &str) -> NodeIndex {
31        if let Some(idx) = self.indices.get(id) {
32            return *idx;
33        }
34        let idx = self.graph.add_node(id.to_string());
35        self.indices.insert(id.to_string(), idx);
36        idx
37    }
38
39    /// Register a task and wire up edges from its declared dependencies.
40    ///
41    /// Returns an error on a duplicate id or a self-dependency. Missing
42    /// dependency tasks and cycles are reported later by [`Dag::validate`].
43    pub fn add_task(&mut self, task: Arc<dyn Task>) -> Result<()> {
44        let id = task.id().to_string();
45        if self.tasks.contains_key(&id) {
46            return Err(ValidationError::DuplicateTask(id).into());
47        }
48
49        let deps = task.dependencies();
50        let task_idx = self.node_for(&id);
51        for dep in &deps {
52            if dep == &id {
53                return Err(ValidationError::SelfDependency(id).into());
54            }
55            let dep_idx = self.node_for(dep);
56            // update_edge avoids duplicate parallel edges if a dep is repeated.
57            self.graph.update_edge(dep_idx, task_idx, ());
58        }
59
60        self.tasks.insert(id, task);
61        Ok(())
62    }
63
64    /// Number of registered tasks.
65    pub fn len(&self) -> usize {
66        self.tasks.len()
67    }
68
69    /// Whether the DAG has no tasks.
70    pub fn is_empty(&self) -> bool {
71        self.tasks.is_empty()
72    }
73
74    /// Look up a task by id.
75    pub fn task(&self, id: &str) -> Option<Arc<dyn Task>> {
76        self.tasks.get(id).cloned()
77    }
78
79    /// All registered task ids (unordered).
80    pub fn task_ids(&self) -> Vec<String> {
81        self.tasks.keys().cloned().collect()
82    }
83
84    /// Direct dependencies of `id` (incoming edges).
85    pub fn dependencies_of(&self, id: &str) -> Vec<String> {
86        self.neighbors(id, Direction::Incoming)
87    }
88
89    /// Direct dependents of `id` (outgoing edges).
90    pub fn dependents_of(&self, id: &str) -> Vec<String> {
91        self.neighbors(id, Direction::Outgoing)
92    }
93
94    fn neighbors(&self, id: &str, dir: Direction) -> Vec<String> {
95        match self.indices.get(id) {
96            Some(&idx) => self
97                .graph
98                .neighbors_directed(idx, dir)
99                .map(|n| self.graph[n].clone())
100                .collect(),
101            None => Vec::new(),
102        }
103    }
104
105    /// Validate the graph: every dependency must be a registered task and the
106    /// graph must be acyclic.
107    pub fn validate(&self) -> Result<()> {
108        for (id, task) in &self.tasks {
109            for dep in task.dependencies() {
110                if !self.tasks.contains_key(&dep) {
111                    return Err(ValidationError::MissingDependency {
112                        task: id.clone(),
113                        dep,
114                    }
115                    .into());
116                }
117            }
118        }
119
120        if let Err(cycle) = toposort(&self.graph, None) {
121            let id = self.graph[cycle.node_id()].clone();
122            return Err(ValidationError::Cycle(id).into());
123        }
124        Ok(())
125    }
126
127    /// Return task ids in dependency order (dependencies first).
128    ///
129    /// Errors if the graph contains a cycle.
130    pub fn topological_order(&self) -> Result<Vec<String>> {
131        match toposort(&self.graph, None) {
132            Ok(order) => Ok(order
133                .into_iter()
134                .map(|n| self.graph[n].clone())
135                // Placeholder nodes for unknown deps are filtered out; validate()
136                // is responsible for surfacing those as errors.
137                .filter(|id| self.tasks.contains_key(id))
138                .collect()),
139            Err(cycle) => Err(ValidationError::Cycle(self.graph[cycle.node_id()].clone()).into()),
140        }
141    }
142}