1use super::{Task, TaskDefinition, TaskGroup, Tasks};
7use crate::Result;
8use petgraph::algo::{is_cyclic_directed, toposort};
9use petgraph::graph::{DiGraph, NodeIndex};
10use petgraph::visit::IntoNodeReferences;
11use std::collections::{HashMap, HashSet};
12use tracing::debug;
13
14#[derive(Debug, Clone)]
16pub struct TaskNode {
17 pub name: String,
19 pub task: Task,
21}
22
23pub struct TaskGraph {
25 graph: DiGraph<TaskNode, ()>,
27 name_to_node: HashMap<String, NodeIndex>,
29}
30
31impl TaskGraph {
32 pub fn new() -> Self {
34 Self {
35 graph: DiGraph::new(),
36 name_to_node: HashMap::new(),
37 }
38 }
39
40 pub fn build_from_definition(
42 &mut self,
43 name: &str,
44 definition: &TaskDefinition,
45 all_tasks: &Tasks,
46 ) -> Result<Vec<NodeIndex>> {
47 match definition {
48 TaskDefinition::Single(task) => {
49 let node = self.add_task(name, task.as_ref().clone())?;
50 Ok(vec![node])
51 }
52 TaskDefinition::Group(group) => self.build_from_group(name, group, all_tasks),
53 }
54 }
55
56 fn build_from_group(
58 &mut self,
59 prefix: &str,
60 group: &TaskGroup,
61 all_tasks: &Tasks,
62 ) -> Result<Vec<NodeIndex>> {
63 match group {
64 TaskGroup::Sequential(tasks) => self.build_sequential_group(prefix, tasks, all_tasks),
65 TaskGroup::Parallel(tasks) => self.build_parallel_group(prefix, tasks, all_tasks),
66 }
67 }
68
69 fn build_sequential_group(
71 &mut self,
72 prefix: &str,
73 tasks: &[TaskDefinition],
74 all_tasks: &Tasks,
75 ) -> Result<Vec<NodeIndex>> {
76 let mut nodes = Vec::new();
77 let mut previous: Option<NodeIndex> = None;
78
79 for (i, task_def) in tasks.iter().enumerate() {
80 let task_name = format!("{}[{}]", prefix, i);
81 let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
82
83 if let Some(prev) = previous
85 && let Some(first) = task_nodes.first()
86 {
87 self.graph.add_edge(prev, *first, ());
88 }
89
90 if let Some(last) = task_nodes.last() {
91 previous = Some(*last);
92 }
93
94 nodes.extend(task_nodes);
95 }
96
97 Ok(nodes)
98 }
99
100 fn build_parallel_group(
102 &mut self,
103 prefix: &str,
104 tasks: &HashMap<String, TaskDefinition>,
105 all_tasks: &Tasks,
106 ) -> Result<Vec<NodeIndex>> {
107 let mut nodes = Vec::new();
108
109 for (name, task_def) in tasks {
110 let task_name = format!("{}.{}", prefix, name);
111 let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
112 nodes.extend(task_nodes);
113 }
114
115 Ok(nodes)
116 }
117
118 pub fn add_task(&mut self, name: &str, task: Task) -> Result<NodeIndex> {
120 if let Some(&node) = self.name_to_node.get(name) {
122 return Ok(node);
123 }
124
125 let node = TaskNode {
126 name: name.to_string(),
127 task,
128 };
129
130 let node_index = self.graph.add_node(node);
131 self.name_to_node.insert(name.to_string(), node_index);
132 debug!("Added task node '{}'", name);
133
134 Ok(node_index)
135 }
136
137 fn add_dependency_edges(&mut self) -> Result<()> {
140 let mut missing_deps = Vec::new();
141 let mut edges_to_add = Vec::new();
142
143 for (node_index, node) in self.graph.node_references() {
145 for dep_name in &node.task.depends_on {
146 if let Some(&dep_node_index) = self.name_to_node.get(dep_name as &str) {
147 edges_to_add.push((dep_node_index, node_index));
149 } else {
150 missing_deps.push((node.name.clone(), dep_name.clone()));
151 }
152 }
153 }
154
155 if !missing_deps.is_empty() {
157 let missing_list = missing_deps
158 .iter()
159 .map(|(task, dep)| format!("Task '{}' depends on missing task '{}'", task, dep))
160 .collect::<Vec<_>>()
161 .join(", ");
162 return Err(crate::Error::configuration(format!(
163 "Missing dependencies: {}",
164 missing_list
165 )));
166 }
167
168 for (from, to) in edges_to_add {
170 self.graph.add_edge(from, to, ());
171 }
172
173 Ok(())
174 }
175
176 pub fn has_cycles(&self) -> bool {
178 is_cyclic_directed(&self.graph)
179 }
180
181 pub fn topological_sort(&self) -> Result<Vec<TaskNode>> {
183 if self.has_cycles() {
184 return Err(crate::Error::configuration(
185 "Task dependency graph contains cycles".to_string(),
186 ));
187 }
188
189 match toposort(&self.graph, None) {
190 Ok(sorted_indices) => Ok(sorted_indices
191 .into_iter()
192 .map(|idx| self.graph[idx].clone())
193 .collect()),
194 Err(_) => Err(crate::Error::configuration(
195 "Failed to sort tasks topologically".to_string(),
196 )),
197 }
198 }
199
200 pub fn get_parallel_groups(&self) -> Result<Vec<Vec<TaskNode>>> {
202 let sorted = self.topological_sort()?;
203
204 if sorted.is_empty() {
205 return Ok(vec![]);
206 }
207
208 let mut groups: Vec<Vec<TaskNode>> = vec![];
210 let mut processed: HashMap<String, usize> = HashMap::new();
211
212 for task in sorted {
213 let mut level = 0;
215 for dep in &task.task.depends_on {
216 if let Some(&dep_level) = processed.get(dep) {
217 level = level.max(dep_level + 1);
218 }
219 }
220
221 if level >= groups.len() {
223 groups.resize(level + 1, vec![]);
224 }
225 groups[level].push(task.clone());
226 processed.insert(task.name.clone(), level);
227 }
228
229 Ok(groups)
230 }
231
232 pub fn task_count(&self) -> usize {
234 self.graph.node_count()
235 }
236
237 pub fn contains_task(&self, name: &str) -> bool {
239 self.name_to_node.contains_key(name)
240 }
241
242 pub fn build_complete_graph(&mut self, tasks: &Tasks) -> Result<()> {
245 for (name, definition) in tasks.tasks.iter() {
247 match definition {
248 TaskDefinition::Single(task) => {
249 self.add_task(name, task.as_ref().clone())?;
250 }
251 TaskDefinition::Group(_) => {
252 }
256 }
257 }
258
259 self.add_dependency_edges()?;
261
262 Ok(())
263 }
264
265 pub fn build_for_task(&mut self, task_name: &str, all_tasks: &Tasks) -> Result<()> {
267 let mut to_process = vec![task_name.to_string()];
268 let mut processed = HashSet::new();
269
270 debug!(
271 "Building graph for '{}' with tasks {:?}",
272 task_name,
273 all_tasks.list_tasks()
274 );
275
276 while let Some(current_name) = to_process.pop() {
278 if processed.contains(¤t_name) {
279 continue;
280 }
281 processed.insert(current_name.clone());
282
283 if let Some(definition) = all_tasks.get(¤t_name) {
284 match definition {
285 TaskDefinition::Single(task) => {
286 self.add_task(¤t_name, task.as_ref().clone())?;
287 for dep in &task.depends_on {
289 if !processed.contains(dep) {
290 to_process.push(dep.clone());
291 }
292 }
293 }
294 TaskDefinition::Group(_) => {
295 self.build_from_definition(¤t_name, definition, all_tasks)?;
297 }
298 }
299 } else {
300 debug!("Task '{}' not found while building graph", current_name);
301 }
302 }
303
304 self.add_dependency_edges()?;
306
307 Ok(())
308 }
309}
310
311impl Default for TaskGraph {
312 fn default() -> Self {
313 Self::new()
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal echo task labelled `name` that depends on `deps`.
    fn create_test_task(name: &str, deps: Vec<String>) -> Task {
        let label = name;
        Task {
            command: format!("echo {}", label),
            args: vec![],
            shell: None,
            env: HashMap::new(),
            depends_on: deps,
            inputs: vec![],
            outputs: vec![],
            external_inputs: None,
            workspaces: vec![],
            description: Some(format!("Test task {}", label)),
        }
    }

    /// Registers a test task called `name` with the given dependency names.
    fn add(graph: &mut TaskGraph, name: &str, dependencies: &[&str]) {
        let deps: Vec<String> = dependencies.iter().map(|dep| dep.to_string()).collect();
        graph.add_task(name, create_test_task(name, deps)).unwrap();
    }

    #[test]
    fn test_task_graph_new() {
        assert_eq!(TaskGraph::new().task_count(), 0);
    }

    #[test]
    fn test_add_single_task() {
        let mut graph = TaskGraph::new();

        let first = graph
            .add_task("test", create_test_task("test", vec![]))
            .unwrap();
        assert!(graph.contains_task("test"));
        assert_eq!(graph.task_count(), 1);

        // Re-adding the same name must reuse the existing node.
        let second = graph
            .add_task("test", create_test_task("test", vec![]))
            .unwrap();
        assert_eq!(first, second);
        assert_eq!(graph.task_count(), 1);
    }

    #[test]
    fn test_task_dependencies() {
        let mut graph = TaskGraph::new();
        add(&mut graph, "task1", &[]);
        add(&mut graph, "task2", &["task1"]);
        add(&mut graph, "task3", &["task1", "task2"]);
        graph.add_dependency_edges().unwrap();

        assert_eq!(graph.task_count(), 3);
        assert!(!graph.has_cycles());

        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);

        let position_of = |name: &str| {
            sorted
                .iter()
                .position(|node| node.name == name)
                .unwrap()
        };
        assert!(position_of("task1") < position_of("task2"));
        assert!(position_of("task1") < position_of("task3"));
        assert!(position_of("task2") < position_of("task3"));
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = TaskGraph::new();
        add(&mut graph, "task1", &["task3"]);
        add(&mut graph, "task2", &["task1"]);
        add(&mut graph, "task3", &["task2"]);
        graph.add_dependency_edges().unwrap();

        assert!(graph.has_cycles());
        assert!(graph.topological_sort().is_err());
    }

    #[test]
    fn test_parallel_groups() {
        let mut graph = TaskGraph::new();
        add(&mut graph, "task1", &[]);
        add(&mut graph, "task2", &[]);
        add(&mut graph, "task3", &["task1"]);
        add(&mut graph, "task4", &["task2"]);
        add(&mut graph, "task5", &["task3", "task4"]);
        graph.add_dependency_edges().unwrap();

        let groups = graph.get_parallel_groups().unwrap();
        let sizes: Vec<usize> = groups.iter().map(Vec::len).collect();

        assert_eq!(groups.len(), 3);
        assert_eq!(sizes[0], 2);
        assert_eq!(sizes[1], 2);
        assert_eq!(sizes[2], 1);
        assert_eq!(groups[2][0].name, "task5");
    }

    #[test]
    fn test_build_from_sequential_group() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();

        let group = TaskGroup::Sequential(vec![
            TaskDefinition::Single(Box::new(create_test_task("t1", vec![]))),
            TaskDefinition::Single(Box::new(create_test_task("t2", vec![]))),
        ]);

        let nodes = graph.build_from_group("seq", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);

        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted[0].name, "seq[0]");
        assert_eq!(sorted[1].name, "seq[1]");
    }

    #[test]
    fn test_build_from_parallel_group() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();

        let parallel_tasks: HashMap<String, TaskDefinition> = [
            ("first", create_test_task("t1", vec![])),
            ("second", create_test_task("t2", vec![])),
        ]
        .into_iter()
        .map(|(key, task)| (key.to_string(), TaskDefinition::Single(Box::new(task))))
        .collect();

        let group = TaskGroup::Parallel(parallel_tasks);

        let nodes = graph.build_from_group("par", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 2);
    }

    #[test]
    fn test_three_way_cycle_detection() {
        let mut graph = TaskGraph::new();
        add(&mut graph, "task_a", &["task_c"]);
        add(&mut graph, "task_b", &["task_a"]);
        add(&mut graph, "task_c", &["task_b"]);
        graph.add_dependency_edges().unwrap();

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_self_dependency_cycle() {
        let mut graph = TaskGraph::new();
        add(&mut graph, "self_ref", &["self_ref"]);
        graph.add_dependency_edges().unwrap();

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_complex_dependency_graph() {
        // Diamond: a -> {b, c} -> d
        let mut graph = TaskGraph::new();
        add(&mut graph, "a", &[]);
        add(&mut graph, "b", &["a"]);
        add(&mut graph, "c", &["a"]);
        add(&mut graph, "d", &["b", "c"]);
        graph.add_dependency_edges().unwrap();

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();
        let sizes: Vec<usize> = groups.iter().map(Vec::len).collect();
        assert_eq!(groups.len(), 3);
        assert_eq!(sizes, vec![1, 2, 1]);
    }

    #[test]
    fn test_missing_dependency() {
        let mut graph = TaskGraph::new();
        add(&mut graph, "dependent", &["missing"]);

        assert!(graph.add_dependency_edges().is_err());
    }

    #[test]
    fn test_empty_graph() {
        let graph = TaskGraph::new();

        assert_eq!(graph.task_count(), 0);
        assert!(!graph.has_cycles());
        assert!(graph.get_parallel_groups().unwrap().is_empty());
    }

    #[test]
    fn test_single_task_no_deps() {
        let mut graph = TaskGraph::new();
        add(&mut graph, "solo", &[]);

        assert_eq!(graph.task_count(), 1);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 1);
    }

    #[test]
    fn test_linear_chain() {
        let mut graph = TaskGraph::new();
        add(&mut graph, "a", &[]);
        add(&mut graph, "b", &["a"]);
        add(&mut graph, "c", &["b"]);
        add(&mut graph, "d", &["c"]);
        graph.add_dependency_edges().unwrap();

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 4);
        assert!(groups.iter().all(|group| group.len() == 1));
    }
}