// trueno/brick/exec_graph/traversal/analysis.rs
use std::collections::{HashMap, HashSet};

use super::core::ExecutionGraph;
use crate::brick::exec_graph::node::{EdgeType, ExecutionNode, ExecutionNodeId, TransferDirection};
8impl ExecutionGraph {
9 fn node_timing_ns(&self, id: ExecutionNodeId) -> u64 {
15 debug_assert!(
16 (id.0 as usize) < self.nodes.len(),
17 "CB-BUDGET: node id {} out of bounds (graph has {} nodes)",
18 id.0,
19 self.nodes.len()
20 );
21 match &self.nodes[id.0 as usize] {
22 ExecutionNode::Brick { timing_ns, .. } => *timing_ns,
23 ExecutionNode::Kernel { timing_ns, .. } => timing_ns.unwrap_or(0),
24 ExecutionNode::Transfer { timing_ns, .. } => timing_ns.unwrap_or(0),
25 ExecutionNode::Function { .. }
26 | ExecutionNode::Layer { .. }
27 | ExecutionNode::AsyncTask { .. } => 0,
28 }
29 }
30
31 pub fn critical_path(&self) -> (Vec<ExecutionNodeId>, u64) {
38 if self.nodes.is_empty() {
39 return (vec![], 0);
40 }
41
42 let mut adj: Vec<Vec<(u32, u64)>> = vec![vec![]; self.nodes.len()];
44 for edge in &self.edges {
45 match &edge.edge_type {
46 EdgeType::DependsOn | EdgeType::Sequence => {
47 let weight = self.node_timing_ns(edge.dst);
48 adj[edge.src.0 as usize].push((edge.dst.0, weight));
49 }
50 EdgeType::Contains | EdgeType::Calls | EdgeType::Launches => {
51 let weight = self.node_timing_ns(edge.dst);
53 adj[edge.src.0 as usize].push((edge.dst.0, weight));
54 }
55 EdgeType::Transfer { .. } => {
56 let weight = self.node_timing_ns(edge.dst);
58 adj[edge.src.0 as usize].push((edge.dst.0, weight));
59 }
60 }
61 }
62
63 let mut in_degree = vec![0u32; self.nodes.len()];
65 for edges in &adj {
66 for (dst, _) in edges {
67 in_degree[*dst as usize] += 1;
68 }
69 }
70
71 let mut queue: Vec<u32> =
72 (0..self.nodes.len() as u32).filter(|&i| in_degree[i as usize] == 0).collect();
73 let mut topo_order = Vec::with_capacity(self.nodes.len());
74
75 while let Some(u) = queue.pop() {
76 topo_order.push(u);
77 for (v, _) in &adj[u as usize] {
78 in_degree[*v as usize] -= 1;
79 if in_degree[*v as usize] == 0 {
80 queue.push(*v);
81 }
82 }
83 }
84
85 let mut dist = vec![0u64; self.nodes.len()];
87 let mut pred = vec![None::<u32>; self.nodes.len()];
88
89 for &node in &topo_order {
91 if self.edges.iter().all(|e| e.dst.0 != node) {
92 dist[node as usize] = self.node_timing_ns(ExecutionNodeId(node));
93 }
94 }
95
96 for &u in &topo_order {
97 for (v, weight) in &adj[u as usize] {
98 let new_dist = dist[u as usize] + weight;
99 if new_dist > dist[*v as usize] {
100 dist[*v as usize] = new_dist;
101 pred[*v as usize] = Some(u);
102 }
103 }
104 }
105
106 let (end_node, &total_time) =
108 dist.iter().enumerate().max_by_key(|(_, &d)| d).unwrap_or((0, &0));
109
110 let mut path = vec![];
112 let mut current = Some(end_node as u32);
113 while let Some(node) = current {
114 path.push(ExecutionNodeId(node));
115 current = pred[node as usize];
116 }
117 path.reverse();
118
119 (path, total_time)
120 }
121
122 pub fn compute_slack(&self) -> HashMap<ExecutionNodeId, u64> {
126 let (critical_path, total_time) = self.critical_path();
127 let critical_set: std::collections::HashSet<_> = critical_path.iter().copied().collect();
128
129 let mut slack = HashMap::new();
130
131 let mut reverse_adj: Vec<Vec<u32>> = vec![vec![]; self.nodes.len()];
133 for edge in &self.edges {
134 reverse_adj[edge.dst.0 as usize].push(edge.src.0);
135 }
136
137 let mut earliest = vec![0u64; self.nodes.len()];
139 for i in 0..self.nodes.len() {
140 let mut max_pred = 0u64;
141 for &pred in &reverse_adj[i] {
142 max_pred = max_pred
143 .max(earliest[pred as usize] + self.node_timing_ns(ExecutionNodeId(pred)));
144 }
145 earliest[i] = max_pred;
146 }
147
148 let mut latest = vec![total_time; self.nodes.len()];
150 for i in (0..self.nodes.len()).rev() {
151 let timing = self.node_timing_ns(ExecutionNodeId(i as u32));
152 let mut min_succ = total_time;
153 for edge in &self.edges {
154 if edge.src.0 == i as u32 {
155 min_succ = min_succ.min(latest[edge.dst.0 as usize]);
156 }
157 }
158 latest[i] = min_succ.saturating_sub(timing);
159 }
160
161 for i in 0..self.nodes.len() {
163 let node_id = ExecutionNodeId(i as u32);
164 let node_slack = if critical_set.contains(&node_id) {
165 0
166 } else {
167 latest[i].saturating_sub(earliest[i])
168 };
169 slack.insert(node_id, node_slack);
170 }
171
172 slack
173 }
174
175 pub fn roofline_distance(
182 &self,
183 peak_tflops: f32,
184 peak_bandwidth_gb_s: f32,
185 ) -> HashMap<ExecutionNodeId, f32> {
186 let mut distances = HashMap::new();
187
188 for (i, node) in self.nodes.iter().enumerate() {
189 if let ExecutionNode::Kernel { arithmetic_intensity, achieved_tflops, .. } = node {
190 if let (Some(ai), Some(achieved)) = (arithmetic_intensity, achieved_tflops) {
191 let bandwidth_bound = *ai * peak_bandwidth_gb_s / 1000.0; let roofline_bound = peak_tflops.min(bandwidth_bound);
194 let efficiency = achieved / roofline_bound;
195 let distance = 1.0 - efficiency.min(1.0);
196 distances.insert(ExecutionNodeId(i as u32), distance);
197 }
198 }
199 }
200
201 distances
202 }
203
204 pub fn detect_ping_pong(&self) -> Vec<(ExecutionNodeId, ExecutionNodeId)> {
208 let mut patterns = Vec::new();
209
210 let transfers: Vec<(usize, &ExecutionNode)> = self
212 .nodes
213 .iter()
214 .enumerate()
215 .filter(|(_, n)| matches!(n, ExecutionNode::Transfer { .. }))
216 .collect();
217
218 for i in 0..transfers.len() {
220 for j in (i + 1)..transfers.len() {
221 if let (
222 ExecutionNode::Transfer {
223 src: src1,
224 dst: dst1,
225 direction: dir1,
226 bytes: bytes1,
227 ..
228 },
229 ExecutionNode::Transfer {
230 src: src2,
231 dst: dst2,
232 direction: dir2,
233 bytes: bytes2,
234 ..
235 },
236 ) = (&transfers[i].1, &transfers[j].1)
237 {
238 let is_ping_pong = (*dir1 == TransferDirection::H2D
240 && *dir2 == TransferDirection::D2H
241 && dst1 == src2
242 && bytes1 == bytes2)
243 || (*dir1 == TransferDirection::D2H
244 && *dir2 == TransferDirection::H2D
245 && src1 == dst2
246 && bytes1 == bytes2);
247
248 if is_ping_pong {
249 patterns.push((
250 ExecutionNodeId(transfers[i].0 as u32),
251 ExecutionNodeId(transfers[j].0 as u32),
252 ));
253 }
254 }
255 }
256 }
257
258 patterns
259 }
260
261 pub fn critical_path_summary(&self) -> String {
263 let (path, total_ns) = self.critical_path();
264 let slack = self.compute_slack();
265
266 let mut output = String::new();
267 output.push_str(&format!(
268 "Critical Path: {:.2}ms ({} nodes)\n",
269 total_ns as f64 / 1_000_000.0,
270 path.len()
271 ));
272 output.push_str("─".repeat(50).as_str());
273 output.push('\n');
274
275 for (i, node_id) in path.iter().enumerate() {
276 let node = &self.nodes[node_id.0 as usize];
277 let timing = self.node_timing_ns(*node_id);
278 let node_name = Self::format_node_name(node);
279
280 let prefix = if i == 0 {
281 "┌"
282 } else if i == path.len() - 1 {
283 "└"
284 } else {
285 "│"
286 };
287 output.push_str(&format!(
288 "{} {} ({:.1}µs)\n",
289 prefix,
290 node_name,
291 timing as f64 / 1000.0
292 ));
293 }
294
295 let mut slack_vec: Vec<_> = slack.iter().collect();
297 slack_vec.sort_by(|a, b| b.1.cmp(a.1));
298
299 if slack_vec.iter().any(|(_, &s)| s > 0) {
300 output.push_str("\nParallelization Opportunities (high slack):\n");
301 for (node_id, &node_slack) in slack_vec.iter().take(5) {
302 if node_slack > 0 {
303 let node = &self.nodes[node_id.0 as usize];
304 let node_name = Self::format_node_name(node);
305 output.push_str(&format!(
306 " {} slack={:.1}µs\n",
307 node_name,
308 node_slack as f64 / 1000.0
309 ));
310 }
311 }
312 }
313
314 output
315 }
316
317 fn format_node_name(node: &ExecutionNode) -> String {
319 match node {
320 ExecutionNode::Layer { index } => format!("Layer {}", index),
321 ExecutionNode::Brick { id, .. } => id.name().to_string(),
322 ExecutionNode::Kernel { name, .. } => name.clone(),
323 ExecutionNode::Function { name, .. } => name.clone(),
324 ExecutionNode::Transfer { direction, src, dst, .. } => {
325 format!("{:?} {} → {}", direction, src, dst)
326 }
327 ExecutionNode::AsyncTask { name, poll_count, .. } => {
328 format!("{} ({}polls)", name, poll_count)
329 }
330 }
331 }
332}