Skip to main content

trueno_ptx_debug/analyzer/
control_flow.rs

1//! Control Flow Analyzer - CFG construction and barrier analysis
2
3use crate::bugs::Severity;
4use crate::parser::types::Opcode;
5use crate::parser::{Instruction, KernelDef, PtxModule, SourceLocation, Statement};
6use std::collections::{HashMap, HashSet};
7
/// Identifier of a basic block in a [`ControlFlowGraph`].
///
/// Graphs built by `ControlFlowAnalyzer::build_cfg` assign ids sequentially,
/// so an id is also the block's index into `ControlFlowGraph::nodes`.
pub type NodeId = usize;
10
11/// Control Flow Graph node
12#[derive(Debug, Clone)]
13pub struct CfgNode {
14    /// Node ID
15    pub id: NodeId,
16    /// Label (if any)
17    pub label: Option<String>,
18    /// Instructions in this basic block
19    pub instructions: Vec<Instruction>,
20    /// Successor node IDs
21    pub successors: Vec<NodeId>,
22    /// Predecessor node IDs
23    pub predecessors: Vec<NodeId>,
24}
25
26/// Control Flow Graph
27#[derive(Debug, Clone)]
28pub struct ControlFlowGraph {
29    /// Nodes in the graph
30    pub nodes: Vec<CfgNode>,
31    /// Entry node ID
32    pub entry: NodeId,
33    /// Exit node IDs
34    pub exits: Vec<NodeId>,
35    /// Label to node ID mapping
36    pub label_to_node: HashMap<String, NodeId>,
37}
38
39impl ControlFlowGraph {
40    /// Get all nodes
41    pub fn nodes(&self) -> &[CfgNode] {
42        &self.nodes
43    }
44
45    /// Get node by ID
46    pub fn get_node(&self, id: NodeId) -> Option<&CfgNode> {
47        self.nodes.get(id)
48    }
49
50    /// Find unreachable nodes
51    pub fn find_unreachable(&self) -> Vec<NodeId> {
52        let mut reachable = HashSet::new();
53        let mut worklist = vec![self.entry];
54
55        while let Some(node_id) = worklist.pop() {
56            if reachable.insert(node_id) {
57                if let Some(node) = self.get_node(node_id) {
58                    worklist.extend(&node.successors);
59                }
60            }
61        }
62
63        self.nodes
64            .iter()
65            .filter(|n| !reachable.contains(&n.id))
66            .map(|n| n.id)
67            .collect()
68    }
69}
70
71/// Barrier violation
72#[derive(Debug, Clone)]
73pub struct BarrierViolation {
74    /// Write location
75    pub write_loc: SourceLocation,
76    /// Read location
77    pub read_loc: SourceLocation,
78    /// Severity
79    pub severity: Severity,
80    /// Message
81    pub message: String,
82}
83
84/// Control Flow Analyzer
85pub struct ControlFlowAnalyzer {
86    cfg: Option<ControlFlowGraph>,
87}
88
89/// Finalize the current basic block, pushing it onto `nodes` and registering
90/// its label (if any) in `label_to_node`. Returns `None` as the new
91/// `current_label` value so callers can do `current_label = flush_block(...)`.
92fn flush_block(
93    nodes: &mut Vec<CfgNode>,
94    label_to_node: &mut HashMap<String, NodeId>,
95    current_label: &mut Option<String>,
96    instructions: &mut Vec<Instruction>,
97) {
98    let node_id = nodes.len();
99    if let Some(ref label) = *current_label {
100        label_to_node.insert(label.clone(), node_id);
101    }
102    nodes.push(CfgNode {
103        id: node_id,
104        label: current_label.take(),
105        instructions: std::mem::take(instructions),
106        successors: Vec::new(),
107        predecessors: Vec::new(),
108    });
109}
110
111/// Collect branch-target edges for a `bra` instruction.
112fn collect_branch_edges(
113    instr: &Instruction,
114    src: NodeId,
115    label_to_node: &HashMap<String, NodeId>,
116    edges: &mut Vec<(NodeId, NodeId)>,
117) {
118    for op in &instr.operands {
119        if let crate::parser::Operand::Label(target) = op {
120            if let Some(&target_id) = label_to_node.get(target) {
121                edges.push((src, target_id));
122            }
123        }
124    }
125}
126
127/// Determine edges and exit nodes for each basic block in the second pass.
128fn collect_edges_for_node(
129    node_idx: NodeId,
130    last_opcode: Option<&Opcode>,
131    last_instr: Option<&Instruction>,
132    node_count: usize,
133    label_to_node: &HashMap<String, NodeId>,
134    edges: &mut Vec<(NodeId, NodeId)>,
135    exits: &mut Vec<NodeId>,
136) {
137    match last_opcode {
138        Some(Opcode::Bra) => {
139            if let Some(instr) = last_instr {
140                collect_branch_edges(instr, node_idx, label_to_node, edges);
141            }
142            // Also fallthrough if conditional
143            if node_idx + 1 < node_count {
144                edges.push((node_idx, node_idx + 1));
145            }
146        }
147        Some(Opcode::Ret | Opcode::Exit) => {
148            exits.push(node_idx);
149        }
150        _ => {
151            if node_idx + 1 < node_count {
152                edges.push((node_idx, node_idx + 1));
153            } else {
154                exits.push(node_idx);
155            }
156        }
157    }
158}
159
160impl ControlFlowAnalyzer {
161    /// Create a new control flow analyzer
162    pub fn new() -> Self {
163        Self { cfg: None }
164    }
165
166    /// Build CFG for a kernel
167    pub fn build_cfg(&mut self, kernel: &KernelDef) -> ControlFlowGraph {
168        let mut nodes = Vec::new();
169        let mut label_to_node: HashMap<String, NodeId> = HashMap::new();
170        let mut current_instructions = Vec::new();
171        let mut current_label: Option<String> = None;
172
173        // First pass: create basic blocks
174        for stmt in &kernel.body {
175            match stmt {
176                Statement::Label(label) => {
177                    if !current_instructions.is_empty() || current_label.is_some() {
178                        flush_block(
179                            &mut nodes,
180                            &mut label_to_node,
181                            &mut current_label,
182                            &mut current_instructions,
183                        );
184                    }
185                    current_label = Some(label.clone());
186                }
187                Statement::Instruction(instr) => {
188                    current_instructions.push(instr.clone());
189                    if instr.opcode.is_branch() {
190                        flush_block(
191                            &mut nodes,
192                            &mut label_to_node,
193                            &mut current_label,
194                            &mut current_instructions,
195                        );
196                    }
197                }
198                _ => {}
199            }
200        }
201
202        // Add final block if any
203        if !current_instructions.is_empty() || current_label.is_some() {
204            flush_block(
205                &mut nodes,
206                &mut label_to_node,
207                &mut current_label,
208                &mut current_instructions,
209            );
210        }
211
212        // Create empty entry node if needed
213        if nodes.is_empty() {
214            nodes.push(CfgNode {
215                id: 0,
216                label: None,
217                instructions: Vec::new(),
218                successors: Vec::new(),
219                predecessors: Vec::new(),
220            });
221        }
222
223        // Second pass: collect edges
224        let mut edges: Vec<(NodeId, NodeId)> = Vec::new();
225        let mut exits = Vec::new();
226        let node_count = nodes.len();
227
228        for i in 0..node_count {
229            let last_instr = nodes[i].instructions.last().cloned();
230            let last_opcode = last_instr.as_ref().map(|instr| &instr.opcode);
231            collect_edges_for_node(
232                i,
233                last_opcode,
234                last_instr.as_ref(),
235                node_count,
236                &label_to_node,
237                &mut edges,
238                &mut exits,
239            );
240        }
241
242        // Apply edges
243        for (from, to) in edges {
244            nodes[from].successors.push(to);
245            nodes[to].predecessors.push(from);
246        }
247
248        let cfg = ControlFlowGraph {
249            nodes,
250            entry: 0,
251            exits,
252            label_to_node,
253        };
254
255        self.cfg = Some(cfg.clone());
256        cfg
257    }
258
259    /// Analyze barrier synchronization
260    pub fn analyze_barriers(&self, _module: &PtxModule) -> Vec<BarrierViolation> {
261        let mut violations = Vec::new();
262
263        if let Some(ref cfg) = self.cfg {
264            for node in cfg.nodes() {
265                let mut last_shared_write: Option<SourceLocation> = None;
266
267                for instr in &node.instructions {
268                    // Check for shared memory operations
269                    let is_shared = instr
270                        .modifiers
271                        .iter()
272                        .any(|m| matches!(m, crate::parser::types::Modifier::Shared));
273
274                    if instr.opcode == Opcode::St && is_shared {
275                        last_shared_write = Some(instr.location.clone());
276                    } else if instr.opcode == Opcode::Ld && is_shared {
277                        if let Some(ref write_loc) = last_shared_write {
278                            // Check if there's a barrier between write and read
279                            // Simplified: just check if bar.sync appears
280                            let has_barrier = node
281                                .instructions
282                                .iter()
283                                .any(|i| matches!(i.opcode, Opcode::Bar));
284
285                            if !has_barrier {
286                                violations.push(BarrierViolation {
287                                    write_loc: write_loc.clone(),
288                                    read_loc: instr.location.clone(),
289                                    severity: Severity::High,
290                                    message: "Missing barrier between shared memory write and read"
291                                        .into(),
292                                });
293                            }
294                        }
295                    } else if matches!(instr.opcode, Opcode::Bar) {
296                        last_shared_write = None;
297                    }
298                }
299            }
300        }
301
302        violations
303    }
304}
305
306impl Default for ControlFlowAnalyzer {
307    fn default() -> Self {
308        Self::new()
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    // F061: All code paths reach ret or exit
    #[test]
    fn f061_all_paths_reach_exit() {
        let ptx = r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test()
            {
                .reg .u32 %r<10>;
                mov.u32 %r0, 0;
                ret;
            }
        "#;
        let mut parser = Parser::new(ptx).expect("parser creation should succeed");
        let module = parser.parse().expect("parsing should succeed");

        let mut analyzer = ControlFlowAnalyzer::new();
        let cfg = analyzer.build_cfg(&module.kernels[0]);

        assert!(!cfg.exits.is_empty(), "F061: Should have exit nodes");
        // Every exit node must actually end in ret or exit.
        for &exit in &cfg.exits {
            let node = cfg.get_node(exit).expect("exit ids must be valid nodes");
            let last = node.instructions.last().map(|instr| &instr.opcode);
            assert!(
                matches!(last, Some(Opcode::Ret | Opcode::Exit)),
                "F061: exit node {} does not end in ret/exit",
                exit
            );
        }
    }

    // F062: No unreachable code
    #[test]
    fn f062_no_unreachable_code() {
        let ptx = r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test()
            {
                .reg .u32 %r<10>;
                mov.u32 %r0, 0;
                ret;
            }
        "#;
        let mut parser = Parser::new(ptx).expect("parser creation should succeed");
        let module = parser.parse().expect("parsing should succeed");

        let mut analyzer = ControlFlowAnalyzer::new();
        let cfg = analyzer.build_cfg(&module.kernels[0]);
        let unreachable = cfg.find_unreachable();

        assert!(
            unreachable.is_empty(),
            "F062: Found unreachable code: {:?}",
            unreachable
        );
    }

    // F036: bar.sync after shared write, before read
    #[test]
    fn f036_barrier_after_shared_write() {
        let ptx = r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test()
            {
                .reg .u32 %r<10>;
                st.shared.u32 [%r0], %r1;
                bar.sync 0;
                ld.shared.u32 %r2, [%r3];
                ret;
            }
        "#;
        let mut parser = Parser::new(ptx).expect("parser creation should succeed");
        let module = parser.parse().expect("parsing should succeed");

        let mut analyzer = ControlFlowAnalyzer::new();
        let _ = analyzer.build_cfg(&module.kernels[0]);
        let violations = analyzer.analyze_barriers(&module);

        assert!(
            violations.is_empty(),
            "F036: Should have no barrier violations with proper sync"
        );
    }

    // F036 (negative): a shared read directly after a shared write, with no
    // barrier, must be reported. Without this case an analyzer that never
    // reports anything would pass the F036 test above.
    #[test]
    fn f036_missing_barrier_is_reported() {
        let ptx = r#"
            .version 8.0
            .target sm_70
            .address_size 64

            .entry test()
            {
                .reg .u32 %r<10>;
                st.shared.u32 [%r0], %r1;
                ld.shared.u32 %r2, [%r3];
                ret;
            }
        "#;
        let mut parser = Parser::new(ptx).expect("parser creation should succeed");
        let module = parser.parse().expect("parsing should succeed");

        let mut analyzer = ControlFlowAnalyzer::new();
        let _ = analyzer.build_cfg(&module.kernels[0]);
        let violations = analyzer.analyze_barriers(&module);

        assert_eq!(
            violations.len(),
            1,
            "F036: expected exactly one barrier violation, got {:?}",
            violations
        );
        assert!(matches!(violations[0].severity, Severity::High));
    }
}