cuenv_core/tasks/
graph.rs

1//! Task graph builder using petgraph
2//!
3//! This module builds directed acyclic graphs (DAGs) from task definitions
4//! to handle dependencies and determine execution order.
5
6use super::{ParallelGroup, Task, TaskDefinition, TaskGroup, Tasks};
7use crate::Result;
8use petgraph::algo::{is_cyclic_directed, toposort};
9use petgraph::graph::{DiGraph, NodeIndex};
10use petgraph::visit::IntoNodeReferences;
11use std::collections::{HashMap, HashSet};
12use tracing::debug;
13
14/// A node in the task graph
15#[derive(Debug, Clone)]
16pub struct TaskNode {
17    /// Name of the task
18    pub name: String,
19    /// The task to execute
20    pub task: Task,
21}
22
23/// Task graph for dependency resolution and execution ordering
24pub struct TaskGraph {
25    /// The directed graph of tasks
26    graph: DiGraph<TaskNode, ()>,
27    /// Map from task names to node indices
28    name_to_node: HashMap<String, NodeIndex>,
29}
30
31impl TaskGraph {
32    /// Create a new empty task graph
33    pub fn new() -> Self {
34        Self {
35            graph: DiGraph::new(),
36            name_to_node: HashMap::new(),
37        }
38    }
39
40    /// Build a graph from a task definition
41    pub fn build_from_definition(
42        &mut self,
43        name: &str,
44        definition: &TaskDefinition,
45        all_tasks: &Tasks,
46    ) -> Result<Vec<NodeIndex>> {
47        match definition {
48            TaskDefinition::Single(task) => {
49                let node = self.add_task(name, task.as_ref().clone())?;
50                Ok(vec![node])
51            }
52            TaskDefinition::Group(group) => self.build_from_group(name, group, all_tasks),
53        }
54    }
55
56    /// Build a graph from a task group
57    fn build_from_group(
58        &mut self,
59        prefix: &str,
60        group: &TaskGroup,
61        all_tasks: &Tasks,
62    ) -> Result<Vec<NodeIndex>> {
63        match group {
64            TaskGroup::Sequential(tasks) => self.build_sequential_group(prefix, tasks, all_tasks),
65            TaskGroup::Parallel(group) => self.build_parallel_group(prefix, group, all_tasks),
66        }
67    }
68
69    /// Build a sequential task group (tasks run one after another)
70    fn build_sequential_group(
71        &mut self,
72        prefix: &str,
73        tasks: &[TaskDefinition],
74        all_tasks: &Tasks,
75    ) -> Result<Vec<NodeIndex>> {
76        let mut nodes = Vec::new();
77        let mut previous: Option<NodeIndex> = None;
78
79        for (i, task_def) in tasks.iter().enumerate() {
80            let task_name = format!("{}[{}]", prefix, i);
81            let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
82
83            // For sequential execution, link previous task to current
84            if let Some(prev) = previous
85                && let Some(first) = task_nodes.first()
86            {
87                self.graph.add_edge(prev, *first, ());
88            }
89
90            if let Some(last) = task_nodes.last() {
91                previous = Some(*last);
92            }
93
94            nodes.extend(task_nodes);
95        }
96
97        Ok(nodes)
98    }
99
100    /// Build a parallel task group (tasks can run concurrently)
101    fn build_parallel_group(
102        &mut self,
103        prefix: &str,
104        group: &ParallelGroup,
105        all_tasks: &Tasks,
106    ) -> Result<Vec<NodeIndex>> {
107        let mut nodes = Vec::new();
108
109        for (name, task_def) in &group.tasks {
110            let task_name = format!("{}.{}", prefix, name);
111            let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
112
113            // Apply group-level dependencies to each subtask
114            if !group.depends_on.is_empty() {
115                for node_idx in &task_nodes {
116                    let node = &mut self.graph[*node_idx];
117                    for dep in &group.depends_on {
118                        if !node.task.depends_on.contains(dep) {
119                            node.task.depends_on.push(dep.clone());
120                        }
121                    }
122                }
123            }
124
125            nodes.extend(task_nodes);
126        }
127
128        Ok(nodes)
129    }
130
131    /// Add a single task to the graph
132    pub fn add_task(&mut self, name: &str, task: Task) -> Result<NodeIndex> {
133        // Check if task already exists
134        if let Some(&node) = self.name_to_node.get(name) {
135            return Ok(node);
136        }
137
138        let node = TaskNode {
139            name: name.to_string(),
140            task,
141        };
142
143        let node_index = self.graph.add_node(node);
144        self.name_to_node.insert(name.to_string(), node_index);
145        debug!("Added task node '{}'", name);
146
147        Ok(node_index)
148    }
149
150    /// Add dependency edges after all tasks have been added
151    /// This ensures proper cycle detection and missing dependency validation
152    fn add_dependency_edges(&mut self) -> Result<()> {
153        let mut missing_deps = Vec::new();
154        let mut edges_to_add = Vec::new();
155
156        // Collect all dependency relationships
157        for (node_index, node) in self.graph.node_references() {
158            for dep_name in &node.task.depends_on {
159                if let Some(&dep_node_index) = self.name_to_node.get(dep_name as &str) {
160                    // Record edge to add later
161                    edges_to_add.push((dep_node_index, node_index));
162                } else {
163                    missing_deps.push((node.name.clone(), dep_name.clone()));
164                }
165            }
166        }
167
168        // Report missing dependencies
169        if !missing_deps.is_empty() {
170            let missing_list = missing_deps
171                .iter()
172                .map(|(task, dep)| format!("Task '{}' depends on missing task '{}'", task, dep))
173                .collect::<Vec<_>>()
174                .join(", ");
175            return Err(crate::Error::configuration(format!(
176                "Missing dependencies: {}",
177                missing_list
178            )));
179        }
180
181        // Add all edges
182        for (from, to) in edges_to_add {
183            self.graph.add_edge(from, to, ());
184        }
185
186        Ok(())
187    }
188
189    /// Check if the graph has cycles
190    pub fn has_cycles(&self) -> bool {
191        is_cyclic_directed(&self.graph)
192    }
193
194    /// Get topologically sorted list of tasks
195    pub fn topological_sort(&self) -> Result<Vec<TaskNode>> {
196        if self.has_cycles() {
197            return Err(crate::Error::configuration(
198                "Task dependency graph contains cycles".to_string(),
199            ));
200        }
201
202        match toposort(&self.graph, None) {
203            Ok(sorted_indices) => Ok(sorted_indices
204                .into_iter()
205                .map(|idx| self.graph[idx].clone())
206                .collect()),
207            Err(_) => Err(crate::Error::configuration(
208                "Failed to sort tasks topologically".to_string(),
209            )),
210        }
211    }
212
213    /// Get all tasks that can run in parallel (no dependencies between them)
214    pub fn get_parallel_groups(&self) -> Result<Vec<Vec<TaskNode>>> {
215        let sorted = self.topological_sort()?;
216
217        if sorted.is_empty() {
218            return Ok(vec![]);
219        }
220
221        // Group tasks by their dependency level
222        let mut groups: Vec<Vec<TaskNode>> = vec![];
223        let mut processed: HashMap<String, usize> = HashMap::new();
224
225        for task in sorted {
226            // Find the maximum level of all dependencies
227            let mut level = 0;
228            for dep in &task.task.depends_on {
229                if let Some(&dep_level) = processed.get(dep) {
230                    level = level.max(dep_level + 1);
231                }
232            }
233
234            // Add to appropriate group
235            if level >= groups.len() {
236                groups.resize(level + 1, vec![]);
237            }
238            groups[level].push(task.clone());
239            processed.insert(task.name.clone(), level);
240        }
241
242        Ok(groups)
243    }
244
245    /// Get the number of tasks in the graph
246    pub fn task_count(&self) -> usize {
247        self.graph.node_count()
248    }
249
250    /// Check if a task exists in the graph
251    pub fn contains_task(&self, name: &str) -> bool {
252        self.name_to_node.contains_key(name)
253    }
254
255    /// Build a complete graph from tasks with proper dependency resolution
256    /// This performs a two-pass build: first adding all nodes, then all edges
257    pub fn build_complete_graph(&mut self, tasks: &Tasks) -> Result<()> {
258        // First pass: Add all tasks as nodes
259        for (name, definition) in tasks.tasks.iter() {
260            match definition {
261                TaskDefinition::Single(task) => {
262                    self.add_task(name, task.as_ref().clone())?;
263                }
264                TaskDefinition::Group(_) => {
265                    // For groups, we'd need to expand them - this is more complex
266                    // and not needed for the current fix. Groups should be handled
267                    // by build_from_definition which already works correctly.
268                }
269            }
270        }
271
272        // Second pass: Add all dependency edges
273        self.add_dependency_edges()?;
274
275        Ok(())
276    }
277
278    /// Build graph for a specific task and all its transitive dependencies
279    pub fn build_for_task(&mut self, task_name: &str, all_tasks: &Tasks) -> Result<()> {
280        let mut to_process = vec![task_name.to_string()];
281        let mut processed = HashSet::new();
282
283        debug!(
284            "Building graph for '{}' with tasks {:?}",
285            task_name,
286            all_tasks.list_tasks()
287        );
288
289        // First pass: Collect all tasks that need to be included
290        while let Some(current_name) = to_process.pop() {
291            if processed.contains(&current_name) {
292                continue;
293            }
294            processed.insert(current_name.clone());
295
296            if let Some(definition) = all_tasks.get(&current_name) {
297                match definition {
298                    TaskDefinition::Single(task) => {
299                        self.add_task(&current_name, task.as_ref().clone())?;
300                        // Add dependencies to processing queue
301                        for dep in &task.depends_on {
302                            if !processed.contains(dep) {
303                                to_process.push(dep.clone());
304                            }
305                        }
306                    }
307                    TaskDefinition::Group(_) => {
308                        // Handle groups with build_from_definition
309                        let added_nodes =
310                            self.build_from_definition(&current_name, definition, all_tasks)?;
311                        // Collect dependencies from newly added tasks
312                        for node_idx in added_nodes {
313                            let node = &self.graph[node_idx];
314                            for dep in &node.task.depends_on {
315                                if !processed.contains(dep) {
316                                    to_process.push(dep.clone());
317                                }
318                            }
319                        }
320                    }
321                }
322            } else {
323                debug!("Task '{}' not found while building graph", current_name);
324            }
325        }
326
327        // Second pass: Add dependency edges
328        self.add_dependency_edges()?;
329
330        Ok(())
331    }
332}
333
334impl Default for TaskGraph {
335    fn default() -> Self {
336        Self::new()
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal task whose command echoes `name`.
    fn create_test_task(name: &str, deps: Vec<String>) -> Task {
        Task {
            command: format!("echo {}", name),
            depends_on: deps,
            description: Some(format!("Test task {}", name)),
            ..Default::default()
        }
    }

    #[test]
    fn test_task_graph_new() {
        let graph = TaskGraph::new();
        assert_eq!(graph.task_count(), 0);
    }

    #[test]
    fn test_add_single_task() {
        let mut graph = TaskGraph::new();
        let task = create_test_task("test", vec![]);

        let node = graph.add_task("test", task).unwrap();
        assert!(graph.contains_task("test"));
        assert_eq!(graph.task_count(), 1);

        // Adding a task under an existing name returns the existing node.
        let task2 = create_test_task("test", vec![]);
        let node2 = graph.add_task("test", task2).unwrap();
        assert_eq!(node, node2);
        assert_eq!(graph.task_count(), 1);
    }

    #[test]
    fn test_task_dependencies() {
        let mut graph = TaskGraph::new();

        let task1 = create_test_task("task1", vec![]);
        let task2 = create_test_task("task2", vec!["task1".to_string()]);
        let task3 = create_test_task("task3", vec!["task1".to_string(), "task2".to_string()]);

        graph.add_task("task1", task1).unwrap();
        graph.add_task("task2", task2).unwrap();
        graph.add_task("task3", task3).unwrap();
        graph.add_dependency_edges().unwrap(); // Edges are added once all nodes exist

        assert_eq!(graph.task_count(), 3);
        assert!(!graph.has_cycles());

        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);

        let positions: HashMap<String, usize> = sorted
            .iter()
            .enumerate()
            .map(|(i, node)| (node.name.clone(), i))
            .collect();

        assert!(positions["task1"] < positions["task2"]);
        assert!(positions["task1"] < positions["task3"]);
        assert!(positions["task2"] < positions["task3"]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = TaskGraph::new();

        // Two-task cycle: task1 -> task2 -> task1
        let task1 = create_test_task("task1", vec!["task2".to_string()]);
        let task2 = create_test_task("task2", vec!["task1".to_string()]);

        graph.add_task("task1", task1).unwrap();
        graph.add_task("task2", task2).unwrap();
        graph.add_dependency_edges().unwrap(); // Edges are added once all nodes exist

        assert!(graph.has_cycles());
        assert!(graph.topological_sort().is_err());
    }

    #[test]
    fn test_parallel_groups() {
        let mut graph = TaskGraph::new();

        // Level 0: task1, task2 (no dependencies)
        // Level 1: task3 (depends on task1), task4 (depends on task2)
        // Level 2: task5 (depends on task3 and task4)
        let task1 = create_test_task("task1", vec![]);
        let task2 = create_test_task("task2", vec![]);
        let task3 = create_test_task("task3", vec!["task1".to_string()]);
        let task4 = create_test_task("task4", vec!["task2".to_string()]);
        let task5 = create_test_task("task5", vec!["task3".to_string(), "task4".to_string()]);

        graph.add_task("task1", task1).unwrap();
        graph.add_task("task2", task2).unwrap();
        graph.add_task("task3", task3).unwrap();
        graph.add_task("task4", task4).unwrap();
        graph.add_task("task5", task5).unwrap();
        graph.add_dependency_edges().unwrap(); // Edges are added once all nodes exist

        let groups = graph.get_parallel_groups().unwrap();

        assert_eq!(groups.len(), 3);
        assert_eq!(groups[0].len(), 2);
        assert_eq!(groups[1].len(), 2);
        assert_eq!(groups[2].len(), 1);
        assert_eq!(groups[2][0].name, "task5");
    }

    #[test]
    fn test_build_from_sequential_group() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();

        let task1 = create_test_task("t1", vec![]);
        let task2 = create_test_task("t2", vec![]);

        let group = TaskGroup::Sequential(vec![
            TaskDefinition::Single(Box::new(task1)),
            TaskDefinition::Single(Box::new(task2)),
        ]);

        let nodes = graph.build_from_group("seq", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);

        // Sequential steps are chained, so the sort order is fixed.
        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted[0].name, "seq[0]");
        assert_eq!(sorted[1].name, "seq[1]");
    }

    #[test]
    fn test_build_from_parallel_group() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();

        let task1 = create_test_task("t1", vec![]);
        let task2 = create_test_task("t2", vec![]);

        let mut parallel_tasks = HashMap::new();
        parallel_tasks.insert("first".to_string(), TaskDefinition::Single(Box::new(task1)));
        parallel_tasks.insert(
            "second".to_string(),
            TaskDefinition::Single(Box::new(task2)),
        );

        let group = TaskGroup::Parallel(ParallelGroup {
            tasks: parallel_tasks,
            depends_on: vec![],
        });

        let nodes = graph.build_from_group("par", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);

        // Parallel subtasks have no edges between them.
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1); // All in the same level
        assert_eq!(groups[0].len(), 2); // Both can run in parallel
    }

    #[test]
    fn test_three_way_cycle_detection() {
        let mut graph = TaskGraph::new();

        // Three-task cycle: A -> B -> C -> A
        let task_a = create_test_task("task_a", vec!["task_c".to_string()]);
        let task_b = create_test_task("task_b", vec!["task_a".to_string()]);
        let task_c = create_test_task("task_c", vec!["task_b".to_string()]);

        graph.add_task("task_a", task_a).unwrap();
        graph.add_task("task_b", task_b).unwrap();
        graph.add_task("task_c", task_c).unwrap();
        graph.add_dependency_edges().unwrap(); // Edges are added once all nodes exist

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_self_dependency_cycle() {
        let mut graph = TaskGraph::new();

        let task = create_test_task("self_ref", vec!["self_ref".to_string()]);
        graph.add_task("self_ref", task).unwrap();
        graph.add_dependency_edges().unwrap(); // Edges are added once all nodes exist

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_complex_dependency_graph() {
        let mut graph = TaskGraph::new();

        // Diamond:
        //     A
        //    / \
        //   B   C
        //    \ /
        //     D
        let task_a = create_test_task("a", vec![]);
        let task_b = create_test_task("b", vec!["a".to_string()]);
        let task_c = create_test_task("c", vec!["a".to_string()]);
        let task_d = create_test_task("d", vec!["b".to_string(), "c".to_string()]);

        graph.add_task("a", task_a).unwrap();
        graph.add_task("b", task_b).unwrap();
        graph.add_task("c", task_c).unwrap();
        graph.add_task("d", task_d).unwrap();
        graph.add_dependency_edges().unwrap(); // Edges are added once all nodes exist

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();

        // Three levels: [A], [B, C], [D]
        assert_eq!(groups.len(), 3);
        assert_eq!(groups[0].len(), 1); // A
        assert_eq!(groups[1].len(), 2); // B and C can run in parallel
        assert_eq!(groups[2].len(), 1); // D
    }

    #[test]
    fn test_missing_dependency() {
        let mut graph = TaskGraph::new();

        let task = create_test_task("dependent", vec!["missing".to_string()]);
        graph.add_task("dependent", task).unwrap();

        // Edge construction must fail because "missing" has no node.
        assert!(graph.add_dependency_edges().is_err());
    }

    #[test]
    fn test_empty_graph() {
        let graph = TaskGraph::new();

        assert_eq!(graph.task_count(), 0);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert!(groups.is_empty());
    }

    #[test]
    fn test_single_task_no_deps() {
        let mut graph = TaskGraph::new();

        let task = create_test_task("solo", vec![]);
        graph.add_task("solo", task).unwrap();

        assert_eq!(graph.task_count(), 1);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 1);
    }

    #[test]
    fn test_linear_chain() {
        let mut graph = TaskGraph::new();

        // Linear chain: A -> B -> C -> D
        let task_a = create_test_task("a", vec![]);
        let task_b = create_test_task("b", vec!["a".to_string()]);
        let task_c = create_test_task("c", vec!["b".to_string()]);
        let task_d = create_test_task("d", vec!["c".to_string()]);

        graph.add_task("a", task_a).unwrap();
        graph.add_task("b", task_b).unwrap();
        graph.add_task("c", task_c).unwrap();
        graph.add_task("d", task_d).unwrap();
        graph.add_dependency_edges().unwrap(); // Edges are added once all nodes exist

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();

        // Each task gets its own level.
        assert_eq!(groups.len(), 4);
        for group in &groups {
            assert_eq!(group.len(), 1);
        }
    }
}
649}