use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use tree_sitter::Node;
/// Checker for CERT C PRE02-C, "Macro replacement lists should be
/// parenthesized".
///
/// Reports `#define` directives (object-like and function-like) whose
/// replacement list contains an operator and does not have one of the shapes
/// that are immune to operator-precedence surprises at the expansion site.
pub struct Pre02C;

impl Pre02C {
    /// Creates the checker; it carries no configuration.
    pub fn new() -> Self {
        Self
    }

    /// Returns the index just past the string or character literal whose
    /// opening quote (`"` or `'`) is at `chars[start]`.
    ///
    /// A backslash escape is skipped as a unit, so `"\""` is one literal. An
    /// unterminated literal extends to the end of `chars`.
    fn skip_literal(chars: &[char], start: usize) -> usize {
        let quote = chars[start];
        let mut i = start + 1;
        while i < chars.len() {
            match chars[i] {
                '\\' => i += 2,
                c if c == quote => return i + 1,
                _ => i += 1,
            }
        }
        chars.len()
    }

    /// Returns the index of the bracket that closes the `(` or `[` at
    /// `chars[open]`, or `None` if it is never closed.
    ///
    /// Nested `()`/`[]` pairs are counted together, and brackets inside string
    /// or character literals are ignored.
    fn matching_close(chars: &[char], open: usize) -> Option<usize> {
        let mut depth = 0usize;
        let mut i = open;
        while i < chars.len() {
            match chars[i] {
                '"' | '\'' => {
                    i = Self::skip_literal(chars, i);
                    continue;
                }
                '(' | '[' => depth += 1,
                ')' | ']' => {
                    depth = depth.checked_sub(1)?;
                    if depth == 0 {
                        return Some(i);
                    }
                }
                _ => {}
            }
            i += 1;
        }
        None
    }

    /// Returns `text` with comments and line continuations removed,
    /// approximating what the C preprocessor sees in a replacement list.
    ///
    /// Each `/* ... */` comment and each backslash-newline becomes a single
    /// space. A `//` comment drops the rest of the text. Comment markers inside
    /// string and character literals are preserved, so `"http://x"` survives
    /// intact.
    fn strip_comments(text: &str) -> String {
        let chars: Vec<char> = text.chars().collect();
        let mut out = String::with_capacity(text.len());
        let mut i = 0;
        while i < chars.len() {
            let c = chars[i];
            let next = chars.get(i + 1).copied();
            match (c, next) {
                ('"', _) | ('\'', _) => {
                    let end = Self::skip_literal(&chars, i);
                    out.extend(chars[i..end].iter());
                    i = end;
                }
                // Lines are spliced before comments are removed, so a line
                // comment runs to the end of the whole replacement list.
                ('/', Some('/')) => break,
                ('/', Some('*')) => {
                    let mut j = i + 2;
                    while j + 1 < chars.len() && !(chars[j] == '*' && chars[j + 1] == '/') {
                        j += 1;
                    }
                    out.push(' ');
                    // Resume after the closing `*/`; an unterminated comment
                    // swallows the rest of the text.
                    i = (j + 2).min(chars.len());
                }
                ('\\', Some('\n')) => {
                    out.push(' ');
                    i += 2;
                }
                ('\\', Some('\r')) => {
                    out.push(' ');
                    i += if chars.get(i + 2) == Some(&'\n') { 3 } else { 2 };
                }
                _ => {
                    out.push(c);
                    i += 1;
                }
            }
        }
        out
    }

    /// Returns `true` when `text` is one parenthesized group: it starts with
    /// `(` and that parenthesis is closed by the final `)`.
    ///
    /// `(a + b)` qualifies. `(a) + (b)` does not. Brackets inside literals are
    /// ignored.
    fn is_fully_parenthesized(&self, text: &str) -> bool {
        let chars: Vec<char> = text.trim().chars().collect();
        chars.first() == Some(&'(')
            && chars.last() == Some(&')')
            && Self::matching_close(&chars, 0) == Some(chars.len() - 1)
    }

    /// Returns `true` when `text` is a single token made only of identifier
    /// characters, i.e. an identifier or a simple numeric literal such as
    /// `42u`. Such a token cannot be split apart by surrounding operators.
    ///
    /// Because `check_macro_definition` only calls this after
    /// `contains_operators` returned `true`, it acts as an explicit safety
    /// net rather than a frequently-taken exemption.
    fn is_single_identifier(&self, text: &str) -> bool {
        let trimmed = text.trim();
        !trimmed.is_empty() && trimmed.chars().all(|c| c.is_alphanumeric() || c == '_')
    }

    /// Returns `true` when `text` is a postfix expression rooted at a plain
    /// identifier: a name followed by zero or more `(...)` calls, `[...]`
    /// subscripts, or `.field` / `->field` accesses. Examples are `f(a + b)`,
    /// `tbl[i + 1]` and `get(p)->next`.
    ///
    /// Postfix operators bind tighter than every other C operator, so such
    /// an expansion cannot be split by the surrounding expression. Keywords
    /// that merely look like calls (`sizeof (x)`, `if (c)`, ...) are
    /// excluded.
    fn is_postfix_expression(&self, text: &str) -> bool {
        const CALL_LIKE_KEYWORDS: [&str; 8] =
            ["sizeof", "alignof", "_Alignof", "if", "while", "for", "switch", "return"];
        let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
        let chars: Vec<char> = text.trim().chars().collect();
        let name_len = chars.iter().take_while(|&&c| is_ident_char(c)).count();
        if name_len == 0 || chars[0].is_ascii_digit() {
            return false;
        }
        let name: String = chars[..name_len].iter().collect();
        if CALL_LIKE_KEYWORDS.contains(&name.as_str()) {
            return false;
        }
        let mut i = name_len;
        while i < chars.len() {
            let c = chars[i];
            if c.is_whitespace() {
                i += 1;
            } else if c == '(' || c == '[' {
                match Self::matching_close(&chars, i) {
                    Some(close) => i = close + 1,
                    None => return false,
                }
            } else if c == '.' || (c == '-' && chars.get(i + 1) == Some(&'>')) {
                i += if c == '.' { 1 } else { 2 };
                while i < chars.len() && chars[i].is_whitespace() {
                    i += 1;
                }
                let field_len = chars[i..].iter().take_while(|&&d| is_ident_char(d)).count();
                if field_len == 0 {
                    return false;
                }
                i += field_len;
            } else {
                return false;
            }
        }
        true
    }

    /// Returns `true` for `(type)(expr)`: a parenthesized type name followed
    /// by a fully parenthesized operand, e.g. `(uint32_t)(a + b)`.
    ///
    /// The type name may contain only identifier characters, spaces and `*`,
    /// and must contain at least one identifier character.
    fn is_cast_with_parenthesized_operand(&self, text: &str) -> bool {
        let chars: Vec<char> = text.trim().chars().collect();
        if chars.first() != Some(&'(') {
            return false;
        }
        let close = match Self::matching_close(&chars, 0) {
            Some(close) => close,
            None => return false,
        };
        let cast_type = &chars[1..close];
        let looks_like_type = cast_type.iter().any(|&c| c.is_alphanumeric() || c == '_')
            && cast_type
                .iter()
                .all(|&c| c.is_alphanumeric() || c == '_' || c == ' ' || c == '*');
        if !looks_like_type {
            return false;
        }
        let operand: String = chars[close + 1..].iter().collect();
        self.is_fully_parenthesized(&operand)
    }

    /// Returns `true` for the statement-macro idiom `do { ... } while (0)`
    /// (also `while (0u)`), matched case-insensitively and regardless of
    /// whitespace inside the trailing `while ( 0 )`.
    fn is_do_while_zero_pattern(&self, text: &str) -> bool {
        let lower = text.trim().to_lowercase();
        let body = match lower.strip_prefix("do") {
            Some(rest) => rest,
            None => return false,
        };
        // `do` must be the keyword itself, not the start of `double` or `done`.
        match body.chars().next() {
            Some(c) if !(c.is_alphanumeric() || c == '_') => {}
            _ => return false,
        }
        let compact: String = body.chars().filter(|c| !c.is_whitespace()).collect();
        compact.ends_with("while(0)") || compact.ends_with("while(0u)")
    }

    /// Returns `true` when `text` contains an operator outside string and
    /// character literals.
    ///
    /// Counted: arithmetic, bitwise, shift, logical, relational and equality
    /// operators, the conditional `?`, and unary `-`, `+`, `!`, `~`, `*`,
    /// `&`. Spacing is irrelevant, so `a+b` and `a + b` both count.
    ///
    /// Not counted:
    /// - `->` member access;
    /// - the sign inside a numeric exponent (`1e-5`, `0x1p+3`);
    /// - a lone `=` (assignment);
    /// - a trailing run of `*`, which is the pointer declarator in a type
    ///   such as `char *`.
    fn contains_operators(&self, text: &str) -> bool {
        let chars: Vec<char> = text.chars().collect();
        let mut in_ident = false;
        let mut in_number = false;
        let mut i = 0;
        while i < chars.len() {
            let c = chars[i];
            let next = chars.get(i + 1).copied();
            if in_number {
                // A preprocessing number continues through letters, digits,
                // `.` and a sign that directly follows an exponent marker.
                // `in_number` is only set after consuming a character, so
                // `i - 1` is in bounds.
                let after_exponent = matches!(chars[i - 1], 'e' | 'E' | 'p' | 'P');
                if c.is_alphanumeric()
                    || c == '_'
                    || c == '.'
                    || (after_exponent && matches!(c, '+' | '-'))
                {
                    i += 1;
                    continue;
                }
                in_number = false;
            }
            if in_ident {
                if c.is_alphanumeric() || c == '_' {
                    i += 1;
                    continue;
                }
                in_ident = false;
            }
            match c {
                '"' | '\'' => {
                    i = Self::skip_literal(&chars, i);
                    continue;
                }
                '0'..='9' => in_number = true,
                '.' if matches!(next, Some(d) if d.is_ascii_digit()) => in_number = true,
                '_' => in_ident = true,
                _ if c.is_alphabetic() => in_ident = true,
                '-' if next == Some('>') => {
                    i += 2;
                    continue;
                }
                '=' => {
                    if next == Some('=') {
                        return true;
                    }
                }
                '*' if chars[i + 1..].iter().all(|&d| d == '*' || d.is_whitespace()) => {}
                '+' | '-' | '*' | '/' | '%' | '&' | '|' | '^' | '<' | '>' | '!' | '~' | '?' => {
                    return true
                }
                _ => {}
            }
            i += 1;
        }
        false
    }

    /// Reports a violation for `node` if it is a `#define` whose replacement
    /// list contains an operator but is not in a precedence-safe shape.
    ///
    /// The safe shapes are: a single identifier, a fully parenthesized
    /// expression, a postfix expression (call/subscript/member access), a
    /// `do { ... } while (0)` statement, or a cast of a parenthesized
    /// operand. Nodes of any other kind, and macros without a replacement
    /// list, are ignored.
    fn check_macro_definition(
        &self,
        node: &Node,
        source: &str,
        violations: &mut Vec<RuleViolation>,
    ) {
        if !matches!(node.kind(), "preproc_def" | "preproc_function_def") {
            return;
        }
        let value_node = match node.child_by_field_name("value") {
            Some(v) => v,
            None => return,
        };
        let cleaned = Self::strip_comments(get_node_text(&value_node, source));
        let value_text = cleaned.trim();
        if !self.contains_operators(value_text) {
            return;
        }
        if self.is_single_identifier(value_text)
            || self.is_fully_parenthesized(value_text)
            || self.is_postfix_expression(value_text)
            || self.is_do_while_zero_pattern(value_text)
            || self.is_cast_with_parenthesized_operand(value_text)
        {
            return;
        }
        violations.push(RuleViolation {
            rule_id: self.rule_id().to_string(),
            severity: self.severity(),
            message: "Macro replacement list should be parenthesized to prevent operator precedence issues.".to_string(),
            file_path: String::new(),
            line: node.start_position().row + 1,
            column: node.start_position().column + 1,
            suggestion: Some(format!(
                "Wrap the entire replacement list in parentheses: ({})",
                value_text
            )),
            ..Default::default()
        });
    }
}
impl CertRule for Pre02C {
    /// Identifier used in reports for this rule.
    fn rule_id(&self) -> &'static str {
        "PRE02-C"
    }

    /// The CERT C identifier; it is the same string as `rule_id`.
    fn cert_id(&self) -> &'static str {
        "PRE02-C"
    }

    /// Short human-readable summary of the rule.
    fn description(&self) -> &'static str {
        "Macro replacement lists should be parenthesized"
    }

    /// Classification reported for this checker.
    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    /// Severity attached to every violation this rule reports.
    fn severity(&self) -> Severity {
        Severity::Medium
    }

    /// Scans the syntax subtree rooted at `node` and returns every PRE02-C
    /// violation found in it.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found: Vec<RuleViolation> = vec![];
        self.check_node(node, source, &mut found);
        found
    }
}
impl Pre02C {
    /// Visits `node` and every descendant in pre-order (a node before its
    /// children, children left to right), checking each one for a PRE02-C
    /// violation and appending any findings to `violations`.
    ///
    /// Walks the tree with an explicit `TreeCursor` instead of recursion, so
    /// very deep syntax trees (e.g. long `else if` chains) cannot exhaust
    /// the call stack.
    fn check_node(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
        let mut cursor = node.walk();
        loop {
            self.check_macro_definition(&cursor.node(), source, violations);
            if cursor.goto_first_child() {
                continue;
            }
            // Leaf reached: move to the next sibling, climbing back up until
            // one exists. The cursor is rooted at `node`, so failing to climb
            // means the whole subtree has been visited.
            loop {
                if cursor.goto_next_sibling() {
                    break;
                }
                if !cursor.goto_parent() {
                    return;
                }
            }
        }
    }
}