use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use std::collections::HashSet;
use tree_sitter::Node;
/// Checker for CERT C rule PRE32-C: "Do not use preprocessor directives in
/// invocations of function-like macros".
///
/// The type carries no state; it is a marker on which the rule's trait
/// implementation hangs.
#[derive(Debug, Default, Clone, Copy)]
pub struct Pre32C;
/// Result of scanning backwards from a preprocessor directive for a call that
/// is still open at the directive's position.
#[derive(Debug, Clone, PartialEq, Eq)]
struct UnclosedCallInfo {
    /// Identifier written immediately before the unclosed `(`.
    function_name: String,
    /// Number of `(` still unclosed between that call's parenthesis and the
    /// directive, including the call's own parenthesis.
    open_parens: usize,
}
impl CertRule for Pre32C {
    /// Identifier under which this checker's violations are reported.
    fn rule_id(&self) -> &'static str {
        "PRE32-C"
    }

    /// Identifier of the rule in the CERT C Coding Standard.
    fn cert_id(&self) -> &'static str {
        "PRE32-C"
    }

    /// Classification of this entry as a rule.
    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    /// Severity attached to this rule.
    fn severity(&self) -> Severity {
        Severity::High
    }

    /// One-line, human-readable statement of the rule.
    fn description(&self) -> &'static str {
        "Do not use preprocessor directives in invocations of function-like macros"
    }

    /// Runs the rule over the syntax tree rooted at `node` and returns every
    /// violation found, in visiting order.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found = Vec::new();
        self.check_node(node, source, &mut found);
        found
    }
}
impl Pre32C {
    /// Recursively visits `node` and all of its descendants, dispatching each
    /// node to the check that applies to its kind.
    ///
    /// Two complementary checks are used:
    /// * `call_expression` nodes whose callee may be a macro have their
    ///   argument text scanned for directive lines;
    /// * directive nodes are checked, on the raw source text, for an enclosing
    ///   call to a possible macro that is still open at the directive. This
    ///   covers cases where a directive splitting the call may keep the parser
    ///   from producing a `call_expression` for it.
    fn check_node(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
        match node.kind() {
            "call_expression" => self.check_function_call(node, source, violations),
            "preproc_ifdef" | "preproc_if" | "preproc_ifndef" | "preproc_else" | "preproc_elif"
            | "preproc_elifdef" | "preproc_call" | "preproc_def" | "preproc_function_def"
            | "preproc_include" => self.check_preproc_for_macro_calls(node, source, violations),
            _ => {}
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                self.check_node(&child, source, violations);
            }
        }
    }

    /// Reports `node` (a directive) if it sits between the parentheses of a
    /// call to a possible macro: some `(` before it is still open, and enough
    /// unmatched `)` follow it to close that call.
    fn check_preproc_for_macro_calls(
        &self,
        node: &Node,
        source: &str,
        violations: &mut Vec<RuleViolation>,
    ) {
        let start_byte = node.start_byte();
        // An offset past the end or inside a UTF-8 sequence means `source` does
        // not match the parsed tree; bail out instead of panicking on the split.
        if !source.is_char_boundary(start_byte) {
            return;
        }
        let (text_before, text_after) = source.split_at(start_byte);
        let call_info = match self.find_unclosed_call_before(text_before) {
            Some(info) => info,
            None => return,
        };
        // Only report when the open call is actually closed after the
        // directive, i.e. the directive really lies inside the argument list.
        if !self.has_matching_close_paren(text_after, call_info.open_parens) {
            return;
        }
        let message = format!(
            "Preprocessor directive '{}' used inside invocation of '{}'. This causes undefined behavior if the function is implemented as a macro",
            Self::directive_label(node, source),
            call_info.function_name
        );
        violations.push(self.make_violation(
            node,
            message,
            "Move preprocessor directives outside the function call using conditional compilation",
        ));
    }

    /// Returns the directive as written, normalised to `#name` (for example
    /// `#ifndef`, `#pragma`, `#define`). The label is taken from the text of
    /// the node's first child.
    ///
    /// Falls back to a label derived from the node kind when that child is
    /// missing or does not look like a directive token.
    fn directive_label(node: &Node, source: &str) -> String {
        node.child(0)
            .and_then(|token| Self::node_text(&token, source))
            .and_then(|text| text.trim().strip_prefix('#'))
            .map(str::trim_start)
            .filter(|name| !name.is_empty())
            .map(|name| format!("#{}", name))
            .unwrap_or_else(|| node.kind().replace("preproc_", "#"))
    }

    /// Walks `text` backwards from its end. It looks for the nearest `(` that
    /// is unclosed within `text` and is immediately preceded (ignoring
    /// whitespace) by an identifier that may name a macro.
    ///
    /// `open_parens` in the result counts every unclosed `(` from that one to
    /// the end of `text`, inclusive.
    ///
    /// Parentheses are matched with a proper stack. A `(` that is balanced by a
    /// later `)` is never taken as the enclosing call, even after an unclosed
    /// non-macro `(` has been passed. Parentheses inside string/character
    /// literals and comments are NOT excluded, so this remains a heuristic.
    fn find_unclosed_call_before(&self, text: &str) -> Option<UnclosedCallInfo> {
        // `)` seen while walking backwards that still wait for their `(`.
        let mut pending_close = 0usize;
        // `(` seen while walking backwards that have no `)` before the end of `text`.
        let mut unclosed = 0usize;
        for (idx, ch) in text.char_indices().rev() {
            match ch {
                ')' => pending_close += 1,
                '(' if pending_close > 0 => pending_close -= 1,
                '(' => {
                    unclosed += 1;
                    if let Some(name) = Self::identifier_before(&text[..idx]) {
                        if self.is_potentially_macro_function(name) {
                            return Some(UnclosedCallInfo {
                                function_name: name.to_string(),
                                open_parens: unclosed,
                            });
                        }
                    }
                }
                _ => {}
            }
        }
        None
    }

    /// Returns the run of alphanumeric / `_` characters that ends `text`, after
    /// trailing whitespace is dropped. Returns `None` if there is no such run.
    fn identifier_before(text: &str) -> Option<&str> {
        let trimmed = text.trim_end();
        let start = trimmed
            .char_indices()
            .rev()
            .take_while(|&(_, c)| c.is_alphanumeric() || c == '_')
            .last()
            .map(|(idx, _)| idx)?;
        Some(&trimmed[start..])
    }

    /// Returns `true` if `text` contains at least `open_count` `)` characters
    /// that are not balanced by a `(` appearing earlier in `text`.
    fn has_matching_close_paren(&self, text: &str, open_count: usize) -> bool {
        // Parentheses are ASCII, so scanning bytes cannot hit the middle of a
        // multi-byte UTF-8 sequence.
        let mut nested = 0usize;
        let mut closed = 0usize;
        for byte in text.bytes() {
            match byte {
                b'(' => nested += 1,
                b')' if nested > 0 => nested -= 1,
                b')' => {
                    closed += 1;
                    if closed >= open_count {
                        return true;
                    }
                }
                _ => {}
            }
        }
        false
    }

    /// For a `call_expression` whose callee text looks like a possible macro,
    /// checks its argument list for directive lines.
    fn check_function_call(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
        let function_node = match node.child_by_field_name("function") {
            Some(function_node) => function_node,
            None => return,
        };
        let function_name = match Self::node_text(&function_node, source) {
            Some(text) => text,
            None => return,
        };
        if !self.is_potentially_macro_function(function_name) {
            return;
        }
        if let Some(arguments) = node.child_by_field_name("arguments") {
            self.check_arguments_for_directives(&arguments, source, function_name, violations);
        }
    }

    /// Reports directive lines found in a call's argument list. One violation
    /// is reported for the list as a whole. Another is reported for each
    /// individual argument node whose own text contains a directive line.
    fn check_arguments_for_directives(
        &self,
        arguments: &Node,
        source: &str,
        function_name: &str,
        violations: &mut Vec<RuleViolation>,
    ) {
        if let Some(args_text) = Self::node_text(arguments, source) {
            if self.contains_preprocessor_directives(args_text) {
                violations.push(self.make_violation(
                    arguments,
                    format!(
                        "Function '{}' called with preprocessor directives in arguments. This causes undefined behavior if the function is implemented as a macro",
                        function_name
                    ),
                    "Move preprocessor directives outside the function call using conditional compilation",
                ));
            }
        }
        for i in 0..arguments.child_count() {
            let child = match arguments.child(i) {
                Some(child) => child,
                None => continue,
            };
            // Punctuation tokens of the argument list carry no argument text.
            if matches!(child.kind(), "," | "(" | ")") {
                continue;
            }
            let arg_text = match Self::node_text(&child, source) {
                Some(text) => text,
                None => continue,
            };
            if self.contains_preprocessor_directives(arg_text) {
                violations.push(self.make_violation(
                    &child,
                    format!(
                        "Argument to '{}' contains preprocessor directive: '{}'",
                        function_name,
                        arg_text.trim()
                    ),
                    "Use conditional compilation to wrap the entire function call",
                ));
            }
        }
    }

    /// Builds a violation of this rule located at the start of `node`, using
    /// 1-based line and column numbers.
    fn make_violation(&self, node: &Node, message: String, suggestion: &str) -> RuleViolation {
        let start_point = node.start_position();
        RuleViolation {
            rule_id: self.rule_id().to_string(),
            severity: Severity::High,
            message,
            file_path: String::new(),
            line: start_point.row + 1,
            column: start_point.column + 1,
            suggestion: Some(suggestion.to_string()),
            ..Default::default()
        }
    }

    /// Source text covered by `node`. Returns `None` instead of panicking when
    /// the node's byte range is not a valid slice of `source`.
    fn node_text<'s>(node: &Node, source: &'s str) -> Option<&'s str> {
        source.get(node.start_byte()..node.end_byte())
    }

    /// Returns `true` when `function_name` might be implemented as a macro.
    /// That is the case if it is one of the listed standard-library functions,
    /// or if it is non-empty and consists only of uppercase letters, digits and
    /// `_` (the usual macro naming convention).
    fn is_potentially_macro_function(&self, function_name: &str) -> bool {
        // Built once on first use instead of on every call.
        static STD_LIB_FUNCTIONS: std::sync::OnceLock<HashSet<&'static str>> =
            std::sync::OnceLock::new();
        let std_lib_functions = STD_LIB_FUNCTIONS.get_or_init(|| {
            [
                "memcpy", "memmove", "memset", "memcmp", "memchr", "strcpy", "strncpy", "strcat",
                "strncat", "strcmp", "strncmp", "strchr", "strrchr", "strpbrk", "strspn",
                "strcspn", "strstr", "strtok", "strlen", "isalnum", "isalpha", "isblank",
                "iscntrl", "isdigit", "isgraph", "islower", "isprint", "ispunct", "isspace",
                "isupper", "isxdigit", "tolower", "toupper", "getc", "putc", "getchar",
                "putchar", "fgetc", "fputc", "getwc", "putwc", "fgetwc", "fputwc", "printf",
                "fprintf", "sprintf", "snprintf", "scanf", "fscanf", "sscanf", "fread", "fwrite",
                "fopen", "fclose", "fseek", "ftell", "rewind", "fgets", "fputs", "abs", "labs",
                "llabs", "fabs", "fabsf", "fabsl", "sqrt", "sqrtf", "sqrtl", "pow", "powf",
                "powl", "sin", "cos", "tan", "asin", "acos", "atan", "atan2", "exp", "log",
                "log10", "ceil", "floor", "fmod", "malloc", "calloc", "realloc", "free",
                "assert", "wmemcpy", "wmemmove", "wmemset", "wmemcmp", "wmemchr", "wcscpy",
                "wcsncpy", "wcscat", "wcsncat", "wcscmp", "wcsncmp", "wcschr", "wcsrchr",
                "wcspbrk", "wcsspn", "wcscspn", "wcsstr", "wcstok", "wcslen",
            ]
            .iter()
            .copied()
            .collect()
        });
        std_lib_functions.contains(function_name)
            || (!function_name.is_empty()
                && function_name
                    .chars()
                    .all(|c| c.is_uppercase() || c == '_' || c.is_ascii_digit()))
    }

    /// Returns `true` if any line of `text` is a preprocessor directive line.
    /// Such a line has `#` as its first non-blank character, followed (after
    /// optional blanks, which C permits) by one of the directive names below.
    ///
    /// C only treats `#` as a directive introducer at the start of a line.
    /// Matching whole lines, rather than substrings, therefore avoids flagging
    /// string literals such as `"#if"`. Line continuations (`\` + newline) are
    /// not treated as directives: they are spliced away before preprocessing
    /// and are harmless inside macro arguments.
    fn contains_preprocessor_directives(&self, text: &str) -> bool {
        const DIRECTIVES: &[&str] = &[
            "define", "undef", "include", "if", "ifdef", "ifndef", "else", "elif", "elifdef",
            "elifndef", "endif", "error", "warning", "pragma", "line", "embed",
        ];
        text.lines().any(|line| {
            let rest = match line.trim_start().strip_prefix('#') {
                Some(rest) => rest.trim_start(),
                None => return false,
            };
            let name_len = rest
                .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
                .unwrap_or(rest.len());
            let name = &rest[..name_len];
            DIRECTIVES.contains(&name)
        })
    }
}