// cuenv_core/tasks/graph.rs

//! Task graph builder using petgraph.
//!
//! This module builds directed acyclic graphs (DAGs) from task definitions
//! to resolve dependencies and determine execution order.

6use super::{Task, TaskDefinition, TaskGroup, Tasks};
7use crate::Result;
8use petgraph::algo::{is_cyclic_directed, toposort};
9use petgraph::graph::{DiGraph, NodeIndex};
10use petgraph::visit::IntoNodeReferences;
11use std::collections::{HashMap, HashSet};
12use tracing::debug;
13
14/// A node in the task graph
15#[derive(Debug, Clone)]
16pub struct TaskNode {
17    /// Name of the task
18    pub name: String,
19    /// The task to execute
20    pub task: Task,
21}
22
23/// Task graph for dependency resolution and execution ordering
24pub struct TaskGraph {
25    /// The directed graph of tasks
26    graph: DiGraph<TaskNode, ()>,
27    /// Map from task names to node indices
28    name_to_node: HashMap<String, NodeIndex>,
29}
30
31impl TaskGraph {
32    /// Create a new empty task graph
33    pub fn new() -> Self {
34        Self {
35            graph: DiGraph::new(),
36            name_to_node: HashMap::new(),
37        }
38    }
39
40    /// Build a graph from a task definition
41    pub fn build_from_definition(
42        &mut self,
43        name: &str,
44        definition: &TaskDefinition,
45        all_tasks: &Tasks,
46    ) -> Result<Vec<NodeIndex>> {
47        match definition {
48            TaskDefinition::Single(task) => {
49                let node = self.add_task(name, task.as_ref().clone())?;
50                Ok(vec![node])
51            }
52            TaskDefinition::Group(group) => self.build_from_group(name, group, all_tasks),
53        }
54    }
55
56    /// Build a graph from a task group
57    fn build_from_group(
58        &mut self,
59        prefix: &str,
60        group: &TaskGroup,
61        all_tasks: &Tasks,
62    ) -> Result<Vec<NodeIndex>> {
63        match group {
64            TaskGroup::Sequential(tasks) => self.build_sequential_group(prefix, tasks, all_tasks),
65            TaskGroup::Parallel(tasks) => self.build_parallel_group(prefix, tasks, all_tasks),
66        }
67    }
68
69    /// Build a sequential task group (tasks run one after another)
70    fn build_sequential_group(
71        &mut self,
72        prefix: &str,
73        tasks: &[TaskDefinition],
74        all_tasks: &Tasks,
75    ) -> Result<Vec<NodeIndex>> {
76        let mut nodes = Vec::new();
77        let mut previous: Option<NodeIndex> = None;
78
79        for (i, task_def) in tasks.iter().enumerate() {
80            let task_name = format!("{}[{}]", prefix, i);
81            let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
82
83            // For sequential execution, link previous task to current
84            if let Some(prev) = previous
85                && let Some(first) = task_nodes.first()
86            {
87                self.graph.add_edge(prev, *first, ());
88            }
89
90            if let Some(last) = task_nodes.last() {
91                previous = Some(*last);
92            }
93
94            nodes.extend(task_nodes);
95        }
96
97        Ok(nodes)
98    }
99
100    /// Build a parallel task group (tasks can run concurrently)
101    fn build_parallel_group(
102        &mut self,
103        prefix: &str,
104        tasks: &HashMap<String, TaskDefinition>,
105        all_tasks: &Tasks,
106    ) -> Result<Vec<NodeIndex>> {
107        let mut nodes = Vec::new();
108
109        for (name, task_def) in tasks {
110            let task_name = format!("{}.{}", prefix, name);
111            let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
112            nodes.extend(task_nodes);
113        }
114
115        Ok(nodes)
116    }
117
118    /// Add a single task to the graph
119    pub fn add_task(&mut self, name: &str, task: Task) -> Result<NodeIndex> {
120        // Check if task already exists
121        if let Some(&node) = self.name_to_node.get(name) {
122            return Ok(node);
123        }
124
125        let node = TaskNode {
126            name: name.to_string(),
127            task,
128        };
129
130        let node_index = self.graph.add_node(node);
131        self.name_to_node.insert(name.to_string(), node_index);
132        debug!("Added task node '{}'", name);
133
134        Ok(node_index)
135    }
136
137    /// Add dependency edges after all tasks have been added
138    /// This ensures proper cycle detection and missing dependency validation
139    fn add_dependency_edges(&mut self) -> Result<()> {
140        let mut missing_deps = Vec::new();
141        let mut edges_to_add = Vec::new();
142
143        // Collect all dependency relationships
144        for (node_index, node) in self.graph.node_references() {
145            for dep_name in &node.task.depends_on {
146                if let Some(&dep_node_index) = self.name_to_node.get(dep_name as &str) {
147                    // Record edge to add later
148                    edges_to_add.push((dep_node_index, node_index));
149                } else {
150                    missing_deps.push((node.name.clone(), dep_name.clone()));
151                }
152            }
153        }
154
155        // Report missing dependencies
156        if !missing_deps.is_empty() {
157            let missing_list = missing_deps
158                .iter()
159                .map(|(task, dep)| format!("Task '{}' depends on missing task '{}'", task, dep))
160                .collect::<Vec<_>>()
161                .join(", ");
162            return Err(crate::Error::configuration(format!(
163                "Missing dependencies: {}",
164                missing_list
165            )));
166        }
167
168        // Add all edges
169        for (from, to) in edges_to_add {
170            self.graph.add_edge(from, to, ());
171        }
172
173        Ok(())
174    }
175
176    /// Check if the graph has cycles
177    pub fn has_cycles(&self) -> bool {
178        is_cyclic_directed(&self.graph)
179    }
180
181    /// Get topologically sorted list of tasks
182    pub fn topological_sort(&self) -> Result<Vec<TaskNode>> {
183        if self.has_cycles() {
184            return Err(crate::Error::configuration(
185                "Task dependency graph contains cycles".to_string(),
186            ));
187        }
188
189        match toposort(&self.graph, None) {
190            Ok(sorted_indices) => Ok(sorted_indices
191                .into_iter()
192                .map(|idx| self.graph[idx].clone())
193                .collect()),
194            Err(_) => Err(crate::Error::configuration(
195                "Failed to sort tasks topologically".to_string(),
196            )),
197        }
198    }
199
200    /// Get all tasks that can run in parallel (no dependencies between them)
201    pub fn get_parallel_groups(&self) -> Result<Vec<Vec<TaskNode>>> {
202        let sorted = self.topological_sort()?;
203
204        if sorted.is_empty() {
205            return Ok(vec![]);
206        }
207
208        // Group tasks by their dependency level
209        let mut groups: Vec<Vec<TaskNode>> = vec![];
210        let mut processed: HashMap<String, usize> = HashMap::new();
211
212        for task in sorted {
213            // Find the maximum level of all dependencies
214            let mut level = 0;
215            for dep in &task.task.depends_on {
216                if let Some(&dep_level) = processed.get(dep) {
217                    level = level.max(dep_level + 1);
218                }
219            }
220
221            // Add to appropriate group
222            if level >= groups.len() {
223                groups.resize(level + 1, vec![]);
224            }
225            groups[level].push(task.clone());
226            processed.insert(task.name.clone(), level);
227        }
228
229        Ok(groups)
230    }
231
232    /// Get the number of tasks in the graph
233    pub fn task_count(&self) -> usize {
234        self.graph.node_count()
235    }
236
237    /// Check if a task exists in the graph
238    pub fn contains_task(&self, name: &str) -> bool {
239        self.name_to_node.contains_key(name)
240    }
241
242    /// Build a complete graph from tasks with proper dependency resolution
243    /// This performs a two-pass build: first adding all nodes, then all edges
244    pub fn build_complete_graph(&mut self, tasks: &Tasks) -> Result<()> {
245        // First pass: Add all tasks as nodes
246        for (name, definition) in tasks.tasks.iter() {
247            match definition {
248                TaskDefinition::Single(task) => {
249                    self.add_task(name, task.as_ref().clone())?;
250                }
251                TaskDefinition::Group(_) => {
252                    // For groups, we'd need to expand them - this is more complex
253                    // and not needed for the current fix. Groups should be handled
254                    // by build_from_definition which already works correctly.
255                }
256            }
257        }
258
259        // Second pass: Add all dependency edges
260        self.add_dependency_edges()?;
261
262        Ok(())
263    }
264
265    /// Build graph for a specific task and all its transitive dependencies
266    pub fn build_for_task(&mut self, task_name: &str, all_tasks: &Tasks) -> Result<()> {
267        let mut to_process = vec![task_name.to_string()];
268        let mut processed = HashSet::new();
269
270        debug!(
271            "Building graph for '{}' with tasks {:?}",
272            task_name,
273            all_tasks.list_tasks()
274        );
275
276        // First pass: Collect all tasks that need to be included
277        while let Some(current_name) = to_process.pop() {
278            if processed.contains(&current_name) {
279                continue;
280            }
281            processed.insert(current_name.clone());
282
283            if let Some(definition) = all_tasks.get(&current_name) {
284                match definition {
285                    TaskDefinition::Single(task) => {
286                        self.add_task(&current_name, task.as_ref().clone())?;
287                        // Add dependencies to processing queue
288                        for dep in &task.depends_on {
289                            if !processed.contains(dep) {
290                                to_process.push(dep.clone());
291                            }
292                        }
293                    }
294                    TaskDefinition::Group(_) => {
295                        // Handle groups with build_from_definition
296                        self.build_from_definition(&current_name, definition, all_tasks)?;
297                    }
298                }
299            } else {
300                debug!("Task '{}' not found while building graph", current_name);
301            }
302        }
303
304        // Second pass: Add dependency edges
305        self.add_dependency_edges()?;
306
307        Ok(())
308    }
309}
310
311impl Default for TaskGraph {
312    fn default() -> Self {
313        Self::new()
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// A task that echoes `name` and depends on `deps`.
    fn create_test_task(name: &str, deps: Vec<String>) -> Task {
        Task {
            command: format!("echo {}", name),
            args: vec![],
            shell: None,
            env: HashMap::new(),
            depends_on: deps,
            inputs: vec![],
            outputs: vec![],
            external_inputs: None,
            workspaces: vec![],
            description: Some(format!("Test task {}", name)),
        }
    }

    /// Graph containing `specs` (task name plus dependency names), added in order,
    /// with dependency edges materialised once every task exists.
    fn graph_with(specs: &[(&str, Vec<&str>)]) -> TaskGraph {
        let mut graph = TaskGraph::new();
        for (name, dep_names) in specs {
            let depends_on: Vec<String> = dep_names.iter().map(|dep| dep.to_string()).collect();
            graph.add_task(name, create_test_task(name, depends_on)).unwrap();
        }
        graph.add_dependency_edges().unwrap();
        graph
    }

    #[test]
    fn test_task_graph_new() {
        assert_eq!(TaskGraph::new().task_count(), 0);
    }

    #[test]
    fn test_add_single_task() {
        let mut graph = TaskGraph::new();

        let first = graph.add_task("test", create_test_task("test", vec![])).unwrap();
        assert!(graph.contains_task("test"));
        assert_eq!(graph.task_count(), 1);

        // Re-adding under an existing name returns the original node.
        let again = graph.add_task("test", create_test_task("test", vec![])).unwrap();
        assert_eq!(first, again);
        assert_eq!(graph.task_count(), 1);
    }

    #[test]
    fn test_task_dependencies() {
        let graph = graph_with(&[
            ("task1", vec![]),
            ("task2", vec!["task1"]),
            ("task3", vec!["task1", "task2"]),
        ]);

        assert_eq!(graph.task_count(), 3);
        assert!(!graph.has_cycles());

        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);

        let position = |name: &str| sorted.iter().position(|node| node.name == name).unwrap();
        assert!(position("task1") < position("task2"));
        assert!(position("task1") < position("task3"));
        assert!(position("task2") < position("task3"));
    }

    #[test]
    fn test_cycle_detection() {
        // Ring of dependencies: task1 needs task3, task3 needs task2, task2 needs task1.
        let graph = graph_with(&[
            ("task1", vec!["task3"]),
            ("task2", vec!["task1"]),
            ("task3", vec!["task2"]),
        ]);

        assert!(graph.has_cycles());
        assert!(graph.topological_sort().is_err());
    }

    #[test]
    fn test_parallel_groups() {
        // Level 0: task1, task2; level 1: task3 (needs task1), task4 (needs task2);
        // level 2: task5 (needs task3 and task4).
        let graph = graph_with(&[
            ("task1", vec![]),
            ("task2", vec![]),
            ("task3", vec!["task1"]),
            ("task4", vec!["task2"]),
            ("task5", vec!["task3", "task4"]),
        ]);

        let groups = graph.get_parallel_groups().unwrap();
        let sizes: Vec<usize> = groups.iter().map(Vec::len).collect();
        assert_eq!(sizes, vec![2, 2, 1]);
        assert_eq!(groups[2][0].name, "task5");
    }

    #[test]
    fn test_build_from_sequential_group() {
        let mut graph = TaskGraph::new();
        let group = TaskGroup::Sequential(vec![
            TaskDefinition::Single(Box::new(create_test_task("t1", vec![]))),
            TaskDefinition::Single(Box::new(create_test_task("t2", vec![]))),
        ]);

        let nodes = graph.build_from_group("seq", &group, &Tasks::new()).unwrap();
        assert_eq!(nodes.len(), 2);

        // Steps of a sequential group are chained, so sorting preserves their order.
        let names: Vec<String> = graph
            .topological_sort()
            .unwrap()
            .into_iter()
            .map(|node| node.name)
            .collect();
        assert_eq!(names, vec!["seq[0]", "seq[1]"]);
    }

    #[test]
    fn test_build_from_parallel_group() {
        let parallel_tasks: HashMap<String, TaskDefinition> = [
            ("first", create_test_task("t1", vec![])),
            ("second", create_test_task("t2", vec![])),
        ]
        .into_iter()
        .map(|(key, task)| (key.to_string(), TaskDefinition::Single(Box::new(task))))
        .collect();
        let group = TaskGroup::Parallel(parallel_tasks);

        let mut graph = TaskGraph::new();
        let nodes = graph.build_from_group("par", &group, &Tasks::new()).unwrap();
        assert_eq!(nodes.len(), 2);
        assert!(!graph.has_cycles());

        // No edges between parallel members: both share the first level.
        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 2);
    }

    #[test]
    fn test_three_way_cycle_detection() {
        let graph = graph_with(&[
            ("task_a", vec!["task_c"]),
            ("task_b", vec!["task_a"]),
            ("task_c", vec!["task_b"]),
        ]);

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_self_dependency_cycle() {
        let graph = graph_with(&[("self_ref", vec!["self_ref"])]);

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_complex_dependency_graph() {
        // Diamond: b and c both need a; d needs b and c.
        let graph = graph_with(&[
            ("a", vec![]),
            ("b", vec!["a"]),
            ("c", vec!["a"]),
            ("d", vec!["b", "c"]),
        ]);

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();
        let sizes: Vec<usize> = groups.iter().map(Vec::len).collect();
        assert_eq!(sizes, vec![1, 2, 1]);
    }

    #[test]
    fn test_missing_dependency() {
        let mut graph = TaskGraph::new();
        let dependent = create_test_task("dependent", vec!["missing".to_string()]);
        graph.add_task("dependent", dependent).unwrap();

        // Materialising edges must reject a dependency on an unknown task.
        assert!(graph.add_dependency_edges().is_err());
    }

    #[test]
    fn test_empty_graph() {
        let graph = TaskGraph::new();

        assert_eq!(graph.task_count(), 0);
        assert!(!graph.has_cycles());
        assert!(graph.get_parallel_groups().unwrap().is_empty());
    }

    #[test]
    fn test_single_task_no_deps() {
        let mut graph = TaskGraph::new();
        graph.add_task("solo", create_test_task("solo", vec![])).unwrap();

        assert_eq!(graph.task_count(), 1);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 1);
    }

    #[test]
    fn test_linear_chain() {
        let graph = graph_with(&[
            ("a", vec![]),
            ("b", vec!["a"]),
            ("c", vec!["b"]),
            ("d", vec!["c"]),
        ]);

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        // A pure chain yields one single-task level per link.
        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 4);
        assert!(groups.iter().all(|group| group.len() == 1));
    }
}