use crate::bugs::Severity;
use crate::parser::types::Opcode;
use crate::parser::{Instruction, KernelDef, PtxModule, SourceLocation, Statement};
use std::collections::{HashMap, HashSet};
/// Identifier of a basic block. Graphs built by `ControlFlowAnalyzer::build_cfg`
/// assign each block an id equal to its index in `ControlFlowGraph::nodes`.
pub type NodeId = usize;
/// A basic block in a kernel's control-flow graph.
///
/// `ControlFlowAnalyzer::build_cfg` starts a new block at each label and closes
/// the current block after each instruction whose opcode reports `is_branch()`.
#[derive(Debug, Clone)]
pub struct CfgNode {
/// This block's id; as assigned by `build_cfg`, equal to its index in `ControlFlowGraph::nodes`.
pub id: NodeId,
/// Label that begins this block, if any.
pub label: Option<String>,
/// The block's instructions, in the order they appear in the kernel body.
pub instructions: Vec<Instruction>,
/// Ids of blocks that control may flow to after this one.
pub successors: Vec<NodeId>,
/// Ids of blocks that list this one as a successor.
pub predecessors: Vec<NodeId>,
}
/// Control-flow graph of a single kernel, as produced by `ControlFlowAnalyzer::build_cfg`.
#[derive(Debug, Clone)]
pub struct ControlFlowGraph {
/// All basic blocks; `build_cfg` stores each block at the index equal to its id.
pub nodes: Vec<CfgNode>,
/// Id of the block where traversal starts (`build_cfg` always uses 0).
pub entry: NodeId,
/// Blocks recorded as kernel exits: those whose last opcode is `Ret` or `Exit`,
/// plus the final block when its last opcode is not `Bra`.
pub exits: Vec<NodeId>,
/// Maps each label to the id of the block it begins.
pub label_to_node: HashMap<String, NodeId>,
}
impl ControlFlowGraph {
    /// Returns all basic blocks as a slice.
    pub fn nodes(&self) -> &[CfgNode] {
        self.nodes.as_slice()
    }

    /// Returns the block stored at index `id`, or `None` if `id` is out of range.
    pub fn get_node(&self, id: NodeId) -> Option<&CfgNode> {
        self.nodes.get(id)
    }

    /// Returns the ids of all blocks that cannot be reached from `entry`.
    ///
    /// Walks `successors` depth-first, looking blocks up by index; an id that
    /// does not name a stored block is marked visited but not expanded. The
    /// result lists unreachable blocks in the order they appear in `nodes`.
    pub fn find_unreachable(&self) -> Vec<NodeId> {
        let mut visited: HashSet<NodeId> = HashSet::new();
        let mut stack: Vec<NodeId> = vec![self.entry];
        while let Some(current) = stack.pop() {
            if !visited.insert(current) {
                // Already expanded via another path.
                continue;
            }
            if let Some(block) = self.get_node(current) {
                stack.extend(block.successors.iter().copied());
            }
        }
        self.nodes
            .iter()
            .map(|block| block.id)
            .filter(|id| !visited.contains(id))
            .collect()
    }
}
/// A shared-memory load that follows a shared-memory store without an
/// intervening barrier, as reported by `ControlFlowAnalyzer::analyze_barriers`.
#[derive(Debug, Clone)]
pub struct BarrierViolation {
/// Location of the shared-memory store.
pub write_loc: SourceLocation,
/// Location of the shared-memory load that follows the store.
pub read_loc: SourceLocation,
/// How serious the finding is.
pub severity: Severity,
/// Human-readable description of the problem.
pub message: String,
}
/// Builds control-flow graphs for kernels and runs CFG-based checks on them.
pub struct ControlFlowAnalyzer {
/// Graph from the most recent `build_cfg` call; `None` until one has been built.
cfg: Option<ControlFlowGraph>,
}
/// Closes the block being accumulated and appends it to `nodes`.
///
/// The new block's id is its index in `nodes`. A pending label is moved into
/// the block and registered in `label_to_node`; the pending instructions are
/// moved out, leaving both `current_label` and `instructions` empty.
fn flush_block(
    nodes: &mut Vec<CfgNode>,
    label_to_node: &mut HashMap<String, NodeId>,
    current_label: &mut Option<String>,
    instructions: &mut Vec<Instruction>,
) {
    let id: NodeId = nodes.len();
    let label = current_label.take();
    if let Some(name) = &label {
        label_to_node.insert(name.clone(), id);
    }
    nodes.push(CfgNode {
        id,
        label,
        instructions: std::mem::take(instructions),
        successors: Vec::new(),
        predecessors: Vec::new(),
    });
}
/// Appends an edge `(src, target)` for every label operand of `instr` that
/// names a known block, in operand order. Unknown labels are skipped.
fn collect_branch_edges(
    instr: &Instruction,
    src: NodeId,
    label_to_node: &HashMap<String, NodeId>,
    edges: &mut Vec<(NodeId, NodeId)>,
) {
    edges.extend(instr.operands.iter().filter_map(|operand| match operand {
        crate::parser::Operand::Label(name) => {
            label_to_node.get(name).map(|&dst| (src, dst))
        }
        _ => None,
    }));
}
/// Appends the outgoing edges of block `node_idx` to `edges`, or records the
/// block in `exits`, based on its final instruction.
///
/// - `Bra`: an edge to every label operand that names a known block, plus a
///   fall-through edge to `node_idx + 1` when that block exists (as if the
///   branch might not be taken).
/// - `Ret` / `Exit`: the block is an exit and gets no edges.
/// - anything else, including an empty block: fall through to the next block,
///   or become an exit if this is the last block.
///
/// `last_opcode` is expected to be the opcode of `last_instr`.
fn collect_edges_for_node(
    node_idx: NodeId,
    last_opcode: Option<&Opcode>,
    last_instr: Option<&Instruction>,
    node_count: usize,
    label_to_node: &HashMap<String, NodeId>,
    edges: &mut Vec<(NodeId, NodeId)>,
    exits: &mut Vec<NodeId>,
) {
    if matches!(last_opcode, Some(Opcode::Ret | Opcode::Exit)) {
        exits.push(node_idx);
        return;
    }
    let is_branch = matches!(last_opcode, Some(Opcode::Bra));
    if is_branch {
        if let Some(terminator) = last_instr {
            collect_branch_edges(terminator, node_idx, label_to_node, edges);
        }
    }
    let next = node_idx + 1;
    if next < node_count {
        edges.push((node_idx, next));
    } else if !is_branch {
        // Falling off the end of the kernel body counts as an exit.
        exits.push(node_idx);
    }
}
impl ControlFlowAnalyzer {
pub fn new() -> Self {
Self { cfg: None }
}
pub fn build_cfg(&mut self, kernel: &KernelDef) -> ControlFlowGraph {
let mut nodes = Vec::new();
let mut label_to_node: HashMap<String, NodeId> = HashMap::new();
let mut current_instructions = Vec::new();
let mut current_label: Option<String> = None;
for stmt in &kernel.body {
match stmt {
Statement::Label(label) => {
if !current_instructions.is_empty() || current_label.is_some() {
flush_block(
&mut nodes,
&mut label_to_node,
&mut current_label,
&mut current_instructions,
);
}
current_label = Some(label.clone());
}
Statement::Instruction(instr) => {
current_instructions.push(instr.clone());
if instr.opcode.is_branch() {
flush_block(
&mut nodes,
&mut label_to_node,
&mut current_label,
&mut current_instructions,
);
}
}
_ => {}
}
}
if !current_instructions.is_empty() || current_label.is_some() {
flush_block(
&mut nodes,
&mut label_to_node,
&mut current_label,
&mut current_instructions,
);
}
if nodes.is_empty() {
nodes.push(CfgNode {
id: 0,
label: None,
instructions: Vec::new(),
successors: Vec::new(),
predecessors: Vec::new(),
});
}
let mut edges: Vec<(NodeId, NodeId)> = Vec::new();
let mut exits = Vec::new();
let node_count = nodes.len();
for i in 0..node_count {
let last_instr = nodes[i].instructions.last().cloned();
let last_opcode = last_instr.as_ref().map(|instr| &instr.opcode);
collect_edges_for_node(
i,
last_opcode,
last_instr.as_ref(),
node_count,
&label_to_node,
&mut edges,
&mut exits,
);
}
for (from, to) in edges {
nodes[from].successors.push(to);
nodes[to].predecessors.push(from);
}
let cfg = ControlFlowGraph {
nodes,
entry: 0,
exits,
label_to_node,
};
self.cfg = Some(cfg.clone());
cfg
}
pub fn analyze_barriers(&self, _module: &PtxModule) -> Vec<BarrierViolation> {
let mut violations = Vec::new();
if let Some(ref cfg) = self.cfg {
for node in cfg.nodes() {
let mut last_shared_write: Option<SourceLocation> = None;
for instr in &node.instructions {
let is_shared = instr
.modifiers
.iter()
.any(|m| matches!(m, crate::parser::types::Modifier::Shared));
if instr.opcode == Opcode::St && is_shared {
last_shared_write = Some(instr.location.clone());
} else if instr.opcode == Opcode::Ld && is_shared {
if let Some(ref write_loc) = last_shared_write {
let has_barrier = node
.instructions
.iter()
.any(|i| matches!(i.opcode, Opcode::Bar));
if !has_barrier {
violations.push(BarrierViolation {
write_loc: write_loc.clone(),
read_loc: instr.location.clone(),
severity: Severity::High,
message: "Missing barrier between shared memory write and read"
.into(),
});
}
}
} else if matches!(instr.opcode, Opcode::Bar) {
last_shared_write = None;
}
}
}
}
violations
}
}
impl Default for ControlFlowAnalyzer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Straight-line kernel: a single `mov` followed by `ret`.
    const STRAIGHT_LINE_PTX: &str = r#"
.version 8.0
.target sm_70
.address_size 64
.entry test()
{
.reg .u32 %r<10>;
mov.u32 %r0, 0;
ret;
}
"#;

    /// Shared-memory store and load separated by `bar.sync`.
    const SYNCED_SHARED_PTX: &str = r#"
.version 8.0
.target sm_70
.address_size 64
.entry test()
{
.reg .u32 %r<10>;
st.shared.u32 [%r0], %r1;
bar.sync 0;
ld.shared.u32 %r2, [%r3];
ret;
}
"#;

    /// Parses a PTX fixture, panicking if the fixture is malformed.
    fn parse_module(ptx: &str) -> PtxModule {
        let mut parser = Parser::new(ptx).expect("parser creation should succeed");
        parser.parse().expect("parsing should succeed")
    }

    #[test]
    fn f061_all_paths_reach_exit() {
        let module = parse_module(STRAIGHT_LINE_PTX);
        let cfg = ControlFlowAnalyzer::new().build_cfg(&module.kernels[0]);
        assert!(!cfg.exits.is_empty(), "F061: Should have exit nodes");
    }

    #[test]
    fn f062_no_unreachable_code() {
        let module = parse_module(STRAIGHT_LINE_PTX);
        let unreachable = ControlFlowAnalyzer::new()
            .build_cfg(&module.kernels[0])
            .find_unreachable();
        assert!(
            unreachable.is_empty(),
            "F062: Found unreachable code: {:?}",
            unreachable
        );
    }

    #[test]
    fn f036_barrier_after_shared_write() {
        let module = parse_module(SYNCED_SHARED_PTX);
        let mut analyzer = ControlFlowAnalyzer::new();
        analyzer.build_cfg(&module.kernels[0]);
        let violations = analyzer.analyze_barriers(&module);
        assert!(
            violations.is_empty(),
            "F036: Should have no barrier violations with proper sync"
        );
    }
}