#![allow(dead_code)]
use super::cfg::{BlockId, FunctionCfg};
use std::collections::{HashMap, HashSet, VecDeque};
use tree_sitter::Node;
/// How a [`Definition`] gives a variable its value.
///
/// The kinds drive the "may this variable be freed / null here?" queries on `ReachingDefs`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum DefinitionKind {
    /// A declaration that does not initialise the variable (non-pointer).
    Declaration,
    /// An assignment or initialiser whose right-hand side is not recognised as special.
    Assignment,
    /// A non-pointer function parameter.
    Parameter,
    /// A store of a null literal (`NULL`, `0`, `nullptr`, possibly cast), or an
    /// uninitialised pointer declaration.
    NullAssignment,
    /// A `free(var)` call: the variable is dangling afterwards.
    FreeCall,
    /// The result of a call that may return null (e.g. `malloc`), or a pointer parameter
    /// whose value is controlled by the caller.
    NullableCall,
}
/// A single definition (write) of a named variable, located within the function's CFG.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Definition {
    /// Name of the variable being defined.
    pub variable: String,
    /// CFG block in which the definition occurs.
    pub block_id: BlockId,
    /// Index of the defining statement within the block's statement list.
    pub statement_index: usize,
    /// What kind of value the definition stores.
    pub kind: DefinitionKind,
    /// Start byte, in the source text, of the syntax node that performs the definition.
    pub byte_offset: usize,
}
/// Result of the reaching-definitions analysis for one function.
///
/// The sets stored per block hold indices into `definitions`.
#[derive(Debug)]
pub struct ReachingDefs {
    /// Every definition found in the function; the per-block sets index into this vector.
    pub definitions: Vec<Definition>,
    /// Definitions reaching the entry of each block.
    pub reaching_in: HashMap<BlockId, HashSet<usize>>,
    /// Definitions reaching the exit of each block.
    pub reaching_out: HashMap<BlockId, HashSet<usize>>,
}
impl ReachingDefs {
    /// Returns the definitions of `var_name` that reach the entry of `block_id`.
    ///
    /// Only definitions flowing in from predecessors are included. Definitions made inside
    /// `block_id` itself are not. The result is ordered by definition index (extraction
    /// order), so callers get the same answer regardless of `HashSet` iteration order. An
    /// unknown block yields an empty vector.
    pub fn defs_reaching_block(&self, block_id: BlockId, var_name: &str) -> Vec<&Definition> {
        let reaching = match self.reaching_in.get(&block_id) {
            Some(reaching) => reaching,
            None => return Vec::new(),
        };
        let mut indices: Vec<usize> = reaching
            .iter()
            .copied()
            .filter(|&idx| self.definitions[idx].variable == var_name)
            .collect();
        // `HashSet` iteration order is unspecified; sort for reproducible output.
        indices.sort_unstable();
        indices
            .into_iter()
            .map(|idx| &self.definitions[idx])
            .collect()
    }

    /// Returns `true` if any definition of `var_name` reaching `block_id` has kind `kind`.
    pub fn has_reaching_def_kind(
        &self,
        block_id: BlockId,
        var_name: &str,
        kind: &DefinitionKind,
    ) -> bool {
        self.defs_reaching_block(block_id, var_name)
            .iter()
            .any(|def| def.kind == *kind)
    }

    /// Returns `true` if a `free(var_name)` may have executed before `block_id` is entered.
    pub fn is_potentially_freed(&self, block_id: BlockId, var_name: &str) -> bool {
        self.has_reaching_def_kind(block_id, var_name, &DefinitionKind::FreeCall)
    }

    /// Returns `true` if `var_name` may be null on entry to `block_id`.
    ///
    /// This holds when a reaching definition is a null store or the result of a call that
    /// may return null.
    pub fn is_potentially_null(&self, block_id: BlockId, var_name: &str) -> bool {
        // Single pass over the reaching set instead of one query per kind.
        self.defs_reaching_block(block_id, var_name)
            .iter()
            .any(|def| {
                matches!(
                    def.kind,
                    DefinitionKind::NullAssignment | DefinitionKind::NullableCall
                )
            })
    }
}
/// Computes the per-block GEN and KILL sets for the reaching-definitions dataflow.
///
/// GEN(b) holds, for every variable defined in `b`, the index of its last definition in `b`.
/// "Last" means last in `definitions` order, which is assumed to be program order.
///
/// KILL(b) holds, for every definition `d` in `b`, all other definitions of `d`'s variable
/// anywhere in the function.
///
/// Indices refer to positions in `definitions`. Every block of `cfg` gets an entry in both
/// maps, possibly empty.
fn compute_gen_kill(
    cfg: &FunctionCfg,
    definitions: &[Definition],
) -> (
    HashMap<BlockId, HashSet<usize>>,
    HashMap<BlockId, HashSet<usize>>,
) {
    // Index definitions both by the block they live in and by the variable they define.
    let mut by_block: HashMap<BlockId, Vec<usize>> = HashMap::new();
    let mut by_var: HashMap<&str, Vec<usize>> = HashMap::new();
    for (idx, def) in definitions.iter().enumerate() {
        by_block.entry(def.block_id).or_default().push(idx);
        by_var.entry(def.variable.as_str()).or_default().push(idx);
    }

    let mut gen_sets: HashMap<BlockId, HashSet<usize>> = HashMap::new();
    let mut kill_sets: HashMap<BlockId, HashSet<usize>> = HashMap::new();
    for block in &cfg.blocks {
        let block_gen = gen_sets.entry(block.id).or_default();
        let block_kill = kill_sets.entry(block.id).or_default();
        let local_defs = match by_block.get(&block.id) {
            Some(local_defs) => local_defs,
            None => continue,
        };

        // Later definitions of the same variable overwrite earlier entries here.
        let mut last_in_block: HashMap<&str, usize> = HashMap::new();
        for &def_idx in local_defs {
            let var = definitions[def_idx].variable.as_str();
            if let Some(same_var) = by_var.get(var) {
                block_kill.extend(same_var.iter().copied().filter(|&other| other != def_idx));
            }
            last_in_block.insert(var, def_idx);
        }
        block_gen.extend(last_in_block.values().copied());
    }

    (gen_sets, kill_sets)
}
/// Runs the forward "may" reaching-definitions analysis over `cfg`.
///
/// For every block `b`:
///   IN(b)  = union of OUT(p) over all predecessors p
///   OUT(b) = GEN(b) union (IN(b) minus KILL(b))
///
/// The fixpoint is computed with a FIFO worklist. `definitions` is moved into the returned
/// [`ReachingDefs`], whose per-block sets hold indices into it.
pub fn compute_reaching_definitions(
    cfg: &FunctionCfg,
    definitions: Vec<Definition>,
) -> ReachingDefs {
    // Lower bound for the defensive iteration cap below.
    const MIN_ITERATION_CAP: usize = 1000;

    let (gen_sets, kill_sets) = compute_gen_kill(cfg, &definitions);
    let empty: HashSet<usize> = HashSet::new();

    let mut reaching_in: HashMap<BlockId, HashSet<usize>> = HashMap::new();
    let mut reaching_out: HashMap<BlockId, HashSet<usize>> = HashMap::new();
    for block in &cfg.blocks {
        reaching_in.insert(block.id, HashSet::new());
        reaching_out.insert(
            block.id,
            gen_sets.get(&block.id).cloned().unwrap_or_default(),
        );
    }

    // `queued` mirrors the worklist's contents so membership tests are O(1) instead of a
    // linear scan of the deque for every successor.
    let mut worklist: VecDeque<BlockId> = VecDeque::with_capacity(cfg.blocks.len());
    let mut queued: HashSet<BlockId> = HashSet::with_capacity(cfg.blocks.len());
    for block in &cfg.blocks {
        if queued.insert(block.id) {
            worklist.push_back(block.id);
        }
    }

    // OUT sets start at GEN and only ever grow, because the transfer function is monotone.
    // Each OUT set is bounded by the number of definitions, so the loop always terminates.
    // The cap is purely defensive. It scales with the problem size so that large functions
    // are never cut off before the fixpoint is reached.
    let num_blocks = cfg.blocks.len().max(1);
    let iteration_cap = num_blocks
        .saturating_mul(num_blocks)
        .saturating_mul(definitions.len().saturating_add(1))
        .max(MIN_ITERATION_CAP);

    let mut iterations = 0usize;
    while let Some(block_id) = worklist.pop_front() {
        queued.remove(&block_id);
        iterations += 1;
        if iterations > iteration_cap {
            break;
        }

        let mut new_in: HashSet<usize> = HashSet::new();
        for (pred_id, _) in cfg.predecessors(block_id) {
            if let Some(pred_out) = reaching_out.get(&pred_id) {
                new_in.extend(pred_out.iter().copied());
            }
        }

        // Borrow GEN/KILL instead of cloning them on every iteration.
        let block_gen = gen_sets.get(&block_id).unwrap_or(&empty);
        let block_kill = kill_sets.get(&block_id).unwrap_or(&empty);
        let mut new_out: HashSet<usize> = new_in.difference(block_kill).copied().collect();
        new_out.extend(block_gen.iter().copied());

        // A block without a recorded OUT set is treated as having an empty one.
        let changed = match reaching_out.get(&block_id) {
            Some(old_out) => *old_out != new_out,
            None => !new_out.is_empty(),
        };
        reaching_in.insert(block_id, new_in);
        if changed {
            reaching_out.insert(block_id, new_out);
            for (succ_id, _) in cfg.successors(block_id) {
                if queued.insert(succ_id) {
                    worklist.push_back(succ_id);
                }
            }
        }
    }

    ReachingDefs {
        definitions,
        reaching_in,
        reaching_out,
    }
}
/// Collects every definition in a function: its parameters first, then the definitions
/// made by each statement of each CFG block.
///
/// `func_node` is queried for its `declarator` (parameters) and `body` fields. `source` is
/// the text the tree was parsed from. Each CFG statement is a `(start, end)` byte range that
/// is mapped back to a syntax node under the body. Ranges that cannot be mapped are skipped.
pub fn extract_definitions(cfg: &FunctionCfg, func_node: &Node, source: &str) -> Vec<Definition> {
    let mut found: Vec<Definition> = Vec::new();

    if let Some(signature) = func_node.child_by_field_name("declarator") {
        extract_param_definitions(&signature, source, &mut found);
    }

    let body = match func_node.child_by_field_name("body") {
        Some(body) => body,
        None => return found,
    };

    for block in &cfg.blocks {
        for (stmt_idx, &(start, end)) in block.statements.iter().enumerate() {
            if let Some(stmt) = find_node_at_range(&body, start, end) {
                extract_definitions_from_node(&stmt, source, block.id, stmt_idx, &mut found);
            }
        }
    }

    found
}
/// Records a [`Definition`] for every named parameter of the function whose declarator is
/// `declarator`.
///
/// The function's declarator may be wrapped (e.g. in a `pointer_declarator` for a function
/// returning a pointer), so the search recurses until a `function_declarator` is found.
/// Parameters whose declarator contains `*` are recorded as [`DefinitionKind::NullableCall`],
/// because the caller may pass null. Other parameters become [`DefinitionKind::Parameter`].
/// Parameters without a name are skipped.
fn extract_param_definitions(declarator: &Node, source: &str, definitions: &mut Vec<Definition>) {
    if declarator.kind() != "function_declarator" {
        for i in 0..declarator.child_count() {
            if let Some(child) = declarator.child(i) {
                extract_param_definitions(&child, source, definitions);
            }
        }
        return;
    }

    let params = match declarator.child_by_field_name("parameters") {
        Some(params) => params,
        None => return,
    };
    for i in 0..params.child_count() {
        let param = match params.child(i) {
            Some(param) if param.kind() == "parameter_declaration" => param,
            _ => continue,
        };
        let param_decl = match param.child_by_field_name("declarator") {
            Some(param_decl) => param_decl,
            None => continue,
        };
        let name = extract_identifier(&param_decl, source);
        if name.is_empty() {
            continue;
        }

        // Inspect only the declarator (`*p`, `(*cb)(int)`, ...). A '*' inside a comment
        // elsewhere in the parameter must not turn a value parameter into a pointer.
        let decl_text = param_decl.utf8_text(source.as_bytes()).unwrap_or("");
        let kind = if decl_text.contains('*') {
            DefinitionKind::NullableCall
        } else {
            DefinitionKind::Parameter
        };
        definitions.push(Definition {
            variable: name,
            // NOTE(review): parameters are attributed to block 0, presumably the CFG entry
            // block; confirm against `FunctionCfg`.
            block_id: 0,
            statement_index: 0,
            kind,
            byte_offset: param.start_byte(),
        });
    }
}
/// Dispatches a statement-level syntax node to the matching definition extractor.
///
/// `declaration` and `assignment_expression` nodes are handled directly. An
/// `expression_statement` forwards its first child. Any other node is searched recursively,
/// so nested declarations and assignments are still found.
fn extract_definitions_from_node(
    node: &Node,
    source: &str,
    block_id: BlockId,
    stmt_idx: usize,
    definitions: &mut Vec<Definition>,
) {
    let kind = node.kind();
    if kind == "declaration" {
        extract_declaration_defs(node, source, block_id, stmt_idx, definitions);
    } else if kind == "assignment_expression" {
        extract_expression_defs(node, source, block_id, stmt_idx, definitions);
    } else if kind == "expression_statement" {
        if let Some(expr) = node.child(0) {
            extract_expression_defs(&expr, source, block_id, stmt_idx, definitions);
        }
    } else {
        for child in (0..node.child_count()).filter_map(|i| node.child(i)) {
            extract_definitions_from_node(&child, source, block_id, stmt_idx, definitions);
        }
    }
}
fn extract_declaration_defs(
node: &Node,
source: &str,
block_id: BlockId,
stmt_idx: usize,
definitions: &mut Vec<Definition>,
) {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "init_declarator" {
if let Some(declarator) = child.child_by_field_name("declarator") {
let name = extract_identifier(&declarator, source);
if name.is_empty() {
continue;
}
let kind = if let Some(value) = child.child_by_field_name("value") {
classify_rvalue(&value, source)
} else if is_pointer_type_declarator(&declarator) {
DefinitionKind::NullAssignment } else {
DefinitionKind::Declaration
};
definitions.push(Definition {
variable: name,
block_id,
statement_index: stmt_idx,
kind,
byte_offset: child.start_byte(),
});
}
}
}
}
}
/// Records the definitions made by the expression tree rooted at `node`.
///
/// * `x = rhs` (plain assignment to an identifier): the kind comes from
///   `classify_rvalue(rhs)`.
/// * `x op= rhs` (compound assignment): always [`DefinitionKind::Assignment`]. The stored
///   value is derived from the old `x`, not equal to `rhs`, so `p += 0` is not a null store.
/// * `free(x)`, including `free((T *)x)` and `free((x))`: [`DefinitionKind::FreeCall`].
///
/// Every sub-expression is visited, so nested assignments such as `a = b = NULL`,
/// `p = NULL, q = NULL` or `use(p = NULL)` all produce definitions. Operands are visited
/// before the enclosing node, so their definitions are recorded before the assignment or
/// call that consumes them.
fn extract_expression_defs(
    node: &Node,
    source: &str,
    block_id: BlockId,
    stmt_idx: usize,
    definitions: &mut Vec<Definition>,
) {
    for i in 0..node.child_count() {
        if let Some(child) = node.child(i) {
            extract_expression_defs(&child, source, block_id, stmt_idx, definitions);
        }
    }

    match node.kind() {
        "assignment_expression" => {
            let (left, right) = match (
                node.child_by_field_name("left"),
                node.child_by_field_name("right"),
            ) {
                (Some(left), Some(right)) => (left, right),
                _ => return,
            };
            if left.kind() != "identifier" {
                return;
            }
            // Only `=` stores the right-hand side verbatim. `+=`, `|=`, ... combine it with
            // the old value. A missing operator field is treated as `=`.
            let is_plain = node
                .child_by_field_name("operator")
                .map_or(true, |op| op.kind() == "=");
            let kind = if is_plain {
                classify_rvalue(&right, source)
            } else {
                DefinitionKind::Assignment
            };
            definitions.push(Definition {
                variable: left.utf8_text(source.as_bytes()).unwrap_or("").to_string(),
                block_id,
                statement_index: stmt_idx,
                kind,
                byte_offset: node.start_byte(),
            });
        }
        "call_expression" => {
            let is_free = node.child_by_field_name("function").map_or(false, |func| {
                func.utf8_text(source.as_bytes()).unwrap_or("") == "free"
            });
            if !is_free {
                return;
            }
            let args = match node.child_by_field_name("arguments") {
                Some(args) => args,
                None => return,
            };
            for i in 0..args.child_count() {
                let arg = match args.child(i) {
                    Some(arg) => strip_parens_and_casts(arg),
                    None => continue,
                };
                if arg.kind() == "identifier" {
                    definitions.push(Definition {
                        variable: arg.utf8_text(source.as_bytes()).unwrap_or("").to_string(),
                        block_id,
                        statement_index: stmt_idx,
                        kind: DefinitionKind::FreeCall,
                        byte_offset: node.start_byte(),
                    });
                }
            }
        }
        _ => {}
    }
}

/// Peels `( ... )` and `(T) ...` wrappers off an expression, e.g. `(void *)(p)` becomes `p`.
fn strip_parens_and_casts<'t>(node: Node<'t>) -> Node<'t> {
    let mut current = node;
    loop {
        let inner = match current.kind() {
            "parenthesized_expression" => current.named_child(0),
            "cast_expression" => current.child_by_field_name("value"),
            _ => None,
        };
        match inner {
            Some(next) => current = next,
            None => return current,
        }
    }
}
/// Classifies the value stored by an assignment or initialiser whose right-hand side is
/// `node`.
///
/// Parentheses and casts (any number of layers) are looked through. This means `(int *)NULL`,
/// `(void *)(0)`, `(T)nullptr` and `(char *)malloc(n)` are classified like their inner
/// expression:
/// * `NULL`, `0` or `nullptr` -> [`DefinitionKind::NullAssignment`]
/// * a call to a function accepted by `is_nullable_function` ->
///   [`DefinitionKind::NullableCall`]
/// * anything else -> [`DefinitionKind::Assignment`]
fn classify_rvalue(node: &Node, source: &str) -> DefinitionKind {
    // Peel `( ... )` and `(T) ...` wrappers down to the underlying expression.
    let mut inner = *node;
    loop {
        let unwrapped = match inner.kind() {
            "parenthesized_expression" => inner.named_child(0),
            "cast_expression" => inner.child_by_field_name("value"),
            _ => None,
        };
        match unwrapped {
            Some(next) => inner = next,
            None => break,
        }
    }

    let text = inner.utf8_text(source.as_bytes()).unwrap_or("").trim();
    if matches!(text, "NULL" | "0" | "nullptr") {
        return DefinitionKind::NullAssignment;
    }

    if inner.kind() == "call_expression" {
        if let Some(func) = inner.child_by_field_name("function") {
            if is_nullable_function(func.utf8_text(source.as_bytes()).unwrap_or("")) {
                return DefinitionKind::NullableCall;
            }
        }
    }

    DefinitionKind::Assignment
}
/// Returns `true` if `name` is a C library function that may return a null pointer.
///
/// The list covers allocation, stream/file, environment/locale, search/duplication and
/// time-conversion routines. Matching is exact and case-sensitive.
fn is_nullable_function(name: &str) -> bool {
    const NULLABLE_FUNCTIONS: &[&str] = &[
        // Memory allocation.
        "malloc",
        "calloc",
        "realloc",
        "aligned_alloc",
        // Streams and files.
        "fopen",
        "fdopen",
        "freopen",
        "tmpfile",
        "fgets",
        "gets",
        // Environment and locale.
        "getenv",
        "setlocale",
        // Searching and string duplication.
        "strtok",
        "bsearch",
        "strstr",
        "strchr",
        "strrchr",
        "strdup",
        "strndup",
        "strpbrk",
        "memchr",
        // Time conversion.
        "localtime",
        "gmtime",
        "asctime",
        "ctime",
    ];
    NULLABLE_FUNCTIONS.contains(&name)
}
fn extract_identifier(node: &Node, source: &str) -> String {
match node.kind() {
"identifier" => node.utf8_text(source.as_bytes()).unwrap_or("").to_string(),
"pointer_declarator" | "array_declarator" => {
if let Some(inner) = node.child_by_field_name("declarator") {
extract_identifier(&inner, source)
} else {
String::new()
}
}
_ => {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "identifier" {
return child.utf8_text(source.as_bytes()).unwrap_or("").to_string();
}
}
}
String::new()
}
}
}
/// Returns `true` if `node` itself is a `pointer_declarator` (e.g. the `*p` in `int *p`).
///
/// Only the node's own kind is inspected; wrapping declarators are not looked through.
fn is_pointer_type_declarator(node: &Node) -> bool {
    matches!(node.kind(), "pointer_declarator")
}
pub fn find_node_at_range<'a>(root: &Node<'a>, start: usize, end: usize) -> Option<Node<'a>> {
if root.start_byte() == start && root.end_byte() == end {
return Some(*root);
}
for i in 0..root.child_count() {
if let Some(child) = root.child(i) {
if child.start_byte() <= start && child.end_byte() >= end {
if let Some(found) = find_node_at_range(&child, start, end) {
return Some(found);
}
return Some(child);
}
}
}
None
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyze::cfg::build_function_cfg;

    /// Parses `code` with the tree-sitter C grammar, returning the tree and an owned copy of
    /// the source text.
    fn parse_function(code: &str) -> (tree_sitter::Tree, String) {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&tree_sitter_c::language()).unwrap();
        let tree = parser.parse(code, None).unwrap();
        (tree, code.to_owned())
    }

    /// Returns the first top-level `function_definition` node in `tree`.
    fn get_func_node<'a>(tree: &'a tree_sitter::Tree) -> Option<tree_sitter::Node<'a>> {
        let root = tree.root_node();
        (0..root.child_count())
            .filter_map(|i| root.child(i))
            .find(|child| child.kind() == "function_definition")
    }

    /// Builds the CFG for the first function in `code` and extracts its definitions.
    fn definitions_for(code: &str) -> Vec<Definition> {
        let (tree, source) = parse_function(code);
        let func_node = get_func_node(&tree).unwrap();
        let cfg = build_function_cfg(&func_node, &source).unwrap();
        extract_definitions(&cfg, &func_node, &source)
    }

    /// Runs the whole pipeline (CFG, definitions, reaching definitions) on the first
    /// function in `code` and hands the results to `check`.
    fn with_reaching_defs(code: &str, check: impl FnOnce(&FunctionCfg, &ReachingDefs)) {
        let (tree, source) = parse_function(code);
        let func_node = get_func_node(&tree).unwrap();
        let cfg = build_function_cfg(&func_node, &source).unwrap();
        let defs = extract_definitions(&cfg, &func_node, &source);
        let reaching = compute_reaching_definitions(&cfg, defs);
        check(&cfg, &reaching);
    }

    /// True when `defs` contains a definition of `var` with exactly `kind`.
    fn has_def(defs: &[Definition], var: &str, kind: DefinitionKind) -> bool {
        defs.iter().any(|d| d.variable == var && d.kind == kind)
    }

    #[test]
    fn test_reaching_defs_simple() {
        let code = r#"
void foo() {
int *p = malloc(10);
if (p == NULL) {
return;
}
*p = 42;
free(p);
}
"#;
        let defs = definitions_for(code);
        assert!(!defs.is_empty());
        assert!(defs.iter().any(|d| d.variable == "p"));
    }

    #[test]
    fn test_free_detection() {
        let code = r#"
void foo(int *p) {
free(p);
*p = 42;
}
"#;
        let defs = definitions_for(code);
        assert!(has_def(&defs, "p", DefinitionKind::FreeCall));
    }

    #[test]
    fn test_null_assignment_detection() {
        let code = r#"
void foo() {
int *p = NULL;
p = malloc(10);
}
"#;
        let defs = definitions_for(code);
        assert!(has_def(&defs, "p", DefinitionKind::NullAssignment));
        assert!(has_def(&defs, "p", DefinitionKind::NullableCall));
    }

    #[test]
    fn test_reaching_defs_end_to_end() {
        let code = r#"
void foo() {
int *p = malloc(10);
if (p == NULL) {
return;
}
*p = 42;
free(p);
}
"#;
        with_reaching_defs(code, |_cfg, reaching| {
            assert!(reaching.reaching_out.values().any(|s| !s.is_empty()));
        });
    }

    #[test]
    fn test_defs_reaching_block() {
        let code = r#"
void foo(int *p) {
if (p) {
free(p);
}
int x = *p;
}
"#;
        with_reaching_defs(code, |cfg, reaching| {
            let has_p_anywhere = cfg
                .blocks
                .iter()
                .any(|b| !reaching.defs_reaching_block(b.id, "p").is_empty());
            let has_p_in_out = reaching.reaching_out.values().any(|set| {
                set.iter()
                    .any(|&idx| reaching.definitions[idx].variable == "p")
            });
            assert!(has_p_anywhere || has_p_in_out, "p should be tracked somewhere");
        });
    }

    #[test]
    fn test_is_potentially_freed() {
        let code = r#"
void foo() {
int *p = malloc(10);
if (p) {
free(p);
}
int *q = p;
}
"#;
        with_reaching_defs(code, |cfg, reaching| {
            // Exercised for coverage only; the result depends on the CFG's block layout.
            let _ = cfg
                .blocks
                .iter()
                .any(|b| reaching.is_potentially_freed(b.id, "p"));
            assert!(
                has_def(&reaching.definitions, "p", DefinitionKind::FreeCall),
                "Should have free(p) in definitions"
            );
        });
    }

    #[test]
    fn test_is_potentially_null() {
        let code = r#"
void foo(int flag) {
int *p = NULL;
if (flag) {
p = malloc(10);
}
int *q = p;
}
"#;
        with_reaching_defs(code, |cfg, reaching| {
            // Exercised for coverage only; the result depends on the CFG's block layout.
            let _ = cfg
                .blocks
                .iter()
                .any(|b| reaching.is_potentially_null(b.id, "p"));
            assert!(
                has_def(&reaching.definitions, "p", DefinitionKind::NullAssignment),
                "Should have NULL assignment in definitions"
            );
        });
    }

    #[test]
    fn test_cast_to_null_detection() {
        let code = r#"
void foo() {
int *p = (int*)NULL;
}
"#;
        let defs = definitions_for(code);
        assert!(
            has_def(&defs, "p", DefinitionKind::NullAssignment),
            "Should detect (int*)NULL as null assignment"
        );
    }

    #[test]
    fn test_cast_wrapped_nullable_call() {
        let code = r#"
void foo() {
int *p = (int*)malloc(10);
}
"#;
        let defs = definitions_for(code);
        assert!(
            has_def(&defs, "p", DefinitionKind::NullableCall),
            "Should detect (int*)malloc() as nullable call"
        );
    }

    #[test]
    fn test_multiple_defs_same_var() {
        let code = r#"
void foo() {
int *p = NULL;
p = malloc(10);
p = NULL;
}
"#;
        let defs = definitions_for(code);
        let p_def_count = defs.iter().filter(|d| d.variable == "p").count();
        assert!(p_def_count >= 2, "Should find multiple defs of p");
    }

    #[test]
    fn test_parameter_pointer_detection() {
        let code = r#"
void foo(int *ptr, int val) {
*ptr = val;
}
"#;
        let defs = definitions_for(code);
        assert!(
            defs.iter().any(|d| d.variable == "ptr"),
            "Should extract parameter definition for ptr"
        );
    }

    #[test]
    fn test_nullptr_detection() {
        let code = r#"
void foo() {
int *p = 0;
}
"#;
        let defs = definitions_for(code);
        assert!(
            has_def(&defs, "p", DefinitionKind::NullAssignment),
            "Should detect 0 as null assignment"
        );
    }

    #[test]
    fn test_non_pointer_parameter() {
        let code = r#"
void foo(int val) {
int x = val;
}
"#;
        let defs = definitions_for(code);
        match defs.iter().find(|d| d.variable == "val") {
            Some(param) => assert_eq!(param.kind, DefinitionKind::Parameter),
            None => panic!("Should detect non-pointer parameter"),
        }
    }

    #[test]
    fn test_plain_assignment_detection() {
        let code = r#"
void foo() {
int x = 5;
}
"#;
        let defs = definitions_for(code);
        assert!(
            has_def(&defs, "x", DefinitionKind::Assignment),
            "Should detect plain assignment"
        );
    }

    #[test]
    fn test_uninitialized_pointer_declaration() {
        let code = r#"
void foo() {
int *p;
}
"#;
        // Smoke test: extraction must not panic on an uninitialised declaration.
        let defs = definitions_for(code);
        let _ = defs.iter().filter(|d| d.variable == "p").count();
    }

    #[test]
    fn test_nullable_functions() {
        for name in &["malloc", "calloc", "realloc", "fopen", "strdup", "getenv"] {
            assert!(is_nullable_function(name), "{} should be nullable", name);
        }
        for name in &["printf", "free", "memcpy"] {
            assert!(!is_nullable_function(name), "{} should not be nullable", name);
        }
    }

    #[test]
    fn test_find_node_at_range_basic() {
        let (tree, _source) = parse_function("int x = 42;\n");
        let root = tree.root_node();
        let decl = root.child(0).unwrap();
        let found = find_node_at_range(&root, decl.start_byte(), decl.end_byte());
        assert!(found.is_some());
    }

    #[test]
    fn test_gen_kill_multiple_blocks() {
        let code = r#"
void foo(int flag) {
int *p = malloc(10);
if (flag) {
p = NULL;
} else {
free(p);
}
int *q = p;
}
"#;
        with_reaching_defs(code, |_cfg, reaching| {
            assert!(
                has_def(&reaching.definitions, "p", DefinitionKind::NullAssignment),
                "Should have NULL assignment for p"
            );
            assert!(
                has_def(&reaching.definitions, "p", DefinitionKind::FreeCall),
                "Should have free(p)"
            );
        });
    }

    #[test]
    fn test_reassignment_kills_previous() {
        let code = r#"
void foo() {
int *p = malloc(10);
p = NULL;
int *q = p;
}
"#;
        with_reaching_defs(code, |cfg, reaching| {
            let last_block = cfg.blocks.last().unwrap();
            let _ = reaching.defs_reaching_block(last_block.id, "p");
            assert!(!reaching.definitions.is_empty(), "Should have definitions");
        });
    }
}