#![allow(dead_code)]
use super::const_eval::MacroConstantMap;
use std::collections::HashMap;
use tree_sitter::Node;
/// Identifier of a [`BasicBlock`]: its index in [`FunctionCfg::blocks`].
pub type BlockId = usize;
/// One node of a function's control-flow graph: the source ranges recorded,
/// in order, while the builder's walk was positioned in this block.
#[derive(Debug, Clone)]
pub struct BasicBlock {
/// Index of this block in `FunctionCfg::blocks`.
pub id: BlockId,
/// `(start_byte, end_byte)` of every statement added to this block (including
/// `if`/loop conditions and `for` initializers/updates), in insertion order.
pub statements: Vec<(usize, usize)>,
/// Span from the smallest start to the largest end among `statements`;
/// stays `(0, 0)` while the block has no statements.
pub byte_range: (usize, usize),
/// Byte range of the `if`/loop condition tested at the end of this block,
/// if any.
pub condition_range: Option<(usize, usize)>,
}
/// Kind of a directed control-flow edge between two [`BasicBlock`]s.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CfgEdge {
/// Unconditional flow into the next block (sequencing, branch joins, entry
/// into loop headers and labeled statements).
Fallthrough,
/// Taken when an `if`/`while`/`for` condition is true.
TrueBranch,
/// Taken when a condition is false: to the `else` branch, the join block, or
/// the loop exit.
FalseBranch,
/// Loop-closing edge: end of a `while` body or `for` update back to the
/// header, or a `do`/`while` condition back to the body.
BackEdge,
/// From a block ending in `return` to a fresh, empty exit block.
Return,
/// From a `break` to the exit block of the innermost enclosing loop.
Break,
/// From a `continue` to the continue target of the innermost enclosing loop.
Continue,
/// From a `goto` to the block that starts at the referenced label.
Goto,
}
/// Control-flow graph of a single C function body.
#[derive(Debug, Clone)]
pub struct FunctionCfg {
/// All blocks, indexed by [`BlockId`].
pub blocks: Vec<BasicBlock>,
/// Directed edges `(from, to, kind)`; no identical triple appears twice.
pub edges: Vec<(BlockId, BlockId, CfgEdge)>,
/// Entry block (always block 0).
pub entry: BlockId,
/// Blocks treated as function exits, sorted and deduplicated. Includes both
/// endpoints of every `Return` edge; see `CfgBuilder::build` for how a
/// fall-through exit is chosen.
pub exits: Vec<BlockId>,
/// `start_byte()` of the `function_definition` node the graph was built from.
function_start_byte: usize,
}
impl FunctionCfg {
    /// Returns `(target, kind)` for every edge leaving `block_id`, in the
    /// order the edges were added.
    pub fn successors(&self, block_id: BlockId) -> Vec<(BlockId, &CfgEdge)> {
        let mut outgoing = Vec::new();
        for (src, dst, kind) in &self.edges {
            if *src == block_id {
                outgoing.push((*dst, kind));
            }
        }
        outgoing
    }

    /// Returns `(source, kind)` for every edge entering `block_id`, in the
    /// order the edges were added.
    pub fn predecessors(&self, block_id: BlockId) -> Vec<(BlockId, &CfgEdge)> {
        let mut incoming = Vec::new();
        for (src, dst, kind) in &self.edges {
            if *dst == block_id {
                incoming.push((*src, kind));
            }
        }
        incoming
    }

    /// Number of blocks in the graph (reachable or not).
    pub fn block_count(&self) -> usize {
        self.blocks.len()
    }

    /// Looks up a block by id; `None` if `id` is out of range.
    pub fn get_block(&self, id: BlockId) -> Option<&BasicBlock> {
        self.blocks.get(id)
    }
}
/// Mutable state used while walking a function body to build a [`FunctionCfg`].
struct CfgBuilder {
/// Blocks created so far; a block's index equals its `BlockId`.
blocks: Vec<BasicBlock>,
/// Edges added so far (deduplicated by `add_edge`).
edges: Vec<(BlockId, BlockId, CfgEdge)>,
/// Block that newly visited statements are appended to.
current_block: BlockId,
/// One `(continue_target, break_target)` pair per enclosing loop, innermost last.
loop_stack: Vec<(BlockId, BlockId)>,
/// Label name -> block that begins at that labeled statement.
label_blocks: HashMap<String, BlockId>,
/// `(block containing the goto, label name)` pairs, resolved in `build`
/// because a label may appear after the `goto` that targets it.
pending_gotos: Vec<(BlockId, String)>,
/// Passed through unchanged to `FunctionCfg::function_start_byte`.
function_start_byte: usize,
/// Identifier values consulted when folding constant conditions.
constants: MacroConstantMap,
}
impl CfgBuilder {
fn new(function_start_byte: usize, constants: MacroConstantMap) -> Self {
let entry_block = BasicBlock {
id: 0,
statements: Vec::new(),
byte_range: (0, 0),
condition_range: None,
};
CfgBuilder {
blocks: vec![entry_block],
edges: Vec::new(),
current_block: 0,
loop_stack: Vec::new(),
label_blocks: HashMap::new(),
pending_gotos: Vec::new(),
function_start_byte,
constants,
}
}
fn new_block(&mut self) -> BlockId {
let id = self.blocks.len();
self.blocks.push(BasicBlock {
id,
statements: Vec::new(),
byte_range: (0, 0),
condition_range: None,
});
id
}
fn add_edge(&mut self, from: BlockId, to: BlockId, kind: CfgEdge) {
if !self
.edges
.iter()
.any(|(f, t, k)| *f == from && *t == to && *k == kind)
{
self.edges.push((from, to, kind));
}
}
fn add_statement(&mut self, start: usize, end: usize) {
if let Some(block) = self.blocks.get_mut(self.current_block) {
block.statements.push((start, end));
if block.byte_range.0 == 0 || start < block.byte_range.0 {
block.byte_range.0 = start;
}
if end > block.byte_range.1 {
block.byte_range.1 = end;
}
}
}
fn build_from_compound_statement<'a>(&mut self, node: &Node<'a>, source: &str) {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
match child.kind() {
"{" | "}" => continue,
_ => self.process_statement(&child, source),
}
}
}
}
fn process_statement<'a>(&mut self, node: &Node<'a>, source: &str) {
match node.kind() {
"if_statement" => self.process_if(node, source),
"while_statement" => self.process_while(node, source),
"for_statement" => self.process_for(node, source),
"do_statement" => self.process_do_while(node, source),
"switch_statement" => {
self.add_statement(node.start_byte(), node.end_byte());
}
"return_statement" => {
self.add_statement(node.start_byte(), node.end_byte());
let exit_block = self.new_block();
self.add_edge(self.current_block, exit_block, CfgEdge::Return);
self.current_block = self.new_block(); }
"break_statement" => {
self.add_statement(node.start_byte(), node.end_byte());
if let Some(&(_, loop_exit)) = self.loop_stack.last() {
self.add_edge(self.current_block, loop_exit, CfgEdge::Break);
}
self.current_block = self.new_block(); }
"continue_statement" => {
self.add_statement(node.start_byte(), node.end_byte());
if let Some(&(loop_header, _)) = self.loop_stack.last() {
self.add_edge(self.current_block, loop_header, CfgEdge::Continue);
}
self.current_block = self.new_block(); }
"goto_statement" => {
self.add_statement(node.start_byte(), node.end_byte());
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "statement_identifier" || child.kind() == "identifier" {
if let Ok(label) = child.utf8_text(source.as_bytes()) {
self.pending_gotos
.push((self.current_block, label.to_string()));
}
}
}
}
self.current_block = self.new_block(); }
"compound_statement" => {
self.build_from_compound_statement(node, source);
}
"labeled_statement" => {
let label_block = self.new_block();
self.add_edge(self.current_block, label_block, CfgEdge::Fallthrough);
self.current_block = label_block;
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "identifier" || child.kind() == "statement_identifier" {
if let Ok(label) = child.utf8_text(source.as_bytes()) {
self.label_blocks.insert(label.to_string(), label_block);
}
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() != ":"
&& child.kind() != "identifier"
&& child.kind() != "statement_identifier"
{
self.process_statement(&child, source);
}
}
}
}
_ => {
self.add_statement(node.start_byte(), node.end_byte());
}
}
}
fn process_if<'a>(&mut self, node: &Node<'a>, source: &str) {
let const_val = if let Some(condition) = node.child_by_field_name("condition") {
self.add_statement(condition.start_byte(), condition.end_byte());
if let Some(block) = self.blocks.get_mut(self.current_block) {
block.condition_range = Some((condition.start_byte(), condition.end_byte()));
}
evaluate_constant_condition(&condition, source, &self.constants)
} else {
None
};
let condition_block = self.current_block;
let then_block = self.new_block();
let join_block = self.new_block();
if const_val != Some(false) {
self.add_edge(condition_block, then_block, CfgEdge::TrueBranch);
self.current_block = then_block;
if let Some(consequence) = node.child_by_field_name("consequence") {
self.process_statement(&consequence, source);
}
self.add_edge(self.current_block, join_block, CfgEdge::Fallthrough);
}
if const_val != Some(true) {
if let Some(alternative) = node.child_by_field_name("alternative") {
let else_block = self.new_block();
self.add_edge(condition_block, else_block, CfgEdge::FalseBranch);
self.current_block = else_block;
for i in 0..alternative.child_count() {
if let Some(child) = alternative.child(i) {
if child.kind() != "else" {
self.process_statement(&child, source);
}
}
}
self.add_edge(self.current_block, join_block, CfgEdge::Fallthrough);
} else {
self.add_edge(condition_block, join_block, CfgEdge::FalseBranch);
}
}
if const_val == Some(false) && node.child_by_field_name("alternative").is_none() {
self.add_edge(condition_block, join_block, CfgEdge::Fallthrough);
}
self.current_block = join_block;
}
fn process_while<'a>(&mut self, node: &Node<'a>, source: &str) {
let header_block = self.new_block();
self.add_edge(self.current_block, header_block, CfgEdge::Fallthrough);
self.current_block = header_block;
let const_val = if let Some(condition) = node.child_by_field_name("condition") {
self.add_statement(condition.start_byte(), condition.end_byte());
if let Some(block) = self.blocks.get_mut(header_block) {
block.condition_range = Some((condition.start_byte(), condition.end_byte()));
}
evaluate_constant_condition(&condition, source, &self.constants)
} else {
None
};
let body_block = self.new_block();
let exit_block = self.new_block();
self.add_edge(header_block, body_block, CfgEdge::TrueBranch);
if const_val != Some(true) {
self.add_edge(header_block, exit_block, CfgEdge::FalseBranch);
}
self.loop_stack.push((header_block, exit_block));
self.current_block = body_block;
if let Some(body) = node.child_by_field_name("body") {
self.process_statement(&body, source);
}
self.add_edge(self.current_block, header_block, CfgEdge::BackEdge);
self.loop_stack.pop();
self.current_block = exit_block;
}
fn process_for<'a>(&mut self, node: &Node<'a>, source: &str) {
if let Some(initializer) = node.child_by_field_name("initializer") {
self.add_statement(initializer.start_byte(), initializer.end_byte());
}
let header_block = self.new_block();
self.add_edge(self.current_block, header_block, CfgEdge::Fallthrough);
self.current_block = header_block;
let const_val = if let Some(condition) = node.child_by_field_name("condition") {
self.add_statement(condition.start_byte(), condition.end_byte());
if let Some(block) = self.blocks.get_mut(header_block) {
block.condition_range = Some((condition.start_byte(), condition.end_byte()));
}
evaluate_constant_condition(&condition, source, &self.constants)
} else {
Some(true)
};
let body_block = self.new_block();
let update_block = self.new_block();
let exit_block = self.new_block();
self.add_edge(header_block, body_block, CfgEdge::TrueBranch);
if const_val != Some(true) {
self.add_edge(header_block, exit_block, CfgEdge::FalseBranch);
}
self.loop_stack.push((update_block, exit_block));
self.current_block = body_block;
if let Some(body) = node.child_by_field_name("body") {
self.process_statement(&body, source);
}
self.add_edge(self.current_block, update_block, CfgEdge::Fallthrough);
self.loop_stack.pop();
self.current_block = update_block;
if let Some(update) = node.child_by_field_name("update") {
self.add_statement(update.start_byte(), update.end_byte());
}
self.add_edge(update_block, header_block, CfgEdge::BackEdge);
self.current_block = exit_block;
}
fn process_do_while<'a>(&mut self, node: &Node<'a>, source: &str) {
let body_block = self.new_block();
self.add_edge(self.current_block, body_block, CfgEdge::Fallthrough);
let exit_block = self.new_block();
self.loop_stack.push((body_block, exit_block));
self.current_block = body_block;
if let Some(body) = node.child_by_field_name("body") {
self.process_statement(&body, source);
}
self.loop_stack.pop();
let cond_block = self.new_block();
self.add_edge(self.current_block, cond_block, CfgEdge::Fallthrough);
self.current_block = cond_block;
if let Some(condition) = node.child_by_field_name("condition") {
self.add_statement(condition.start_byte(), condition.end_byte());
if let Some(block) = self.blocks.get_mut(cond_block) {
block.condition_range = Some((condition.start_byte(), condition.end_byte()));
}
}
self.add_edge(cond_block, body_block, CfgEdge::BackEdge);
self.add_edge(cond_block, exit_block, CfgEdge::FalseBranch);
self.current_block = exit_block;
}
fn build(mut self) -> FunctionCfg {
let goto_edges: Vec<(BlockId, BlockId)> = self
.pending_gotos
.iter()
.filter_map(|(src, label)| self.label_blocks.get(label).map(|&tgt| (*src, tgt)))
.collect();
for (src, tgt) in goto_edges {
self.add_edge(src, tgt, CfgEdge::Goto);
}
let mut exits: Vec<BlockId> = self
.edges
.iter()
.filter(|(_, _, kind)| *kind == CfgEdge::Return)
.map(|(from, _, _)| *from)
.collect();
let return_targets: Vec<BlockId> = self
.edges
.iter()
.filter(|(_, _, kind)| *kind == CfgEdge::Return)
.map(|(_, to, _)| *to)
.collect();
exits.extend(return_targets);
if exits.is_empty() && !self.blocks.is_empty() {
exits.push(self.blocks.len() - 1);
}
exits.sort();
exits.dedup();
FunctionCfg {
blocks: self.blocks,
edges: self.edges,
entry: 0,
exits,
function_start_byte: self.function_start_byte,
}
}
}
/// Tries to fold a C condition expression to a compile-time boolean.
///
/// Enclosing parentheses are stripped first. Supported forms:
/// - integer literals and identifiers, resolved through
///   `resolve_constant_operand` and truthy when non-zero;
/// - the `true` / `false` keywords;
/// - logical `!`, and `&&` / `||` over foldable sub-conditions. A single
///   decisive side fixes the result even if the other side is unknown
///   (`x && 0` is false, `x || 1` is true);
/// - the comparisons `== != < > <= >=` between two resolvable operands;
/// - other unary expressions (`-1`, `+X`) that resolve to an integer.
///
/// Returns `None` when the value cannot be determined statically.
fn evaluate_constant_condition(
    condition: &Node,
    source: &str,
    constants: &MacroConstantMap,
) -> Option<bool> {
    let inner = unwrap_parens_cfg(condition);
    match inner.kind() {
        "number_literal" | "identifier" => {
            resolve_constant_operand(&inner, source, constants).map(|value| value != 0)
        }
        "true" => Some(true),
        "false" => Some(false),
        "unary_expression" => {
            let operator = inner.child_by_field_name("operator")?;
            if operator.utf8_text(source.as_bytes()).ok()?.trim() == "!" {
                let argument = inner.child_by_field_name("argument")?;
                evaluate_constant_condition(&argument, source, constants).map(|value| !value)
            } else {
                resolve_constant_operand(&inner, source, constants).map(|value| value != 0)
            }
        }
        "binary_expression" => {
            let left = inner.child_by_field_name("left")?;
            let operator = inner.child_by_field_name("operator")?;
            let right = inner.child_by_field_name("right")?;
            let op = operator.utf8_text(source.as_bytes()).ok()?.trim();
            match op {
                "&&" | "||" => {
                    let lhs = evaluate_constant_condition(&left, source, constants);
                    let rhs = evaluate_constant_condition(&right, source, constants);
                    // `false` decides `&&`; `true` decides `||`.
                    let decisive = op == "||";
                    if lhs == Some(decisive) || rhs == Some(decisive) {
                        Some(decisive)
                    } else if lhs.is_some() && rhs.is_some() {
                        Some(!decisive)
                    } else {
                        None
                    }
                }
                _ => {
                    let left = unwrap_parens_cfg(&left);
                    let right = unwrap_parens_cfg(&right);
                    let lv = resolve_constant_operand(&left, source, constants)?;
                    let rv = resolve_constant_operand(&right, source, constants)?;
                    match op {
                        "==" => Some(lv == rv),
                        "!=" => Some(lv != rv),
                        "<" => Some(lv < rv),
                        ">" => Some(lv > rv),
                        "<=" => Some(lv <= rv),
                        ">=" => Some(lv >= rv),
                        _ => None,
                    }
                }
            }
        }
        _ => None,
    }
}
fn resolve_constant_operand(
node: &Node,
source: &str,
constants: &MacroConstantMap,
) -> Option<i64> {
match node.kind() {
"number_literal" => {
let text = node.utf8_text(source.as_bytes()).ok()?.trim().to_string();
text.parse::<i64>().ok()
}
"identifier" => {
let name = node.utf8_text(source.as_bytes()).ok()?;
constants.get(name).copied()
}
_ => None,
}
}
fn unwrap_parens_cfg<'a>(node: &'a Node<'a>) -> Node<'a> {
let mut n = *node;
while n.kind() == "parenthesized_expression" {
if let Some(inner) = n.child(1) {
n = inner;
} else {
break;
}
}
n
}
/// Builds the control-flow graph of a C `function_definition` node without
/// any known macro constants, so identifiers in conditions never fold.
///
/// Equivalent to `build_function_cfg_with_constants` with an empty map.
pub fn build_function_cfg(func_node: &Node, source: &str) -> Option<FunctionCfg> {
    let no_constants = MacroConstantMap::new();
    build_function_cfg_with_constants(func_node, source, &no_constants)
}
/// Builds the control-flow graph of a C `function_definition` node.
///
/// `constants` maps identifier names to integer values used when folding
/// constant branch and loop conditions. It is cloned into the builder.
///
/// Returns `None` if `func_node` is not a `function_definition` or its `body`
/// field is missing or is not a `compound_statement`.
pub fn build_function_cfg_with_constants(
    func_node: &Node,
    source: &str,
    constants: &MacroConstantMap,
) -> Option<FunctionCfg> {
    let body = Some(func_node)
        .filter(|node| node.kind() == "function_definition")
        .and_then(|node| node.child_by_field_name("body"))
        .filter(|body| body.kind() == "compound_statement")?;
    let mut builder = CfgBuilder::new(func_node.start_byte(), constants.clone());
    builder.build_from_compound_statement(&body, source);
    Some(builder.build())
}
/// Returns the name of the function defined by `func_node`, taken from its
/// `declarator` field, or `None` if there is no declarator or no name can be
/// extracted from it.
pub fn get_function_name<'a>(func_node: &Node<'a>, source: &'a str) -> Option<&'a str> {
    func_node
        .child_by_field_name("declarator")
        .and_then(|declarator| extract_name_from_declarator(&declarator, source))
}
fn extract_name_from_declarator<'a>(node: &Node<'a>, source: &'a str) -> Option<&'a str> {
match node.kind() {
"identifier" => node.utf8_text(source.as_bytes()).ok(),
"function_declarator" | "pointer_declarator" => {
let inner = node.child_by_field_name("declarator")?;
extract_name_from_declarator(&inner, source)
}
_ => {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "identifier" {
return child.utf8_text(source.as_bytes()).ok();
}
}
}
None
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `code` as C and builds the CFG of its first top-level
    /// `function_definition`; returns `None` when there is none.
    fn parse_and_build_cfg(code: &str) -> Option<FunctionCfg> {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&tree_sitter_c::language()).unwrap();
        let tree = parser.parse(code, None).unwrap();
        let root = tree.root_node();
        let function = (0..root.child_count())
            .filter_map(|i| root.child(i))
            .find(|node| node.kind() == "function_definition")?;
        build_function_cfg(&function, code)
    }

    /// Whether `cfg` has at least one edge of the given kind.
    fn has_edge(cfg: &FunctionCfg, kind: CfgEdge) -> bool {
        cfg.edges.iter().any(|(_, _, edge)| *edge == kind)
    }

    /// Number of edges of the given kind in `cfg`.
    fn count_edges(cfg: &FunctionCfg, kind: CfgEdge) -> usize {
        cfg.edges.iter().filter(|(_, _, edge)| *edge == kind).count()
    }

    #[test]
    fn test_simple_function() {
        let code = r#"
void foo() {
int x = 1;
int y = 2;
}
"#;
        let cfg = parse_and_build_cfg(code).unwrap();
        assert!(cfg.block_count() >= 1);
        assert_eq!(cfg.entry, 0);
    }

    #[test]
    fn test_if_else() {
        let code = r#"
void foo(int x) {
if (x > 0) {
x = 1;
} else {
x = 2;
}
x = 3;
}
"#;
        let cfg = parse_and_build_cfg(code).unwrap();
        assert!(cfg.block_count() >= 4);
        assert!(has_edge(&cfg, CfgEdge::TrueBranch));
        assert!(has_edge(&cfg, CfgEdge::FalseBranch));
    }

    #[test]
    fn test_while_loop() {
        let code = r#"
void foo(int n) {
int i = 0;
while (i < n) {
i++;
}
}
"#;
        let cfg = parse_and_build_cfg(code).unwrap();
        assert!(has_edge(&cfg, CfgEdge::BackEdge));
    }

    #[test]
    fn test_for_loop() {
        let code = r#"
void foo() {
for (int i = 0; i < 10; i++) {
int x = i;
}
}
"#;
        let cfg = parse_and_build_cfg(code).unwrap();
        assert!(has_edge(&cfg, CfgEdge::BackEdge));
    }

    #[test]
    fn test_return_creates_exit() {
        let code = r#"
int foo(int x) {
if (x < 0) {
return -1;
}
return x;
}
"#;
        let cfg = parse_and_build_cfg(code).unwrap();
        assert!(count_edges(&cfg, CfgEdge::Return) >= 2);
    }

    #[test]
    fn test_goto_edges() {
        let code = r#"
void foo(int x) {
if (x < 0) goto skip;
int y = 42;
skip:
use(y);
}
"#;
        let cfg = parse_and_build_cfg(code).unwrap();
        assert!(has_edge(&cfg, CfgEdge::Goto), "Should have a goto edge");
        assert!(
            cfg.block_count() >= 4,
            "goto+label should create at least 4 blocks"
        );
    }

    #[test]
    fn test_break_continue() {
        let code = r#"
void foo(int n) {
for (int i = 0; i < n; i++) {
if (i == 5) break;
if (i == 3) continue;
}
}
"#;
        let cfg = parse_and_build_cfg(code).unwrap();
        assert!(has_edge(&cfg, CfgEdge::Break));
        assert!(has_edge(&cfg, CfgEdge::Continue));
    }
}