trueno/brick/exec_graph/traversal/
core.rs1use std::collections::HashMap;
4
5use crate::brick::exec_graph::node::{
6 EdgeType, ExecutionEdge, ExecutionNode, ExecutionNodeId, TransferDirection,
7};
8
9#[derive(Debug, Default)]
46pub struct ExecutionGraph {
47 pub(crate) nodes: Vec<ExecutionNode>,
49 pub(crate) edges: Vec<ExecutionEdge>,
51 pub(crate) scope_stack: Vec<ExecutionNodeId>,
53 pub(crate) name_to_id: HashMap<String, ExecutionNodeId>,
55}
56
57impl ExecutionGraph {
58 pub fn new() -> Self {
60 Self::default()
61 }
62
63 pub fn add_node(&mut self, node: ExecutionNode) -> ExecutionNodeId {
65 let id = ExecutionNodeId(self.nodes.len() as u32);
66 let name = node.name();
67 self.name_to_id.insert(name, id);
68 self.nodes.push(node);
69 id
70 }
71
72 pub fn add_edge(&mut self, src: ExecutionNodeId, dst: ExecutionNodeId, edge_type: EdgeType) {
74 debug_assert!(
75 (src.0 as usize) < self.nodes.len(),
76 "CB-BUDGET: src node {} does not exist (graph has {} nodes)",
77 src.0,
78 self.nodes.len()
79 );
80 debug_assert!(
81 (dst.0 as usize) < self.nodes.len(),
82 "CB-BUDGET: dst node {} does not exist (graph has {} nodes)",
83 dst.0,
84 self.nodes.len()
85 );
86 self.edges.push(ExecutionEdge { src, dst, edge_type, weight: 1.0 });
87 }
88
89 pub fn add_weighted_edge(
91 &mut self,
92 src: ExecutionNodeId,
93 dst: ExecutionNodeId,
94 edge_type: EdgeType,
95 weight: f32,
96 ) {
97 self.edges.push(ExecutionEdge { src, dst, edge_type, weight });
98 }
99
100 pub fn push_scope(&mut self, node: ExecutionNode) -> ExecutionNodeId {
103 let id = self.add_node(node);
104 if let Some(&parent) = self.scope_stack.last() {
105 self.add_edge(parent, id, EdgeType::Contains);
106 }
107 self.scope_stack.push(id);
108 id
109 }
110
111 pub fn pop_scope(&mut self) -> Option<ExecutionNodeId> {
113 self.scope_stack.pop()
114 }
115
116 pub fn current_scope(&self) -> Option<ExecutionNodeId> {
118 self.scope_stack.last().copied()
119 }
120
121 pub fn add_node_in_scope(&mut self, node: ExecutionNode) -> ExecutionNodeId {
123 let id = self.add_node(node);
124 if let Some(&parent) = self.scope_stack.last() {
125 self.add_edge(parent, id, EdgeType::Contains);
126 }
127 id
128 }
129
130 pub fn record_kernel_launch(
132 &mut self,
133 name: &str,
134 ptx_hash: u64,
135 grid: (u32, u32, u32),
136 block: (u32, u32, u32),
137 shared_mem: u32,
138 ) -> ExecutionNodeId {
139 debug_assert!(grid.0 > 0 && grid.1 > 0 && grid.2 > 0, "CB-BUDGET: grid dims must be > 0");
140 debug_assert!(
141 block.0 > 0 && block.1 > 0 && block.2 > 0,
142 "CB-BUDGET: block dims must be > 0"
143 );
144 let kernel = ExecutionNode::Kernel {
145 name: name.to_string(),
146 ptx_hash,
147 grid,
148 block,
149 shared_mem,
150 timing_ns: None,
151 arithmetic_intensity: None,
152 achieved_tflops: None,
153 };
154 let kernel_id = self.add_node(kernel);
155
156 if let Some(&parent) = self.scope_stack.last() {
158 self.add_edge(parent, kernel_id, EdgeType::Launches);
159 }
160
161 kernel_id
162 }
163
164 #[allow(clippy::too_many_arguments)]
166 pub fn record_kernel_launch_with_metrics(
167 &mut self,
168 name: &str,
169 ptx_hash: u64,
170 grid: (u32, u32, u32),
171 block: (u32, u32, u32),
172 shared_mem: u32,
173 timing_ns: u64,
174 arithmetic_intensity: f32,
175 achieved_tflops: f32,
176 ) -> ExecutionNodeId {
177 let kernel = ExecutionNode::Kernel {
178 name: name.to_string(),
179 ptx_hash,
180 grid,
181 block,
182 shared_mem,
183 timing_ns: Some(timing_ns),
184 arithmetic_intensity: Some(arithmetic_intensity),
185 achieved_tflops: Some(achieved_tflops),
186 };
187 let kernel_id = self.add_node(kernel);
188
189 if let Some(&parent) = self.scope_stack.last() {
190 self.add_edge(parent, kernel_id, EdgeType::Launches);
191 }
192
193 kernel_id
194 }
195
196 pub fn record_transfer(
198 &mut self,
199 src: &str,
200 dst: &str,
201 bytes: u64,
202 direction: TransferDirection,
203 timing_ns: Option<u64>,
204 ) -> ExecutionNodeId {
205 let transfer = ExecutionNode::Transfer {
206 src: src.to_string(),
207 dst: dst.to_string(),
208 bytes,
209 direction,
210 timing_ns,
211 };
212 let transfer_id = self.add_node(transfer);
213
214 if let Some(&parent) = self.scope_stack.last() {
215 self.add_edge(parent, transfer_id, EdgeType::Contains);
216 }
217
218 transfer_id
219 }
220
221 pub fn add_dependency(&mut self, from: ExecutionNodeId, to: ExecutionNodeId) {
223 self.add_edge(from, to, EdgeType::DependsOn);
224 }
225
226 pub fn node(&self, id: ExecutionNodeId) -> Option<&ExecutionNode> {
228 self.nodes.get(id.0 as usize)
229 }
230
231 pub fn node_by_name(&self, name: &str) -> Option<(ExecutionNodeId, &ExecutionNode)> {
233 self.name_to_id.get(name).and_then(|&id| self.nodes.get(id.0 as usize).map(|n| (id, n)))
234 }
235
236 pub fn nodes(&self) -> &[ExecutionNode] {
238 &self.nodes
239 }
240
241 pub fn edges(&self) -> &[ExecutionEdge] {
243 &self.edges
244 }
245
246 pub fn num_nodes(&self) -> usize {
248 self.nodes.len()
249 }
250
251 pub fn num_edges(&self) -> usize {
253 self.edges.len()
254 }
255
256 pub fn outgoing_edges(&self, node: ExecutionNodeId) -> impl Iterator<Item = &ExecutionEdge> {
258 self.edges.iter().filter(move |e| e.src == node)
259 }
260
261 pub fn incoming_edges(&self, node: ExecutionNodeId) -> impl Iterator<Item = &ExecutionEdge> {
263 self.edges.iter().filter(move |e| e.dst == node)
264 }
265
266 pub fn kernel_nodes(&self) -> impl Iterator<Item = (ExecutionNodeId, &ExecutionNode)> {
268 self.nodes
269 .iter()
270 .enumerate()
271 .filter(|(_, n)| n.is_kernel())
272 .map(|(i, n)| (ExecutionNodeId(i as u32), n))
273 }
274
275 pub fn slowest_kernel(&self) -> Option<(ExecutionNodeId, &ExecutionNode, u64)> {
277 let mut slowest: Option<(ExecutionNodeId, &ExecutionNode, u64)> = None;
278
279 for (id, node) in self.nodes.iter().enumerate() {
280 if let ExecutionNode::Brick { timing_ns, .. } = node {
281 let node_id = ExecutionNodeId(id as u32);
283 let has_kernel =
284 self.outgoing_edges(node_id).any(|e| e.edge_type == EdgeType::Launches);
285
286 if has_kernel {
287 match &slowest {
288 None => slowest = Some((node_id, node, *timing_ns)),
289 Some((_, _, t)) if *timing_ns > *t => {
290 slowest = Some((node_id, node, *timing_ns))
291 }
292 Some(_) => {} }
294 }
295 }
296 }
297
298 slowest
299 }
300
301 pub fn clear(&mut self) {
303 self.nodes.clear();
304 self.edges.clear();
305 self.scope_stack.clear();
306 self.name_to_id.clear();
307 }
308
309 pub fn is_scope_balanced(&self) -> bool {
311 self.scope_stack.is_empty()
312 }
313}