use super::cfg::{BasicBlock, BlockId, CfgEdge, FunctionCfg};
use super::dataflow::find_node_at_range;
use crate::analyze::function_summary::FunctionSummary;
use std::collections::{HashMap, HashSet, VecDeque};
use tree_sitter::Node;
/// Nullness fact tracked for one name (variable, field path or pointee) at a program point.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum NullState {
/// Nothing is known yet; `join` lets the other operand win over this state.
Unknown,
/// The value is known to be null (for example it was assigned a null literal).
DefinitelyNull,
/// The value may or may not be null; also the result of joining two differing known states.
PossiblyNull,
/// The value is treated as non-null.
NotNull,
}
impl NullState {
pub fn join(self, other: NullState) -> NullState {
use NullState::*;
if self == other {
return self;
}
match (self, other) {
(Unknown, x) | (x, Unknown) => x,
_ => PossiblyNull,
}
}
pub fn is_unsafe(self) -> bool {
matches!(self, NullState::DefinitelyNull | NullState::PossiblyNull)
}
}
/// Nullness facts at one program point, keyed by name. Besides plain variable names the
/// analysis also stores synthetic `base.field` / `base.index` keys and `*name` pointee keys.
pub type StateMap = HashMap<String, NullState>;
/// Joins two state maps key by key.
///
/// Keys present only in `a` keep their state; keys present only in `b` are joined against
/// `Unknown`, which leaves `b`'s state as the result.
fn join_states(a: &StateMap, b: &StateMap) -> StateMap {
    b.iter().fold(a.clone(), |mut merged, (name, &incoming)| {
        let current = merged.get(name).copied().unwrap_or(NullState::Unknown);
        merged.insert(name.clone(), current.join(incoming));
        merged
    })
}
/// Output of the per-function null-state dataflow analysis.
pub struct NullAnalysisResult {
/// State map holding on entry to each basic block.
pub block_entry_states: HashMap<BlockId, StateMap>,
/// State map holding on exit from each basic block (after its statements were applied).
#[allow(dead_code)]
pub block_exit_states: HashMap<BlockId, StateMap>,
/// Names the analysis treats as pointers: globals, pointer parameters and pointer locals.
pub declared_pointers: HashSet<String>,
}
/// One null-test extracted from a branch condition.
struct ConditionInfo {
/// Name of the tested variable.
var_name: String,
/// State applied to `var_name` on the true edge of the branch.
true_state: NullState,
/// State applied to `var_name` on the false edge of the branch.
false_state: NullState,
}
/// Extracts every null-test found in a branch condition.
///
/// Recognised shapes are `x == NULL` / `NULL == x`, `x != NULL` / `NULL != x`, `!x`, a bare
/// identifier `x`, any of these wrapped in parentheses, and `&&` / `||` combinations of them.
/// Every other expression yields an empty list.
///
/// NOTE(review): the operands of `&&` and `||` are concatenated as-is, so each entry keeps its
/// own true/false states on both outgoing edges of the combined condition. For the true edge of
/// `||` and the false edge of `&&` that looks stronger than what the condition guarantees —
/// confirm this is intended.
fn parse_all_null_conditions(node: &Node, source: &str) -> Vec<ConditionInfo> {
    // Builds the one-element result for a test on `var_name`; `null_when_true` tells whether
    // the test succeeding means the variable is null.
    let single = |var_name: String, null_when_true: bool| {
        let (true_state, false_state) = if null_when_true {
            (NullState::DefinitelyNull, NullState::NotNull)
        } else {
            (NullState::NotNull, NullState::DefinitelyNull)
        };
        vec![ConditionInfo {
            var_name,
            true_state,
            false_state,
        }]
    };

    match node.kind() {
        // Child 1 is the expression between the parentheses.
        "parenthesized_expression" => match node.child(1) {
            Some(inner) => parse_all_null_conditions(&inner, source),
            None => Vec::new(),
        },
        "binary_expression" => {
            let (Some(lhs), Some(op_node), Some(rhs)) = (
                node.child_by_field_name("left"),
                node.child_by_field_name("operator"),
                node.child_by_field_name("right"),
            ) else {
                return Vec::new();
            };
            match get_text(&op_node, source).as_str() {
                cmp @ ("==" | "!=") => extract_null_check_var(&lhs, &rhs, source)
                    .map(|var| single(var, cmp == "=="))
                    .unwrap_or_default(),
                "||" | "&&" => {
                    let mut combined = parse_all_null_conditions(&lhs, source);
                    combined.extend(parse_all_null_conditions(&rhs, source));
                    combined
                }
                _ => Vec::new(),
            }
        }
        "unary_expression" => {
            let negates = node
                .child(0)
                .is_some_and(|op_node| get_text(&op_node, source) == "!");
            if !negates {
                return Vec::new();
            }
            match node.child_by_field_name("argument") {
                Some(operand) if operand.kind() == "identifier" => {
                    single(get_text(&operand, source), true)
                }
                _ => Vec::new(),
            }
        }
        // A bare identifier used as a condition is truthy exactly when it is non-null.
        "identifier" => single(get_text(node, source), false),
        _ => Vec::new(),
    }
}
/// Returns the variable being compared against a null literal, if any.
///
/// Accepts the literal on either side of the comparison: when `right` is a null value and
/// `left` is an identifier the left text is returned, and vice versa. Returns `None` when
/// neither arrangement matches.
fn extract_null_check_var(left: &Node, right: &Node, source: &str) -> Option<String> {
    let lt = get_text(left, source);
    let rt = get_text(right, source);
    if is_null_value(&rt) && left.kind() == "identifier" {
        Some(lt)
    } else if is_null_value(&lt) && right.kind() == "identifier" {
        Some(rt)
    } else {
        None
    }
}
/// Computes a block's exit state by replaying its statements on top of `entry`.
///
/// Statement byte ranges that cannot be resolved to a node under `body_node` are skipped.
/// Pointer declarations met along the way are added to `declared_pointers`.
fn apply_transfer(
    block: &BasicBlock,
    entry: &StateMap,
    body_node: &Node,
    source: &str,
    declared_pointers: &mut HashSet<String>,
    summaries: &HashMap<String, FunctionSummary>,
) -> StateMap {
    let mut exit_state = entry.clone();
    for &(stmt_start, stmt_end) in &block.statements {
        let Some(stmt) = find_node_at_range(body_node, stmt_start, stmt_end) else {
            continue;
        };
        process_statement_for_null_state(
            &stmt,
            source,
            &mut exit_state,
            declared_pointers,
            summaries,
        );
    }
    exit_state
}
fn process_statement_for_null_state(
node: &Node,
source: &str,
state: &mut StateMap,
declared_pointers: &mut HashSet<String>,
summaries: &HashMap<String, FunctionSummary>,
) {
match node.kind() {
"declaration" => {
process_declaration_null(node, source, state, declared_pointers, summaries);
}
"expression_statement" => {
process_assert_for_null_state(node, source, state);
if let Some(expr) = node.child(0) {
process_expression_null(&expr, source, state, declared_pointers, summaries);
}
}
"assignment_expression" => {
process_expression_null(node, source, state, declared_pointers, summaries);
}
"switch_statement" => {
if let Some(body) = node.child_by_field_name("body") {
walk_switch_body_for_null_state(&body, source, state, declared_pointers, summaries);
}
}
_ => {
process_assert_for_null_state(node, source, state);
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "assignment_expression" {
process_expression_null(
&child,
source,
state,
declared_pointers,
summaries,
);
}
}
}
}
}
}
/// Walks the body of a `switch` and applies declarations and assignments found inside it.
///
/// Only `case_statement` and `compound_statement` children are descended into; declarations,
/// expression statements and assignment expressions are processed in source order. All other
/// child kinds are ignored.
fn walk_switch_body_for_null_state(
    node: &Node,
    source: &str,
    state: &mut StateMap,
    declared_pointers: &mut HashSet<String>,
    summaries: &HashMap<String, FunctionSummary>,
) {
    for child in (0..node.child_count()).filter_map(|i| node.child(i)) {
        let kind = child.kind();
        if matches!(kind, "case_statement" | "compound_statement") {
            walk_switch_body_for_null_state(&child, source, state, declared_pointers, summaries);
        } else if matches!(
            kind,
            "declaration" | "expression_statement" | "assignment_expression"
        ) {
            process_statement_for_null_state(&child, source, state, declared_pointers, summaries);
        }
    }
}
/// Updates `state` and `declared_pointers` from one `declaration` node.
///
/// Every declarator that is a pointer (per `is_pointer_declarator`, excluding anything that
/// contains an array declarator) is recorded in `declared_pointers`. For a pointer with an
/// initializer the stored state is the classification of the initializer, except that a
/// `NotNull` classification is replaced by an already-tracked state of the source identifier,
/// of the `base.field` / `base.index` entry, or of the `*name` pointee entry when one exists.
/// Pointers declared without an initializer start as `PossiblyNull`. Non-pointer declarators
/// never receive a state entry.
fn process_declaration_null(
node: &Node,
source: &str,
state: &mut StateMap,
declared_pointers: &mut HashSet<String>,
summaries: &HashMap<String, FunctionSummary>,
) {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "init_declarator" {
if let Some(declarator) = child.child_by_field_name("declarator") {
let var_name = get_identifier_from_declarator(&declarator, source);
if var_name.is_empty() {
continue;
}
let is_ptr = is_pointer_declarator(&declarator) && !contains_array(&declarator);
if is_ptr {
declared_pointers.insert(var_name.clone());
}
if let Some(value) = child.child_by_field_name("value") {
if is_ptr {
let rval = classify_rvalue_null(&value, source, summaries);
// Copy from another variable: inherit that variable's tracked state.
if rval == NullState::NotNull && value.kind() == "identifier" {
let src_name = get_text(&value, source);
if let Some(&src_state) = state.get(&src_name) {
state.insert(var_name, src_state);
continue;
}
}
// Field read: prefer the tracked `base.field` entry; otherwise an unsafe base makes
// the result `PossiblyNull`.
if rval == NullState::NotNull && value.kind() == "field_expression" {
if let Some(arg) = value.child_by_field_name("argument") {
let base = get_text(&arg, source);
if let Some(field_node) = value.child_by_field_name("field") {
let field_name = get_text(&field_node, source);
let dotted = format!("{}.{}", base, field_name);
if let Some(&field_state) = state.get(&dotted) {
state.insert(var_name, field_state);
continue;
}
}
if let Some(&base_state) = state.get(&base) {
if base_state.is_unsafe() {
state.insert(var_name, NullState::PossiblyNull);
continue;
}
}
}
}
// Element read: looked up under the same dotted key scheme, `base.index`.
if rval == NullState::NotNull && value.kind() == "subscript_expression"
{
if let (Some(arg), Some(idx)) = (
value.child_by_field_name("argument"),
value.child_by_field_name("index"),
) {
let base = get_text(&arg, source);
let index = get_text(&idx, source);
let dotted = format!("{}.{}", base, index);
if let Some(&elem_state) = state.get(&dotted) {
state.insert(var_name, elem_state);
continue;
}
}
}
// `*q` initializer: use the tracked pointee state stored under `*q`, if any.
if rval == NullState::NotNull {
if let Some(deref_state) =
extract_deref_pointee_state(&value, source, state)
{
state.insert(var_name, deref_state);
continue;
}
}
// Cast of an identifier: carry the source's pointee entry over to the new variable's
// pointee entry; the variable itself still gets `rval` below.
if value.kind() == "cast_expression" {
if let Some(inner) = value.child_by_field_name("value") {
let inner = unwrap_parens(&inner);
if inner.kind() == "identifier" {
let inner_name = get_text(&inner, source);
let src_key = format!("*{}", inner_name);
if let Some(&s) = state.get(&src_key) {
let dst_key = format!("*{}", var_name);
state.insert(dst_key, s);
}
}
}
}
state.insert(var_name, rval);
}
} else if is_ptr {
// Pointer declared through an `init_declarator` that has no value.
state.insert(var_name, NullState::PossiblyNull);
}
}
} else if child.kind() == "pointer_declarator" || child.kind() == "identifier" {
// Declarator without an initializer, e.g. `int *p;`.
let var_name = get_identifier_from_declarator(&child, source);
if !var_name.is_empty() && is_pointer_declarator(&child) && !contains_array(&child)
{
declared_pointers.insert(var_name.clone());
state.insert(var_name, NullState::PossiblyNull);
}
}
}
}
}
/// Applies an `assignment_expression` to `state`; other node kinds are ignored.
///
/// The assignment is skipped when the left side is a plain identifier that is not a declared
/// pointer. Otherwise the left side's source text becomes the key. A right side classified
/// `NotNull` first tries to inherit an already-tracked state (source identifier, `base.field`,
/// `base.index`, or the `*name` pointee entry); failing that, the classification itself is
/// stored. Assigning a cast of an identifier additionally copies that identifier's pointee
/// entry to the target's pointee entry.
fn process_expression_null(
    node: &Node,
    source: &str,
    state: &mut StateMap,
    declared_pointers: &HashSet<String>,
    summaries: &HashMap<String, FunctionSummary>,
) {
    if node.kind() != "assignment_expression" {
        return;
    }
    let (Some(target), Some(value)) = (
        node.child_by_field_name("left"),
        node.child_by_field_name("right"),
    ) else {
        return;
    };

    let target_name = get_text(&target, source);
    // Non-identifier targets (fields, elements, derefs) are always tracked.
    if target.kind() == "identifier" && !declared_pointers.contains(&target_name) {
        return;
    }

    let classified = classify_rvalue_null(&value, source, summaries);

    if classified == NullState::NotNull {
        let tracked = match value.kind() {
            "identifier" => state.get(&get_text(&value, source)).copied(),
            "field_expression" => match value.child_by_field_name("argument") {
                Some(base_node) => {
                    let base = get_text(&base_node, source);
                    let by_field = value.child_by_field_name("field").and_then(|field_node| {
                        let key = format!("{}.{}", base, get_text(&field_node, source));
                        state.get(&key).copied()
                    });
                    match by_field {
                        Some(found) => Some(found),
                        // No entry for the field itself: an unsafe base taints the read.
                        None => match state.get(&base) {
                            Some(base_state) if base_state.is_unsafe() => {
                                Some(NullState::PossiblyNull)
                            }
                            _ => None,
                        },
                    }
                }
                None => None,
            },
            "subscript_expression" => match (
                value.child_by_field_name("argument"),
                value.child_by_field_name("index"),
            ) {
                (Some(base_node), Some(index_node)) => {
                    let key = format!(
                        "{}.{}",
                        get_text(&base_node, source),
                        get_text(&index_node, source)
                    );
                    state.get(&key).copied()
                }
                _ => None,
            },
            _ => None,
        };
        let tracked = match tracked {
            Some(found) => Some(found),
            None => extract_deref_pointee_state(&value, source, state),
        };
        if let Some(found) = tracked {
            state.insert(target_name, found);
            return;
        }
    }

    if value.kind() == "cast_expression" {
        if let Some(operand) = value.child_by_field_name("value") {
            let stripped = unwrap_parens(&operand);
            if stripped.kind() == "identifier" {
                let source_key = format!("*{}", get_text(&stripped, source));
                if let Some(pointee) = state.get(&source_key).copied() {
                    state.insert(format!("*{}", target_name), pointee);
                }
            }
        }
    }

    state.insert(target_name, classified);
}
/// Refines `state` from an `assert(...)` call.
///
/// `node` may be the call expression itself or an expression statement whose first child is
/// the call. Only calls whose callee text is exactly `assert` are considered, and only the
/// first identifier or binary-expression argument is inspected:
///
/// * `assert(x)` marks `x` as `NotNull`.
/// * `assert(x != NULL)` / `assert(NULL != x)` marks `x` as `NotNull`.
/// * `assert(x == NULL)` / `assert(NULL == x)` marks `x` as `DefinitelyNull`.
///
/// Binary expressions with any other operator leave `state` untouched; previously every
/// comparison against a null literal marked the variable `NotNull`, including `==`.
fn process_assert_for_null_state(node: &Node, source: &str, state: &mut StateMap) {
    let candidate = match node.kind() {
        "expression_statement" => node.child(0),
        "call_expression" => Some(*node),
        _ => None,
    };
    let Some(call) = candidate.filter(|c| c.kind() == "call_expression") else {
        return;
    };
    let is_assert = call
        .child_by_field_name("function")
        .is_some_and(|callee| get_text(&callee, source) == "assert");
    if !is_assert {
        return;
    }
    let Some(args) = call.child_by_field_name("arguments") else {
        return;
    };

    for arg in (0..args.child_count()).filter_map(|i| args.child(i)) {
        match arg.kind() {
            "identifier" => {
                state.insert(get_text(&arg, source), NullState::NotNull);
                return;
            }
            "binary_expression" => {
                // What the assertion establishes depends on the comparison operator.
                let asserted = arg.child_by_field_name("operator").and_then(|op| {
                    match get_text(&op, source).as_str() {
                        "!=" => Some(NullState::NotNull),
                        "==" => Some(NullState::DefinitelyNull),
                        _ => None,
                    }
                });
                if let (Some(asserted), Some(left), Some(right)) = (
                    asserted,
                    arg.child_by_field_name("left"),
                    arg.child_by_field_name("right"),
                ) {
                    let lt = get_text(&left, source);
                    let rt = get_text(&right, source);
                    if is_null_value(&rt) && left.kind() == "identifier" {
                        state.insert(lt, asserted);
                    } else if is_null_value(&lt) && right.kind() == "identifier" {
                        state.insert(rt, asserted);
                    }
                }
                return;
            }
            // Punctuation and other argument shapes: keep looking.
            _ => {}
        }
    }
}
/// Strips any number of enclosing `parenthesized_expression` layers from `node`.
///
/// Stops early and returns the current layer if it has no child at index 1.
fn unwrap_parens<'a>(node: &'a Node<'a>) -> Node<'a> {
    let mut current = *node;
    loop {
        if current.kind() != "parenthesized_expression" {
            return current;
        }
        match current.child(1) {
            Some(inner) => current = inner,
            None => return current,
        }
    }
}
/// Looks up the tracked pointee state for a dereference expression `*name`.
///
/// Surrounding parentheses are ignored. Returns `None` unless the expression is a
/// `pointer_expression` whose operator is `*` and `state` holds an entry under `*<operand>`.
fn extract_deref_pointee_state(node: &Node, source: &str, state: &StateMap) -> Option<NullState> {
    let expr = unwrap_parens(node);
    if expr.kind() != "pointer_expression" {
        return None;
    }
    let op = expr.child_by_field_name("operator")?;
    if get_text(&op, source) != "*" {
        return None;
    }
    let operand = expr.child_by_field_name("argument")?;
    let pointee_key = format!("*{}", get_text(&operand, source));
    state.get(&pointee_key).copied()
}
/// Classifies the nullness of a right-hand-side expression from its syntax alone.
///
/// * A null literal, or a cast whose operand is a null literal, is `DefinitelyNull`.
/// * A call (optionally wrapped in one cast) to a function accepted by `is_nullable_function`
///   is `PossiblyNull`.
/// * Everything else — including address-of expressions and string literals — is `NotNull`.
fn classify_rvalue_null(
    node: &Node,
    source: &str,
    summaries: &HashMap<String, FunctionSummary>,
) -> NullState {
    // True when `expr` is a direct call to a function that may return null.
    fn calls_nullable(
        expr: &Node,
        source: &str,
        summaries: &HashMap<String, FunctionSummary>,
    ) -> bool {
        expr.kind() == "call_expression"
            && expr
                .child_by_field_name("function")
                .is_some_and(|callee| is_nullable_function(&get_text(&callee, source), summaries))
    }

    if is_null_value(&get_text(node, source)) {
        return NullState::DefinitelyNull;
    }

    match node.kind() {
        "cast_expression" => {
            if let Some(operand) = node.child_by_field_name("value") {
                if is_null_value(&get_text(&operand, source)) {
                    return NullState::DefinitelyNull;
                }
                if calls_nullable(&operand, source, summaries) {
                    return NullState::PossiblyNull;
                }
            }
            NullState::NotNull
        }
        "call_expression" if calls_nullable(node, source, summaries) => NullState::PossiblyNull,
        _ => NullState::NotNull,
    }
}
/// Computes a null state for file-scope pointer variables of a translation unit.
///
/// Initializers at file scope provide the first facts; if any file-scope pointer was found,
/// assignments to those pointers inside function bodies are then joined in.
pub fn collect_file_scope_null_states(
    root: &Node,
    source: &str,
    summaries: &HashMap<String, FunctionSummary>,
) -> StateMap {
    let mut states = StateMap::new();
    let mut globals: HashSet<String> = HashSet::new();
    collect_file_scope_pointer_decls(root, source, &mut globals, &mut states, summaries);
    if !globals.is_empty() {
        collect_global_assignments(root, source, &globals, &mut states, summaries);
    }
    states
}
/// Collects pointer variables declared directly under `node` (and inside `preproc_*` children).
///
/// Each pointer declarator's name is added to `global_vars`. When the declarator carries an
/// initializer, its classification is also stored in `result`. Declarators that contain an
/// array declarator are skipped.
fn collect_file_scope_pointer_decls(
    node: &Node,
    source: &str,
    global_vars: &mut HashSet<String>,
    result: &mut StateMap,
    summaries: &HashMap<String, FunctionSummary>,
) {
    for child in (0..node.child_count()).filter_map(|i| node.child(i)) {
        let kind = child.kind();
        if kind.starts_with("preproc_") {
            collect_file_scope_pointer_decls(&child, source, global_vars, result, summaries);
            continue;
        }
        if kind != "declaration" {
            continue;
        }

        for part in (0..child.child_count()).filter_map(|j| child.child(j)) {
            if part.kind() == "init_declarator" {
                let Some(decl) = part.child_by_field_name("declarator") else {
                    continue;
                };
                if !is_pointer_declarator(&decl) || contains_array(&decl) {
                    continue;
                }
                let name = get_identifier_from_declarator(&decl, source);
                if name.is_empty() {
                    continue;
                }
                global_vars.insert(name.clone());
                if let Some(value) = part.child_by_field_name("value") {
                    result.insert(name, classify_rvalue_null(&value, source, summaries));
                }
                continue;
            }

            // Declarator without initializer, e.g. `static int *g;`.
            if !is_pointer_declarator(&part) || contains_array(&part) {
                continue;
            }
            let name = get_identifier_from_declarator(&part, source);
            let is_specifier = matches!(
                part.kind(),
                "storage_class_specifier" | "type_qualifier" | "primitive_type" | "type_identifier"
            );
            if !name.is_empty() && !is_specifier {
                global_vars.insert(name);
            }
        }
    }
}
/// Scans every function body under `node` for assignments to the given global pointers.
///
/// `preproc_*` children are searched recursively; all other non-function children are ignored.
fn collect_global_assignments(
    node: &Node,
    source: &str,
    global_vars: &HashSet<String>,
    result: &mut StateMap,
    summaries: &HashMap<String, FunctionSummary>,
) {
    for child in (0..node.child_count()).filter_map(|i| node.child(i)) {
        let kind = child.kind();
        if kind == "function_definition" {
            if let Some(fn_body) = child.child_by_field_name("body") {
                scan_body_for_global_assignments(&fn_body, source, global_vars, result, summaries);
            }
        } else if kind.starts_with("preproc_") {
            collect_global_assignments(&child, source, global_vars, result, summaries);
        }
    }
}
/// Recursively visits `node` and joins the state of every assignment to a global pointer.
///
/// The assigned state comes from, in order of precedence: the recorded state of another global
/// on the right side (`Unknown` if none is recorded), a null literal identifier, the statement
/// immediately preceding the assignment for any other identifier, or the syntactic
/// classification of the right side for all other expressions.
fn scan_body_for_global_assignments(
    node: &Node,
    source: &str,
    global_vars: &HashSet<String>,
    result: &mut StateMap,
    summaries: &HashMap<String, FunctionSummary>,
) {
    let assigned_global = if node.kind() == "assignment_expression" {
        node.child_by_field_name("left")
            .map(|target| get_text(&target, source))
            .filter(|name| global_vars.contains(name))
    } else {
        None
    };

    if let Some(var_name) = assigned_global {
        if let Some(rhs) = node.child_by_field_name("right") {
            let rhs_text = get_text(&rhs, source);
            let assigned = if rhs.kind() != "identifier" {
                classify_rvalue_null(&rhs, source, summaries)
            } else if global_vars.contains(&rhs_text) {
                result.get(&rhs_text).copied().unwrap_or(NullState::Unknown)
            } else if is_null_value(&rhs_text) {
                NullState::DefinitelyNull
            } else {
                check_preceding_null_assign(node, &rhs_text, source)
            };
            let previous = result.get(&var_name).copied().unwrap_or(NullState::Unknown);
            result.insert(var_name, previous.join(assigned));
        }
    }

    for child in (0..node.child_count()).filter_map(|i| node.child(i)) {
        scan_body_for_global_assignments(&child, source, global_vars, result, summaries);
    }
}
fn check_preceding_null_assign(assignment_node: &Node, var_name: &str, source: &str) -> NullState {
let expr_stmt = if assignment_node.parent().map(|p| p.kind()) == Some("expression_statement") {
assignment_node.parent().unwrap()
} else {
return NullState::Unknown;
};
if let Some(prev) = expr_stmt.prev_sibling() {
if prev.kind() == "expression_statement" {
if let Some(expr) = prev.child(0) {
if expr.kind() == "assignment_expression" {
if let Some(left) = expr.child_by_field_name("left") {
if get_text(&left, source) == var_name {
if let Some(right) = expr.child_by_field_name("right") {
let rhs = get_text(&right, source);
if is_null_value(rhs.trim()) {
return NullState::DefinitelyNull;
}
return classify_rvalue_null(&right, source, &HashMap::new());
}
}
}
}
}
}
if prev.kind() == "declaration" {
for i in 0..prev.child_count() {
if let Some(child) = prev.child(i) {
if child.kind() == "init_declarator" {
if let Some(decl) = child.child_by_field_name("declarator") {
let name = get_identifier_from_declarator(&decl, source);
if name == var_name {
if let Some(value) = child.child_by_field_name("value") {
let vtext = get_text(&value, source);
if is_null_value(vtext.trim()) {
return NullState::DefinitelyNull;
}
return classify_rvalue_null(&value, source, &HashMap::new());
}
}
}
}
}
}
}
}
NullState::Unknown
}
/// Runs the null-state analysis for one function with no global facts and no function name.
///
/// Convenience wrapper around `analyze_null_states_with_globals`.
#[allow(dead_code)]
pub fn analyze_null_states(
    cfg: &FunctionCfg,
    func_node: &Node,
    source: &str,
    summaries: &HashMap<String, FunctionSummary>,
) -> NullAnalysisResult {
    let no_globals = StateMap::new();
    let no_func_name: Option<&str> = None;
    analyze_null_states_with_globals(
        cfg,
        func_node,
        source,
        summaries,
        &no_globals,
        no_func_name,
    )
}
/// Runs the forward null-state dataflow analysis over `cfg` for the function at `func_node`.
///
/// The entry state is seeded from `global_states` (every global is also treated as a declared
/// pointer), from the function's pointer parameters, and — when `func_name` has an entry in
/// `summaries` — from that summary's call-site states for parameters, parameter fields
/// (`param.field` keys) and parameter pointees (`*param` keys). Blocks are then processed from
/// a worklist until exit states stop changing or the iteration cap is reached.
///
/// Returns empty maps when `func_node` has no `body` field.
pub fn analyze_null_states_with_globals(
cfg: &FunctionCfg,
func_node: &Node,
source: &str,
summaries: &HashMap<String, FunctionSummary>,
global_states: &StateMap,
func_name: Option<&str>,
) -> NullAnalysisResult {
let body = match func_node.child_by_field_name("body") {
Some(b) => b,
None => {
return NullAnalysisResult {
block_entry_states: HashMap::new(),
block_exit_states: HashMap::new(),
declared_pointers: HashSet::new(),
}
}
};
let mut declared_pointers = HashSet::new();
let mut initial_state = StateMap::new();
// Globals enter the function with their pre-computed state.
for (name, &state) in global_states {
initial_state.insert(name.clone(), state);
declared_pointers.insert(name.clone());
}
let func_summary = func_name.and_then(|name| summaries.get(name));
let callsite_states = func_summary.map(|s| &s.callsite_param_null_states);
if let Some(declarator) = func_node.child_by_field_name("declarator") {
collect_param_pointer_state(
&declarator,
source,
&mut initial_state,
&mut declared_pointers,
callsite_states,
);
// Seed `param.field` entries from what call sites passed in.
if let Some(summary) = func_summary {
if !summary.callsite_param_field_null_states.is_empty() {
let param_names =
crate::analyze::function_summary::collect_param_names(func_node, source);
for (param_idx, field_states) in &summary.callsite_param_field_null_states {
if let Some(param_name) = param_names.get(*param_idx) {
if !param_name.is_empty() {
for (field_name, &state) in field_states {
let key = format!("{}.{}", param_name, field_name);
initial_state.insert(key, state);
}
}
}
}
}
}
// Seed `*param` pointee entries from what call sites passed in.
if let Some(summary) = func_summary {
if !summary.callsite_param_pointee_null_states.is_empty() {
let param_names =
crate::analyze::function_summary::collect_param_names(func_node, source);
for (param_idx, &state) in &summary.callsite_param_pointee_null_states {
if let Some(param_name) = param_names.get(*param_idx) {
if !param_name.is_empty() {
let key = format!("*{}", param_name);
initial_state.insert(key, state);
}
}
}
}
}
}
let mut entry_states: HashMap<BlockId, StateMap> = HashMap::new();
let mut exit_states: HashMap<BlockId, StateMap> = HashMap::new();
for block in &cfg.blocks {
entry_states.insert(block.id, StateMap::new());
exit_states.insert(block.id, StateMap::new());
}
// The entry block is evaluated once up front from the seeded state.
entry_states.insert(cfg.entry, initial_state.clone());
let entry_exit = apply_transfer(
&cfg.blocks[cfg.entry],
&initial_state,
&body,
source,
&mut declared_pointers,
summaries,
);
exit_states.insert(cfg.entry, entry_exit);
// `in_worklist` mirrors the queue contents so a block is never queued twice.
let mut worklist: VecDeque<BlockId> = VecDeque::new();
let mut in_worklist: HashSet<BlockId> = HashSet::new();
for (succ, _) in cfg.successors(cfg.entry) {
worklist.push_back(succ);
in_worklist.insert(succ);
}
let mut iterations = 0;
const MAX_ITERATIONS: usize = 500;
while let Some(block_id) = worklist.pop_front() {
in_worklist.remove(&block_id);
iterations += 1;
// Hard cap on total work, scaled by the number of blocks.
if iterations > MAX_ITERATIONS * cfg.blocks.len() {
break;
}
let preds = cfg.predecessors(block_id);
let mut new_entry = StateMap::new();
let mut first = true;
// Entry state = join over all predecessors of their exit state refined by the edge taken.
for (pred_id, edge_kind) in &preds {
let pred_exit = exit_states.get(pred_id).cloned().unwrap_or_default();
let refined =
apply_edge_refinement(&pred_exit, *pred_id, edge_kind, cfg, &body, source);
if first {
new_entry = refined;
first = false;
} else {
new_entry = join_states(&new_entry, &refined);
}
}
// `first` still set means the block has no predecessors; nothing to compute.
if first {
continue;
}
let block = &cfg.blocks[block_id];
let new_exit = apply_transfer(
block,
&new_entry,
&body,
source,
&mut declared_pointers,
summaries,
);
let old_exit = exit_states.get(&block_id);
// Successors are re-queued only when this block's exit state changed.
if old_exit.is_none_or(|old| *old != new_exit) {
entry_states.insert(block_id, new_entry);
exit_states.insert(block_id, new_exit);
for (succ, _) in cfg.successors(block_id) {
if in_worklist.insert(succ) {
worklist.push_back(succ);
}
}
} else {
entry_states.insert(block_id, new_entry);
}
}
NullAnalysisResult {
block_entry_states: entry_states,
block_exit_states: exit_states,
declared_pointers,
}
}
/// Narrows a predecessor's exit state according to the branch edge being followed.
///
/// Edges other than the true/false branch, predecessors without a recorded condition range,
/// and conditions that cannot be located under `body` leave the state unchanged. Only names
/// that already have an entry in the state are refined.
fn apply_edge_refinement(
    pred_exit: &StateMap,
    pred_id: BlockId,
    edge_kind: &CfgEdge,
    cfg: &FunctionCfg,
    body: &Node,
    source: &str,
) -> StateMap {
    let mut refined = pred_exit.clone();

    let on_true_edge = match edge_kind {
        CfgEdge::TrueBranch => true,
        CfgEdge::FalseBranch => false,
        _ => return refined,
    };

    let condition = cfg
        .get_block(pred_id)
        .and_then(|block| block.condition_range)
        .and_then(|(start, end)| find_node_at_range(body, start, end));
    let Some(condition) = condition else {
        return refined;
    };

    for info in parse_all_null_conditions(&condition, source) {
        let narrowed = if on_true_edge {
            info.true_state
        } else {
            info.false_state
        };
        if let Some(slot) = refined.get_mut(&info.var_name) {
            *slot = narrowed;
        }
    }
    refined
}
pub fn is_null_deref_at(
result: &NullAnalysisResult,
cfg: &FunctionCfg,
body: &Node,
source: &str,
var_name: &str,
deref_byte: usize,
summaries: &HashMap<String, FunctionSummary>,
) -> bool {
let block = match find_block_containing(cfg, deref_byte) {
Some(b) => b,
None => return false, };
let entry = match result.block_entry_states.get(&block.id) {
Some(s) => s,
None => return false,
};
let mut state = entry.clone();
let mut declared_pointers = result.declared_pointers.clone();
for &(start, end) in &block.statements {
if start >= deref_byte {
break;
}
if let Some(stmt_node) = find_node_at_range(body, start, end) {
process_statement_for_null_state(
&stmt_node,
source,
&mut state,
&mut declared_pointers,
summaries,
);
}
}
match state.get(var_name) {
Some(ns) => ns.is_unsafe(),
None => false, }
}
/// Returns the state of `var_name` just before byte offset `byte_offset`.
///
/// Starts from the entry state of the block containing the offset and replays every statement
/// of that block, in order, up to the first one whose start is at or past `byte_offset`.
/// Returns `Unknown` when the block or its entry state cannot be found, or when the variable
/// has no entry.
pub fn get_var_state_at(
    result: &NullAnalysisResult,
    cfg: &FunctionCfg,
    body: &Node,
    source: &str,
    var_name: &str,
    byte_offset: usize,
    summaries: &HashMap<String, FunctionSummary>,
) -> NullState {
    let Some(block) = find_block_containing(cfg, byte_offset) else {
        return NullState::Unknown;
    };
    let Some(entry) = result.block_entry_states.get(&block.id) else {
        return NullState::Unknown;
    };

    // Work on copies so the stored analysis result is left untouched.
    let mut replayed = entry.clone();
    let mut pointers = result.declared_pointers.clone();

    let preceding = block
        .statements
        .iter()
        .take_while(|&&(start, _)| start < byte_offset);
    for &(start, end) in preceding {
        let Some(stmt) = find_node_at_range(body, start, end) else {
            continue;
        };
        process_statement_for_null_state(&stmt, source, &mut replayed, &mut pointers, summaries);
    }

    replayed.get(var_name).copied().unwrap_or(NullState::Unknown)
}
/// Finds the basic block that covers `byte_offset`.
///
/// A block owning a statement whose half-open range contains the offset wins. If no statement
/// matches, the first block whose own byte range contains the offset is used, ignoring blocks
/// whose range starts at 0.
fn find_block_containing(cfg: &FunctionCfg, byte_offset: usize) -> Option<&BasicBlock> {
    let covers = |(start, end): (usize, usize)| start <= byte_offset && byte_offset < end;

    cfg.blocks
        .iter()
        .find(|block| block.statements.iter().any(|&span| covers(span)))
        .or_else(|| {
            cfg.blocks
                .iter()
                .find(|block| block.byte_range.0 > 0 && covers(block.byte_range))
        })
}
fn collect_param_pointer_state(
declarator: &Node,
source: &str,
state: &mut StateMap,
declared_pointers: &mut HashSet<String>,
callsite_states: Option<&HashMap<usize, NullState>>,
) {
if declarator.kind() == "function_declarator" {
if let Some(params) = declarator.child_by_field_name("parameters") {
let mut param_idx: usize = 0;
for i in 0..params.child_count() {
if let Some(param) = params.child(i) {
if param.kind() == "parameter_declaration" {
let param_text = get_text(¶m, source);
if let Some(param_decl) = param.child_by_field_name("declarator") {
let name = get_identifier_from_declarator(¶m_decl, source);
if !name.is_empty()
&& (is_pointer_declarator(¶m_decl)
|| param_text.contains('*')
|| param_text.starts_with("FILE")
|| name.contains("callback"))
{
declared_pointers.insert(name.clone());
let seed_state = if let Some(cs) = callsite_states {
cs.get(¶m_idx)
.copied()
.map(|s| match s {
NullState::Unknown => NullState::NotNull,
other => other,
})
.unwrap_or(NullState::NotNull)
} else {
NullState::NotNull
};
state.insert(name, seed_state);
}
}
param_idx += 1;
}
}
}
}
} else {
for i in 0..declarator.child_count() {
if let Some(child) = declarator.child(i) {
collect_param_pointer_state(
&child,
source,
state,
declared_pointers,
callsite_states,
);
}
}
}
}
/// Returns an owned copy of the source text spanned by `node`.
fn get_text(node: &Node, source: &str) -> String {
    let (start, end) = (node.start_byte(), node.end_byte());
    String::from(&source[start..end])
}
/// Returns `true` when `text`, ignoring surrounding whitespace, is a null literal:
/// `NULL`, `0` or `nullptr`.
pub fn is_null_value(text: &str) -> bool {
    matches!(text.trim(), "NULL" | "0" | "nullptr")
}
/// Returns `true` when a call to `func_name` may yield a null pointer.
///
/// A function qualifies if its summary in `summaries` is flagged `can_return_null`, or if its
/// name is on the built-in list below.
pub fn is_nullable_function(func_name: &str, summaries: &HashMap<String, FunctionSummary>) -> bool {
    const KNOWN_NULLABLE: &[&str] = &[
        "malloc",
        "calloc",
        "realloc",
        "strstr",
        "strchr",
        "strrchr",
        "fopen",
        "fdopen",
        "freopen",
        "tmpfile",
        "popen",
        "getenv",
        "setlocale",
        "strtok",
        "bsearch",
        "fgets",
        "gets",
        "strdup",
        "strndup",
        "strpbrk",
        "memchr",
        "localtime",
        "gmtime",
        "asctime",
        "ctime",
        "create_int",
    ];

    let flagged_by_summary = summaries
        .get(func_name)
        .is_some_and(|summary| summary.can_return_null);
    flagged_by_summary || KNOWN_NULLABLE.iter().any(|&known| known == func_name)
}
/// Returns `true` when `node` is a cast expression whose operand is a null literal,
/// e.g. `(char *)NULL`.
#[allow(dead_code)]
pub fn is_cast_to_null(node: &Node, source: &str) -> bool {
    node.kind() == "cast_expression"
        && node
            .child_by_field_name("value")
            .is_some_and(|operand| is_null_value(&get_text(&operand, source)))
}
/// Returns `true` when `declarator` is, or contains anywhere below it, a `pointer_declarator`
/// or an `array_declarator`.
pub fn is_pointer_declarator(declarator: &Node) -> bool {
    if matches!(
        declarator.kind(),
        "pointer_declarator" | "array_declarator"
    ) {
        return true;
    }
    (0..declarator.child_count())
        .filter_map(|i| declarator.child(i))
        .any(|child| is_pointer_declarator(&child))
}
/// Returns `true` when `node` is, or contains anywhere below it, an `array_declarator`.
fn contains_array(node: &Node) -> bool {
    node.kind() == "array_declarator"
        || (0..node.child_count())
            .filter_map(|i| node.child(i))
            .any(|child| contains_array(&child))
}
fn get_identifier_from_declarator(declarator: &Node, source: &str) -> String {
match declarator.kind() {
"identifier" => get_text(declarator, source),
"pointer_declarator" | "array_declarator" => {
if let Some(inner) = declarator.child_by_field_name("declarator") {
get_identifier_from_declarator(&inner, source)
} else {
String::new()
}
}
_ => {
for i in 0..declarator.child_count() {
if let Some(child) = declarator.child(i) {
if child.kind() == "identifier" {
return get_text(&child, source);
}
}
}
String::new()
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analyze::cfg::build_function_cfg;

    /// Parses `code` with the C grammar.
    fn parse_c(code: &str) -> tree_sitter::Tree {
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&tree_sitter_c::language()).unwrap();
        parser.parse(code, None).unwrap()
    }

    /// Analyzes the first function in `code` and reports whether `var` is flagged as a possible
    /// null dereference at the first occurrence of `deref_text`.
    fn deref_flagged(code: &str, var: &str, deref_text: &str) -> bool {
        let tree = parse_c(code);
        let root = tree.root_node();
        let func = (0..root.child_count())
            .filter_map(|i| root.child(i))
            .find(|c| c.kind() == "function_definition")
            .unwrap();
        let cfg = build_function_cfg(&func, code).unwrap();
        let summaries = HashMap::new();
        let result = analyze_null_states(&cfg, &func, code, &summaries);

        let body = root
            .child(0)
            .unwrap()
            .child_by_field_name("body")
            .unwrap();
        let deref_pos = code.find(deref_text).unwrap();
        is_null_deref_at(&result, &cfg, &body, code, var, deref_pos, &summaries)
    }

    #[test]
    fn test_null_assigned_then_deref() {
        let code = r#"
void foo() {
int *p = NULL;
*p = 42;
}
"#;
        assert!(deref_flagged(code, "p", "*p = 42"));
    }

    #[test]
    fn test_null_check_before_deref() {
        let code = r#"
void foo() {
int *p = NULL;
if (p != NULL) {
*p = 42;
}
}
"#;
        assert!(!deref_flagged(code, "p", "*p = 42"));
    }

    #[test]
    fn test_early_return_after_null_check() {
        let code = r#"
int foo(int *p) {
if (p == NULL) {
return -1;
}
*p = 42;
return 0;
}
"#;
        assert!(!deref_flagged(code, "p", "*p = 42"));
    }

    #[test]
    fn test_deref_inside_null_branch() {
        let code = r#"
void foo(int *p) {
if (p == NULL) {
*p = 42;
}
}
"#;
        assert!(deref_flagged(code, "p", "*p = 42"));
    }

    #[test]
    fn test_malloc_with_check() {
        let code = r#"
void foo() {
int *p = malloc(sizeof(int));
if (p == NULL) {
return;
}
*p = 42;
}
"#;
        assert!(!deref_flagged(code, "p", "*p = 42"));
    }

    #[test]
    fn test_malloc_without_check() {
        let code = r#"
void foo() {
int *p = malloc(sizeof(int));
*p = 42;
}
"#;
        assert!(deref_flagged(code, "p", "*p = 42"));
    }

    #[test]
    fn test_while_loop_guard() {
        let code = r#"
void foo(int *p) {
while (p != NULL) {
*p = 42;
p = NULL;
}
}
"#;
        assert!(!deref_flagged(code, "p", "*p = 42"));
    }

    #[test]
    fn test_global_prepass_null_static() {
        let code = r#"
static int *globalData;
void source() {
int *data;
data = NULL;
globalData = data;
}
void sink() {
int *data = globalData;
*data = 42;
}
"#;
        let tree = parse_c(code);
        let root = tree.root_node();
        let summaries = HashMap::new();

        // The pre-pass must see that `source` stores a null local into the global.
        let globals = collect_file_scope_null_states(&root, code, &summaries);
        assert_eq!(globals.get("globalData"), Some(&NullState::DefinitelyNull));

        let is_sink = |c: &Node| {
            c.kind() == "function_definition" && code[c.start_byte()..c.end_byte()].contains("sink")
        };
        let sink_func = (0..root.child_count())
            .filter_map(|i| root.child(i))
            .find(|c| is_sink(c))
            .unwrap();
        let cfg = build_function_cfg(&sink_func, code).unwrap();
        let result =
            analyze_null_states_with_globals(&cfg, &sink_func, code, &summaries, &globals, None);

        let body = sink_func.child_by_field_name("body").unwrap();
        let deref_pos = code.find("*data = 42").unwrap();
        assert!(is_null_deref_at(
            &result, &cfg, &body, code, "data", deref_pos, &summaries
        ));
    }

    #[test]
    fn test_global_prepass_nonnull_static() {
        let code = r#"
static char *globalData;
void source() {
char *data;
data = "Good";
globalData = data;
}
void sink() {
char *data = globalData;
data[0];
}
"#;
        let tree = parse_c(code);
        let root = tree.root_node();
        let summaries = HashMap::new();
        let globals = collect_file_scope_null_states(&root, code, &summaries);
        assert_eq!(globals.get("globalData"), Some(&NullState::NotNull));
    }
}