cuenv_ci/executor/
graph.rs

1//! CI Task Graph
2//!
3//! Builds a directed acyclic graph (DAG) from IR tasks for dependency-ordered
4//! parallel execution.
5
6use crate::compiler::digest::compute_task_digest;
7use crate::ir::{IntermediateRepresentation, Task as IRTask};
8use petgraph::algo::{is_cyclic_directed, toposort};
9use petgraph::graph::{DiGraph, NodeIndex};
10use petgraph::visit::IntoNodeReferences;
11use std::collections::HashMap;
12use thiserror::Error;
13
/// Errors produced while building or ordering a [`CITaskGraph`].
#[derive(Debug, Error)]
pub enum GraphError {
    /// The dependency edges contain a cycle, so no execution order exists.
    ///
    /// `tasks` is a comma-separated list of task IDs assembled by
    /// [`CITaskGraph::from_ir`] for the error message.
    #[error("Task dependency graph contains cycle involving: {tasks}")]
    CyclicDependency { tasks: String },

    /// A task names a dependency that is not defined in the IR.
    ///
    /// `task` is the ID of the task declaring the dependency and
    /// `dependency` is the ID that could not be resolved.
    #[error("Task '{task}' depends on non-existent task '{dependency}'")]
    MissingDependency { task: String, dependency: String },

    /// Topological sorting of the graph failed.
    ///
    /// Returned by [`CITaskGraph::get_parallel_groups`] when petgraph's
    /// `toposort` reports an error.
    #[error("Failed to determine task execution order")]
    SortFailed,
}
29
/// A node in the CI task graph: one IR task plus its cache digest.
#[derive(Debug, Clone)]
pub struct CITaskNode {
    /// Task ID from IR (a copy of `task.id`, set by [`CITaskGraph::from_ir`])
    pub id: String,
    /// The IR task definition (cloned from the IR when the graph is built)
    pub task: IRTask,
    /// Digest for cache lookup.
    ///
    /// This is an empty string when the node is created; it is filled in by
    /// [`CITaskGraph::compute_digests`], which is meant to run after secret
    /// resolution.
    pub digest: String,
}
40
41/// CI Task graph for dependency resolution and parallel execution
42pub struct CITaskGraph {
43    /// The directed graph of tasks
44    graph: DiGraph<CITaskNode, ()>,
45    /// Map from task IDs to node indices
46    id_to_index: HashMap<String, NodeIndex>,
47}
48
49impl CITaskGraph {
50    /// Build a task graph from an IR document
51    ///
52    /// # Errors
53    /// Returns error if dependencies reference non-existent tasks or if
54    /// the graph contains cycles.
55    pub fn from_ir(ir: &IntermediateRepresentation) -> Result<Self, GraphError> {
56        let mut graph = DiGraph::new();
57        let mut id_to_index = HashMap::new();
58
59        // First pass: Add all tasks as nodes
60        for task in &ir.tasks {
61            let node = CITaskNode {
62                id: task.id.clone(),
63                task: task.clone(),
64                digest: String::new(), // Computed later with secrets
65            };
66            let index = graph.add_node(node);
67            id_to_index.insert(task.id.clone(), index);
68        }
69
70        // Second pass: Add dependency edges
71        let mut edges_to_add = Vec::new();
72        for task in &ir.tasks {
73            let task_index = id_to_index[&task.id];
74            for dep_id in &task.depends_on {
75                let dep_index =
76                    id_to_index
77                        .get(dep_id)
78                        .ok_or_else(|| GraphError::MissingDependency {
79                            task: task.id.clone(),
80                            dependency: dep_id.clone(),
81                        })?;
82                // Edge goes from dependency to dependent (dep -> task)
83                edges_to_add.push((*dep_index, task_index));
84            }
85        }
86
87        for (from, to) in edges_to_add {
88            graph.add_edge(from, to, ());
89        }
90
91        let result = Self { graph, id_to_index };
92
93        // Check for cycles
94        if result.has_cycles() {
95            // Find tasks involved in cycle for error message
96            let task_ids: Vec<_> = result
97                .graph
98                .node_references()
99                .map(|(_, n)| n.id.clone())
100                .collect();
101            return Err(GraphError::CyclicDependency {
102                tasks: task_ids.join(", "),
103            });
104        }
105
106        Ok(result)
107    }
108
109    /// Check if the graph contains cycles
110    #[must_use]
111    pub fn has_cycles(&self) -> bool {
112        is_cyclic_directed(&self.graph)
113    }
114
115    /// Get the number of tasks in the graph
116    #[must_use]
117    pub fn task_count(&self) -> usize {
118        self.graph.node_count()
119    }
120
121    /// Get tasks grouped by dependency level for parallel execution
122    ///
123    /// Returns groups where all tasks in a group can execute concurrently
124    /// because they have no dependencies on each other.
125    ///
126    /// # Errors
127    /// Returns error if topological sort fails (shouldn't happen if cycle
128    /// check passed).
129    pub fn get_parallel_groups(&self) -> Result<Vec<Vec<&CITaskNode>>, GraphError> {
130        // Topological sort
131        let sorted_indices = toposort(&self.graph, None).map_err(|_| GraphError::SortFailed)?;
132
133        if sorted_indices.is_empty() {
134            return Ok(vec![]);
135        }
136
137        // Group tasks by their dependency level
138        let mut groups: Vec<Vec<&CITaskNode>> = vec![];
139        let mut processed: HashMap<&str, usize> = HashMap::new();
140
141        for node_index in sorted_indices {
142            let node = &self.graph[node_index];
143
144            // Find the maximum level of all dependencies
145            let mut level = 0;
146            for dep_id in &node.task.depends_on {
147                if let Some(&dep_level) = processed.get(dep_id.as_str()) {
148                    level = level.max(dep_level + 1);
149                }
150            }
151
152            // Add to appropriate group
153            if level >= groups.len() {
154                groups.resize_with(level + 1, Vec::new);
155            }
156            groups[level].push(node);
157            processed.insert(&node.id, level);
158        }
159
160        Ok(groups)
161    }
162
163    /// Compute digests for all tasks after secret resolution
164    ///
165    /// This must be called after secrets are resolved to include secret
166    /// fingerprints in the digest computation.
167    ///
168    /// # Arguments
169    /// * `ir` - The IR document (for runtime lookups)
170    /// * `secret_fingerprints` - Map of `task_id` -> (`secret_name` -> fingerprint)
171    /// * `system_salt` - Optional system salt for secret HMAC
172    pub fn compute_digests(
173        &mut self,
174        ir: &IntermediateRepresentation,
175        secret_fingerprints: &HashMap<String, HashMap<String, String>>,
176        system_salt: Option<&str>,
177    ) {
178        for node_index in self.graph.node_indices() {
179            let node = &self.graph[node_index];
180            let task = &node.task;
181
182            // Get runtime digest if task has runtime
183            let runtime_digest = task
184                .runtime
185                .as_ref()
186                .and_then(|rid| ir.runtimes.iter().find(|r| &r.id == rid))
187                .map(|r| r.digest.as_str());
188
189            // Get secret fingerprints for this task
190            let task_fingerprints = secret_fingerprints.get(&task.id);
191
192            let digest = compute_task_digest(
193                &task.command,
194                &task.env,
195                &task.inputs,
196                runtime_digest,
197                task_fingerprints,
198                system_salt,
199            );
200
201            // Update the node's digest
202            self.graph[node_index].digest = digest;
203        }
204    }
205
206    /// Get a task node by ID
207    #[must_use]
208    pub fn get_task(&self, id: &str) -> Option<&CITaskNode> {
209        self.id_to_index.get(id).map(|&idx| &self.graph[idx])
210    }
211
212    /// Get all task IDs in the graph
213    #[must_use]
214    pub fn task_ids(&self) -> Vec<&str> {
215        self.graph
216            .node_references()
217            .map(|(_, n)| n.id.as_str())
218            .collect()
219    }
220
221    /// Propagate `cache_policy`: disabled to tasks that transitively depend on deployment tasks
222    ///
223    /// According to PRD v1.3, tasks depending on deployments should inherit
224    /// `cache_policy`: disabled for that execution to ensure deployment ordering
225    /// is always respected.
226    ///
227    /// # Returns
228    /// List of task IDs that had their cache policy changed
229    pub fn propagate_deployment_cache_policy(&mut self) -> Vec<String> {
230        use crate::ir::CachePolicy;
231        use petgraph::visit::Dfs;
232
233        let mut changed = Vec::new();
234
235        // Find all deployment task indices
236        let deployment_indices: Vec<NodeIndex> = self
237            .graph
238            .node_indices()
239            .filter(|&idx| self.graph[idx].task.deployment)
240            .collect();
241
242        // For each deployment task, traverse all descendants and mark them as disabled
243        for deploy_idx in deployment_indices {
244            // Use DFS to find all tasks that depend on this deployment
245            // We need to traverse in reverse direction (dependents of deployment)
246            let mut dfs = Dfs::new(&self.graph, deploy_idx);
247            while let Some(node_idx) = dfs.next(&self.graph) {
248                // Skip the deployment task itself (already marked disabled by validation)
249                if node_idx == deploy_idx {
250                    continue;
251                }
252
253                let node = &mut self.graph[node_idx];
254                if node.task.cache_policy != CachePolicy::Disabled {
255                    tracing::debug!(
256                        task = %node.id,
257                        reason = "depends on deployment task",
258                        "Setting cache_policy to disabled"
259                    );
260                    node.task.cache_policy = CachePolicy::Disabled;
261                    changed.push(node.id.clone());
262                }
263            }
264        }
265
266        // Deduplicate (task may depend on multiple deployment tasks)
267        changed.sort();
268        changed.dedup();
269
270        if !changed.is_empty() {
271            tracing::info!(
272                count = changed.len(),
273                tasks = ?changed,
274                "Disabled caching for tasks depending on deployments"
275            );
276        }
277
278        changed
279    }
280
281    /// Check if a task transitively depends on any deployment task
282    #[must_use]
283    pub fn depends_on_deployment(&self, task_id: &str) -> bool {
284        use petgraph::algo::has_path_connecting;
285
286        let Some(&task_idx) = self.id_to_index.get(task_id) else {
287            return false;
288        };
289
290        // Check if there's a path from any deployment task to this task
291        for node_idx in self.graph.node_indices() {
292            if self.graph[node_idx].task.deployment
293                && node_idx != task_idx
294                && has_path_connecting(&self.graph, node_idx, task_idx, None)
295            {
296                return true;
297            }
298        }
299
300        false
301    }
302}
303
304#[cfg(test)]
305mod tests {
306    use super::*;
307    use crate::ir::{CachePolicy, StageConfiguration, Task};
308
309    fn make_task(id: &str, deps: &[&str]) -> Task {
310        Task {
311            id: id.to_string(),
312            runtime: None,
313            command: vec!["echo".to_string(), id.to_string()],
314            shell: false,
315            env: HashMap::new(),
316            secrets: HashMap::new(),
317            resources: None,
318            concurrency_group: None,
319            inputs: vec![],
320            outputs: vec![],
321            depends_on: deps.iter().map(|s| (*s).to_string()).collect(),
322            cache_policy: CachePolicy::Normal,
323            deployment: false,
324            manual_approval: false,
325        }
326    }
327
328    fn make_ir(tasks: Vec<Task>) -> IntermediateRepresentation {
329        IntermediateRepresentation {
330            version: "1.4".to_string(),
331            pipeline: crate::ir::PipelineMetadata {
332                name: "test".to_string(),
333                environment: None,
334                requires_onepassword: false,
335                project_name: None,
336                trigger: None,
337            },
338            runtimes: vec![],
339            stages: StageConfiguration::default(),
340            tasks,
341        }
342    }
343
344    #[test]
345    fn test_simple_graph() {
346        let ir = make_ir(vec![make_task("build", &[]), make_task("test", &["build"])]);
347
348        let graph = CITaskGraph::from_ir(&ir).unwrap();
349        assert_eq!(graph.task_count(), 2);
350        assert!(!graph.has_cycles());
351    }
352
353    #[test]
354    fn test_parallel_groups_linear() {
355        // build -> test -> deploy (all sequential)
356        let ir = make_ir(vec![
357            make_task("build", &[]),
358            make_task("test", &["build"]),
359            make_task("deploy", &["test"]),
360        ]);
361
362        let graph = CITaskGraph::from_ir(&ir).unwrap();
363        let groups = graph.get_parallel_groups().unwrap();
364
365        assert_eq!(groups.len(), 3);
366        assert_eq!(groups[0].len(), 1); // build
367        assert_eq!(groups[1].len(), 1); // test
368        assert_eq!(groups[2].len(), 1); // deploy
369    }
370
371    #[test]
372    fn test_parallel_groups_diamond() {
373        // build -> test1 -\
374        //       -> test2 -/-> deploy
375        let ir = make_ir(vec![
376            make_task("build", &[]),
377            make_task("test1", &["build"]),
378            make_task("test2", &["build"]),
379            make_task("deploy", &["test1", "test2"]),
380        ]);
381
382        let graph = CITaskGraph::from_ir(&ir).unwrap();
383        let groups = graph.get_parallel_groups().unwrap();
384
385        assert_eq!(groups.len(), 3);
386        assert_eq!(groups[0].len(), 1); // build
387        assert_eq!(groups[1].len(), 2); // test1, test2 (parallel)
388        assert_eq!(groups[2].len(), 1); // deploy
389    }
390
391    #[test]
392    fn test_cycle_detection() {
393        let ir = make_ir(vec![
394            make_task("a", &["c"]),
395            make_task("b", &["a"]),
396            make_task("c", &["b"]),
397        ]);
398
399        let result = CITaskGraph::from_ir(&ir);
400        assert!(matches!(result, Err(GraphError::CyclicDependency { .. })));
401    }
402
403    #[test]
404    fn test_missing_dependency() {
405        let ir = make_ir(vec![make_task("test", &["nonexistent"])]);
406
407        let result = CITaskGraph::from_ir(&ir);
408        assert!(matches!(result, Err(GraphError::MissingDependency { .. })));
409    }
410
411    #[test]
412    fn test_digest_computation() {
413        let ir = make_ir(vec![make_task("build", &[])]);
414
415        let mut graph = CITaskGraph::from_ir(&ir).unwrap();
416        graph.compute_digests(&ir, &HashMap::new(), None);
417
418        let task = graph.get_task("build").unwrap();
419        assert!(!task.digest.is_empty());
420        assert!(task.digest.starts_with("sha256:"));
421    }
422
423    #[test]
424    fn test_independent_tasks_same_group() {
425        // Three independent tasks should all be in level 0
426        let ir = make_ir(vec![
427            make_task("task1", &[]),
428            make_task("task2", &[]),
429            make_task("task3", &[]),
430        ]);
431
432        let graph = CITaskGraph::from_ir(&ir).unwrap();
433        let groups = graph.get_parallel_groups().unwrap();
434
435        assert_eq!(groups.len(), 1);
436        assert_eq!(groups[0].len(), 3);
437    }
438}