1use std::collections::{HashMap, HashSet, VecDeque};
4
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7
8use crate::step::TaskStep;
9
10#[derive(Debug, Error)]
11pub enum GraphError {
12 #[error("Cycle detected in task graph")]
13 CycleDetected,
14 #[error("Missing dependency: step {step} depends on {dependency} which does not exist")]
15 MissingDependency { step: String, dependency: String },
16 #[error("Step not found: {0}")]
17 StepNotFound(String),
18}
19
20#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct RollbackAction {
23 pub step_id: String,
24 pub description: String,
25 pub command: Option<String>,
26}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct TaskGraph {
31 pub steps: HashMap<String, TaskStep>,
32 pub edges: Vec<(String, String)>, }
34
35impl TaskGraph {
36 pub fn from_steps(steps: Vec<TaskStep>) -> Result<Self, GraphError> {
39 let step_map: HashMap<String, TaskStep> =
40 steps.into_iter().map(|s| (s.id.clone(), s)).collect();
41
42 let mut edges = Vec::new();
43 for step in step_map.values() {
44 for dep in &step.depends_on {
45 if !step_map.contains_key(dep) {
46 return Err(GraphError::MissingDependency {
47 step: step.id.clone(),
48 dependency: dep.clone(),
49 });
50 }
51 edges.push((dep.clone(), step.id.clone()));
52 }
53 }
54
55 let graph = Self {
56 steps: step_map,
57 edges,
58 };
59 graph.validate()?;
60 Ok(graph)
61 }
62
63 pub fn validate(&self) -> Result<(), GraphError> {
65 let mut in_degree: HashMap<&str, usize> = HashMap::new();
66 let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
67
68 for id in self.steps.keys() {
69 in_degree.entry(id.as_str()).or_insert(0);
70 adjacency.entry(id.as_str()).or_default();
71 }
72
73 for (from, to) in &self.edges {
74 *in_degree.entry(to.as_str()).or_insert(0) += 1;
75 adjacency
76 .entry(from.as_str())
77 .or_default()
78 .push(to.as_str());
79 }
80
81 let mut queue: VecDeque<&str> = in_degree
82 .iter()
83 .filter(|(_, °)| deg == 0)
84 .map(|(&id, _)| id)
85 .collect();
86
87 let mut visited = 0;
88 while let Some(node) = queue.pop_front() {
89 visited += 1;
90 for &next in adjacency.get(node).unwrap_or(&vec![]) {
91 let deg = in_degree
92 .get_mut(next)
93 .expect("invariant: every step id seeded into in_degree map at start");
94 *deg -= 1;
95 if *deg == 0 {
96 queue.push_back(next);
97 }
98 }
99 }
100
101 if visited != self.steps.len() {
102 return Err(GraphError::CycleDetected);
103 }
104 Ok(())
105 }
106
107 pub fn ready_steps(&self, succeeded: &HashSet<String>) -> Vec<String> {
117 let order = self.topological_order();
118 let rank: HashMap<&str, usize> = order
119 .iter()
120 .enumerate()
121 .map(|(i, id)| (id.as_str(), i))
122 .collect();
123
124 let mut ready: Vec<String> = self
125 .steps
126 .values()
127 .filter(|step| {
128 !succeeded.contains(&step.id)
129 && step.depends_on.iter().all(|dep| succeeded.contains(dep))
130 })
131 .map(|s| s.id.clone())
132 .collect();
133 ready.sort_by_key(|id| rank.get(id.as_str()).copied().unwrap_or(usize::MAX));
134 ready
135 }
136
137 pub fn transitive_dependents(&self, step_id: &str) -> Vec<String> {
141 let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
142 for (from, to) in &self.edges {
143 adjacency
144 .entry(from.as_str())
145 .or_default()
146 .push(to.as_str());
147 }
148
149 let mut out = Vec::new();
150 let mut seen: HashSet<String> = HashSet::new();
151 let mut queue: VecDeque<&str> = VecDeque::new();
152 if let Some(starts) = adjacency.get(step_id) {
153 for &s in starts {
154 queue.push_back(s);
155 }
156 }
157 while let Some(node) = queue.pop_front() {
158 if !seen.insert(node.to_string()) {
159 continue;
160 }
161 out.push(node.to_string());
162 if let Some(nexts) = adjacency.get(node) {
163 for &n in nexts {
164 queue.push_back(n);
165 }
166 }
167 }
168 out
169 }
170
171 pub fn topological_order(&self) -> Vec<String> {
178 let mut in_degree: HashMap<&str, usize> = HashMap::new();
179 let mut adjacency: HashMap<&str, Vec<&str>> = HashMap::new();
180
181 for id in self.steps.keys() {
182 in_degree.entry(id.as_str()).or_insert(0);
183 adjacency.entry(id.as_str()).or_default();
184 }
185
186 for (from, to) in &self.edges {
187 *in_degree.entry(to.as_str()).or_insert(0) += 1;
188 adjacency
189 .entry(from.as_str())
190 .or_default()
191 .push(to.as_str());
192 }
193
194 use std::cmp::Reverse;
196 use std::collections::BinaryHeap;
197 let mut queue: BinaryHeap<Reverse<&str>> = in_degree
198 .iter()
199 .filter(|(_, °)| deg == 0)
200 .map(|(&id, _)| Reverse(id))
201 .collect();
202
203 let mut order = Vec::new();
204 while let Some(Reverse(node)) = queue.pop() {
205 order.push(node.to_string());
206 for &next in adjacency.get(node).unwrap_or(&vec![]) {
207 let deg = in_degree
208 .get_mut(next)
209 .expect("invariant: every step id seeded into in_degree map at start");
210 *deg -= 1;
211 if *deg == 0 {
212 queue.push(Reverse(next));
213 }
214 }
215 }
216
217 order
218 }
219
220 pub fn add_steps(&mut self, new_steps: Vec<TaskStep>) -> Result<(), GraphError> {
227 let mut universe: HashSet<String> = self.steps.keys().cloned().collect();
230 for s in &new_steps {
231 universe.insert(s.id.clone());
232 }
233 for s in &new_steps {
234 for dep in &s.depends_on {
235 if !universe.contains(dep) {
236 return Err(GraphError::MissingDependency {
237 step: s.id.clone(),
238 dependency: dep.clone(),
239 });
240 }
241 }
242 }
243 for s in new_steps {
244 for dep in &s.depends_on {
245 self.edges.push((dep.clone(), s.id.clone()));
246 }
247 self.steps.insert(s.id.clone(), s);
248 }
249 self.validate()
251 }
252
253 pub fn rollback_order(&self, from_step: &str) -> Vec<RollbackAction> {
255 let order = self.topological_order();
256 let mut result = Vec::new();
257
258 let mut include = false;
260 for id in order.iter().rev() {
261 if id == from_step {
262 include = true;
263 }
264 if include {
265 if let Some(step) = self.steps.get(id) {
266 result.push(RollbackAction {
267 step_id: id.clone(),
268 description: format!("Rollback: {}", step.description),
269 command: None,
270 });
271 }
272 }
273 }
274
275 result
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::step::{StepAction, TaskStep};
    use audit::ActionTier;

    /// Builds a minimal plan step with the given id and dependency ids.
    fn make_step(id: &str, deps: Vec<&str>) -> TaskStep {
        TaskStep {
            id: id.to_string(),
            description: format!("Step {id}"),
            action: StepAction::Plan {
                output: "plan".to_string(),
            },
            depends_on: deps.into_iter().map(String::from).collect(),
            tier: ActionTier::Execute,
            estimated_tokens: 0,
        }
    }

    #[test]
    fn test_valid_graph() {
        let steps = vec![
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
            make_step("c", vec!["a"]),
            make_step("d", vec!["b", "c"]),
        ];
        let graph = TaskGraph::from_steps(steps).unwrap();
        assert_eq!(graph.steps.len(), 4);
        assert_eq!(graph.edges.len(), 4);
    }

    #[test]
    fn test_cycle_detected() {
        let steps = vec![
            make_step("a", vec!["c"]),
            make_step("b", vec!["a"]),
            make_step("c", vec!["b"]),
        ];
        let result = TaskGraph::from_steps(steps);
        assert!(matches!(result, Err(GraphError::CycleDetected)));
    }

    #[test]
    fn test_missing_dependency() {
        let steps = vec![make_step("a", vec!["nonexistent"])];
        let result = TaskGraph::from_steps(steps);
        assert!(matches!(result, Err(GraphError::MissingDependency { .. })));
    }

    #[test]
    fn test_ready_steps() {
        let steps = vec![
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
            make_step("c", vec![]),
            make_step("d", vec!["b", "c"]),
        ];
        let graph = TaskGraph::from_steps(steps).unwrap();

        let completed = HashSet::new();
        let mut ready = graph.ready_steps(&completed);
        ready.sort();
        assert_eq!(ready, vec!["a", "c"]);

        let completed: HashSet<String> = ["a".to_string()].into();
        let mut ready = graph.ready_steps(&completed);
        ready.sort();
        assert_eq!(ready, vec!["b", "c"]);

        let completed: HashSet<String> = ["a", "b", "c"].iter().map(|s| s.to_string()).collect();
        let ready = graph.ready_steps(&completed);
        assert_eq!(ready, vec!["d"]);
    }

    #[test]
    fn test_topological_order() {
        let steps = vec![
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
            make_step("c", vec!["b"]),
        ];
        let graph = TaskGraph::from_steps(steps).unwrap();
        let order = graph.topological_order();
        assert_eq!(order, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_transitive_dependents() {
        let steps = vec![
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
            make_step("c", vec!["a"]),
            make_step("d", vec!["b", "c"]),
            make_step("e", vec!["d"]),
        ];
        let graph = TaskGraph::from_steps(steps).unwrap();

        let mut deps = graph.transitive_dependents("a");
        deps.sort();
        assert_eq!(deps, vec!["b", "c", "d", "e"]);

        let mut deps = graph.transitive_dependents("b");
        deps.sort();
        assert_eq!(deps, vec!["d", "e"]);

        assert!(graph.transitive_dependents("e").is_empty());
    }

    #[test]
    fn test_topological_order_is_deterministic() {
        let steps = vec![
            make_step("c", vec![]),
            make_step("a", vec![]),
            make_step("b", vec![]),
        ];
        let graph = TaskGraph::from_steps(steps).unwrap();
        let first = graph.topological_order();
        for _ in 0..20 {
            assert_eq!(graph.topological_order(), first);
        }
        assert_eq!(first, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_ready_steps_returns_deterministic_order() {
        let steps = vec![
            make_step("c", vec![]),
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
        ];
        let graph = TaskGraph::from_steps(steps).unwrap();
        let ready = graph.ready_steps(&HashSet::new());
        let pos_a = ready.iter().position(|s| s == "a").unwrap();
        let pos_c = ready.iter().position(|s| s == "c").unwrap();
        assert!(pos_a < pos_c);
    }

    #[test]
    fn test_add_steps_extends_graph() {
        let mut graph = TaskGraph::from_steps(vec![make_step("a", vec![])]).unwrap();
        graph
            .add_steps(vec![make_step("b", vec!["a"]), make_step("c", vec!["b"])])
            .unwrap();
        assert_eq!(graph.steps.len(), 3);
        assert_eq!(graph.edges.len(), 2);
        assert_eq!(graph.topological_order(), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_add_steps_rejects_missing_dependency_without_mutating() {
        let mut graph = TaskGraph::from_steps(vec![make_step("a", vec![])]).unwrap();
        let result = graph.add_steps(vec![make_step("b", vec!["ghost"])]);
        assert!(matches!(result, Err(GraphError::MissingDependency { .. })));
        assert_eq!(graph.steps.len(), 1);
        assert!(graph.edges.is_empty());
    }

    #[test]
    fn test_add_steps_detects_cycle_in_batch() {
        let mut graph = TaskGraph::from_steps(vec![make_step("a", vec![])]).unwrap();
        let result = graph.add_steps(vec![make_step("x", vec!["y"]), make_step("y", vec!["x"])]);
        assert!(matches!(result, Err(GraphError::CycleDetected)));
    }

    #[test]
    fn test_rollback_order() {
        let steps = vec![
            make_step("a", vec![]),
            make_step("b", vec!["a"]),
            make_step("c", vec!["b"]),
        ];
        let graph = TaskGraph::from_steps(steps).unwrap();

        let actions = graph.rollback_order("b");
        let ids: Vec<&str> = actions.iter().map(|a| a.step_id.as_str()).collect();
        assert_eq!(ids, vec!["b", "a"]);
        assert_eq!(actions[0].description, "Rollback: Step b");
        assert!(actions.iter().all(|a| a.command.is_none()));

        assert!(graph.rollback_order("missing").is_empty());
    }
}