Skip to main content

trueno/brick/exec_graph/traversal/
analysis.rs

1//! ExecutionGraph analysis: critical path, slack, roofline, ping-pong detection.
2
3use std::collections::HashMap;
4
5use super::core::ExecutionGraph;
6use crate::brick::exec_graph::node::{EdgeType, ExecutionNode, ExecutionNodeId, TransferDirection};
7
8impl ExecutionGraph {
9    // ========================
10    // Phase 9: Critical Path Analysis (CPA)
11    // ========================
12
13    /// Get timing for a node (ns). Returns 0 for non-timed nodes.
14    fn node_timing_ns(&self, id: ExecutionNodeId) -> u64 {
15        debug_assert!(
16            (id.0 as usize) < self.nodes.len(),
17            "CB-BUDGET: node id {} out of bounds (graph has {} nodes)",
18            id.0,
19            self.nodes.len()
20        );
21        match &self.nodes[id.0 as usize] {
22            ExecutionNode::Brick { timing_ns, .. } => *timing_ns,
23            ExecutionNode::Kernel { timing_ns, .. } => timing_ns.unwrap_or(0),
24            ExecutionNode::Transfer { timing_ns, .. } => timing_ns.unwrap_or(0),
25            ExecutionNode::Function { .. }
26            | ExecutionNode::Layer { .. }
27            | ExecutionNode::AsyncTask { .. } => 0,
28        }
29    }
30
31    /// Compute critical path through execution graph using longest-path algorithm.
32    ///
33    /// Returns (critical_path_nodes, total_time_ns). The critical path represents
34    /// the longest chain of dependencies that determines total execution time.
35    ///
36    /// Reference: Graham et al. (1979) "Scheduling Algorithms for Multi-Processor Systems"
37    pub fn critical_path(&self) -> (Vec<ExecutionNodeId>, u64) {
38        if self.nodes.is_empty() {
39            return (vec![], 0);
40        }
41
42        // Build adjacency list for DependsOn and Sequence edges
43        let mut adj: Vec<Vec<(u32, u64)>> = vec![vec![]; self.nodes.len()];
44        for edge in &self.edges {
45            match &edge.edge_type {
46                EdgeType::DependsOn | EdgeType::Sequence => {
47                    let weight = self.node_timing_ns(edge.dst);
48                    adj[edge.src.0 as usize].push((edge.dst.0, weight));
49                }
50                EdgeType::Contains | EdgeType::Calls | EdgeType::Launches => {
51                    // Hierarchical edges: children contribute to parent time
52                    let weight = self.node_timing_ns(edge.dst);
53                    adj[edge.src.0 as usize].push((edge.dst.0, weight));
54                }
55                EdgeType::Transfer { .. } => {
56                    // Transfer edges carry their own timing
57                    let weight = self.node_timing_ns(edge.dst);
58                    adj[edge.src.0 as usize].push((edge.dst.0, weight));
59                }
60            }
61        }
62
63        // Topological sort using Kahn's algorithm
64        let mut in_degree = vec![0u32; self.nodes.len()];
65        for edges in &adj {
66            for (dst, _) in edges {
67                in_degree[*dst as usize] += 1;
68            }
69        }
70
71        let mut queue: Vec<u32> =
72            (0..self.nodes.len() as u32).filter(|&i| in_degree[i as usize] == 0).collect();
73        let mut topo_order = Vec::with_capacity(self.nodes.len());
74
75        while let Some(u) = queue.pop() {
76            topo_order.push(u);
77            for (v, _) in &adj[u as usize] {
78                in_degree[*v as usize] -= 1;
79                if in_degree[*v as usize] == 0 {
80                    queue.push(*v);
81                }
82            }
83        }
84
85        // Longest path DP
86        let mut dist = vec![0u64; self.nodes.len()];
87        let mut pred = vec![None::<u32>; self.nodes.len()];
88
89        // Initialize with node's own timing for roots
90        for &node in &topo_order {
91            if self.edges.iter().all(|e| e.dst.0 != node) {
92                dist[node as usize] = self.node_timing_ns(ExecutionNodeId(node));
93            }
94        }
95
96        for &u in &topo_order {
97            for (v, weight) in &adj[u as usize] {
98                let new_dist = dist[u as usize] + weight;
99                if new_dist > dist[*v as usize] {
100                    dist[*v as usize] = new_dist;
101                    pred[*v as usize] = Some(u);
102                }
103            }
104        }
105
106        // Find endpoint with maximum distance
107        let (end_node, &total_time) =
108            dist.iter().enumerate().max_by_key(|(_, &d)| d).unwrap_or((0, &0));
109
110        // Reconstruct path
111        let mut path = vec![];
112        let mut current = Some(end_node as u32);
113        while let Some(node) = current {
114            path.push(ExecutionNodeId(node));
115            current = pred[node as usize];
116        }
117        path.reverse();
118
119        (path, total_time)
120    }
121
122    /// Compute slack for each node (how much it can be delayed without affecting total time).
123    ///
124    /// Returns map from node ID to slack in nanoseconds. Nodes on critical path have slack = 0.
125    pub fn compute_slack(&self) -> HashMap<ExecutionNodeId, u64> {
126        let (critical_path, total_time) = self.critical_path();
127        let critical_set: std::collections::HashSet<_> = critical_path.iter().copied().collect();
128
129        let mut slack = HashMap::new();
130
131        // Build reverse adjacency
132        let mut reverse_adj: Vec<Vec<u32>> = vec![vec![]; self.nodes.len()];
133        for edge in &self.edges {
134            reverse_adj[edge.dst.0 as usize].push(edge.src.0);
135        }
136
137        // Forward pass: earliest start time
138        let mut earliest = vec![0u64; self.nodes.len()];
139        for i in 0..self.nodes.len() {
140            let mut max_pred = 0u64;
141            for &pred in &reverse_adj[i] {
142                max_pred = max_pred
143                    .max(earliest[pred as usize] + self.node_timing_ns(ExecutionNodeId(pred)));
144            }
145            earliest[i] = max_pred;
146        }
147
148        // Backward pass: latest start time
149        let mut latest = vec![total_time; self.nodes.len()];
150        for i in (0..self.nodes.len()).rev() {
151            let timing = self.node_timing_ns(ExecutionNodeId(i as u32));
152            let mut min_succ = total_time;
153            for edge in &self.edges {
154                if edge.src.0 == i as u32 {
155                    min_succ = min_succ.min(latest[edge.dst.0 as usize]);
156                }
157            }
158            latest[i] = min_succ.saturating_sub(timing);
159        }
160
161        // Slack = latest - earliest
162        for i in 0..self.nodes.len() {
163            let node_id = ExecutionNodeId(i as u32);
164            let node_slack = if critical_set.contains(&node_id) {
165                0
166            } else {
167                latest[i].saturating_sub(earliest[i])
168            };
169            slack.insert(node_id, node_slack);
170        }
171
172        slack
173    }
174
175    /// Compute roofline distance for kernel nodes.
176    ///
177    /// Returns map from kernel node ID to distance from roofline (0.0 = optimal).
178    /// Distance = 1.0 - min(achieved/peak_compute, achieved/peak_bandwidth).
179    ///
180    /// Reference: Williams et al. (2009) "Roofline: An Insightful Visual Performance Model"
181    pub fn roofline_distance(
182        &self,
183        peak_tflops: f32,
184        peak_bandwidth_gb_s: f32,
185    ) -> HashMap<ExecutionNodeId, f32> {
186        let mut distances = HashMap::new();
187
188        for (i, node) in self.nodes.iter().enumerate() {
189            if let ExecutionNode::Kernel { arithmetic_intensity, achieved_tflops, .. } = node {
190                if let (Some(ai), Some(achieved)) = (arithmetic_intensity, achieved_tflops) {
191                    // Roofline model: achievable = min(peak_compute, ai * bandwidth)
192                    let bandwidth_bound = *ai * peak_bandwidth_gb_s / 1000.0; // Convert GB/s to TFLOP/s
193                    let roofline_bound = peak_tflops.min(bandwidth_bound);
194                    let efficiency = achieved / roofline_bound;
195                    let distance = 1.0 - efficiency.min(1.0);
196                    distances.insert(ExecutionNodeId(i as u32), distance);
197                }
198            }
199        }
200
201        distances
202    }
203
204    /// Detect ping-pong memory transfer patterns (wasteful H2D followed by D2H).
205    ///
206    /// Returns pairs of transfer node IDs that exhibit ping-pong behavior.
207    pub fn detect_ping_pong(&self) -> Vec<(ExecutionNodeId, ExecutionNodeId)> {
208        let mut patterns = Vec::new();
209
210        // Find transfer nodes
211        let transfers: Vec<(usize, &ExecutionNode)> = self
212            .nodes
213            .iter()
214            .enumerate()
215            .filter(|(_, n)| matches!(n, ExecutionNode::Transfer { .. }))
216            .collect();
217
218        // Check for H2D followed by D2H on same data
219        for i in 0..transfers.len() {
220            for j in (i + 1)..transfers.len() {
221                if let (
222                    ExecutionNode::Transfer {
223                        src: src1,
224                        dst: dst1,
225                        direction: dir1,
226                        bytes: bytes1,
227                        ..
228                    },
229                    ExecutionNode::Transfer {
230                        src: src2,
231                        dst: dst2,
232                        direction: dir2,
233                        bytes: bytes2,
234                        ..
235                    },
236                ) = (&transfers[i].1, &transfers[j].1)
237                {
238                    // Ping-pong: H2D then D2H with matching src/dst and same size
239                    let is_ping_pong = (*dir1 == TransferDirection::H2D
240                        && *dir2 == TransferDirection::D2H
241                        && dst1 == src2
242                        && bytes1 == bytes2)
243                        || (*dir1 == TransferDirection::D2H
244                            && *dir2 == TransferDirection::H2D
245                            && src1 == dst2
246                            && bytes1 == bytes2);
247
248                    if is_ping_pong {
249                        patterns.push((
250                            ExecutionNodeId(transfers[i].0 as u32),
251                            ExecutionNodeId(transfers[j].0 as u32),
252                        ));
253                    }
254                }
255            }
256        }
257
258        patterns
259    }
260
261    /// Get critical path analysis summary as formatted string.
262    pub fn critical_path_summary(&self) -> String {
263        let (path, total_ns) = self.critical_path();
264        let slack = self.compute_slack();
265
266        let mut output = String::new();
267        output.push_str(&format!(
268            "Critical Path: {:.2}ms ({} nodes)\n",
269            total_ns as f64 / 1_000_000.0,
270            path.len()
271        ));
272        output.push_str("─".repeat(50).as_str());
273        output.push('\n');
274
275        for (i, node_id) in path.iter().enumerate() {
276            let node = &self.nodes[node_id.0 as usize];
277            let timing = self.node_timing_ns(*node_id);
278            let node_name = Self::format_node_name(node);
279
280            let prefix = if i == 0 {
281                "┌"
282            } else if i == path.len() - 1 {
283                "└"
284            } else {
285                "│"
286            };
287            output.push_str(&format!(
288                "{} {} ({:.1}µs)\n",
289                prefix,
290                node_name,
291                timing as f64 / 1000.0
292            ));
293        }
294
295        // Show nodes with most slack (parallelization opportunities)
296        let mut slack_vec: Vec<_> = slack.iter().collect();
297        slack_vec.sort_by(|a, b| b.1.cmp(a.1));
298
299        if slack_vec.iter().any(|(_, &s)| s > 0) {
300            output.push_str("\nParallelization Opportunities (high slack):\n");
301            for (node_id, &node_slack) in slack_vec.iter().take(5) {
302                if node_slack > 0 {
303                    let node = &self.nodes[node_id.0 as usize];
304                    let node_name = Self::format_node_name(node);
305                    output.push_str(&format!(
306                        "  {} slack={:.1}µs\n",
307                        node_name,
308                        node_slack as f64 / 1000.0
309                    ));
310                }
311            }
312        }
313
314        output
315    }
316
317    /// Format a node name for display in critical path summaries.
318    fn format_node_name(node: &ExecutionNode) -> String {
319        match node {
320            ExecutionNode::Layer { index } => format!("Layer {}", index),
321            ExecutionNode::Brick { id, .. } => id.name().to_string(),
322            ExecutionNode::Kernel { name, .. } => name.clone(),
323            ExecutionNode::Function { name, .. } => name.clone(),
324            ExecutionNode::Transfer { direction, src, dst, .. } => {
325                format!("{:?} {} → {}", direction, src, dst)
326            }
327            ExecutionNode::AsyncTask { name, poll_count, .. } => {
328                format!("{} ({}polls)", name, poll_count)
329            }
330        }
331    }
332}