Skip to main content

cuenv_task_graph/
graph.rs

1//! Task graph builder using petgraph.
2//!
3//! This module builds directed acyclic graphs (DAGs) from task definitions
4//! to handle dependencies and determine execution order.
5
6use crate::{Error, Result, TaskNodeData, TaskResolution, TaskResolver};
7use petgraph::algo::{is_cyclic_directed, toposort};
8use petgraph::graph::{DiGraph, NodeIndex};
9use petgraph::visit::IntoNodeReferences;
10use std::collections::{HashMap, HashSet};
11use tracing::debug;
12
/// One vertex of the task graph: a task's name paired with its task data.
#[derive(Clone, Debug)]
pub struct GraphNode<T> {
    /// Task name; this is the key used when resolving dependencies by name.
    pub name: String,
    /// Caller-supplied task data stored at this vertex.
    pub task: T,
}
21
22/// Task graph for dependency resolution and execution ordering.
23///
24/// This is a generic graph that can hold any task type implementing [`TaskNodeData`].
25/// It provides methods for building the graph, resolving dependencies, and
26/// computing execution order.
27pub struct TaskGraph<T: TaskNodeData> {
28    /// The directed graph of tasks.
29    graph: DiGraph<GraphNode<T>, ()>,
30    /// Map from task names to node indices.
31    name_to_node: HashMap<String, NodeIndex>,
32    /// Map from group prefix to child task names (for dependency expansion).
33    group_children: HashMap<String, Vec<String>>,
34}
35
36impl<T: TaskNodeData> TaskGraph<T> {
37    /// Create a new empty task graph.
38    #[must_use]
39    pub fn new() -> Self {
40        Self {
41            graph: DiGraph::new(),
42            name_to_node: HashMap::new(),
43            group_children: HashMap::new(),
44        }
45    }
46
47    /// Add a single task to the graph.
48    ///
49    /// If a task with the same name already exists, returns the existing node index.
50    ///
51    /// # Errors
52    ///
53    /// Currently infallible, but returns `Result` for API consistency.
54    pub fn add_task(&mut self, name: &str, task: T) -> Result<NodeIndex> {
55        // Check if task already exists
56        if let Some(&node) = self.name_to_node.get(name) {
57            return Ok(node);
58        }
59
60        let node = GraphNode {
61            name: name.to_string(),
62            task,
63        };
64
65        let node_index = self.graph.add_node(node);
66        self.name_to_node.insert(name.to_string(), node_index);
67        debug!("Added task node '{}'", name);
68
69        Ok(node_index)
70    }
71
72    /// Get a mutable reference to a task node by index.
73    pub fn get_node_mut(&mut self, index: NodeIndex) -> Option<&mut GraphNode<T>> {
74        self.graph.node_weight_mut(index)
75    }
76
77    /// Get a reference to a task node by name.
78    #[must_use]
79    pub fn get_node_by_name(&self, name: &str) -> Option<&GraphNode<T>> {
80        self.name_to_node
81            .get(name)
82            .and_then(|&idx| self.graph.node_weight(idx))
83    }
84
85    /// Register a group of child task names under a group prefix.
86    ///
87    /// This enables group-level dependency expansion where depending on
88    /// a group name will expand to depend on all child tasks.
89    pub fn register_group(&mut self, prefix: &str, children: Vec<String>) {
90        if !children.is_empty() {
91            self.group_children.insert(prefix.to_string(), children);
92        }
93    }
94
95    /// Expand a dependency name to leaf task names.
96    ///
97    /// If the dependency is a direct task, returns it as-is.
98    /// If it's a group name, recursively expands to all leaf tasks in that group.
99    fn expand_dep_to_leaf_tasks(&self, dep_name: &str) -> Vec<String> {
100        if self.name_to_node.contains_key(dep_name) {
101            // It's a leaf task (exists directly in the graph)
102            vec![dep_name.to_string()]
103        } else if let Some(children) = self.group_children.get(dep_name) {
104            // It's a group - recursively expand children
105            children
106                .iter()
107                .flat_map(|child| self.expand_dep_to_leaf_tasks(child))
108                .collect()
109        } else {
110            // Not found - will be caught as missing dependency later
111            vec![dep_name.to_string()]
112        }
113    }
114
115    /// Add dependency edges after all tasks have been added.
116    ///
117    /// This ensures proper cycle detection and missing dependency validation.
118    ///
119    /// # Errors
120    ///
121    /// Returns an error if any task depends on a non-existent task.
122    pub fn add_dependency_edges(&mut self) -> Result<()> {
123        let mut missing_deps = Vec::new();
124        let mut edges_to_add = Vec::new();
125
126        // Collect all dependency relationships
127        for (node_index, node) in self.graph.node_references() {
128            for dep_name in node.task.dependency_names() {
129                // Expand group references to leaf tasks
130                let expanded_deps = self.expand_dep_to_leaf_tasks(dep_name);
131
132                for expanded_dep in expanded_deps {
133                    if let Some(&dep_node_index) = self.name_to_node.get(&expanded_dep) {
134                        // Record edge to add later
135                        edges_to_add.push((dep_node_index, node_index));
136                    } else {
137                        missing_deps.push((node.name.clone(), expanded_dep));
138                    }
139                }
140            }
141        }
142
143        // Report missing dependencies
144        if !missing_deps.is_empty() {
145            return Err(Error::MissingDependencies {
146                missing: missing_deps,
147            });
148        }
149
150        // Add all edges
151        for (from, to) in edges_to_add {
152            self.graph.add_edge(from, to, ());
153        }
154
155        Ok(())
156    }
157
158    /// Add a direct edge between two tasks.
159    ///
160    /// This is a low-level method for adding edges directly, typically used
161    /// for sequential group ordering.
162    pub fn add_edge(&mut self, from: NodeIndex, to: NodeIndex) {
163        self.graph.add_edge(from, to, ());
164    }
165
166    /// Check if the graph has cycles.
167    #[must_use]
168    pub fn has_cycles(&self) -> bool {
169        is_cyclic_directed(&self.graph)
170    }
171
172    /// Get topologically sorted list of tasks.
173    ///
174    /// # Errors
175    ///
176    /// Returns an error if the graph contains cycles.
177    pub fn topological_sort(&self) -> Result<Vec<GraphNode<T>>> {
178        if self.has_cycles() {
179            return Err(Error::CycleDetected {
180                message: "Task dependency graph contains cycles".to_string(),
181            });
182        }
183
184        match toposort(&self.graph, None) {
185            Ok(sorted_indices) => Ok(sorted_indices
186                .into_iter()
187                .map(|idx| self.graph[idx].clone())
188                .collect()),
189            Err(_) => Err(Error::TopologicalSortFailed {
190                reason: "petgraph toposort failed".to_string(),
191            }),
192        }
193    }
194
195    /// Get all tasks that can run in parallel (no dependencies between them).
196    ///
197    /// Returns a vector of parallel groups, where each group contains tasks
198    /// that can execute concurrently. Groups are ordered by dependency level.
199    ///
200    /// # Errors
201    ///
202    /// Returns an error if the graph contains cycles.
203    pub fn get_parallel_groups(&self) -> Result<Vec<Vec<GraphNode<T>>>> {
204        let sorted = self.topological_sort()?;
205
206        if sorted.is_empty() {
207            return Ok(vec![]);
208        }
209
210        // Group tasks by their dependency level
211        let mut groups: Vec<Vec<GraphNode<T>>> = vec![];
212        let mut processed: HashMap<String, usize> = HashMap::new();
213
214        for task in sorted {
215            // Find the maximum level of all dependencies
216            let mut level = 0;
217            for dep in task.task.dependency_names() {
218                if let Some(&dep_level) = processed.get(dep) {
219                    level = level.max(dep_level + 1);
220                }
221            }
222
223            // Add to appropriate group
224            if level >= groups.len() {
225                groups.resize(level + 1, vec![]);
226            }
227            groups[level].push(task.clone());
228            processed.insert(task.name.clone(), level);
229        }
230
231        Ok(groups)
232    }
233
234    /// Get the number of tasks in the graph.
235    #[must_use]
236    pub fn task_count(&self) -> usize {
237        self.graph.node_count()
238    }
239
240    /// Check if a task exists in the graph.
241    #[must_use]
242    pub fn contains_task(&self, name: &str) -> bool {
243        self.name_to_node.contains_key(name)
244    }
245
246    /// Get the node index for a task by name.
247    #[must_use]
248    pub fn get_node_index(&self, name: &str) -> Option<NodeIndex> {
249        self.name_to_node.get(name).copied()
250    }
251
252    /// Iterate over all nodes in the graph.
253    pub fn iter_nodes(&self) -> impl Iterator<Item = (NodeIndex, &GraphNode<T>)> {
254        self.graph.node_references()
255    }
256
257    /// Build graph for a specific task and all its transitive dependencies.
258    ///
259    /// This method takes an iterator of all available tasks and builds
260    /// only the subgraph needed for the requested task.
261    ///
262    /// # Arguments
263    ///
264    /// * `task_name` - The name of the task to build the graph for
265    /// * `get_task` - Function that returns the task data for a given name
266    ///
267    /// # Errors
268    ///
269    /// Returns an error if dependencies cannot be resolved.
270    pub fn build_for_task<F>(&mut self, task_name: &str, mut get_task: F) -> Result<()>
271    where
272        F: FnMut(&str) -> Option<T>,
273    {
274        let mut to_process = vec![task_name.to_string()];
275        let mut processed = HashSet::new();
276
277        debug!("Building graph for '{}'", task_name);
278
279        // First pass: Collect all tasks that need to be included
280        while let Some(current_name) = to_process.pop() {
281            if processed.contains(&current_name) {
282                continue;
283            }
284            processed.insert(current_name.clone());
285
286            if let Some(task) = get_task(&current_name) {
287                // Collect dependencies before adding the task
288                let deps: Vec<String> = task.dependency_names().map(String::from).collect();
289
290                self.add_task(&current_name, task)?;
291
292                // Add dependencies to processing queue
293                for dep in deps {
294                    if !processed.contains(&dep) {
295                        to_process.push(dep);
296                    }
297                }
298            } else {
299                debug!("Task '{}' not found while building graph", current_name);
300            }
301        }
302
303        // Second pass: Add dependency edges
304        self.add_dependency_edges()?;
305
306        Ok(())
307    }
308
309    /// Build graph for a specific task using a resolver that handles group expansion.
310    ///
311    /// This method uses the [`TaskResolver`] trait to resolve task names, which enables
312    /// unified handling of single tasks and groups (sequential/parallel).
313    ///
314    /// # Arguments
315    ///
316    /// * `task_name` - The name of the task to build the graph for
317    /// * `resolver` - Implementation of [`TaskResolver`] that provides task lookup and group expansion
318    ///
319    /// # Errors
320    ///
321    /// Returns an error if dependencies cannot be resolved.
322    pub fn build_for_task_with_resolver<R>(&mut self, task_name: &str, resolver: &R) -> Result<()>
323    where
324        R: TaskResolver<T>,
325    {
326        let mut to_process = vec![task_name.to_string()];
327        let mut processed = HashSet::new();
328        // Track sequential orderings for second pass
329        let mut sequential_orderings: Vec<Vec<String>> = Vec::new();
330        // Track parallel group depends_on to apply to leaf tasks
331        let mut pending_group_deps: HashMap<String, Vec<String>> = HashMap::new();
332
333        debug!("Building graph with resolver for '{}'", task_name);
334
335        // First pass: Collect all tasks and track sequential groups
336        while let Some(current_name) = to_process.pop() {
337            if processed.contains(&current_name) {
338                continue;
339            }
340            processed.insert(current_name.clone());
341
342            match resolver.resolve(&current_name) {
343                Some(TaskResolution::Single(mut task)) => {
344                    // Apply any pending group-level dependencies
345                    // Walk up the path to find parent groups
346                    let path_parts: Vec<&str> = current_name.split('.').collect();
347                    for i in 1..path_parts.len() {
348                        let parent_path = path_parts[..i].join(".");
349                        if let Some(deps) = pending_group_deps.get(&parent_path) {
350                            for dep in deps {
351                                task.add_dependency(dep.clone());
352                            }
353                        }
354                    }
355                    // Also check for bracket notation parents (e.g., "build[0]" -> "build")
356                    if let Some(bracket_idx) = current_name.find('[') {
357                        let parent_path = &current_name[..bracket_idx];
358                        if let Some(deps) = pending_group_deps.get(parent_path) {
359                            for dep in deps {
360                                task.add_dependency(dep.clone());
361                            }
362                        }
363                    }
364
365                    self.add_task(&current_name, task.clone())?;
366
367                    // Add dependencies to processing queue
368                    for dep in task.dependency_names() {
369                        if !processed.contains(dep) {
370                            to_process.push(dep.to_string());
371                        }
372                    }
373                }
374                Some(TaskResolution::Sequential { children }) => {
375                    self.register_group(&current_name, children.clone());
376                    // Track ordering for second pass
377                    sequential_orderings.push(children.clone());
378                    for child in children {
379                        if !processed.contains(&child) {
380                            to_process.push(child);
381                        }
382                    }
383                }
384                Some(TaskResolution::Parallel {
385                    children,
386                    depends_on,
387                }) => {
388                    self.register_group(&current_name, children.clone());
389                    // Store group-level deps to apply to leaf tasks
390                    if !depends_on.is_empty() {
391                        pending_group_deps.insert(current_name.clone(), depends_on.clone());
392                        // Also add the group deps to processing queue
393                        for dep in &depends_on {
394                            if !processed.contains(dep) {
395                                to_process.push(dep.clone());
396                            }
397                        }
398                    }
399                    for child in children {
400                        if !processed.contains(&child) {
401                            to_process.push(child);
402                        }
403                    }
404                }
405                None => {
406                    debug!("Task '{}' not found while building graph", current_name);
407                }
408            }
409        }
410
411        // Second pass: Add sequential ordering edges
412        for ordering in sequential_orderings {
413            for window in ordering.windows(2) {
414                if let [prev, next] = window {
415                    // Add edge from prev to next (prev must complete before next)
416                    if let (Some(prev_idx), Some(next_idx)) =
417                        (self.get_node_index(prev), self.get_node_index(next))
418                    {
419                        self.add_edge(prev_idx, next_idx);
420                    }
421                }
422            }
423        }
424
425        // Third pass: Add dependency edges from task.depends_on
426        self.add_dependency_edges()
427    }
428
429    /// Compute which tasks from a pipeline are affected, using transitive dependency propagation.
430    ///
431    /// This method determines which tasks need to run based on:
432    /// 1. Direct effect: The predicate returns true for the task
433    /// 2. Transitive effect: A task depends on an affected task
434    /// 3. External effect: An external dependency (e.g., `#project:task`) is affected
435    ///
436    /// # Arguments
437    ///
438    /// * `pipeline_tasks` - The names of tasks in the pipeline to check
439    /// * `is_directly_affected` - Predicate that returns true if a task is directly affected
440    /// * `is_external_affected` - Optional predicate for external dependencies (starting with `#`)
441    ///
442    /// # Returns
443    ///
444    /// A vector of task names that are affected, in pipeline order.
445    ///
446    /// # Example
447    ///
448    /// ```ignore
449    /// // Without external dependency checking
450    /// let affected = graph.compute_affected(
451    ///     &["build", "test", "deploy"],
452    ///     |task| task.is_affected_by(&changed_files, &project_root),
453    ///     None::<fn(&str) -> bool>,
454    /// );
455    ///
456    /// // With external dependency checking (for CI cross-project deps)
457    /// let affected = graph.compute_affected(
458    ///     &["build", "test", "deploy"],
459    ///     |task| task.is_affected_by(&changed_files, &project_root),
460    ///     Some(|dep: &str| check_external_dependency(dep, &all_projects, &changed_files)),
461    /// );
462    /// ```
463    #[allow(clippy::needless_pass_by_value)] // Option<E> is intentionally by-value for ergonomic API
464    pub fn compute_affected<F, E>(
465        &self,
466        pipeline_tasks: &[impl AsRef<str>],
467        is_directly_affected: F,
468        is_external_affected: Option<E>,
469    ) -> Vec<String>
470    where
471        F: Fn(&T) -> bool,
472        E: Fn(&str) -> bool,
473    {
474        use std::collections::HashSet;
475
476        let mut affected = HashSet::new();
477
478        // 1. Find directly affected tasks
479        for task_name in pipeline_tasks {
480            let task_name = task_name.as_ref();
481            if let Some(node) = self.get_node_by_name(task_name)
482                && is_directly_affected(&node.task)
483            {
484                affected.insert(task_name.to_string());
485            }
486        }
487
488        // 2. Propagate through dependencies (tasks that depend on affected tasks become affected)
489        let mut changed = true;
490        while changed {
491            changed = false;
492            for task_name in pipeline_tasks {
493                let task_name = task_name.as_ref();
494                if affected.contains(task_name) {
495                    continue;
496                }
497
498                if let Some(node) = self.get_node_by_name(task_name) {
499                    for dep in node.task.dependency_names() {
500                        // Check if this is an external dependency (starts with #)
501                        // Note: Cross-project refs are no longer supported, but keeping check for safety
502                        if dep.starts_with('#') {
503                            if is_external_affected
504                                .as_ref()
505                                .is_some_and(|resolver| resolver(dep))
506                            {
507                                affected.insert(task_name.to_string());
508                                changed = true;
509                                break;
510                            }
511                            continue;
512                        }
513
514                        // Expand group dependencies to their leaf tasks
515                        let leaf_deps = self.expand_dep_to_leaf_tasks(dep);
516                        for leaf_dep in leaf_deps {
517                            if affected.contains(&leaf_dep) {
518                                affected.insert(task_name.to_string());
519                                changed = true;
520                                break;
521                            }
522                        }
523                        if changed {
524                            break;
525                        }
526                    }
527                }
528            }
529        }
530
531        // Return in pipeline order
532        pipeline_tasks
533            .iter()
534            .map(|t| t.as_ref().to_string())
535            .filter(|t| affected.contains(t))
536            .collect()
537    }
538}
539
540impl<T: TaskNodeData> Default for TaskGraph<T> {
541    fn default() -> Self {
542        Self::new()
543    }
544}
545
/// Compute the transitive closure of dependencies from an initial set.
///
/// Starting from `initial`, repeatedly follows `get_deps` and returns every
/// name reached, including the starting names themselves. Each name is
/// expanded at most once, so cyclic dependency data terminates. Names for
/// which `get_deps` returns `None` are included but not expanded.
///
/// # Arguments
///
/// * `initial` - Starting set of node names
/// * `get_deps` - Function that returns dependencies for a given node name
///
/// # Example
///
/// ```ignore
/// use cuenv_task_graph::compute_transitive_closure;
/// use std::collections::HashMap;
///
/// let deps: HashMap<&str, Vec<String>> = [
///     ("build", vec![]),
///     ("test", vec!["build".to_string()]),
///     ("deploy", vec!["test".to_string()]),
/// ].into_iter().collect();
///
/// let closure = compute_transitive_closure(
///     ["deploy"],
///     |name| deps.get(name).map(|v| v.as_slice()),
/// );
/// // closure contains: {"deploy", "test", "build"}
/// ```
#[must_use]
pub fn compute_transitive_closure<'a>(
    initial: impl IntoIterator<Item = &'a str>,
    mut get_deps: impl FnMut(&str) -> Option<&'a [String]>,
) -> HashSet<String> {
    let mut all = HashSet::new();
    // Worklist of names still to expand, popped LIFO (depth-first). Traversal
    // order does not affect the resulting set.
    let mut frontier: Vec<&'a str> = Vec::new();

    for name in initial {
        if all.insert(name.to_string()) {
            frontier.push(name);
        }
    }

    while let Some(current) = frontier.pop() {
        let Some(deps) = get_deps(current) else {
            continue;
        };
        for dep in deps {
            // Check before inserting so already-seen names don't allocate a String.
            if !all.contains(dep.as_str()) {
                all.insert(dep.clone());
                // `dep` borrows from the `'a` slice returned by `get_deps`, not
                // from `all`, so it stays valid while `all` grows.
                frontier.push(dep.as_str());
            }
        }
    }

    all
}
606
607#[cfg(test)]
608mod tests {
609    use super::*;
610
611    /// Simple test task implementation
612    #[derive(Clone, Debug, Default)]
613    struct TestTask {
614        depends_on: Vec<String>,
615    }
616
617    impl TestTask {
618        fn new(deps: &[&str]) -> Self {
619            Self {
620                depends_on: deps.iter().map(|s| (*s).to_string()).collect(),
621            }
622        }
623    }
624
625    impl TaskNodeData for TestTask {
626        fn dependency_names(&self) -> impl Iterator<Item = &str> {
627            self.depends_on.iter().map(String::as_str)
628        }
629
630        fn add_dependency(&mut self, dep: String) {
631            if !self.depends_on.contains(&dep) {
632                self.depends_on.push(dep);
633            }
634        }
635    }
636
637    #[test]
638    fn test_task_graph_new() {
639        let graph: TaskGraph<TestTask> = TaskGraph::new();
640        assert_eq!(graph.task_count(), 0);
641    }
642
643    #[test]
644    fn test_add_single_task() {
645        let mut graph = TaskGraph::new();
646        let task = TestTask::new(&[]);
647
648        let node = graph.add_task("test", task).unwrap();
649        assert!(graph.contains_task("test"));
650        assert_eq!(graph.task_count(), 1);
651
652        // Adding same task again should return same node
653        let task2 = TestTask::new(&[]);
654        let node2 = graph.add_task("test", task2).unwrap();
655        assert_eq!(node, node2);
656        assert_eq!(graph.task_count(), 1);
657    }
658
659    #[test]
660    fn test_task_dependencies() {
661        let mut graph = TaskGraph::new();
662
663        // Add tasks with dependencies
664        let task1 = TestTask::new(&[]);
665        let task2 = TestTask::new(&["task1"]);
666        let task3 = TestTask::new(&["task1", "task2"]);
667
668        graph.add_task("task1", task1).unwrap();
669        graph.add_task("task2", task2).unwrap();
670        graph.add_task("task3", task3).unwrap();
671        graph.add_dependency_edges().unwrap();
672
673        assert_eq!(graph.task_count(), 3);
674        assert!(!graph.has_cycles());
675
676        let sorted = graph.topological_sort().unwrap();
677        assert_eq!(sorted.len(), 3);
678
679        // task1 should come before task2 and task3
680        let positions: HashMap<String, usize> = sorted
681            .iter()
682            .enumerate()
683            .map(|(i, node)| (node.name.clone(), i))
684            .collect();
685
686        assert!(positions["task1"] < positions["task2"]);
687        assert!(positions["task1"] < positions["task3"]);
688        assert!(positions["task2"] < positions["task3"]);
689    }
690
691    #[test]
692    fn test_cycle_detection() {
693        let mut graph = TaskGraph::new();
694
695        // Create a cycle: task1 -> task2 -> task3 -> task1
696        let task1 = TestTask::new(&["task3"]);
697        let task2 = TestTask::new(&["task1"]);
698        let task3 = TestTask::new(&["task2"]);
699
700        graph.add_task("task1", task1).unwrap();
701        graph.add_task("task2", task2).unwrap();
702        graph.add_task("task3", task3).unwrap();
703        graph.add_dependency_edges().unwrap();
704
705        assert!(graph.has_cycles());
706        assert!(graph.topological_sort().is_err());
707    }
708
709    #[test]
710    fn test_parallel_groups() {
711        let mut graph = TaskGraph::new();
712
713        // Create tasks that can run in parallel
714        // Level 0: task1, task2 (no dependencies)
715        // Level 1: task3 (depends on task1), task4 (depends on task2)
716        // Level 2: task5 (depends on task3 and task4)
717
718        let task1 = TestTask::new(&[]);
719        let task2 = TestTask::new(&[]);
720        let task3 = TestTask::new(&["task1"]);
721        let task4 = TestTask::new(&["task2"]);
722        let task5 = TestTask::new(&["task3", "task4"]);
723
724        graph.add_task("task1", task1).unwrap();
725        graph.add_task("task2", task2).unwrap();
726        graph.add_task("task3", task3).unwrap();
727        graph.add_task("task4", task4).unwrap();
728        graph.add_task("task5", task5).unwrap();
729        graph.add_dependency_edges().unwrap();
730
731        let groups = graph.get_parallel_groups().unwrap();
732
733        // Should have 3 levels
734        assert_eq!(groups.len(), 3);
735
736        // Level 0 should have 2 tasks
737        assert_eq!(groups[0].len(), 2);
738
739        // Level 1 should have 2 tasks
740        assert_eq!(groups[1].len(), 2);
741
742        // Level 2 should have 1 task
743        assert_eq!(groups[2].len(), 1);
744        assert_eq!(groups[2][0].name, "task5");
745    }
746
747    #[test]
748    fn test_group_dependency_expansion() {
749        let mut graph = TaskGraph::new();
750
751        // Register a group "build" with two children
752        graph.register_group(
753            "build",
754            vec!["build.deps".to_string(), "build.compile".to_string()],
755        );
756
757        // Add the child tasks
758        let deps_task = TestTask::new(&[]);
759        let compile_task = TestTask::new(&[]);
760        graph.add_task("build.deps", deps_task).unwrap();
761        graph.add_task("build.compile", compile_task).unwrap();
762
763        // Add a task that depends on the group name "build"
764        let test_task = TestTask::new(&["build"]);
765        graph.add_task("test", test_task).unwrap();
766
767        // This should succeed - "build" expands to both children
768        graph.add_dependency_edges().unwrap();
769
770        assert!(!graph.has_cycles());
771        assert_eq!(graph.task_count(), 3);
772
773        // test should come after both build.deps and build.compile
774        let sorted = graph.topological_sort().unwrap();
775        let positions: HashMap<String, usize> = sorted
776            .iter()
777            .enumerate()
778            .map(|(i, node)| (node.name.clone(), i))
779            .collect();
780
781        assert!(positions["build.deps"] < positions["test"]);
782        assert!(positions["build.compile"] < positions["test"]);
783    }
784
785    #[test]
786    fn test_missing_dependency() {
787        let mut graph = TaskGraph::new();
788
789        // Create task with dependency that doesn't exist
790        let task = TestTask::new(&["missing"]);
791        graph.add_task("dependent", task).unwrap();
792
793        // Should fail to add edges due to missing dependency
794        assert!(graph.add_dependency_edges().is_err());
795    }
796
797    #[test]
798    fn test_empty_graph() {
799        let graph: TaskGraph<TestTask> = TaskGraph::new();
800
801        assert_eq!(graph.task_count(), 0);
802        assert!(!graph.has_cycles());
803
804        let groups = graph.get_parallel_groups().unwrap();
805        assert!(groups.is_empty());
806    }
807
808    #[test]
809    fn test_diamond_dependency() {
810        let mut graph = TaskGraph::new();
811
812        // Create a diamond dependency pattern:
813        //     A
814        //    / \
815        //   B   C
816        //    \ /
817        //     D
818        let task_a = TestTask::new(&[]);
819        let task_b = TestTask::new(&["a"]);
820        let task_c = TestTask::new(&["a"]);
821        let task_d = TestTask::new(&["b", "c"]);
822
823        graph.add_task("a", task_a).unwrap();
824        graph.add_task("b", task_b).unwrap();
825        graph.add_task("c", task_c).unwrap();
826        graph.add_task("d", task_d).unwrap();
827        graph.add_dependency_edges().unwrap();
828
829        assert!(!graph.has_cycles());
830        assert_eq!(graph.task_count(), 4);
831
832        let groups = graph.get_parallel_groups().unwrap();
833
834        // Should have 3 levels: [A], [B,C], [D]
835        assert_eq!(groups.len(), 3);
836        assert_eq!(groups[0].len(), 1); // A
837        assert_eq!(groups[1].len(), 2); // B and C can run in parallel
838        assert_eq!(groups[2].len(), 1); // D
839    }
840
841    #[test]
842    fn test_self_dependency_cycle() {
843        let mut graph = TaskGraph::new();
844
845        // Create self-referencing task
846        let task = TestTask::new(&["self_ref"]);
847        graph.add_task("self_ref", task).unwrap();
848        graph.add_dependency_edges().unwrap();
849
850        assert!(graph.has_cycles());
851        assert!(graph.get_parallel_groups().is_err());
852    }
853
854    /// Test that shared dependencies appear only once in the DAG.
855    ///
856    /// When task A and task B both depend on task C, task C should only
857    /// appear once in the task graph (deduplication).
858    #[test]
859    fn test_shared_dependency_deduplication() {
860        let mut graph = TaskGraph::new();
861
862        // Create pattern where both A and B depend on C:
863        //     C
864        //    / \
865        //   A   B
866        let task_c = TestTask::new(&[]);
867        let task_a = TestTask::new(&["c"]);
868        let task_b = TestTask::new(&["c"]);
869
870        graph.add_task("c", task_c).unwrap();
871        graph.add_task("a", task_a).unwrap();
872        graph.add_task("b", task_b).unwrap();
873        graph.add_dependency_edges().unwrap();
874
875        // Verify task C appears exactly once in the graph
876        assert_eq!(graph.task_count(), 3, "Should have exactly 3 tasks");
877
878        // Count occurrences of task C in the topological sort
879        let sorted = graph.topological_sort().unwrap();
880        let c_count = sorted.iter().filter(|node| node.name == "c").count();
881        assert_eq!(c_count, 1, "Task C should appear exactly once in the DAG");
882
883        // Verify execution order: C comes before both A and B
884        let positions: std::collections::HashMap<String, usize> = sorted
885            .iter()
886            .enumerate()
887            .map(|(i, node)| (node.name.clone(), i))
888            .collect();
889        assert!(positions["c"] < positions["a"], "C should execute before A");
890        assert!(positions["c"] < positions["b"], "C should execute before B");
891
892        // Verify parallel groups: C in level 0, A and B in level 1
893        let groups = graph.get_parallel_groups().unwrap();
894        assert_eq!(groups.len(), 2, "Should have 2 execution levels");
895        assert_eq!(groups[0].len(), 1, "Level 0 should have 1 task (C)");
896        assert_eq!(groups[0][0].name, "c");
897        assert_eq!(groups[1].len(), 2, "Level 1 should have 2 tasks (A and B)");
898    }
899
900    #[test]
901    fn test_build_for_task() {
902        let mut graph = TaskGraph::new();
903
904        // Create a map of available tasks
905        let mut all_tasks = HashMap::new();
906        all_tasks.insert("a".to_string(), TestTask::new(&[]));
907        all_tasks.insert("b".to_string(), TestTask::new(&["a"]));
908        all_tasks.insert("c".to_string(), TestTask::new(&["b"]));
909        all_tasks.insert("d".to_string(), TestTask::new(&[])); // Not a dependency of c
910
911        // Build graph for "c" - should include a, b, c but not d
912        graph
913            .build_for_task("c", |name| all_tasks.get(name).cloned())
914            .unwrap();
915
916        assert_eq!(graph.task_count(), 3);
917        assert!(graph.contains_task("a"));
918        assert!(graph.contains_task("b"));
919        assert!(graph.contains_task("c"));
920        assert!(!graph.contains_task("d"));
921    }
922
923    // Tests for TaskResolver functionality
924
925    use crate::{TaskResolution, TaskResolver};
926
927    /// Test resolver that supports groups
928    struct TestResolver {
929        tasks: HashMap<String, TestTask>,
930        sequential_groups: HashMap<String, Vec<String>>,
931        parallel_groups: HashMap<String, (Vec<String>, Vec<String>)>, // (children, depends_on)
932    }
933
934    impl TestResolver {
935        fn new() -> Self {
936            Self {
937                tasks: HashMap::new(),
938                sequential_groups: HashMap::new(),
939                parallel_groups: HashMap::new(),
940            }
941        }
942
943        fn add_task(&mut self, name: &str, task: TestTask) {
944            self.tasks.insert(name.to_string(), task);
945        }
946
947        fn add_sequential_group(&mut self, name: &str, children: &[&str]) {
948            self.sequential_groups.insert(
949                name.to_string(),
950                children.iter().map(|s| (*s).to_string()).collect(),
951            );
952        }
953
954        fn add_parallel_group(&mut self, name: &str, children: &[&str], depends_on: &[&str]) {
955            self.parallel_groups.insert(
956                name.to_string(),
957                (
958                    children.iter().map(|s| (*s).to_string()).collect(),
959                    depends_on.iter().map(|s| (*s).to_string()).collect(),
960                ),
961            );
962        }
963    }
964
965    impl TaskResolver<TestTask> for TestResolver {
966        fn resolve(&self, name: &str) -> Option<TaskResolution<TestTask>> {
967            // Check if it's a direct task
968            if let Some(task) = self.tasks.get(name) {
969                return Some(TaskResolution::Single(task.clone()));
970            }
971            // Check if it's a sequential group
972            if let Some(children) = self.sequential_groups.get(name) {
973                return Some(TaskResolution::Sequential {
974                    children: children.clone(),
975                });
976            }
977            // Check if it's a parallel group
978            if let Some((children, depends_on)) = self.parallel_groups.get(name) {
979                return Some(TaskResolution::Parallel {
980                    children: children.clone(),
981                    depends_on: depends_on.clone(),
982                });
983            }
984            None
985        }
986    }
987
988    #[test]
989    fn test_resolver_single_task() {
990        let mut resolver = TestResolver::new();
991        resolver.add_task("build", TestTask::new(&[]));
992        resolver.add_task("test", TestTask::new(&["build"]));
993
994        let mut graph = TaskGraph::new();
995        graph
996            .build_for_task_with_resolver("test", &resolver)
997            .unwrap();
998
999        assert_eq!(graph.task_count(), 2);
1000        assert!(graph.contains_task("build"));
1001        assert!(graph.contains_task("test"));
1002
1003        let sorted = graph.topological_sort().unwrap();
1004        let positions: HashMap<String, usize> = sorted
1005            .iter()
1006            .enumerate()
1007            .map(|(i, n)| (n.name.clone(), i))
1008            .collect();
1009
1010        assert!(positions["build"] < positions["test"]);
1011    }
1012
1013    #[test]
1014    fn test_resolver_sequential_group() {
1015        let mut resolver = TestResolver::new();
1016        // Sequential group: build[0] -> build[1] -> build[2]
1017        resolver.add_sequential_group("build", &["build[0]", "build[1]", "build[2]"]);
1018        resolver.add_task("build[0]", TestTask::new(&[]));
1019        resolver.add_task("build[1]", TestTask::new(&[]));
1020        resolver.add_task("build[2]", TestTask::new(&[]));
1021
1022        let mut graph = TaskGraph::new();
1023        graph
1024            .build_for_task_with_resolver("build", &resolver)
1025            .unwrap();
1026
1027        assert_eq!(graph.task_count(), 3);
1028
1029        let sorted = graph.topological_sort().unwrap();
1030        let positions: HashMap<String, usize> = sorted
1031            .iter()
1032            .enumerate()
1033            .map(|(i, n)| (n.name.clone(), i))
1034            .collect();
1035
1036        // Sequential ordering must be preserved
1037        assert!(positions["build[0]"] < positions["build[1]"]);
1038        assert!(positions["build[1]"] < positions["build[2]"]);
1039    }
1040
1041    #[test]
1042    fn test_resolver_parallel_group() {
1043        let mut resolver = TestResolver::new();
1044        // Parallel group with children
1045        resolver.add_parallel_group(
1046            "build",
1047            &["build.frontend", "build.backend"],
1048            &[], // no group-level deps
1049        );
1050        resolver.add_task("build.frontend", TestTask::new(&[]));
1051        resolver.add_task("build.backend", TestTask::new(&[]));
1052
1053        let mut graph = TaskGraph::new();
1054        graph
1055            .build_for_task_with_resolver("build", &resolver)
1056            .unwrap();
1057
1058        assert_eq!(graph.task_count(), 2);
1059        assert!(graph.contains_task("build.frontend"));
1060        assert!(graph.contains_task("build.backend"));
1061
1062        // Both should be at same level (can run in parallel)
1063        let groups = graph.get_parallel_groups().unwrap();
1064        assert_eq!(groups.len(), 1); // Single level
1065        assert_eq!(groups[0].len(), 2); // Both tasks
1066    }
1067
1068    #[test]
1069    fn test_resolver_parallel_group_with_depends_on() {
1070        let mut resolver = TestResolver::new();
1071        // Setup task first
1072        resolver.add_task("setup", TestTask::new(&[]));
1073        // Parallel group with group-level depends_on
1074        resolver.add_parallel_group(
1075            "build",
1076            &["build.frontend", "build.backend"],
1077            &["setup"], // group depends on setup
1078        );
1079        resolver.add_task("build.frontend", TestTask::new(&[]));
1080        resolver.add_task("build.backend", TestTask::new(&[]));
1081
1082        let mut graph = TaskGraph::new();
1083        graph
1084            .build_for_task_with_resolver("build", &resolver)
1085            .unwrap();
1086
1087        assert_eq!(graph.task_count(), 3);
1088
1089        let sorted = graph.topological_sort().unwrap();
1090        let positions: HashMap<String, usize> = sorted
1091            .iter()
1092            .enumerate()
1093            .map(|(i, n)| (n.name.clone(), i))
1094            .collect();
1095
1096        // Setup must come before both children
1097        assert!(positions["setup"] < positions["build.frontend"]);
1098        assert!(positions["setup"] < positions["build.backend"]);
1099    }
1100
1101    #[test]
1102    fn test_resolver_nested_groups() {
1103        let mut resolver = TestResolver::new();
1104        // Top level parallel group
1105        resolver.add_parallel_group("build", &["build.frontend", "build.backend"], &[]);
1106        // Nested sequential group
1107        resolver.add_sequential_group(
1108            "build.frontend",
1109            &["build.frontend[0]", "build.frontend[1]"],
1110        );
1111        resolver.add_task("build.frontend[0]", TestTask::new(&[]));
1112        resolver.add_task("build.frontend[1]", TestTask::new(&[]));
1113        resolver.add_task("build.backend", TestTask::new(&[]));
1114
1115        let mut graph = TaskGraph::new();
1116        graph
1117            .build_for_task_with_resolver("build", &resolver)
1118            .unwrap();
1119
1120        assert_eq!(graph.task_count(), 3);
1121
1122        let sorted = graph.topological_sort().unwrap();
1123        let positions: HashMap<String, usize> = sorted
1124            .iter()
1125            .enumerate()
1126            .map(|(i, n)| (n.name.clone(), i))
1127            .collect();
1128
1129        // Sequential ordering within frontend must be preserved
1130        assert!(positions["build.frontend[0]"] < positions["build.frontend[1]"]);
1131    }
1132
1133    // ==========================================================================
1134    // compute_affected tests
1135    // ==========================================================================
1136
1137    #[test]
1138    fn test_compute_affected_direct() {
1139        let mut graph = TaskGraph::new();
1140        graph.add_task("build", TestTask::new(&[])).unwrap();
1141        graph.add_task("test", TestTask::new(&["build"])).unwrap();
1142        graph.add_task("deploy", TestTask::new(&["test"])).unwrap();
1143        graph.add_dependency_edges().unwrap();
1144
1145        // Only build is directly affected
1146        let affected = graph.compute_affected(
1147            &["build", "test", "deploy"],
1148            |task| {
1149                // Simulate: build has no deps (directly affected), others don't
1150                task.depends_on.is_empty()
1151            },
1152            None::<fn(&str) -> bool>,
1153        );
1154
1155        // build is directly affected, test and deploy are transitively affected
1156        assert_eq!(affected, vec!["build", "test", "deploy"]);
1157    }
1158
1159    #[test]
1160    fn test_compute_affected_none() {
1161        let mut graph = TaskGraph::new();
1162        graph.add_task("build", TestTask::new(&[])).unwrap();
1163        graph.add_task("test", TestTask::new(&["build"])).unwrap();
1164        graph.add_dependency_edges().unwrap();
1165
1166        // Nothing is directly affected
1167        let affected =
1168            graph.compute_affected(&["build", "test"], |_task| false, None::<fn(&str) -> bool>);
1169
1170        assert!(affected.is_empty());
1171    }
1172
1173    #[test]
1174    fn test_compute_affected_preserves_pipeline_order() {
1175        let mut graph = TaskGraph::new();
1176        graph.add_task("deploy", TestTask::new(&["test"])).unwrap();
1177        graph.add_task("test", TestTask::new(&["build"])).unwrap();
1178        graph.add_task("build", TestTask::new(&[])).unwrap();
1179        graph.add_dependency_edges().unwrap();
1180
1181        // All directly affected
1182        let affected = graph.compute_affected(
1183            &["build", "test", "deploy"],
1184            |_| true,
1185            None::<fn(&str) -> bool>,
1186        );
1187
1188        // Should preserve pipeline order, not graph order
1189        assert_eq!(affected, vec!["build", "test", "deploy"]);
1190    }
1191
1192    #[test]
1193    fn test_compute_affected_transitive_only() {
1194        let mut graph = TaskGraph::new();
1195        graph.add_task("build", TestTask::new(&[])).unwrap();
1196        graph.add_task("test", TestTask::new(&["build"])).unwrap();
1197        graph.add_task("deploy", TestTask::new(&["test"])).unwrap();
1198        graph.add_dependency_edges().unwrap();
1199
1200        // Only test is directly affected, but deploy depends on it
1201        let affected = graph.compute_affected(
1202            &["build", "test", "deploy"],
1203            |task| {
1204                // Only "test" has exactly one dependency
1205                task.depends_on.len() == 1 && task.depends_on[0] == "build"
1206            },
1207            None::<fn(&str) -> bool>,
1208        );
1209
1210        // test is directly affected, deploy is transitively affected
1211        // build is not affected because nothing depends on what build does
1212        assert_eq!(affected, vec!["test", "deploy"]);
1213    }
1214
1215    #[test]
1216    fn test_compute_affected_with_external_resolver() {
1217        let mut graph = TaskGraph::new();
1218        // build depends on an external project task, test depends on build
1219        graph
1220            .add_task("build", TestTask::new(&["#external:lib"]))
1221            .unwrap();
1222        graph.add_task("test", TestTask::new(&["build"])).unwrap();
1223        // Don't call add_dependency_edges() - external deps would fail validation
1224        // We manually add the internal edge
1225        let build_idx = *graph.name_to_node.get("build").unwrap();
1226        let test_idx = *graph.name_to_node.get("test").unwrap();
1227        graph.add_edge(build_idx, test_idx);
1228
1229        // External resolver: #external:lib is affected
1230        let affected = graph.compute_affected(
1231            &["build", "test"],
1232            |_task| false, // Nothing directly affected
1233            Some(|dep: &str| dep == "#external:lib"),
1234        );
1235
1236        // build is affected via external dep, test is transitively affected
1237        assert_eq!(affected, vec!["build", "test"]);
1238    }
1239
1240    #[test]
1241    fn test_compute_affected_external_not_affected() {
1242        let mut graph = TaskGraph::new();
1243        graph
1244            .add_task("build", TestTask::new(&["#external:lib"]))
1245            .unwrap();
1246        graph.add_task("test", TestTask::new(&["build"])).unwrap();
1247        // Don't call add_dependency_edges() - external deps would fail validation
1248        let build_idx = *graph.name_to_node.get("build").unwrap();
1249        let test_idx = *graph.name_to_node.get("test").unwrap();
1250        graph.add_edge(build_idx, test_idx);
1251
1252        // External resolver: nothing is affected
1253        let affected =
1254            graph.compute_affected(&["build", "test"], |_task| false, Some(|_dep: &str| false));
1255
1256        assert!(affected.is_empty());
1257    }
1258
1259    // ==========================================================================
1260    // compute_transitive_closure tests
1261    // ==========================================================================
1262
1263    #[test]
1264    fn test_transitive_closure_empty() {
1265        let deps: std::collections::HashMap<&str, Vec<String>> = std::collections::HashMap::new();
1266        let closure = compute_transitive_closure(std::iter::empty::<&str>(), |name| {
1267            deps.get(name).map(|v| v.as_slice())
1268        });
1269        assert!(closure.is_empty());
1270    }
1271
1272    #[test]
1273    fn test_transitive_closure_single_node_no_deps() {
1274        let deps: std::collections::HashMap<&str, Vec<String>> =
1275            [("build", vec![])].into_iter().collect();
1276        let closure =
1277            compute_transitive_closure(["build"], |name| deps.get(name).map(|v| v.as_slice()));
1278        assert_eq!(closure.len(), 1);
1279        assert!(closure.contains("build"));
1280    }
1281
1282    #[test]
1283    fn test_transitive_closure_chain() {
1284        // deploy -> test -> build
1285        let deps: std::collections::HashMap<&str, Vec<String>> = [
1286            ("build", vec![]),
1287            ("test", vec!["build".to_string()]),
1288            ("deploy", vec!["test".to_string()]),
1289        ]
1290        .into_iter()
1291        .collect();
1292
1293        let closure =
1294            compute_transitive_closure(["deploy"], |name| deps.get(name).map(|v| v.as_slice()));
1295
1296        assert_eq!(closure.len(), 3);
1297        assert!(closure.contains("deploy"));
1298        assert!(closure.contains("test"));
1299        assert!(closure.contains("build"));
1300    }
1301
1302    #[test]
1303    fn test_transitive_closure_diamond() {
1304        //      A
1305        //     / \
1306        //    B   C
1307        //     \ /
1308        //      D
1309        let deps: std::collections::HashMap<&str, Vec<String>> = [
1310            ("D", vec![]),
1311            ("B", vec!["D".to_string()]),
1312            ("C", vec!["D".to_string()]),
1313            ("A", vec!["B".to_string(), "C".to_string()]),
1314        ]
1315        .into_iter()
1316        .collect();
1317
1318        let closure =
1319            compute_transitive_closure(["A"], |name| deps.get(name).map(|v| v.as_slice()));
1320
1321        assert_eq!(closure.len(), 4);
1322        assert!(closure.contains("A"));
1323        assert!(closure.contains("B"));
1324        assert!(closure.contains("C"));
1325        assert!(closure.contains("D"));
1326    }
1327
1328    #[test]
1329    fn test_transitive_closure_multiple_initial() {
1330        // Two separate chains: A -> B, C -> D
1331        let deps: std::collections::HashMap<&str, Vec<String>> = [
1332            ("B", vec![]),
1333            ("A", vec!["B".to_string()]),
1334            ("D", vec![]),
1335            ("C", vec!["D".to_string()]),
1336        ]
1337        .into_iter()
1338        .collect();
1339
1340        let closure =
1341            compute_transitive_closure(["A", "C"], |name| deps.get(name).map(|v| v.as_slice()));
1342
1343        assert_eq!(closure.len(), 4);
1344        assert!(closure.contains("A"));
1345        assert!(closure.contains("B"));
1346        assert!(closure.contains("C"));
1347        assert!(closure.contains("D"));
1348    }
1349
1350    #[test]
1351    fn test_transitive_closure_missing_dep() {
1352        // A depends on nonexistent B - should just include A
1353        let deps: std::collections::HashMap<&str, Vec<String>> =
1354            [("A", vec!["B".to_string()])].into_iter().collect();
1355
1356        let closure =
1357            compute_transitive_closure(["A"], |name| deps.get(name).map(|v| v.as_slice()));
1358
1359        // A is included, B is added to closure even though it has no entry (it's a valid node name)
1360        assert_eq!(closure.len(), 2);
1361        assert!(closure.contains("A"));
1362        assert!(closure.contains("B"));
1363    }
1364}