1use super::{ParallelGroup, Task, TaskDefinition, TaskGroup, Tasks};
7use crate::Result;
8use petgraph::algo::{is_cyclic_directed, toposort};
9use petgraph::graph::{DiGraph, NodeIndex};
10use petgraph::visit::IntoNodeReferences;
11use std::collections::{HashMap, HashSet};
12use tracing::debug;
13
14#[derive(Debug, Clone)]
16pub struct TaskNode {
17 pub name: String,
19 pub task: Task,
21}
22
23pub struct TaskGraph {
25 graph: DiGraph<TaskNode, ()>,
27 name_to_node: HashMap<String, NodeIndex>,
29}
30
31impl TaskGraph {
32 pub fn new() -> Self {
34 Self {
35 graph: DiGraph::new(),
36 name_to_node: HashMap::new(),
37 }
38 }
39
40 pub fn build_from_definition(
42 &mut self,
43 name: &str,
44 definition: &TaskDefinition,
45 all_tasks: &Tasks,
46 ) -> Result<Vec<NodeIndex>> {
47 match definition {
48 TaskDefinition::Single(task) => {
49 let node = self.add_task(name, task.as_ref().clone())?;
50 Ok(vec![node])
51 }
52 TaskDefinition::Group(group) => self.build_from_group(name, group, all_tasks),
53 }
54 }
55
56 fn build_from_group(
58 &mut self,
59 prefix: &str,
60 group: &TaskGroup,
61 all_tasks: &Tasks,
62 ) -> Result<Vec<NodeIndex>> {
63 match group {
64 TaskGroup::Sequential(tasks) => self.build_sequential_group(prefix, tasks, all_tasks),
65 TaskGroup::Parallel(group) => self.build_parallel_group(prefix, group, all_tasks),
66 }
67 }
68
69 fn build_sequential_group(
71 &mut self,
72 prefix: &str,
73 tasks: &[TaskDefinition],
74 all_tasks: &Tasks,
75 ) -> Result<Vec<NodeIndex>> {
76 let mut nodes = Vec::new();
77 let mut previous: Option<NodeIndex> = None;
78
79 for (i, task_def) in tasks.iter().enumerate() {
80 let task_name = format!("{}[{}]", prefix, i);
81 let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
82
83 if let Some(prev) = previous
85 && let Some(first) = task_nodes.first()
86 {
87 self.graph.add_edge(prev, *first, ());
88 }
89
90 if let Some(last) = task_nodes.last() {
91 previous = Some(*last);
92 }
93
94 nodes.extend(task_nodes);
95 }
96
97 Ok(nodes)
98 }
99
100 fn build_parallel_group(
102 &mut self,
103 prefix: &str,
104 group: &ParallelGroup,
105 all_tasks: &Tasks,
106 ) -> Result<Vec<NodeIndex>> {
107 let mut nodes = Vec::new();
108
109 for (name, task_def) in &group.tasks {
110 let task_name = format!("{}.{}", prefix, name);
111 let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
112
113 if !group.depends_on.is_empty() {
115 for node_idx in &task_nodes {
116 let node = &mut self.graph[*node_idx];
117 for dep in &group.depends_on {
118 if !node.task.depends_on.contains(dep) {
119 node.task.depends_on.push(dep.clone());
120 }
121 }
122 }
123 }
124
125 nodes.extend(task_nodes);
126 }
127
128 Ok(nodes)
129 }
130
131 pub fn add_task(&mut self, name: &str, task: Task) -> Result<NodeIndex> {
133 if let Some(&node) = self.name_to_node.get(name) {
135 return Ok(node);
136 }
137
138 let node = TaskNode {
139 name: name.to_string(),
140 task,
141 };
142
143 let node_index = self.graph.add_node(node);
144 self.name_to_node.insert(name.to_string(), node_index);
145 debug!("Added task node '{}'", name);
146
147 Ok(node_index)
148 }
149
150 fn add_dependency_edges(&mut self) -> Result<()> {
153 let mut missing_deps = Vec::new();
154 let mut edges_to_add = Vec::new();
155
156 for (node_index, node) in self.graph.node_references() {
158 for dep_name in &node.task.depends_on {
159 if let Some(&dep_node_index) = self.name_to_node.get(dep_name as &str) {
160 edges_to_add.push((dep_node_index, node_index));
162 } else {
163 missing_deps.push((node.name.clone(), dep_name.clone()));
164 }
165 }
166 }
167
168 if !missing_deps.is_empty() {
170 let missing_list = missing_deps
171 .iter()
172 .map(|(task, dep)| format!("Task '{}' depends on missing task '{}'", task, dep))
173 .collect::<Vec<_>>()
174 .join(", ");
175 return Err(crate::Error::configuration(format!(
176 "Missing dependencies: {}",
177 missing_list
178 )));
179 }
180
181 for (from, to) in edges_to_add {
183 self.graph.add_edge(from, to, ());
184 }
185
186 Ok(())
187 }
188
189 pub fn has_cycles(&self) -> bool {
191 is_cyclic_directed(&self.graph)
192 }
193
194 pub fn topological_sort(&self) -> Result<Vec<TaskNode>> {
196 if self.has_cycles() {
197 return Err(crate::Error::configuration(
198 "Task dependency graph contains cycles".to_string(),
199 ));
200 }
201
202 match toposort(&self.graph, None) {
203 Ok(sorted_indices) => Ok(sorted_indices
204 .into_iter()
205 .map(|idx| self.graph[idx].clone())
206 .collect()),
207 Err(_) => Err(crate::Error::configuration(
208 "Failed to sort tasks topologically".to_string(),
209 )),
210 }
211 }
212
213 pub fn get_parallel_groups(&self) -> Result<Vec<Vec<TaskNode>>> {
215 let sorted = self.topological_sort()?;
216
217 if sorted.is_empty() {
218 return Ok(vec![]);
219 }
220
221 let mut groups: Vec<Vec<TaskNode>> = vec![];
223 let mut processed: HashMap<String, usize> = HashMap::new();
224
225 for task in sorted {
226 let mut level = 0;
228 for dep in &task.task.depends_on {
229 if let Some(&dep_level) = processed.get(dep) {
230 level = level.max(dep_level + 1);
231 }
232 }
233
234 if level >= groups.len() {
236 groups.resize(level + 1, vec![]);
237 }
238 groups[level].push(task.clone());
239 processed.insert(task.name.clone(), level);
240 }
241
242 Ok(groups)
243 }
244
245 pub fn task_count(&self) -> usize {
247 self.graph.node_count()
248 }
249
250 pub fn contains_task(&self, name: &str) -> bool {
252 self.name_to_node.contains_key(name)
253 }
254
255 pub fn build_complete_graph(&mut self, tasks: &Tasks) -> Result<()> {
258 for (name, definition) in tasks.tasks.iter() {
260 match definition {
261 TaskDefinition::Single(task) => {
262 self.add_task(name, task.as_ref().clone())?;
263 }
264 TaskDefinition::Group(_) => {
265 }
269 }
270 }
271
272 self.add_dependency_edges()?;
274
275 Ok(())
276 }
277
278 pub fn build_for_task(&mut self, task_name: &str, all_tasks: &Tasks) -> Result<()> {
280 let mut to_process = vec![task_name.to_string()];
281 let mut processed = HashSet::new();
282
283 debug!(
284 "Building graph for '{}' with tasks {:?}",
285 task_name,
286 all_tasks.list_tasks()
287 );
288
289 while let Some(current_name) = to_process.pop() {
291 if processed.contains(¤t_name) {
292 continue;
293 }
294 processed.insert(current_name.clone());
295
296 if let Some(definition) = all_tasks.get(¤t_name) {
297 match definition {
298 TaskDefinition::Single(task) => {
299 self.add_task(¤t_name, task.as_ref().clone())?;
300 for dep in &task.depends_on {
302 if !processed.contains(dep) {
303 to_process.push(dep.clone());
304 }
305 }
306 }
307 TaskDefinition::Group(_) => {
308 let added_nodes =
310 self.build_from_definition(¤t_name, definition, all_tasks)?;
311 for node_idx in added_nodes {
313 let node = &self.graph[node_idx];
314 for dep in &node.task.depends_on {
315 if !processed.contains(dep) {
316 to_process.push(dep.clone());
317 }
318 }
319 }
320 }
321 }
322 } else {
323 debug!("Task '{}' not found while building graph", current_name);
324 }
325 }
326
327 self.add_dependency_edges()?;
329
330 Ok(())
331 }
332}
333
334impl Default for TaskGraph {
335 fn default() -> Self {
336 Self::new()
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task that echoes `name` and depends on `deps`.
    fn create_test_task(name: &str, deps: Vec<String>) -> Task {
        Task {
            command: format!("echo {}", name),
            depends_on: deps,
            description: Some(format!("Test task {}", name)),
            ..Default::default()
        }
    }

    #[test]
    fn test_task_graph_new() {
        let graph = TaskGraph::new();
        assert_eq!(graph.task_count(), 0);
    }

    #[test]
    fn test_add_single_task() {
        let mut graph = TaskGraph::new();
        let task = create_test_task("test", vec![]);

        let node = graph.add_task("test", task).unwrap();
        assert!(graph.contains_task("test"));
        assert_eq!(graph.task_count(), 1);

        // Re-adding under the same name returns the existing node.
        let task2 = create_test_task("test", vec![]);
        let node2 = graph.add_task("test", task2).unwrap();
        assert_eq!(node, node2);
        assert_eq!(graph.task_count(), 1);
    }

    #[test]
    fn test_task_dependencies() {
        let mut graph = TaskGraph::new();

        let task1 = create_test_task("task1", vec![]);
        let task2 = create_test_task("task2", vec!["task1".to_string()]);
        let task3 = create_test_task("task3", vec!["task1".to_string(), "task2".to_string()]);

        graph.add_task("task1", task1).unwrap();
        graph.add_task("task2", task2).unwrap();
        graph.add_task("task3", task3).unwrap();
        graph.add_dependency_edges().unwrap();

        assert_eq!(graph.task_count(), 3);
        assert!(!graph.has_cycles());

        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);

        let positions: HashMap<String, usize> = sorted
            .iter()
            .enumerate()
            .map(|(i, node)| (node.name.clone(), i))
            .collect();

        assert!(positions["task1"] < positions["task2"]);
        assert!(positions["task1"] < positions["task3"]);
        assert!(positions["task2"] < positions["task3"]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = TaskGraph::new();

        let task1 = create_test_task("task1", vec!["task3".to_string()]);
        let task2 = create_test_task("task2", vec!["task1".to_string()]);
        let task3 = create_test_task("task3", vec!["task2".to_string()]);

        graph.add_task("task1", task1).unwrap();
        graph.add_task("task2", task2).unwrap();
        graph.add_task("task3", task3).unwrap();
        graph.add_dependency_edges().unwrap();

        assert!(graph.has_cycles());
        assert!(graph.topological_sort().is_err());
    }

    #[test]
    fn test_parallel_groups() {
        let mut graph = TaskGraph::new();

        // Two independent chains (task1 -> task3, task2 -> task4) joined by task5.
        let task1 = create_test_task("task1", vec![]);
        let task2 = create_test_task("task2", vec![]);
        let task3 = create_test_task("task3", vec!["task1".to_string()]);
        let task4 = create_test_task("task4", vec!["task2".to_string()]);
        let task5 = create_test_task("task5", vec!["task3".to_string(), "task4".to_string()]);

        graph.add_task("task1", task1).unwrap();
        graph.add_task("task2", task2).unwrap();
        graph.add_task("task3", task3).unwrap();
        graph.add_task("task4", task4).unwrap();
        graph.add_task("task5", task5).unwrap();
        graph.add_dependency_edges().unwrap();

        let groups = graph.get_parallel_groups().unwrap();

        assert_eq!(groups.len(), 3);
        assert_eq!(groups[0].len(), 2);
        assert_eq!(groups[1].len(), 2);
        assert_eq!(groups[2].len(), 1);
        assert_eq!(groups[2][0].name, "task5");
    }

    #[test]
    fn test_build_from_sequential_group() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();

        let task1 = create_test_task("t1", vec![]);
        let task2 = create_test_task("t2", vec![]);

        let group = TaskGroup::Sequential(vec![
            TaskDefinition::Single(Box::new(task1)),
            TaskDefinition::Single(Box::new(task2)),
        ]);

        let nodes = graph.build_from_group("seq", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);

        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted[0].name, "seq[0]");
        assert_eq!(sorted[1].name, "seq[1]");
    }

    #[test]
    fn test_build_from_parallel_group() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();

        let task1 = create_test_task("t1", vec![]);
        let task2 = create_test_task("t2", vec![]);

        let mut parallel_tasks = HashMap::new();
        parallel_tasks.insert("first".to_string(), TaskDefinition::Single(Box::new(task1)));
        parallel_tasks.insert(
            "second".to_string(),
            TaskDefinition::Single(Box::new(task2)),
        );

        let group = TaskGroup::Parallel(ParallelGroup {
            tasks: parallel_tasks,
            depends_on: vec![],
        });

        let nodes = graph.build_from_group("par", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 2);
    }

    #[test]
    fn test_parallel_group_propagates_depends_on() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();

        let task1 = create_test_task("t1", vec![]);
        // Already declares the group-level dependency; it must not be duplicated.
        let task2 = create_test_task("t2", vec!["setup".to_string()]);

        let mut parallel_tasks = HashMap::new();
        parallel_tasks.insert("first".to_string(), TaskDefinition::Single(Box::new(task1)));
        parallel_tasks.insert(
            "second".to_string(),
            TaskDefinition::Single(Box::new(task2)),
        );

        let group = TaskGroup::Parallel(ParallelGroup {
            tasks: parallel_tasks,
            depends_on: vec!["setup".to_string()],
        });

        let nodes = graph.build_from_group("par", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);

        for idx in nodes {
            let deps = &graph.graph[idx].task.depends_on;
            assert_eq!(deps.iter().filter(|d| d.as_str() == "setup").count(), 1);
        }
    }

    #[test]
    fn test_three_way_cycle_detection() {
        let mut graph = TaskGraph::new();

        let task_a = create_test_task("task_a", vec!["task_c".to_string()]);
        let task_b = create_test_task("task_b", vec!["task_a".to_string()]);
        let task_c = create_test_task("task_c", vec!["task_b".to_string()]);

        graph.add_task("task_a", task_a).unwrap();
        graph.add_task("task_b", task_b).unwrap();
        graph.add_task("task_c", task_c).unwrap();
        graph.add_dependency_edges().unwrap();

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_self_dependency_cycle() {
        let mut graph = TaskGraph::new();

        let task = create_test_task("self_ref", vec!["self_ref".to_string()]);
        graph.add_task("self_ref", task).unwrap();
        graph.add_dependency_edges().unwrap();

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_complex_dependency_graph() {
        let mut graph = TaskGraph::new();

        // Diamond: a -> {b, c} -> d.
        let task_a = create_test_task("a", vec![]);
        let task_b = create_test_task("b", vec!["a".to_string()]);
        let task_c = create_test_task("c", vec!["a".to_string()]);
        let task_d = create_test_task("d", vec!["b".to_string(), "c".to_string()]);

        graph.add_task("a", task_a).unwrap();
        graph.add_task("b", task_b).unwrap();
        graph.add_task("c", task_c).unwrap();
        graph.add_task("d", task_d).unwrap();
        graph.add_dependency_edges().unwrap();

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();

        assert_eq!(groups.len(), 3);
        assert_eq!(groups[0].len(), 1);
        assert_eq!(groups[1].len(), 2);
        assert_eq!(groups[2].len(), 1);
    }

    #[test]
    fn test_missing_dependency() {
        let mut graph = TaskGraph::new();

        let task = create_test_task("dependent", vec!["missing".to_string()]);
        graph.add_task("dependent", task).unwrap();

        assert!(graph.add_dependency_edges().is_err());
    }

    #[test]
    fn test_empty_graph() {
        let graph = TaskGraph::new();

        assert_eq!(graph.task_count(), 0);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert!(groups.is_empty());
    }

    #[test]
    fn test_single_task_no_deps() {
        let mut graph = TaskGraph::new();

        let task = create_test_task("solo", vec![]);
        graph.add_task("solo", task).unwrap();

        assert_eq!(graph.task_count(), 1);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 1);
    }

    #[test]
    fn test_linear_chain() {
        let mut graph = TaskGraph::new();

        let task_a = create_test_task("a", vec![]);
        let task_b = create_test_task("b", vec!["a".to_string()]);
        let task_c = create_test_task("c", vec!["b".to_string()]);
        let task_d = create_test_task("d", vec!["c".to_string()]);

        graph.add_task("a", task_a).unwrap();
        graph.add_task("b", task_b).unwrap();
        graph.add_task("c", task_c).unwrap();
        graph.add_task("d", task_d).unwrap();
        graph.add_dependency_edges().unwrap();

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();

        // A pure chain has no parallelism: one task per level.
        assert_eq!(groups.len(), 4);
        for group in &groups {
            assert_eq!(group.len(), 1);
        }
    }
}