// brainos_orchestrate/graph.rs

//! DAG dependency graph with parallel execution support.
2
3use std::collections::{HashMap, HashSet, VecDeque};
4
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7
8use crate::step::TaskStep;
9
10#[derive(Debug, Error)]
11pub enum GraphError {
12    #[error("Cycle detected in task graph")]
13    CycleDetected,
14    #[error("Missing dependency: step {step} depends on {dependency} which does not exist")]
15    MissingDependency { step: String, dependency: String },
16    #[error("Step not found: {0}")]
17    StepNotFound(String),
18}
19
20/// Rollback action for a single step.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct RollbackAction {
23    pub step_id: String,
24    pub description: String,
25    pub command: Option<String>,
26}
27
28/// Directed acyclic graph of task steps.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct TaskGraph {
31    pub steps: HashMap<String, TaskStep>,
32    pub edges: Vec<(String, String)>, // (from, to) = from must complete before to
33}
34
35impl TaskGraph {
36    /// Create a new graph from a list of steps.
37    /// Edges are derived from each step's `depends_on` field.
38    pub fn from_steps(steps: Vec<TaskStep>) -> Result<Self, GraphError> {
39        let step_map: HashMap<String, TaskStep> =
40            steps.into_iter().map(|s| (s.id.clone(), s)).collect();
41
42        let mut edges = Vec::new();
43        for step in step_map.values() {
44            for dep in &step.depends_on {
45                if !step_map.contains_key(dep) {
46                    return Err(GraphError::MissingDependency {
47                        step: step.id.clone(),
48                        dependency: dep.clone(),
49                    });
50                }
51                edges.push((dep.clone(), step.id.clone()));
52            }
53        }
54
55        let graph = Self {
56            steps: step_map,
57            edges,
58        };
59        graph.validate()?;
60        Ok(graph)
61    }
62
63    /// Validate that the graph has no cycles (topological sort).
64    pub fn validate(&self) -> Result<(), GraphError> {
65        let mut in_degree: HashMap<&str, usize> = HashMap::new();
66        let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
67
68        for id in self.steps.keys() {
69            in_degree.entry(id.as_str()).or_insert(0);
70            adjacency.entry(id.as_str()).or_default();
71        }
72
73        for (from, to) in &self.edges {
74            *in_degree.entry(to.as_str()).or_insert(0) += 1;
75            adjacency
76                .entry(from.as_str())
77                .or_default()
78                .push(to.as_str());
79        }
80
81        let mut queue: VecDeque<&str> = in_degree
82            .iter()
83            .filter(|(_, &deg)| deg == 0)
84            .map(|(&id, _)| id)
85            .collect();
86
87        let mut visited = 0;
88        while let Some(node) = queue.pop_front() {
89            visited += 1;
90            for &next in adjacency.get(node).unwrap_or(&vec![]) {
91                let deg = in_degree
92                    .get_mut(next)
93                    .expect("invariant: every step id seeded into in_degree map at start");
94                *deg -= 1;
95                if *deg == 0 {
96                    queue.push_back(next);
97                }
98            }
99        }
100
101        if visited != self.steps.len() {
102            return Err(GraphError::CycleDetected);
103        }
104        Ok(())
105    }
106
107    /// Return step IDs that are ready to execute (all dependencies satisfied).
108    ///
109    /// `succeeded` must contain only steps that completed *successfully* —
110    /// passing in the broader "terminal" set would let a failed step's
111    /// dependents fire against missing artifacts. The orchestrator marks
112    /// dependents of a failed step as `Skipped` separately.
113    ///
114    /// Result is sorted by topological order index so plans run in the
115    /// order the decomposer intended, not in `HashMap` iteration order.
116    pub fn ready_steps(&self, succeeded: &HashSet<String>) -> Vec<String> {
117        let order = self.topological_order();
118        let rank: HashMap<&str, usize> = order
119            .iter()
120            .enumerate()
121            .map(|(i, id)| (id.as_str(), i))
122            .collect();
123
124        let mut ready: Vec<String> = self
125            .steps
126            .values()
127            .filter(|step| {
128                !succeeded.contains(&step.id)
129                    && step.depends_on.iter().all(|dep| succeeded.contains(dep))
130            })
131            .map(|s| s.id.clone())
132            .collect();
133        ready.sort_by_key(|id| rank.get(id.as_str()).copied().unwrap_or(usize::MAX));
134        ready
135    }
136
137    /// Return all steps that transitively depend on `step_id` (excluding
138    /// `step_id` itself). Used by the orchestrator to mark dependents
139    /// `Skipped` when an upstream step fails.
140    pub fn transitive_dependents(&self, step_id: &str) -> Vec<String> {
141        let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
142        for (from, to) in &self.edges {
143            adjacency
144                .entry(from.as_str())
145                .or_default()
146                .push(to.as_str());
147        }
148
149        let mut out = Vec::new();
150        let mut seen: HashSet<String> = HashSet::new();
151        let mut queue: VecDeque<&str> = VecDeque::new();
152        if let Some(starts) = adjacency.get(step_id) {
153            for &s in starts {
154                queue.push_back(s);
155            }
156        }
157        while let Some(node) = queue.pop_front() {
158            if !seen.insert(node.to_string()) {
159                continue;
160            }
161            out.push(node.to_string());
162            if let Some(nexts) = adjacency.get(node) {
163                for &n in nexts {
164                    queue.push_back(n);
165                }
166            }
167        }
168        out
169    }
170
171    /// Topological sort — returns step IDs in execution order.
172    ///
173    /// Order is deterministic: ties (independent steps with the same
174    /// in-degree) break by step id. Without this, two unrelated no-dep
175    /// steps would execute in `HashMap` iteration order, which meant a
176    /// "Step 1 / Step 2" plan could run in reverse on a given run.
177    pub fn topological_order(&self) -> Vec<String> {
178        let mut in_degree: HashMap<&str, usize> = HashMap::new();
179        let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
180
181        for id in self.steps.keys() {
182            in_degree.entry(id.as_str()).or_insert(0);
183            adjacency.entry(id.as_str()).or_default();
184        }
185
186        for (from, to) in &self.edges {
187            *in_degree.entry(to.as_str()).or_insert(0) += 1;
188            adjacency
189                .entry(from.as_str())
190                .or_default()
191                .push(to.as_str());
192        }
193
194        // BinaryHeap with Reverse → pop smallest id first → stable order.
195        use std::cmp::Reverse;
196        use std::collections::BinaryHeap;
197        let mut queue: BinaryHeap<Reverse<&str>> = in_degree
198            .iter()
199            .filter(|(_, &deg)| deg == 0)
200            .map(|(&id, _)| Reverse(id))
201            .collect();
202
203        let mut order = Vec::new();
204        while let Some(Reverse(node)) = queue.pop() {
205            order.push(node.to_string());
206            for &next in adjacency.get(node).unwrap_or(&vec![]) {
207                let deg = in_degree
208                    .get_mut(next)
209                    .expect("invariant: every step id seeded into in_degree map at start");
210                *deg -= 1;
211                if *deg == 0 {
212                    queue.push(Reverse(next));
213                }
214            }
215        }
216
217        order
218    }
219
220    /// Splice additional steps into an existing graph. Used by the
221    /// replan-on-failure loop — after an initial step fails, the
222    /// orchestrator asks the decomposer for a corrective sub-plan and
223    /// inserts the resulting steps so the execution loop can pick them
224    /// up on the next iteration. Each new step's `depends_on` may
225    /// reference either an existing step id or another new step.
226    pub fn add_steps(&mut self, new_steps: Vec<TaskStep>) -> Result<(), GraphError> {
227        // Pre-validate against the union of existing + incoming step ids
228        // so a new step that depends on a sibling new step is allowed.
229        let mut universe: HashSet<String> = self.steps.keys().cloned().collect();
230        for s in &new_steps {
231            universe.insert(s.id.clone());
232        }
233        for s in &new_steps {
234            for dep in &s.depends_on {
235                if !universe.contains(dep) {
236                    return Err(GraphError::MissingDependency {
237                        step: s.id.clone(),
238                        dependency: dep.clone(),
239                    });
240                }
241            }
242        }
243        for s in new_steps {
244            for dep in &s.depends_on {
245                self.edges.push((dep.clone(), s.id.clone()));
246            }
247            self.steps.insert(s.id.clone(), s);
248        }
249        // Re-validate for cycles introduced by the splice.
250        self.validate()
251    }
252
253    /// Reverse topological order — for rollback.
254    pub fn rollback_order(&self, from_step: &str) -> Vec<RollbackAction> {
255        let order = self.topological_order();
256        let mut result = Vec::new();
257
258        // Find all steps that were completed before (and including) the failed step
259        let mut include = false;
260        for id in order.iter().rev() {
261            if id == from_step {
262                include = true;
263            }
264            if include {
265                if let Some(step) = self.steps.get(id) {
266                    result.push(RollbackAction {
267                        step_id: id.clone(),
268                        description: format!("Rollback: {}", step.description),
269                        command: None,
270                    });
271                }
272            }
273        }
274
275        result
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::step::{StepAction, TaskStep};
    use audit::ActionTier;

    /// Build a minimal `Plan` step with the given id and dependency ids.
    fn make_step(id: &str, deps: Vec<&str>) -> TaskStep {
        let depends_on = deps.into_iter().map(String::from).collect();
        TaskStep {
            id: id.to_string(),
            description: format!("Step {id}"),
            action: StepAction::Plan {
                output: "plan".to_string(),
            },
            depends_on,
            tier: ActionTier::Execute,
            estimated_tokens: 0,
        }
    }

    /// Owned set of step ids, in the form `ready_steps` expects.
    fn succeeded(ids: &[&str]) -> HashSet<String> {
        ids.iter().map(|id| (*id).to_string()).collect()
    }

    /// Return `items` sorted, for order-insensitive comparisons.
    fn sorted(mut items: Vec<String>) -> Vec<String> {
        items.sort();
        items
    }

    #[test]
    fn test_valid_graph() {
        // Diamond: a→b, a→c, b→d, c→d
        let graph = TaskGraph::from_steps(vec![
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
            make_step("c", vec!["a"]),
            make_step("d", vec!["b", "c"]),
        ])
        .unwrap();

        assert_eq!(graph.steps.len(), 4);
        assert_eq!(graph.edges.len(), 4);
    }

    #[test]
    fn test_cycle_detected() {
        // a → b → c → a
        let outcome = TaskGraph::from_steps(vec![
            make_step("a", vec!["c"]),
            make_step("b", vec!["a"]),
            make_step("c", vec!["b"]),
        ]);
        assert!(matches!(outcome, Err(GraphError::CycleDetected)));
    }

    #[test]
    fn test_missing_dependency() {
        let outcome = TaskGraph::from_steps(vec![make_step("a", vec!["nonexistent"])]);
        assert!(matches!(outcome, Err(GraphError::MissingDependency { .. })));
    }

    #[test]
    fn test_ready_steps() {
        let graph = TaskGraph::from_steps(vec![
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
            make_step("c", vec![]),
            make_step("d", vec!["b", "c"]),
        ])
        .unwrap();

        // Nothing done yet: only the two roots are runnable.
        assert_eq!(sorted(graph.ready_steps(&HashSet::new())), vec!["a", "c"]);

        // `a` done: `b` unlocks, `c` is still pending.
        assert_eq!(
            sorted(graph.ready_steps(&succeeded(&["a"]))),
            vec!["b", "c"]
        );

        // Everything upstream of `d` done.
        assert_eq!(graph.ready_steps(&succeeded(&["a", "b", "c"])), vec!["d"]);
    }

    #[test]
    fn test_topological_order() {
        let chain = TaskGraph::from_steps(vec![
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
            make_step("c", vec!["b"]),
        ])
        .unwrap();

        assert_eq!(chain.topological_order(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_transitive_dependents() {
        let graph = TaskGraph::from_steps(vec![
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
            make_step("c", vec!["a"]),
            make_step("d", vec!["b", "c"]),
            make_step("e", vec!["d"]),
        ])
        .unwrap();

        assert_eq!(
            sorted(graph.transitive_dependents("a")),
            vec!["b", "c", "d", "e"]
        );
        assert_eq!(sorted(graph.transitive_dependents("b")), vec!["d", "e"]);
        assert!(graph.transitive_dependents("e").is_empty());
    }

    #[test]
    fn test_topological_order_is_deterministic() {
        // Three independent steps: without tie-breaking they would come
        // back in HashMap iteration order. The order must be stable so a
        // "step 1 / step 2 / step 3" plan runs as written.
        let graph = TaskGraph::from_steps(vec![
            make_step("c", vec![]),
            make_step("a", vec![]),
            make_step("b", vec![]),
        ])
        .unwrap();

        let baseline = graph.topological_order();
        for _ in 0..20 {
            assert_eq!(graph.topological_order(), baseline);
        }
        assert_eq!(baseline, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_ready_steps_returns_deterministic_order() {
        let graph = TaskGraph::from_steps(vec![
            make_step("c", vec![]),
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
        ])
        .unwrap();

        let ready = graph.ready_steps(&HashSet::new());
        let index_of = |wanted: &str| ready.iter().position(|s| s == wanted).unwrap();

        // `a` precedes `c` in topological order, so it must also precede
        // it here regardless of HashMap iteration order.
        assert!(index_of("a") < index_of("c"));
    }
}