use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use tree_sitter::Node;
/// Stateless checker type for the `MEM10-C` entry
/// ("Define and use a pointer validation function").
///
/// All behaviour lives in its `CertRule` implementation and in the private
/// helper methods of its inherent `impl`.
pub struct Mem10C;
// `CertRule` metadata and analysis entry point for the MEM10-C checker.
impl CertRule for Mem10C {
    /// Identifier attached to every violation this checker reports.
    fn rule_id(&self) -> &'static str {
        "MEM10-C"
    }

    /// One-line summary of the guideline.
    fn description(&self) -> &'static str {
        "Define and use a pointer validation function"
    }

    /// Default severity used for the NULL-check findings.
    fn severity(&self) -> Severity {
        Severity::Medium
    }

    /// Category under which this checker is registered.
    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    /// CERT identifier; identical to [`Self::rule_id`].
    fn cert_id(&self) -> &'static str {
        self.rule_id()
    }

    /// Runs both sub-checks over the subtree rooted at `node` and returns
    /// everything they reported, NULL-check findings first, then `sizeof`
    /// findings.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found = Vec::new();
        self.check_pointer_validation(node, source, &mut found);
        self.check_sizeof_pointer_misuse(node, source, &mut found);
        found
    }
}
impl Mem10C {
fn collect_enclosing_params<'a>(
&self,
node: &Node<'a>,
source: &str,
) -> std::collections::HashSet<String> {
let mut current = node.parent();
while let Some(p) = current {
if p.kind() == "function_definition" {
return self.extract_pointer_param_names(&p, source);
}
current = p.parent();
}
std::collections::HashSet::new()
}
/// Collects parameter names from every `function_declarator` reachable from
/// the children of `func_node`.
///
/// The actual filtering of parameters is done by
/// `collect_params_from_declarator`; this method only drives it over each
/// direct child and returns the accumulated set.
fn extract_pointer_param_names(
    &self,
    func_node: &Node,
    source: &str,
) -> std::collections::HashSet<String> {
    let mut collected = std::collections::HashSet::new();
    let mut cursor = func_node.walk();
    for part in func_node.children(&mut cursor) {
        self.collect_params_from_declarator(&part, source, &mut collected);
    }
    collected
}
/// Recursively searches `node` for `function_declarator` nodes and records
/// the names of their parameters in `names`.
///
/// For each `parameter_declaration` with a `declarator` field:
/// * `pointer_declarator` / `array_declarator`: the first identifier found
///   inside the declarator is recorded;
/// * a bare `identifier` declarator: its text is recorded as well (so
///   non-pointer parameters also end up in the set);
/// * anything else is ignored.
///
/// Once a `function_declarator` is reached the search does not descend into
/// it any further.
fn collect_params_from_declarator(
    &self,
    node: &Node,
    source: &str,
    names: &mut std::collections::HashSet<String>,
) {
    if node.kind() != "function_declarator" {
        // Not a declarator yet: keep looking in every child.
        let mut cursor = node.walk();
        for inner in node.children(&mut cursor) {
            self.collect_params_from_declarator(&inner, source, names);
        }
        return;
    }

    let Some(param_list) = node.child_by_field_name("parameters") else {
        return;
    };

    let mut cursor = param_list.walk();
    for param in param_list.children(&mut cursor) {
        if param.kind() != "parameter_declaration" {
            continue;
        }
        let Some(declarator) = param.child_by_field_name("declarator") else {
            continue;
        };
        match declarator.kind() {
            "pointer_declarator" | "array_declarator" => {
                if let Some(name) = find_identifier_in_node(&declarator, source) {
                    names.insert(name);
                }
            }
            "identifier" => {
                names.insert(get_node_text(&declarator, source).to_string());
            }
            _ => {}
        }
    }
}
/// Reports ad-hoc NULL checks on function parameters in the subtree rooted
/// at `node`.
///
/// An `if_statement` is reported when all of the following hold:
/// * its condition is recognised by `is_direct_null_check`;
/// * `extract_checked_var_name` can name the checked variable;
/// * that name is among the parameters of the enclosing function;
/// * it is not a positive guard (`is_positive_guard`) whose variable is never
///   referenced by a later sibling of the `if` (`is_param_used_after_if`).
///
/// Every child of `node` is then visited recursively, regardless of whether
/// `node` itself produced a finding.
fn check_pointer_validation(
    &self,
    node: &Node,
    source: &str,
    violations: &mut Vec<RuleViolation>,
) {
    let condition = if node.kind() == "if_statement" {
        node.child_by_field_name("condition")
    } else {
        None
    };

    if let Some(condition) = condition.filter(|cond| self.is_direct_null_check(cond, source)) {
        if let Some(var_name) = extract_checked_var_name(&condition, source) {
            let is_param = self
                .collect_enclosing_params(node, source)
                .contains(&var_name);
            // A positive guard is tolerated only when the parameter is not
            // touched again after the `if`.
            if is_param
                && !(is_positive_guard(&condition, source)
                    && !is_param_used_after_if(node, &var_name, source))
            {
                let at = condition.start_position();
                violations.push(RuleViolation {
                    rule_id: self.rule_id().to_string(),
                    message: "Direct NULL check for pointer validation. \
                              Define and use a dedicated pointer validation function \
                              instead of ad-hoc NULL checks. This centralizes validation \
                              logic and allows platform-specific enhancements."
                        .to_string(),
                    severity: self.severity(),
                    line: at.row + 1,
                    column: at.column + 1,
                    file_path: String::new(),
                    suggestion: Some(
                        "Create a validation function like 'int valid(void *ptr)' \
                         and use 'if (!valid(ptr))' instead of 'if (ptr == NULL)'"
                            .to_string(),
                    ),
                    requires_manual_review: Some(true),
                });
            }
        }
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        self.check_pointer_validation(&child, source, violations);
    }
}
/// Returns `true` when `condition` looks like an ad-hoc NULL test.
///
/// Two shapes are recognised:
/// * a textual comparison against `NULL` (`== NULL`, `!= NULL`, the
///   no-space spellings, and the reversed `NULL == x` / `NULL != x` forms),
///   provided the condition contains no call expression (a call is taken to
///   be a validation function);
/// * a logical negation of a bare identifier (`!ptr`).
///
/// Enclosing parentheses are stripped before the negation test: the
/// `condition` field of a C `if_statement` is a `parenthesized_expression`,
/// so without unwrapping the `!ptr` form would never be seen.
fn is_direct_null_check(&self, condition: &Node, source: &str) -> bool {
    // Same spellings that `is_positive_guard` accepts, plus their `==` twins.
    const NULL_COMPARISONS: [&str; 8] = [
        "== NULL", "!= NULL", "==NULL", "!=NULL", "NULL ==", "NULL !=", "NULL==", "NULL!=",
    ];

    let condition_text = get_node_text(condition, source);
    let compares_to_null = NULL_COMPARISONS
        .iter()
        .any(|pattern| condition_text.contains(*pattern));
    if compares_to_null && !self.appears_to_be_validation_function_call(condition, source) {
        return true;
    }

    // Peel off `( ... )` wrappers; child 0 is the opening parenthesis, child 1
    // the wrapped expression.
    let mut expr = *condition;
    while expr.kind() == "parenthesized_expression" {
        match expr.child(1) {
            Some(inner) => expr = inner,
            None => return false,
        }
    }

    if expr.kind() != "unary_expression" {
        return false;
    }
    let is_negation = expr
        .child_by_field_name("operator")
        .is_some_and(|op| get_node_text(&op, source) == "!");
    is_negation
        && expr
            .child_by_field_name("argument")
            .is_some_and(|operand| operand.kind() == "identifier")
}
/// Heuristic: a condition is treated as going through a validation function
/// as soon as it contains any call expression.
fn appears_to_be_validation_function_call(&self, condition: &Node, source: &str) -> bool {
    let has_call = self.contains_call_expression(condition, source);
    has_call
}
/// Returns `true` if `node` or any of its descendants is a `call_expression`.
///
/// `_source` is not inspected; it is only forwarded to the recursive calls.
fn contains_call_expression(&self, node: &Node, _source: &str) -> bool {
    node.kind() == "call_expression"
        || (0..node.child_count())
            .filter_map(|idx| node.child(idx))
            .any(|sub| self.contains_call_expression(&sub, _source))
}
/// Scans the subtree rooted at `node` for calls to the memory functions
/// `malloc`, `calloc`, `realloc`, `memset`, `memcpy` and `memmove`, and hands
/// their argument lists to `check_sizeof_args_in_call`.
///
/// The callee is matched by its exact source text. All children are visited
/// recursively, including the children of a matching call.
fn check_sizeof_pointer_misuse(
    &self,
    node: &Node,
    source: &str,
    violations: &mut Vec<RuleViolation>,
) {
    const SIZED_CALLS: [&str; 6] = ["malloc", "calloc", "realloc", "memset", "memcpy", "memmove"];

    if node.kind() == "call_expression" {
        let callee = node.child_by_field_name("function");
        let arg_list = node.child_by_field_name("arguments");
        if let (Some(callee), Some(arg_list)) = (callee, arg_list) {
            let callee_name = get_node_text(&callee, source);
            if SIZED_CALLS.contains(&callee_name) {
                self.check_sizeof_args_in_call(callee_name, &arg_list, node, source, violations);
            }
        }
    }

    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        self.check_sizeof_pointer_misuse(&child, source, violations);
    }
}
/// Runs `find_sizeof_pointer_in_expr` on every argument of a call.
///
/// `args` is the call's argument-list node; its punctuation children (`(`,
/// `)` and `,`) are skipped so only the argument expressions are examined.
fn check_sizeof_args_in_call(
    &self,
    func_name: &str,
    args: &Node,
    call_node: &Node,
    source: &str,
    violations: &mut Vec<RuleViolation>,
) {
    let mut cursor = args.walk();
    let operands = args
        .children(&mut cursor)
        .filter(|tok| !matches!(tok.kind(), "(" | ")" | ","));
    for operand in operands {
        self.find_sizeof_pointer_in_expr(&operand, func_name, call_node, source, violations);
    }
}
/// Looks for `sizeof` applied to a pointer variable inside `node` and records
/// a violation for each occurrence.
///
/// * `node` - expression (sub)tree taken from a call argument.
/// * `func_name` - name of the surrounding call, used only in the message.
/// * `call_node` - the surrounding call; forwarded unchanged on recursion.
///
/// Both spellings of the operand are handled: the parenthesized
/// `sizeof(ptr)` and the bare `sizeof ptr`. The operand must be a plain
/// identifier that `is_pointer_variable` recognises. The search does not
/// descend into a `sizeof_expression` once one has been found.
///
/// NOTE(review): the byte count in the message comes from
/// `cfg!(target_pointer_width)`, i.e. the platform this analyzer was built
/// for, which is not necessarily the platform of the analysed C code.
fn find_sizeof_pointer_in_expr(
    &self,
    node: &Node,
    func_name: &str,
    call_node: &Node,
    source: &str,
    violations: &mut Vec<RuleViolation>,
) {
    if node.kind() != "sizeof_expression" {
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            self.find_sizeof_pointer_in_expr(&child, func_name, call_node, source, violations);
        }
        return;
    }

    let mut cursor = node.walk();
    for operand in node.children(&mut cursor) {
        // `sizeof ptr` exposes the identifier directly; `sizeof(ptr)` wraps it
        // in a parenthesized expression whose child 1 is the inner expression.
        let target = match operand.kind() {
            "identifier" => Some(operand),
            "parenthesized_expression" => operand
                .child(1)
                .filter(|inner| inner.kind() == "identifier"),
            _ => None,
        };
        let Some(target) = target else {
            continue;
        };

        let var_name = get_node_text(&target, source);
        if !self.is_pointer_variable(&target, var_name, source) {
            continue;
        }

        violations.push(RuleViolation {
            rule_id: self.rule_id().to_string(),
            message: format!(
                "sizeof({}) returns the size of a pointer ({}), not the pointed-to data. \
                 Use sizeof(*{}) or sizeof(type) in {}() call.",
                var_name,
                if cfg!(target_pointer_width = "64") { "8 bytes" } else { "4 bytes" },
                var_name,
                func_name,
            ),
            severity: Severity::Medium,
            line: node.start_position().row + 1,
            column: node.start_position().column + 1,
            file_path: String::new(),
            suggestion: Some(format!(
                "Replace sizeof({}) with sizeof(*{}) to get the size of the pointed-to type",
                var_name, var_name,
            )),
            requires_manual_review: Some(false),
        });
    }
}
fn is_pointer_variable(&self, node: &Node, var_name: &str, source: &str) -> bool {
let mut current = node.parent();
while let Some(p) = current {
if p.kind() == "function_definition" {
if self.is_pointer_param(&p, var_name, source) {
return true;
}
if let Some(body) = p.child_by_field_name("body") {
if self.is_pointer_local_var(&body, var_name, source) {
return true;
}
}
return false;
}
current = p.parent();
}
false
}
/// Returns `true` when `var_name` is among the names that
/// `extract_pointer_param_names` yields for `func_node`.
fn is_pointer_param(&self, func_node: &Node, var_name: &str, source: &str) -> bool {
    self.extract_pointer_param_names(func_node, source)
        .contains(var_name)
}
/// Returns `true` when `body` directly contains a declaration that
/// introduces `var_name` through a pointer declarator.
///
/// Both initialised (`char *p = ...;`, where the pointer declarator sits
/// inside an `init_declarator`) and uninitialised (`char *p;`, where the
/// pointer declarator is a direct child of the declaration) forms are
/// recognised. Only declarations that are immediate children of `body` are
/// inspected; nested blocks are not searched.
fn is_pointer_local_var(&self, body: &Node, var_name: &str, source: &str) -> bool {
    let mut body_cursor = body.walk();
    for stmt in body.children(&mut body_cursor) {
        if stmt.kind() != "declaration" {
            continue;
        }
        let mut decl_cursor = stmt.walk();
        for part in stmt.children(&mut decl_cursor) {
            // With an initializer the real declarator is one level down.
            let declarator = if part.kind() == "init_declarator" {
                part.child_by_field_name("declarator")
            } else {
                Some(part)
            };
            let declares_pointer = declarator
                .filter(|decl| decl.kind() == "pointer_declarator")
                .and_then(|decl| find_identifier_in_node(&decl, source))
                .is_some_and(|declared| declared == var_name);
            if declares_pointer {
                return true;
            }
        }
    }
    false
}
}
/// Extracts the name of the variable being NULL-tested by `condition`.
///
/// Handles three node kinds:
/// * `binary_expression`: one side must have the text `NULL` and the other
///   must be an identifier (either order); the operator is not inspected;
/// * `parenthesized_expression`: recurses into child 1, the wrapped
///   expression;
/// * `unary_expression`: the operator must be `!` and the argument an
///   identifier.
///
/// Returns `None` for every other shape.
fn extract_checked_var_name(condition: &Node, source: &str) -> Option<String> {
    match condition.kind() {
        "binary_expression" => {
            let lhs = condition.child_by_field_name("left")?;
            let rhs = condition.child_by_field_name("right")?;
            let lhs_text = get_node_text(&lhs, source);
            let rhs_text = get_node_text(&rhs, source);
            if rhs_text == "NULL" && lhs.kind() == "identifier" {
                Some(lhs_text.to_string())
            } else if lhs_text == "NULL" && rhs.kind() == "identifier" {
                Some(rhs_text.to_string())
            } else {
                None
            }
        }
        "parenthesized_expression" => condition
            .child(1)
            .and_then(|inner| extract_checked_var_name(&inner, source)),
        "unary_expression" => {
            let op = condition.child_by_field_name("operator")?;
            if get_node_text(&op, source) != "!" {
                return None;
            }
            let operand = condition.child_by_field_name("argument")?;
            (operand.kind() == "identifier").then(|| get_node_text(&operand, source).to_string())
        }
        _ => None,
    }
}
/// Returns the text of the first `identifier` found in a pre-order,
/// left-to-right walk of `node` (the node itself included), or `None` if the
/// subtree contains no identifier.
fn find_identifier_in_node(node: &Node, source: &str) -> Option<String> {
    if node.kind() == "identifier" {
        return Some(get_node_text(node, source).to_string());
    }
    (0..node.child_count())
        .filter_map(|idx| node.child(idx))
        .find_map(|sub| find_identifier_in_node(&sub, source))
}
/// Returns `true` when `condition` tests that a pointer is *not* NULL.
///
/// Accepted forms are a textual inequality against `NULL` (with or without
/// a space, `NULL` on either side), a bare identifier, or a parenthesized
/// expression whose child 1 is an identifier.
fn is_positive_guard(condition: &Node, source: &str) -> bool {
    const INEQUALITIES: [&str; 4] = ["!= NULL", "!=NULL", "NULL !=", "NULL!="];

    let text = get_node_text(condition, source);
    if INEQUALITIES.iter().any(|pattern| text.contains(*pattern)) {
        return true;
    }

    match condition.kind() {
        "identifier" => true,
        "parenthesized_expression" => condition
            .child(1)
            .is_some_and(|inner| inner.kind() == "identifier"),
        _ => false,
    }
}
fn is_param_used_after_if(if_node: &Node, var_name: &str, source: &str) -> bool {
let parent = match if_node.parent() {
Some(p) => p,
None => return false,
};
let mut found_if = false;
for i in 0..parent.child_count() {
if let Some(sibling) = parent.child(i) {
if sibling.id() == if_node.id() {
found_if = true;
continue;
}
if found_if && contains_identifier(&sibling, var_name, source) {
return true;
}
}
}
false
}
/// Returns `true` if `node` or any descendant is an `identifier` whose source
/// text equals `name`.
fn contains_identifier(node: &Node, name: &str, source: &str) -> bool {
    let is_match = node.kind() == "identifier" && get_node_text(node, source) == name;
    is_match
        || (0..node.child_count())
            .filter_map(|idx| node.child(idx))
            .any(|sub| contains_identifier(&sub, name, source))
}