cuenv_core/tasks/
graph.rs

1//! Task graph builder using petgraph
2//!
3//! This module builds directed acyclic graphs (DAGs) from task definitions
4//! to handle dependencies and determine execution order.
5
6use super::{Task, TaskDefinition, TaskGroup, Tasks};
7use crate::Result;
8use petgraph::algo::{is_cyclic_directed, toposort};
9use petgraph::graph::{DiGraph, NodeIndex};
10use petgraph::visit::IntoNodeReferences;
11use std::collections::{HashMap, HashSet};
12
13/// A node in the task graph
14#[derive(Debug, Clone)]
15pub struct TaskNode {
16    /// Name of the task
17    pub name: String,
18    /// The task to execute
19    pub task: Task,
20}
21
22/// Task graph for dependency resolution and execution ordering
23pub struct TaskGraph {
24    /// The directed graph of tasks
25    graph: DiGraph<TaskNode, ()>,
26    /// Map from task names to node indices
27    name_to_node: HashMap<String, NodeIndex>,
28}
29
30impl TaskGraph {
31    /// Create a new empty task graph
32    pub fn new() -> Self {
33        Self {
34            graph: DiGraph::new(),
35            name_to_node: HashMap::new(),
36        }
37    }
38
39    /// Build a graph from a task definition
40    pub fn build_from_definition(
41        &mut self,
42        name: &str,
43        definition: &TaskDefinition,
44        all_tasks: &Tasks,
45    ) -> Result<Vec<NodeIndex>> {
46        match definition {
47            TaskDefinition::Single(task) => {
48                let node = self.add_task(name, task.clone())?;
49                Ok(vec![node])
50            }
51            TaskDefinition::Group(group) => self.build_from_group(name, group, all_tasks),
52        }
53    }
54
55    /// Build a graph from a task group
56    fn build_from_group(
57        &mut self,
58        prefix: &str,
59        group: &TaskGroup,
60        all_tasks: &Tasks,
61    ) -> Result<Vec<NodeIndex>> {
62        match group {
63            TaskGroup::Sequential(tasks) => self.build_sequential_group(prefix, tasks, all_tasks),
64            TaskGroup::Parallel(tasks) => self.build_parallel_group(prefix, tasks, all_tasks),
65        }
66    }
67
68    /// Build a sequential task group (tasks run one after another)
69    fn build_sequential_group(
70        &mut self,
71        prefix: &str,
72        tasks: &[TaskDefinition],
73        all_tasks: &Tasks,
74    ) -> Result<Vec<NodeIndex>> {
75        let mut nodes = Vec::new();
76        let mut previous: Option<NodeIndex> = None;
77
78        for (i, task_def) in tasks.iter().enumerate() {
79            let task_name = format!("{}[{}]", prefix, i);
80            let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
81
82            // For sequential execution, link previous task to current
83            if let Some(prev) = previous
84                && let Some(first) = task_nodes.first()
85            {
86                self.graph.add_edge(prev, *first, ());
87            }
88
89            if let Some(last) = task_nodes.last() {
90                previous = Some(*last);
91            }
92
93            nodes.extend(task_nodes);
94        }
95
96        Ok(nodes)
97    }
98
99    /// Build a parallel task group (tasks can run concurrently)
100    fn build_parallel_group(
101        &mut self,
102        prefix: &str,
103        tasks: &HashMap<String, TaskDefinition>,
104        all_tasks: &Tasks,
105    ) -> Result<Vec<NodeIndex>> {
106        let mut nodes = Vec::new();
107
108        for (name, task_def) in tasks {
109            let task_name = format!("{}.{}", prefix, name);
110            let task_nodes = self.build_from_definition(&task_name, task_def, all_tasks)?;
111            nodes.extend(task_nodes);
112        }
113
114        Ok(nodes)
115    }
116
117    /// Add a single task to the graph
118    pub fn add_task(&mut self, name: &str, task: Task) -> Result<NodeIndex> {
119        // Check if task already exists
120        if let Some(&node) = self.name_to_node.get(name) {
121            return Ok(node);
122        }
123
124        let node = TaskNode {
125            name: name.to_string(),
126            task: task.clone(),
127        };
128
129        let node_index = self.graph.add_node(node);
130        self.name_to_node.insert(name.to_string(), node_index);
131
132        Ok(node_index)
133    }
134
135    /// Add dependency edges after all tasks have been added
136    /// This ensures proper cycle detection and missing dependency validation
137    fn add_dependency_edges(&mut self) -> Result<()> {
138        let mut missing_deps = Vec::new();
139        let mut edges_to_add = Vec::new();
140
141        // Collect all dependency relationships
142        for (node_index, node) in self.graph.node_references() {
143            for dep_name in &node.task.depends_on {
144                if let Some(&dep_node_index) = self.name_to_node.get(dep_name as &str) {
145                    // Record edge to add later
146                    edges_to_add.push((dep_node_index, node_index));
147                } else {
148                    missing_deps.push((node.name.clone(), dep_name.clone()));
149                }
150            }
151        }
152
153        // Report missing dependencies
154        if !missing_deps.is_empty() {
155            let missing_list = missing_deps
156                .iter()
157                .map(|(task, dep)| format!("Task '{}' depends on missing task '{}'", task, dep))
158                .collect::<Vec<_>>()
159                .join(", ");
160            return Err(crate::Error::configuration(format!(
161                "Missing dependencies: {}",
162                missing_list
163            )));
164        }
165
166        // Add all edges
167        for (from, to) in edges_to_add {
168            self.graph.add_edge(from, to, ());
169        }
170
171        Ok(())
172    }
173
174    /// Check if the graph has cycles
175    pub fn has_cycles(&self) -> bool {
176        is_cyclic_directed(&self.graph)
177    }
178
179    /// Get topologically sorted list of tasks
180    pub fn topological_sort(&self) -> Result<Vec<TaskNode>> {
181        if self.has_cycles() {
182            return Err(crate::Error::configuration(
183                "Task dependency graph contains cycles".to_string(),
184            ));
185        }
186
187        match toposort(&self.graph, None) {
188            Ok(sorted_indices) => Ok(sorted_indices
189                .into_iter()
190                .map(|idx| self.graph[idx].clone())
191                .collect()),
192            Err(_) => Err(crate::Error::configuration(
193                "Failed to sort tasks topologically".to_string(),
194            )),
195        }
196    }
197
198    /// Get all tasks that can run in parallel (no dependencies between them)
199    pub fn get_parallel_groups(&self) -> Result<Vec<Vec<TaskNode>>> {
200        let sorted = self.topological_sort()?;
201
202        if sorted.is_empty() {
203            return Ok(vec![]);
204        }
205
206        // Group tasks by their dependency level
207        let mut groups: Vec<Vec<TaskNode>> = vec![];
208        let mut processed: HashMap<String, usize> = HashMap::new();
209
210        for task in sorted {
211            // Find the maximum level of all dependencies
212            let mut level = 0;
213            for dep in &task.task.depends_on {
214                if let Some(&dep_level) = processed.get(dep) {
215                    level = level.max(dep_level + 1);
216                }
217            }
218
219            // Add to appropriate group
220            if level >= groups.len() {
221                groups.resize(level + 1, vec![]);
222            }
223            groups[level].push(task.clone());
224            processed.insert(task.name.clone(), level);
225        }
226
227        Ok(groups)
228    }
229
230    /// Get the number of tasks in the graph
231    pub fn task_count(&self) -> usize {
232        self.graph.node_count()
233    }
234
235    /// Check if a task exists in the graph
236    pub fn contains_task(&self, name: &str) -> bool {
237        self.name_to_node.contains_key(name)
238    }
239
240    /// Build a complete graph from tasks with proper dependency resolution
241    /// This performs a two-pass build: first adding all nodes, then all edges
242    pub fn build_complete_graph(&mut self, tasks: &Tasks) -> Result<()> {
243        // First pass: Add all tasks as nodes
244        for (name, definition) in tasks.tasks.iter() {
245            match definition {
246                TaskDefinition::Single(task) => {
247                    self.add_task(name, task.clone())?;
248                }
249                TaskDefinition::Group(_) => {
250                    // For groups, we'd need to expand them - this is more complex
251                    // and not needed for the current fix. Groups should be handled
252                    // by build_from_definition which already works correctly.
253                }
254            }
255        }
256
257        // Second pass: Add all dependency edges
258        self.add_dependency_edges()?;
259
260        Ok(())
261    }
262
263    /// Build graph for a specific task and all its transitive dependencies
264    pub fn build_for_task(&mut self, task_name: &str, all_tasks: &Tasks) -> Result<()> {
265        let mut to_process = vec![task_name.to_string()];
266        let mut processed = HashSet::new();
267
268        // First pass: Collect all tasks that need to be included
269        while let Some(current_name) = to_process.pop() {
270            if processed.contains(&current_name) {
271                continue;
272            }
273            processed.insert(current_name.clone());
274
275            if let Some(definition) = all_tasks.get(&current_name) {
276                match definition {
277                    TaskDefinition::Single(task) => {
278                        self.add_task(&current_name, task.clone())?;
279                        // Add dependencies to processing queue
280                        for dep in &task.depends_on {
281                            if !processed.contains(dep) {
282                                to_process.push(dep.clone());
283                            }
284                        }
285                    }
286                    TaskDefinition::Group(_) => {
287                        // Handle groups with build_from_definition
288                        self.build_from_definition(&current_name, definition, all_tasks)?;
289                    }
290                }
291            }
292        }
293
294        // Second pass: Add dependency edges
295        self.add_dependency_edges()?;
296
297        Ok(())
298    }
299}
300
301impl Default for TaskGraph {
302    fn default() -> Self {
303        Self::new()
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal task that echoes `name` and depends on `deps`.
    fn create_test_task(name: &str, deps: Vec<String>) -> Task {
        Task {
            command: format!("echo {}", name),
            args: vec![],
            shell: None,
            env: HashMap::new(),
            depends_on: deps,
            inputs: vec![],
            outputs: vec![],
            description: Some(format!("Test task {}", name)),
        }
    }

    /// Build a `Tasks` collection containing the given single (non-group) tasks.
    fn tasks_from(defs: Vec<(&str, Task)>) -> Tasks {
        let mut tasks = Tasks::new();
        for (name, task) in defs {
            tasks
                .tasks
                .insert(name.to_string(), TaskDefinition::Single(task));
        }
        tasks
    }

    #[test]
    fn test_task_graph_new() {
        let graph = TaskGraph::new();
        assert_eq!(graph.task_count(), 0);
    }

    #[test]
    fn test_add_single_task() {
        let mut graph = TaskGraph::new();
        let task = create_test_task("test", vec![]);

        let node = graph.add_task("test", task).unwrap();
        assert!(graph.contains_task("test"));
        assert_eq!(graph.task_count(), 1);

        // Adding same task again should return same node
        let task2 = create_test_task("test", vec![]);
        let node2 = graph.add_task("test", task2).unwrap();
        assert_eq!(node, node2);
        assert_eq!(graph.task_count(), 1);
    }

    #[test]
    fn test_task_dependencies() {
        let mut graph = TaskGraph::new();

        // Add tasks with dependencies
        let task1 = create_test_task("task1", vec![]);
        let task2 = create_test_task("task2", vec!["task1".to_string()]);
        let task3 = create_test_task("task3", vec!["task1".to_string(), "task2".to_string()]);

        graph.add_task("task1", task1).unwrap();
        graph.add_task("task2", task2).unwrap();
        graph.add_task("task3", task3).unwrap();
        graph.add_dependency_edges().unwrap(); // Add dependency edges after adding all tasks

        assert_eq!(graph.task_count(), 3);
        assert!(!graph.has_cycles());

        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);

        // task1 should come before task2 and task3
        let positions: HashMap<String, usize> = sorted
            .iter()
            .enumerate()
            .map(|(i, node)| (node.name.clone(), i))
            .collect();

        assert!(positions["task1"] < positions["task2"]);
        assert!(positions["task1"] < positions["task3"]);
        assert!(positions["task2"] < positions["task3"]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = TaskGraph::new();

        // Two tasks that depend on each other form a 2-cycle
        // (the 3-cycle case is covered by test_three_way_cycle_detection).
        let task1 = create_test_task("task1", vec!["task2".to_string()]);
        let task2 = create_test_task("task2", vec!["task1".to_string()]);

        graph.add_task("task1", task1).unwrap();
        graph.add_task("task2", task2).unwrap();
        graph.add_dependency_edges().unwrap(); // Add dependency edges after adding all tasks

        assert!(graph.has_cycles());
        assert!(graph.topological_sort().is_err());
    }

    #[test]
    fn test_parallel_groups() {
        let mut graph = TaskGraph::new();

        // Create tasks that can run in parallel
        // Level 0: task1, task2 (no dependencies)
        // Level 1: task3 (depends on task1), task4 (depends on task2)
        // Level 2: task5 (depends on task3 and task4)

        let task1 = create_test_task("task1", vec![]);
        let task2 = create_test_task("task2", vec![]);
        let task3 = create_test_task("task3", vec!["task1".to_string()]);
        let task4 = create_test_task("task4", vec!["task2".to_string()]);
        let task5 = create_test_task("task5", vec!["task3".to_string(), "task4".to_string()]);

        graph.add_task("task1", task1).unwrap();
        graph.add_task("task2", task2).unwrap();
        graph.add_task("task3", task3).unwrap();
        graph.add_task("task4", task4).unwrap();
        graph.add_task("task5", task5).unwrap();
        graph.add_dependency_edges().unwrap(); // Add dependency edges after adding all tasks

        let groups = graph.get_parallel_groups().unwrap();

        // Should have 3 levels
        assert_eq!(groups.len(), 3);

        // Level 0 should have 2 tasks
        assert_eq!(groups[0].len(), 2);

        // Level 1 should have 2 tasks
        assert_eq!(groups[1].len(), 2);

        // Level 2 should have 1 task
        assert_eq!(groups[2].len(), 1);
        assert_eq!(groups[2][0].name, "task5");
    }

    #[test]
    fn test_build_from_sequential_group() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();

        let task1 = create_test_task("t1", vec![]);
        let task2 = create_test_task("t2", vec![]);

        let group = TaskGroup::Sequential(vec![
            TaskDefinition::Single(task1),
            TaskDefinition::Single(task2),
        ]);

        let nodes = graph.build_from_group("seq", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);

        // Sequential tasks should have dependency chain
        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 2);
        assert_eq!(sorted[0].name, "seq[0]");
        assert_eq!(sorted[1].name, "seq[1]");
    }

    #[test]
    fn test_build_from_parallel_group() {
        let mut graph = TaskGraph::new();
        let tasks = Tasks::new();

        let task1 = create_test_task("t1", vec![]);
        let task2 = create_test_task("t2", vec![]);

        let mut parallel_tasks = HashMap::new();
        parallel_tasks.insert("first".to_string(), TaskDefinition::Single(task1));
        parallel_tasks.insert("second".to_string(), TaskDefinition::Single(task2));

        let group = TaskGroup::Parallel(parallel_tasks);

        let nodes = graph.build_from_group("par", &group, &tasks).unwrap();
        assert_eq!(nodes.len(), 2);

        // Parallel tasks should not have dependencies between them
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1); // All in same level
        assert_eq!(groups[0].len(), 2); // Both can run in parallel
    }

    #[test]
    fn test_three_way_cycle_detection() {
        let mut graph = TaskGraph::new();

        // Create cyclic dependencies: A -> B -> C -> A
        let task_a = create_test_task("task_a", vec!["task_c".to_string()]);
        let task_b = create_test_task("task_b", vec!["task_a".to_string()]);
        let task_c = create_test_task("task_c", vec!["task_b".to_string()]);

        graph.add_task("task_a", task_a).unwrap();
        graph.add_task("task_b", task_b).unwrap();
        graph.add_task("task_c", task_c).unwrap();
        graph.add_dependency_edges().unwrap(); // Add dependency edges after adding all tasks

        // This should create a cycle
        assert!(graph.has_cycles());

        // Should fail when trying to get parallel groups
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_self_dependency_cycle() {
        let mut graph = TaskGraph::new();

        // Create self-referencing task
        let task = create_test_task("self_ref", vec!["self_ref".to_string()]);
        graph.add_task("self_ref", task).unwrap();
        graph.add_dependency_edges().unwrap(); // Add dependency edges after adding all tasks

        assert!(graph.has_cycles());
        assert!(graph.get_parallel_groups().is_err());
    }

    #[test]
    fn test_complex_dependency_graph() {
        let mut graph = TaskGraph::new();

        // Create a diamond dependency pattern:
        //     A
        //    / \
        //   B   C
        //    \ /
        //     D
        let task_a = create_test_task("a", vec![]);
        let task_b = create_test_task("b", vec!["a".to_string()]);
        let task_c = create_test_task("c", vec!["a".to_string()]);
        let task_d = create_test_task("d", vec!["b".to_string(), "c".to_string()]);

        graph.add_task("a", task_a).unwrap();
        graph.add_task("b", task_b).unwrap();
        graph.add_task("c", task_c).unwrap();
        graph.add_task("d", task_d).unwrap();
        graph.add_dependency_edges().unwrap(); // Add dependency edges after adding all tasks

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();

        // Should have 3 levels: [A], [B,C], [D]
        assert_eq!(groups.len(), 3);
        assert_eq!(groups[0].len(), 1); // A
        assert_eq!(groups[1].len(), 2); // B and C can run in parallel
        assert_eq!(groups[2].len(), 1); // D
    }

    #[test]
    fn test_missing_dependency() {
        let mut graph = TaskGraph::new();

        // Create task with dependency that doesn't exist
        let task = create_test_task("dependent", vec!["missing".to_string()]);
        graph.add_task("dependent", task).unwrap();

        // Resolving dependency edges must fail because "missing" is not in the graph
        assert!(graph.add_dependency_edges().is_err());
    }

    #[test]
    fn test_empty_graph() {
        let graph = TaskGraph::new();

        assert_eq!(graph.task_count(), 0);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert!(groups.is_empty());
    }

    #[test]
    fn test_single_task_no_deps() {
        let mut graph = TaskGraph::new();

        let task = create_test_task("solo", vec![]);
        graph.add_task("solo", task).unwrap();

        assert_eq!(graph.task_count(), 1);
        assert!(!graph.has_cycles());

        let groups = graph.get_parallel_groups().unwrap();
        assert_eq!(groups.len(), 1);
        assert_eq!(groups[0].len(), 1);
    }

    #[test]
    fn test_linear_chain() {
        let mut graph = TaskGraph::new();

        // Create linear chain: A -> B -> C -> D
        let task_a = create_test_task("a", vec![]);
        let task_b = create_test_task("b", vec!["a".to_string()]);
        let task_c = create_test_task("c", vec!["b".to_string()]);
        let task_d = create_test_task("d", vec!["c".to_string()]);

        graph.add_task("a", task_a).unwrap();
        graph.add_task("b", task_b).unwrap();
        graph.add_task("c", task_c).unwrap();
        graph.add_task("d", task_d).unwrap();
        graph.add_dependency_edges().unwrap(); // Add dependency edges after adding all tasks

        assert!(!graph.has_cycles());
        assert_eq!(graph.task_count(), 4);

        let groups = graph.get_parallel_groups().unwrap();

        // Should be 4 sequential groups
        assert_eq!(groups.len(), 4);
        for group in &groups {
            assert_eq!(group.len(), 1);
        }
    }

    #[test]
    fn test_build_complete_graph_with_single_tasks() {
        let tasks = tasks_from(vec![
            ("a", create_test_task("a", vec![])),
            ("b", create_test_task("b", vec!["a".to_string()])),
        ]);

        let mut graph = TaskGraph::new();
        graph.build_complete_graph(&tasks).unwrap();

        assert_eq!(graph.task_count(), 2);
        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted[0].name, "a");
        assert_eq!(sorted[1].name, "b");
    }

    #[test]
    fn test_build_complete_graph_missing_dependency() {
        let tasks = tasks_from(vec![("a", create_test_task("a", vec!["nope".to_string()]))]);

        let mut graph = TaskGraph::new();
        assert!(graph.build_complete_graph(&tasks).is_err());
    }

    #[test]
    fn test_build_for_task_includes_only_transitive_dependencies() {
        let tasks = tasks_from(vec![
            ("a", create_test_task("a", vec![])),
            ("b", create_test_task("b", vec!["a".to_string()])),
            ("c", create_test_task("c", vec!["b".to_string()])),
            ("unrelated", create_test_task("unrelated", vec![])),
        ]);

        let mut graph = TaskGraph::new();
        graph.build_for_task("c", &tasks).unwrap();

        assert_eq!(graph.task_count(), 3);
        assert!(!graph.contains_task("unrelated"));

        let names: Vec<String> = graph
            .topological_sort()
            .unwrap()
            .into_iter()
            .map(|node| node.name)
            .collect();
        assert_eq!(names, vec!["a", "b", "c"]);
    }
}