1use super::{Task, TaskDefinition, TaskGroup, Tasks};
7use crate::Result;
8use petgraph::algo::{is_cyclic_directed, toposort};
9use petgraph::graph::{DiGraph, NodeIndex};
10use petgraph::visit::IntoNodeReferences;
11use std::collections::{HashMap, HashSet};
12
13#[derive(Debug, Clone)]
15pub struct TaskNode {
16 pub name: String,
18 pub task: Task,
20}
21
22pub struct TaskGraph {
24 graph: DiGraph<TaskNode, ()>,
26 name_to_node: HashMap<String, NodeIndex>,
28}
29
30impl TaskGraph {
31 pub fn new() -> Self {
33 Self {
34 graph: DiGraph::new(),
35 name_to_node: HashMap::new(),
36 }
37 }
38
39 pub fn build_from_definition(
41 &mut self,
42 name: &str,
43 definition: &TaskDefinition,
44 all_tasks: &Tasks,
45 ) -> Result<Vec<NodeIndex>> {
46 match definition {
47 TaskDefinition::Single(task) => {
48 let node = self.add_task(name, task.clone())?;
49 Ok(vec![node])
50 }
51 TaskDefinition::Group(group) => self.build_from_group(name, group, all_tasks),
52 }
53 }
54
55 fn build_from_group(
57 &mut self,
58 prefix: &str,
59 group: &TaskGroup,
60 all_tasks: &Tasks,
61 ) -> Result<Vec<NodeIndex>> {
62 match group {
63 TaskGroup::Sequential(tasks) => self.build_sequential_group(prefix, tasks, all_tasks),
64 TaskGroup::Parallel(tasks) => self.build_parallel_group(prefix, tasks, all_tasks),
65 }
66 }
67
68 fn build_sequential_group(
70 &mut self,
71 prefix: &str,
72 tasks: &[TaskDefinition],
73 all_tasks: &Tasks,
74 ) -> Result<Vec<NodeIndex>> {
75 let mut nodes = Vec::new();
76 let mut previous: Option<NodeIndex> = None;
77
78 for (i, task_def) in tasks.iter().enumerate() {
79 let task_name = format!("{}[{}]", prefix, i);
80 let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
81
82 if let Some(prev) = previous
84 && let Some(first) = task_nodes.first()
85 {
86 self.graph.add_edge(prev, *first, ());
87 }
88
89 if let Some(last) = task_nodes.last() {
90 previous = Some(*last);
91 }
92
93 nodes.extend(task_nodes);
94 }
95
96 Ok(nodes)
97 }
98
99 fn build_parallel_group(
101 &mut self,
102 prefix: &str,
103 tasks: &HashMap<String, TaskDefinition>,
104 all_tasks: &Tasks,
105 ) -> Result<Vec<NodeIndex>> {
106 let mut nodes = Vec::new();
107
108 for (name, task_def) in tasks {
109 let task_name = format!("{}.{}", prefix, name);
110 let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
111 nodes.extend(task_nodes);
112 }
113
114 Ok(nodes)
115 }
116
117 pub fn add_task(&mut self, name: &str, task: Task) -> Result<NodeIndex> {
119 if let Some(&node) = self.name_to_node.get(name) {
121 return Ok(node);
122 }
123
124 let node = TaskNode {
125 name: name.to_string(),
126 task: task.clone(),
127 };
128
129 let node_index = self.graph.add_node(node);
130 self.name_to_node.insert(name.to_string(), node_index);
131
132 Ok(node_index)
133 }
134
135 fn add_dependency_edges(&mut self) -> Result<()> {
138 let mut missing_deps = Vec::new();
139 let mut edges_to_add = Vec::new();
140
141 for (node_index, node) in self.graph.node_references() {
143 for dep_name in &node.task.depends_on {
144 if let Some(&dep_node_index) = self.name_to_node.get(dep_name as &str) {
145 edges_to_add.push((dep_node_index, node_index));
147 } else {
148 missing_deps.push((node.name.clone(), dep_name.clone()));
149 }
150 }
151 }
152
153 if !missing_deps.is_empty() {
155 let missing_list = missing_deps
156 .iter()
157 .map(|(task, dep)| format!("Task '{}' depends on missing task '{}'", task, dep))
158 .collect::<Vec<_>>()
159 .join(", ");
160 return Err(crate::Error::configuration(format!(
161 "Missing dependencies: {}",
162 missing_list
163 )));
164 }
165
166 for (from, to) in edges_to_add {
168 self.graph.add_edge(from, to, ());
169 }
170
171 Ok(())
172 }
173
174 pub fn has_cycles(&self) -> bool {
176 is_cyclic_directed(&self.graph)
177 }
178
179 pub fn topological_sort(&self) -> Result<Vec<TaskNode>> {
181 if self.has_cycles() {
182 return Err(crate::Error::configuration(
183 "Task dependency graph contains cycles".to_string(),
184 ));
185 }
186
187 match toposort(&self.graph, None) {
188 Ok(sorted_indices) => Ok(sorted_indices
189 .into_iter()
190 .map(|idx| self.graph[idx].clone())
191 .collect()),
192 Err(_) => Err(crate::Error::configuration(
193 "Failed to sort tasks topologically".to_string(),
194 )),
195 }
196 }
197
198 pub fn get_parallel_groups(&self) -> Result<Vec<Vec<TaskNode>>> {
200 let sorted = self.topological_sort()?;
201
202 if sorted.is_empty() {
203 return Ok(vec![]);
204 }
205
206 let mut groups: Vec<Vec<TaskNode>> = vec![];
208 let mut processed: HashMap<String, usize> = HashMap::new();
209
210 for task in sorted {
211 let mut level = 0;
213 for dep in &task.task.depends_on {
214 if let Some(&dep_level) = processed.get(dep) {
215 level = level.max(dep_level + 1);
216 }
217 }
218
219 if level >= groups.len() {
221 groups.resize(level + 1, vec![]);
222 }
223 groups[level].push(task.clone());
224 processed.insert(task.name.clone(), level);
225 }
226
227 Ok(groups)
228 }
229
230 pub fn task_count(&self) -> usize {
232 self.graph.node_count()
233 }
234
235 pub fn contains_task(&self, name: &str) -> bool {
237 self.name_to_node.contains_key(name)
238 }
239
240 pub fn build_complete_graph(&mut self, tasks: &Tasks) -> Result<()> {
243 for (name, definition) in tasks.tasks.iter() {
245 match definition {
246 TaskDefinition::Single(task) => {
247 self.add_task(name, task.clone())?;
248 }
249 TaskDefinition::Group(_) => {
250 }
254 }
255 }
256
257 self.add_dependency_edges()?;
259
260 Ok(())
261 }
262
263 pub fn build_for_task(&mut self, task_name: &str, all_tasks: &Tasks) -> Result<()> {
265 let mut to_process = vec![task_name.to_string()];
266 let mut processed = HashSet::new();
267
268 while let Some(current_name) = to_process.pop() {
270 if processed.contains(¤t_name) {
271 continue;
272 }
273 processed.insert(current_name.clone());
274
275 if let Some(definition) = all_tasks.get(¤t_name) {
276 match definition {
277 TaskDefinition::Single(task) => {
278 self.add_task(¤t_name, task.clone())?;
279 for dep in &task.depends_on {
281 if !processed.contains(dep) {
282 to_process.push(dep.clone());
283 }
284 }
285 }
286 TaskDefinition::Group(_) => {
287 self.build_from_definition(¤t_name, definition, all_tasks)?;
289 }
290 }
291 }
292 }
293
294 self.add_dependency_edges()?;
296
297 Ok(())
298 }
299}
300
301impl Default for TaskGraph {
302 fn default() -> Self {
303 Self::new()
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a task that runs `echo <name>` and depends on `deps`.
    fn create_test_task(name: &str, deps: Vec<String>) -> Task {
        Task {
            command: format!("echo {}", name),
            args: vec![],
            shell: None,
            env: HashMap::new(),
            depends_on: deps,
            inputs: vec![],
            outputs: vec![],
            description: Some(format!("Test task {}", name)),
        }
    }

    /// Turns string literals into the owned `depends_on` list a `Task` expects.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn test_task_graph_new() {
        assert_eq!(TaskGraph::new().task_count(), 0);
    }

    #[test]
    fn test_add_single_task() {
        let mut graph = TaskGraph::new();

        let first = graph.add_task("test", create_test_task("test", vec![])).unwrap();
        assert!(graph.contains_task("test"));
        assert_eq!(graph.task_count(), 1);

        // Re-adding the same name must hand back the original node, not create another.
        let second = graph.add_task("test", create_test_task("test", vec![])).unwrap();
        assert_eq!(first, second);
        assert_eq!(graph.task_count(), 1);
    }

    #[test]
    fn test_task_dependencies() {
        let mut graph = TaskGraph::new();
        for (name, deps) in [
            ("task1", vec![]),
            ("task2", names(&["task1"])),
            ("task3", names(&["task1", "task2"])),
        ] {
            graph.add_task(name, create_test_task(name, deps)).unwrap();
        }
        graph.add_dependency_edges().unwrap();

        assert_eq!(graph.task_count(), 3);
        assert!(!graph.has_cycles());

        let order = graph.topological_sort().unwrap();
        assert_eq!(order.len(), 3);

        let rank = |name: &str| {
            order
                .iter()
                .position(|node| node.name == name)
                .expect("task missing from topological order")
        };
        assert!(rank("task1") < rank("task2"));
        assert!(rank("task1") < rank("task3"));
        assert!(rank("task2") < rank("task3"));
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = TaskGraph::new();
        for (name, dep) in [("task1", "task3"), ("task2", "task1"), ("task3", "task2")] {
            graph.add_task(name, create_test_task(name, names(&[dep]))).unwrap();
        }
        graph.add_dependency_edges().unwrap();

        assert!(graph.has_cycles());
        assert!(graph.topological_sort().is_err());
    }

    #[test]
    fn test_parallel_groups() {
        let mut graph = TaskGraph::new();
        for (name, deps) in [
            ("task1", vec![]),
            ("task2", vec![]),
            ("task3", names(&["task1"])),
            ("task4", names(&["task2"])),
            ("task5", names(&["task3", "task4"])),
        ] {
            graph.add_task(name, create_test_task(name, deps)).unwrap();
        }
        graph.add_dependency_edges().unwrap();

        let groups = graph.get_parallel_groups().unwrap();
        let sizes: Vec<usize> = groups.iter().map(Vec::len).collect();
        assert_eq!(sizes, vec![2, 2, 1]);
        assert_eq!(groups[2][0].name, "task5");
    }

    #[test]
    fn test_build_from_sequential_group() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();
        let group = TaskGroup::Sequential(vec![
            TaskDefinition::Single(create_test_task("t1", vec![])),
            TaskDefinition::Single(create_test_task("t2", vec![])),
        ]);

        let nodes = graph.build_from_group("seq", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);

        let order: Vec<String> = graph
            .topological_sort()
            .unwrap()
            .into_iter()
            .map(|node| node.name)
            .collect();
        assert_eq!(order, vec!["seq[0]", "seq[1]"]);
    }

    #[test]
    fn test_build_from_parallel_group() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();

        let parallel_tasks: HashMap<String, TaskDefinition> = [
            ("first", create_test_task("t1", vec![])),
            ("second", create_test_task("t2", vec![])),
        ]
        .into_iter()
        .map(|(key, task)| (key.to_string(), TaskDefinition::Single(task)))
        .collect();
        let group = TaskGroup::Parallel(parallel_tasks);

        let nodes = graph.build_from_group("par", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 2);
    }

    #[test]
    fn test_three_way_cycle_detection() {
        let mut graph = TaskGraph::new();
        for (name, dep) in [("task_a", "task_c"), ("task_b", "task_a"), ("task_c", "task_b")] {
            graph.add_task(name, create_test_task(name, names(&[dep]))).unwrap();
        }
        graph.add_dependency_edges().unwrap();

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_self_dependency_cycle() {
        let mut graph = TaskGraph::new();
        graph
            .add_task("self_ref", create_test_task("self_ref", names(&["self_ref"])))
            .unwrap();
        graph.add_dependency_edges().unwrap();

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_complex_dependency_graph() {
        let mut graph = TaskGraph::new();
        for (name, deps) in [
            ("a", vec![]),
            ("b", names(&["a"])),
            ("c", names(&["a"])),
            ("d", names(&["b", "c"])),
        ] {
            graph.add_task(name, create_test_task(name, deps)).unwrap();
        }
        graph.add_dependency_edges().unwrap();

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();
        let sizes: Vec<usize> = groups.iter().map(Vec::len).collect();
        assert_eq!(sizes, vec![1, 2, 1]);
    }

    #[test]
    fn test_missing_dependency() {
        let mut graph = TaskGraph::new();
        graph
            .add_task("dependent", create_test_task("dependent", names(&["missing"])))
            .unwrap();

        assert!(graph.add_dependency_edges().is_err());
    }

    #[test]
    fn test_empty_graph() {
        let graph = TaskGraph::new();

        assert_eq!(graph.task_count(), 0);
        assert!(!graph.has_cycles());
        assert!(graph.get_parallel_groups().unwrap().is_empty());
    }

    #[test]
    fn test_single_task_no_deps() {
        let mut graph = TaskGraph::new();
        graph.add_task("solo", create_test_task("solo", vec![])).unwrap();

        assert_eq!(graph.task_count(), 1);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 1);
    }

    #[test]
    fn test_linear_chain() {
        let mut graph = TaskGraph::new();
        for (name, deps) in [
            ("a", vec![]),
            ("b", names(&["a"])),
            ("c", names(&["b"])),
            ("d", names(&["c"])),
        ] {
            graph.add_task(name, create_test_task(name, deps)).unwrap();
        }
        graph.add_dependency_edges().unwrap();

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 4);
        assert!(groups.iter().all(|group| group.len() == 1));
    }
}