use super::super::{CertRule, RuleViolation};
use crate::analyze::cfg::{self as cfg_mod, FunctionCfg};
use crate::analyze::dataflow::find_node_at_range;
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet, VecDeque};
use tree_sitter::Node;
/// Checker for CERT C recommendation MEM01-C ("Store a new value in pointers
/// immediately after free()").
pub struct Mem01C {
// Control-flow graphs keyed by `usize`; `check_function` looks an entry up by
// the start byte of the function node. Wrapped in a `RefCell` because
// `set_function_cfgs` only receives `&self` yet has to replace the map.
function_cfgs: RefCell<HashMap<usize, FunctionCfg>>,
}
impl Mem01C {
    /// Builds a checker whose CFG cache starts out empty.
    pub fn new() -> Self {
        let function_cfgs = RefCell::new(HashMap::new());
        Self { function_cfgs }
    }
}
/// Effect a single statement or expression has on the pointer that is being
/// tracked after its `free()` call.
///
/// The enum carries no data, so it is `Copy` and fully `Eq`; values can be
/// passed around and compared without borrowing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PtrAction {
    /// A new value is stored in the pointer, or a declaration introduces the
    /// same name.
    Reassigned,
    /// The pointer is passed to `free` again.
    FreedAgain,
    /// The pointer is mentioned in any other way.
    Used,
    /// The pointer is not mentioned at all.
    Irrelevant,
}
/// `CertRule` implementation for MEM01-C: static rule metadata, CFG cache
/// installation and the analysis entry point.
impl CertRule for Mem01C {
    /// Identifier attached to every violation this rule reports.
    fn rule_id(&self) -> &'static str {
        "MEM01-C"
    }

    /// One-line summary of the recommendation.
    fn description(&self) -> &'static str {
        "Store a new value in pointers immediately after free()"
    }

    /// Severity assigned to the rule.
    fn severity(&self) -> Severity {
        Severity::High
    }

    /// MEM01-C is a recommendation.
    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    /// CERT identifier; identical to [`Self::rule_id`].
    fn cert_id(&self) -> &'static str {
        "MEM01-C"
    }

    /// Replaces the cached CFG map with a copy of `cfgs`.
    fn set_function_cfgs(&self, cfgs: &HashMap<usize, FunctionCfg>) {
        // Clone first, then take the mutable borrow, mirroring the evaluation
        // order of a plain assignment.
        let snapshot = cfgs.clone();
        *self.function_cfgs.borrow_mut() = snapshot;
    }

    /// Walks the tree rooted at `node` and returns every violation found.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found = Vec::new();
        self.check_node(node, source, &mut found);
        found
    }
}
impl Mem01C {
/// Recursively searches `node` for function definitions and analyses each
/// one with `check_function`.
///
/// The walk does not descend into a function definition once it has been
/// handed over for analysis.
fn check_node(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
    if node.kind() != "function_definition" {
        let children = (0..node.child_count()).filter_map(|idx| node.child(idx));
        for child in children {
            self.check_node(&child, source, violations);
        }
        return;
    }
    self.check_function(node, source, violations);
}
/// Analyses one function definition and appends a violation for every
/// `free(...)` call whose argument may be used or freed again before it is
/// given a new value.
///
/// The CFG is taken from the cache (keyed by the function's start byte) when
/// present and built on the fly otherwise. Functions without a body, or for
/// which no CFG can be obtained, are skipped silently.
fn check_function(&self, func_node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
    let body = match func_node.child_by_field_name("body") {
        Some(found) => found,
        None => return,
    };

    let cached = self.function_cfgs.borrow();
    // Owns a freshly built CFG when the cache has no entry, so that `cfg`
    // can be a plain reference in both cases.
    let built_here;
    let cfg = match cached.get(&func_node.start_byte()) {
        Some(known) => known,
        None => match cfg_mod::build_function_cfg(func_node, source) {
            Some(fresh) => {
                built_here = fresh;
                &built_here
            }
            None => return,
        },
    };

    for (ptr_name, free_byte, line, column) in self.collect_free_calls(&body, source) {
        if !self.ptr_has_post_free_use(cfg, &body, source, &ptr_name, free_byte) {
            continue;
        }
        let message = format!(
            "Pointer '{}' is used or freed again after free() without reassignment",
            ptr_name
        );
        let suggestion = format!(
            "Set '{} = NULL;' after free({}) or remove the subsequent use",
            ptr_name, ptr_name
        );
        violations.push(RuleViolation {
            rule_id: self.rule_id().to_string(),
            severity: Severity::High,
            message,
            // Left empty here; this block has no file path available.
            file_path: String::new(),
            line,
            column,
            suggestion: Some(suggestion),
            ..Default::default()
        });
    }
}
/// Gathers every `free(...)` call found beneath `node`.
///
/// Each entry is `(argument text, start byte of the call, 1-based line,
/// 1-based column)`.
fn collect_free_calls(&self, node: &Node, source: &str) -> Vec<(String, usize, usize, usize)> {
    let mut calls = Vec::new();
    self.walk_for_free_calls(node, source, &mut calls);
    calls
}
/// Pre-order walk that records each call whose callee text is exactly
/// `free` and whose first argument can be extracted.
///
/// Children are always visited, including those of a recorded call.
fn walk_for_free_calls(
    &self,
    node: &Node,
    source: &str,
    results: &mut Vec<(String, usize, usize, usize)>,
) {
    let calls_free = node.kind() == "call_expression"
        && node
            .child_by_field_name("function")
            .map_or(false, |callee| get_node_text(&callee, source) == "free");

    if calls_free {
        if let Some(ptr_name) = self.extract_free_arg(node, source) {
            let start = node.start_position();
            // Tree-sitter positions are zero-based; reports are one-based.
            results.push((ptr_name, node.start_byte(), start.row + 1, start.column + 1));
        }
    }

    for child in (0..node.child_count()).filter_map(|idx| node.child(idx)) {
        self.walk_for_free_calls(&child, source, results);
    }
}
fn extract_free_arg(&self, call_node: &Node, source: &str) -> Option<String> {
let args = call_node.child_by_field_name("arguments")?;
for i in 0..args.child_count() {
if let Some(arg) = args.child(i) {
if arg.kind() != "(" && arg.kind() != ")" && arg.kind() != "," {
return Some(get_node_text(&arg, source).to_string());
}
}
}
None
}
/// Returns `true` when `ptr_name` may be used or freed again after the
/// `free()` call starting at `free_byte`, before a new value is stored in it.
///
/// The remainder of the block containing the call is scanned first. If that
/// is inconclusive, CFG successors are explored breadth-first: a path ends at
/// the first block that reassigns the pointer, and the search reports `true`
/// at the first block that uses or frees it.
///
/// Edges are followed regardless of their kind, so the search is
/// path-insensitive: a back-edge guarded by a constant-false condition is
/// treated like any other.
///
/// When no CFG block covers `free_byte`, the call is reported (`true`).
fn ptr_has_post_free_use(
    &self,
    cfg: &FunctionCfg,
    body: &Node,
    source: &str,
    ptr_name: &str,
    free_byte: usize,
) -> bool {
    let containing_block = match find_block_containing(cfg, free_byte) {
        Some(block) => block,
        // NOTE(review): an unmapped call is reported rather than skipped;
        // confirm that this conservative default is intended.
        None => return true,
    };

    match self.scan_block_from(containing_block, body, source, ptr_name, free_byte) {
        Some(PtrAction::FreedAgain) | Some(PtrAction::Used) => return true,
        Some(PtrAction::Reassigned) => return false,
        Some(PtrAction::Irrelevant) | None => {}
    }

    // The containing block is deliberately NOT pre-marked as visited. If a
    // loop back-edge leads to it again, the block runs from its first
    // statement on the next iteration, so it must be scanned in full: a use
    // before the `free()` (or the `free()` itself) then acts on the pointer
    // released by the previous iteration.
    let mut visited: HashSet<usize> = HashSet::new();
    let mut queue: VecDeque<usize> = VecDeque::new();
    for (succ_id, _edge) in cfg.successors(containing_block.id) {
        queue.push_back(succ_id);
    }

    while let Some(block_id) = queue.pop_front() {
        if !visited.insert(block_id) {
            continue;
        }
        let block = match cfg.get_block(block_id) {
            Some(found) => found,
            None => continue,
        };
        match self.scan_block_all(block, body, source, ptr_name) {
            Some(PtrAction::FreedAgain) | Some(PtrAction::Used) => return true,
            // The pointer holds a fresh value beyond this block: stop here.
            Some(PtrAction::Reassigned) => {}
            Some(PtrAction::Irrelevant) | None => {
                for (succ_id, _edge) in cfg.successors(block_id) {
                    queue.push_back(succ_id);
                }
            }
        }
    }
    false
}
/// Classifies the statements of `block` that start strictly after
/// `after_byte` and returns the first action that is not
/// `PtrAction::Irrelevant`.
///
/// Statements whose byte range cannot be mapped back to a syntax node are
/// skipped. Returns `None` when nothing relevant is found.
fn scan_block_from(
    &self,
    block: &crate::analyze::cfg::BasicBlock,
    body: &Node,
    source: &str,
    ptr_name: &str,
    after_byte: usize,
) -> Option<PtrAction> {
    block
        .statements
        .iter()
        .filter(|&&(start, _)| start > after_byte)
        .filter_map(|&(start, end)| find_node_at_range(body, start, end))
        .map(|stmt| classify_stmt_for_ptr(&stmt, source, ptr_name))
        .find(|action| *action != PtrAction::Irrelevant)
}
/// Classifies every statement of `block` in order and returns the first
/// action that is not `PtrAction::Irrelevant`.
///
/// Statements whose byte range cannot be mapped back to a syntax node are
/// skipped. Returns `None` when nothing relevant is found.
fn scan_block_all(
    &self,
    block: &crate::analyze::cfg::BasicBlock,
    body: &Node,
    source: &str,
    ptr_name: &str,
) -> Option<PtrAction> {
    block
        .statements
        .iter()
        .filter_map(|&(start, end)| find_node_at_range(body, start, end))
        .map(|stmt| classify_stmt_for_ptr(&stmt, source, ptr_name))
        .find(|action| *action != PtrAction::Irrelevant)
}
}
/// Classifies what the statement-level node `node` does to `ptr_name`.
///
/// * `expression_statement` — the wrapped expression is classified by
///   `classify_expr_for_ptr`.
/// * `assignment_expression` — a bare assignment that is not wrapped in an
///   expression statement (for example a `for` update clause such as
///   `p = next`) is classified by `classify_expr_for_ptr` as well, so that a
///   reassignment is not mistaken for a use.
/// * `return_statement` — `Used` when the returned expression mentions the
///   pointer.
/// * `declaration` — `Reassigned` when the first declarator introduces the
///   same name, otherwise `Used` when the pointer is mentioned anywhere in
///   the declaration.
/// * anything else — `Used` when the pointer is mentioned anywhere below
///   `node`.
///
/// Returns `PtrAction::Irrelevant` when the pointer is not involved.
fn classify_stmt_for_ptr(node: &Node, source: &str, ptr_name: &str) -> PtrAction {
    match node.kind() {
        "expression_statement" => match node.child(0) {
            Some(expr) => classify_expr_for_ptr(&expr, source, ptr_name),
            None => PtrAction::Irrelevant,
        },
        "assignment_expression" => classify_expr_for_ptr(node, source, ptr_name),
        "return_statement" => {
            // Child 0 is the `return` keyword; child 1, when present, is the
            // returned expression (or the terminating `;`).
            let returns_ptr = node.child_count() > 1
                && node
                    .child(1)
                    .map_or(false, |value| subtree_contains_identifier(&value, source, ptr_name));
            if returns_ptr {
                PtrAction::Used
            } else {
                PtrAction::Irrelevant
            }
        }
        "declaration" => {
            let redeclares_ptr =
                find_declarator_name(node, source).map_or(false, |declared| declared == ptr_name);
            if redeclares_ptr {
                PtrAction::Reassigned
            } else if subtree_contains_identifier(node, source, ptr_name) {
                PtrAction::Used
            } else {
                PtrAction::Irrelevant
            }
        }
        _ => {
            if subtree_contains_identifier(node, source, ptr_name) {
                PtrAction::Used
            } else {
                PtrAction::Irrelevant
            }
        }
    }
}
/// Classifies what the expression `expr` does to `ptr_name`.
///
/// * Assignment whose left-hand side is exactly `ptr_name`: `Reassigned`
///   only for a plain `=` whose right-hand side does not mention the
///   pointer. `p = p->next`, `p = realloc(p, n)` and compound forms such as
///   `p += 1` all read the stale value before storing the new one, so they
///   are `Used`.
/// * Any other assignment: `Used` when the pointer is mentioned.
/// * Call: `FreedAgain` when the callee is `free` and an argument mentions
///   the pointer; `Used` when any other call mentions it in an argument or
///   in the callee expression (for example `p->callback(x)`).
/// * Anything else: `Used` when the pointer is mentioned anywhere in the
///   expression.
///
/// Returns `PtrAction::Irrelevant` when the pointer is not involved.
fn classify_expr_for_ptr(expr: &Node, source: &str, ptr_name: &str) -> PtrAction {
    match expr.kind() {
        "assignment_expression" => {
            let targets_ptr = expr
                .child_by_field_name("left")
                .map_or(false, |left| get_node_text(&left, source) == ptr_name);

            if targets_ptr {
                // The operator token's kind is its literal text ("=", "+=", ...).
                // A missing operator field is treated as a plain assignment.
                let is_compound = expr
                    .child_by_field_name("operator")
                    .map_or(false, |op| op.kind() != "=");
                let rhs_reads_ptr = expr
                    .child_by_field_name("right")
                    .map_or(false, |right| subtree_contains_identifier(&right, source, ptr_name));

                if is_compound || rhs_reads_ptr {
                    PtrAction::Used
                } else {
                    PtrAction::Reassigned
                }
            } else if subtree_contains_identifier(expr, source, ptr_name) {
                PtrAction::Used
            } else {
                PtrAction::Irrelevant
            }
        }
        "call_expression" => {
            let callee = expr.child_by_field_name("function");
            let in_arguments = expr
                .child_by_field_name("arguments")
                .map_or(false, |args| arg_list_contains_identifier(&args, source, ptr_name));

            if in_arguments {
                let is_free = callee.map_or(false, |f| get_node_text(&f, source) == "free");
                return if is_free {
                    PtrAction::FreedAgain
                } else {
                    PtrAction::Used
                };
            }

            // Calling through the pointer dereferences it just like passing
            // it as an argument does.
            let in_callee =
                callee.map_or(false, |f| subtree_contains_identifier(&f, source, ptr_name));
            if in_callee {
                PtrAction::Used
            } else {
                PtrAction::Irrelevant
            }
        }
        // Update expressions (`p++`) and every other kind share one rule.
        _ => {
            if subtree_contains_identifier(expr, source, ptr_name) {
                PtrAction::Used
            } else {
                PtrAction::Irrelevant
            }
        }
    }
}
/// Reports whether an `identifier` node whose text equals `name` occurs in
/// the subtree rooted at `node`.
///
/// An `identifier` node is a leaf for this search: its text is compared and
/// the search does not look any deeper. Other node kinds are searched
/// depth-first.
fn subtree_contains_identifier(node: &Node, source: &str, name: &str) -> bool {
    if node.kind() == "identifier" {
        return get_node_text(node, source) == name;
    }
    (0..node.child_count())
        .filter_map(|idx| node.child(idx))
        .any(|child| subtree_contains_identifier(&child, source, name))
}
fn arg_list_contains_identifier(args: &Node, source: &str, name: &str) -> bool {
for i in 0..args.child_count() {
if let Some(arg) = args.child(i) {
if arg.kind() != "(" && arg.kind() != ")" && arg.kind() != "," {
if subtree_contains_identifier(&arg, source, name) {
return true;
}
}
}
}
false
}
/// Returns the name introduced by the first declarator of the declaration
/// `decl`.
///
/// Only the first direct child that is a declarator is considered: an
/// `init_declarator` with a `declarator` field, or a bare pointer, array or
/// identifier declarator. If no name can be extracted from that child the
/// result is `None`; later declarators are not examined.
fn find_declarator_name(decl: &Node, source: &str) -> Option<String> {
    for idx in 0..decl.child_count() {
        let child = match decl.child(idx) {
            Some(present) => present,
            None => continue,
        };
        let declarator = match child.kind() {
            // An init_declarator without a `declarator` field is skipped.
            "init_declarator" => child.child_by_field_name("declarator"),
            "pointer_declarator" | "array_declarator" | "identifier" => Some(child),
            _ => None,
        };
        if let Some(found) = declarator {
            return extract_identifier_from_declarator(&found, source);
        }
    }
    None
}
/// Peels pointer and array declarators off `node` until an identifier is
/// reached, and returns that identifier's text.
///
/// A pointer or array declarator is followed through its `declarator` field
/// when it has one; otherwise its children are tried in order and the first
/// name found wins. Every other node kind yields `None`.
fn extract_identifier_from_declarator(node: &Node, source: &str) -> Option<String> {
    match node.kind() {
        "identifier" => Some(get_node_text(node, source).to_string()),
        "pointer_declarator" | "array_declarator" => {
            match node.child_by_field_name("declarator") {
                Some(inner) => extract_identifier_from_declarator(&inner, source),
                None => (0..node.child_count())
                    .filter_map(|idx| node.child(idx))
                    .find_map(|child| extract_identifier_from_declarator(&child, source)),
            }
        }
        _ => None,
    }
}
/// Finds the CFG block that covers `byte_offset`.
///
/// Blocks are matched by their individual statement ranges first
/// (`start <= offset < end`). If no statement range matches, the first block
/// whose overall `byte_range` covers the offset is returned; blocks whose
/// range starts at byte 0 are excluded from that fallback.
fn find_block_containing(
    cfg: &FunctionCfg,
    byte_offset: usize,
) -> Option<&crate::analyze::cfg::BasicBlock> {
    let by_statement = cfg.blocks.iter().find(|block| {
        block
            .statements
            .iter()
            .any(|&(start, end)| start <= byte_offset && byte_offset < end)
    });

    by_statement.or_else(|| {
        cfg.blocks.iter().find(|block| {
            let covers_offset =
                block.byte_range.0 <= byte_offset && byte_offset < block.byte_range.1;
            block.byte_range.0 > 0 && covers_offset
        })
    })
}