Skip to main content

trueno/brick/exec_graph/traversal/
core.rs

1//! ExecutionGraph core: struct definition, basic graph operations, scope management.
2
3use std::collections::HashMap;
4
5use crate::brick::exec_graph::node::{
6    EdgeType, ExecutionEdge, ExecutionNode, ExecutionNodeId, TransferDirection,
7};
8
/// Execution path graph for tracking brick → kernel → PTX relationships.
///
/// PAR-201: Captures the full execution hierarchy for profiling analysis.
///
/// Nodes live in a flat list and are addressed by [`ExecutionNodeId`], which is
/// the node's index in that list. Edges live in a second flat list and refer to
/// nodes by ID. A scope stack lets callers record nodes hierarchically: nodes
/// added "in scope" are linked to the innermost open scope.
///
/// # Example
///
/// ```rust,ignore
/// use trueno::brick::{ExecutionGraph, ExecutionNode, EdgeType};
///
/// let mut graph = ExecutionGraph::new();
///
/// // Add layer scope
/// let layer_id = graph.add_node(ExecutionNode::Layer { index: 0 });
///
/// // Add brick within layer
/// let brick_id = graph.add_node(ExecutionNode::Brick {
///     id: BrickId::QkvProjection,
///     timing_ns: 1000,
///     elements: 4096,
/// });
/// graph.add_edge(layer_id, brick_id, EdgeType::Contains);
///
/// // Add kernel launched by brick
/// let kernel_id = graph.add_node(ExecutionNode::Kernel {
///     name: "batched_q4k_gemv".into(),
///     ptx_hash: 0x7a3b1c2d,
///     grid: (32, 1, 1),
///     block: (256, 1, 1),
///     shared_mem: 4096,
///     timing_ns: None,
///     arithmetic_intensity: None,
///     achieved_tflops: None,
/// });
/// graph.add_edge(brick_id, kernel_id, EdgeType::Launches);
///
/// // Export to trueno-graph for analysis
/// #[cfg(feature = "execution-graph")]
/// let csr = graph.to_csr();
/// ```
#[derive(Debug, Default)]
pub struct ExecutionGraph {
    /// All nodes in the graph; a node's position here is its `ExecutionNodeId`.
    pub(crate) nodes: Vec<ExecutionNode>,
    /// All edges in the graph, in insertion order.
    pub(crate) edges: Vec<ExecutionEdge>,
    /// Scope stack for hierarchical recording; the innermost open scope is last.
    pub(crate) scope_stack: Vec<ExecutionNodeId>,
    /// Node name → ID mapping for fast lookup.
    /// When several nodes share a name, the most recently added one wins.
    pub(crate) name_to_id: HashMap<String, ExecutionNodeId>,
}
56
57impl ExecutionGraph {
58    /// Create a new empty execution graph.
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    /// Add a node to the graph, returning its ID.
64    pub fn add_node(&mut self, node: ExecutionNode) -> ExecutionNodeId {
65        let id = ExecutionNodeId(self.nodes.len() as u32);
66        let name = node.name();
67        self.name_to_id.insert(name, id);
68        self.nodes.push(node);
69        id
70    }
71
72    /// Add an edge between two nodes.
73    pub fn add_edge(&mut self, src: ExecutionNodeId, dst: ExecutionNodeId, edge_type: EdgeType) {
74        debug_assert!(
75            (src.0 as usize) < self.nodes.len(),
76            "CB-BUDGET: src node {} does not exist (graph has {} nodes)",
77            src.0,
78            self.nodes.len()
79        );
80        debug_assert!(
81            (dst.0 as usize) < self.nodes.len(),
82            "CB-BUDGET: dst node {} does not exist (graph has {} nodes)",
83            dst.0,
84            self.nodes.len()
85        );
86        self.edges.push(ExecutionEdge { src, dst, edge_type, weight: 1.0 });
87    }
88
89    /// Add an edge with a weight.
90    pub fn add_weighted_edge(
91        &mut self,
92        src: ExecutionNodeId,
93        dst: ExecutionNodeId,
94        edge_type: EdgeType,
95        weight: f32,
96    ) {
97        self.edges.push(ExecutionEdge { src, dst, edge_type, weight });
98    }
99
100    /// Push a scope for hierarchical recording.
101    /// All subsequent nodes will be children of this scope.
102    pub fn push_scope(&mut self, node: ExecutionNode) -> ExecutionNodeId {
103        let id = self.add_node(node);
104        if let Some(&parent) = self.scope_stack.last() {
105            self.add_edge(parent, id, EdgeType::Contains);
106        }
107        self.scope_stack.push(id);
108        id
109    }
110
111    /// Pop the current scope.
112    pub fn pop_scope(&mut self) -> Option<ExecutionNodeId> {
113        self.scope_stack.pop()
114    }
115
116    /// Get the current scope (if any).
117    pub fn current_scope(&self) -> Option<ExecutionNodeId> {
118        self.scope_stack.last().copied()
119    }
120
121    /// Add a node under the current scope.
122    pub fn add_node_in_scope(&mut self, node: ExecutionNode) -> ExecutionNodeId {
123        let id = self.add_node(node);
124        if let Some(&parent) = self.scope_stack.last() {
125            self.add_edge(parent, id, EdgeType::Contains);
126        }
127        id
128    }
129
130    /// Record a kernel launch under the current scope.
131    pub fn record_kernel_launch(
132        &mut self,
133        name: &str,
134        ptx_hash: u64,
135        grid: (u32, u32, u32),
136        block: (u32, u32, u32),
137        shared_mem: u32,
138    ) -> ExecutionNodeId {
139        debug_assert!(grid.0 > 0 && grid.1 > 0 && grid.2 > 0, "CB-BUDGET: grid dims must be > 0");
140        debug_assert!(
141            block.0 > 0 && block.1 > 0 && block.2 > 0,
142            "CB-BUDGET: block dims must be > 0"
143        );
144        let kernel = ExecutionNode::Kernel {
145            name: name.to_string(),
146            ptx_hash,
147            grid,
148            block,
149            shared_mem,
150            timing_ns: None,
151            arithmetic_intensity: None,
152            achieved_tflops: None,
153        };
154        let kernel_id = self.add_node(kernel);
155
156        // Link from current scope with Launches edge
157        if let Some(&parent) = self.scope_stack.last() {
158            self.add_edge(parent, kernel_id, EdgeType::Launches);
159        }
160
161        kernel_id
162    }
163
164    /// Record a kernel launch with roofline metrics (Phase 9).
165    #[allow(clippy::too_many_arguments)]
166    pub fn record_kernel_launch_with_metrics(
167        &mut self,
168        name: &str,
169        ptx_hash: u64,
170        grid: (u32, u32, u32),
171        block: (u32, u32, u32),
172        shared_mem: u32,
173        timing_ns: u64,
174        arithmetic_intensity: f32,
175        achieved_tflops: f32,
176    ) -> ExecutionNodeId {
177        let kernel = ExecutionNode::Kernel {
178            name: name.to_string(),
179            ptx_hash,
180            grid,
181            block,
182            shared_mem,
183            timing_ns: Some(timing_ns),
184            arithmetic_intensity: Some(arithmetic_intensity),
185            achieved_tflops: Some(achieved_tflops),
186        };
187        let kernel_id = self.add_node(kernel);
188
189        if let Some(&parent) = self.scope_stack.last() {
190            self.add_edge(parent, kernel_id, EdgeType::Launches);
191        }
192
193        kernel_id
194    }
195
196    /// Record a memory transfer (Phase 9: data movement topology).
197    pub fn record_transfer(
198        &mut self,
199        src: &str,
200        dst: &str,
201        bytes: u64,
202        direction: TransferDirection,
203        timing_ns: Option<u64>,
204    ) -> ExecutionNodeId {
205        let transfer = ExecutionNode::Transfer {
206            src: src.to_string(),
207            dst: dst.to_string(),
208            bytes,
209            direction,
210            timing_ns,
211        };
212        let transfer_id = self.add_node(transfer);
213
214        if let Some(&parent) = self.scope_stack.last() {
215            self.add_edge(parent, transfer_id, EdgeType::Contains);
216        }
217
218        transfer_id
219    }
220
221    /// Add a dependency edge for critical path analysis (Phase 9).
222    pub fn add_dependency(&mut self, from: ExecutionNodeId, to: ExecutionNodeId) {
223        self.add_edge(from, to, EdgeType::DependsOn);
224    }
225
226    /// Get a node by ID.
227    pub fn node(&self, id: ExecutionNodeId) -> Option<&ExecutionNode> {
228        self.nodes.get(id.0 as usize)
229    }
230
231    /// Get a node by name.
232    pub fn node_by_name(&self, name: &str) -> Option<(ExecutionNodeId, &ExecutionNode)> {
233        self.name_to_id.get(name).and_then(|&id| self.nodes.get(id.0 as usize).map(|n| (id, n)))
234    }
235
236    /// Get all nodes.
237    pub fn nodes(&self) -> &[ExecutionNode] {
238        &self.nodes
239    }
240
241    /// Get all edges.
242    pub fn edges(&self) -> &[ExecutionEdge] {
243        &self.edges
244    }
245
246    /// Number of nodes.
247    pub fn num_nodes(&self) -> usize {
248        self.nodes.len()
249    }
250
251    /// Number of edges.
252    pub fn num_edges(&self) -> usize {
253        self.edges.len()
254    }
255
256    /// Get outgoing edges for a node.
257    pub fn outgoing_edges(&self, node: ExecutionNodeId) -> impl Iterator<Item = &ExecutionEdge> {
258        self.edges.iter().filter(move |e| e.src == node)
259    }
260
261    /// Get incoming edges for a node.
262    pub fn incoming_edges(&self, node: ExecutionNodeId) -> impl Iterator<Item = &ExecutionEdge> {
263        self.edges.iter().filter(move |e| e.dst == node)
264    }
265
266    /// Find all kernel nodes.
267    pub fn kernel_nodes(&self) -> impl Iterator<Item = (ExecutionNodeId, &ExecutionNode)> {
268        self.nodes
269            .iter()
270            .enumerate()
271            .filter(|(_, n)| n.is_kernel())
272            .map(|(i, n)| (ExecutionNodeId(i as u32), n))
273    }
274
275    /// Find the slowest kernel (by parent brick timing).
276    pub fn slowest_kernel(&self) -> Option<(ExecutionNodeId, &ExecutionNode, u64)> {
277        let mut slowest: Option<(ExecutionNodeId, &ExecutionNode, u64)> = None;
278
279        for (id, node) in self.nodes.iter().enumerate() {
280            if let ExecutionNode::Brick { timing_ns, .. } = node {
281                // Check if this brick has kernel children
282                let node_id = ExecutionNodeId(id as u32);
283                let has_kernel =
284                    self.outgoing_edges(node_id).any(|e| e.edge_type == EdgeType::Launches);
285
286                if has_kernel {
287                    match &slowest {
288                        None => slowest = Some((node_id, node, *timing_ns)),
289                        Some((_, _, t)) if *timing_ns > *t => {
290                            slowest = Some((node_id, node, *timing_ns))
291                        }
292                        Some(_) => {} // Keep existing slowest
293                    }
294                }
295            }
296        }
297
298        slowest
299    }
300
301    /// Clear the graph.
302    pub fn clear(&mut self) {
303        self.nodes.clear();
304        self.edges.clear();
305        self.scope_stack.clear();
306        self.name_to_id.clear();
307    }
308
309    /// Check if scope stack is balanced (empty).
310    pub fn is_scope_balanced(&self) -> bool {
311        self.scope_stack.is_empty()
312    }
313}