use super::super::{CertRule, RuleViolation};
use crate::analyze::context::ProjectContext;
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::cell::RefCell;
use std::collections::HashSet;
use tree_sitter::Node;
/// Checker for CERT C recommendation DCL07-C ("Include the appropriate type
/// information in function declarators").
///
/// The checks in this file report K&R-style function definitions and
/// assignments of a function to a function pointer whose parameter count
/// differs from the function's definition.
pub struct Dcl07C {
/// Function names collected from the project context (known functions,
/// header-declared functions and macro alias names). Written by
/// `set_project_context`; none of the checks in this file read it.
cross_file_functions: RefCell<HashSet<String>>,
}
impl Dcl07C {
pub fn new() -> Self {
Dcl07C {
cross_file_functions: RefCell::new(HashSet::new()),
}
}
/// Visits `node` and all of its descendants in pre-order, dispatching each
/// node to the check that applies to its kind.
///
/// * `function_definition` nodes are checked for K&R-style parameters.
/// * `assignment_expression` nodes are checked for function pointer
///   assignments with a mismatched parameter count.
///
/// Every violation found is appended to `violations`.
fn check_node<'a>(
    &self,
    node: &Node<'a>,
    source: &'a str,
    violations: &mut Vec<RuleViolation>,
) {
    match node.kind() {
        "function_definition" => self.check_kr_style_function(node, source, violations),
        "assignment_expression" => {
            self.check_function_pointer_assignment(node, source, violations)
        }
        _ => {}
    }

    // Descend into every child regardless of whether this node was checked.
    let children = (0..node.child_count()).filter_map(|index| node.child(index));
    for child in children {
        self.check_node(&child, source, violations);
    }
}
fn check_kr_style_function<'a>(
&self,
func_node: &Node<'a>,
_source: &'a str,
violations: &mut Vec<RuleViolation>,
) {
if let Some(declarator) = func_node.child_by_field_name("declarator") {
let mut found_declarations_after_declarator = false;
let declarator_end = declarator.end_byte();
for i in 0..func_node.child_count() {
if let Some(child) = func_node.child(i) {
if child.start_byte() > declarator_end && child.kind() == "declaration" {
if let Some(body) = func_node.child_by_field_name("body") {
if child.start_byte() < body.start_byte() {
found_declarations_after_declarator = true;
break;
}
}
}
}
}
if found_declarations_after_declarator {
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
line: func_node.start_position().row + 1,
column: func_node.start_position().column + 1,
message: "Function uses K&R style (identifier-list) parameter declarations - use prototype form with type information in parameter list".to_string(),
severity: self.severity(),
file_path: String::new(),
suggestion: Some("Declare parameters with types in the parameter list: int max(int a, int b)".to_string()),
requires_manual_review: None,
});
}
}
}
/// Reports an assignment `ptr = func` where the declared parameter count of
/// the function pointer differs from that of the function's definition.
///
/// The check only applies when the right-hand side is a plain identifier
/// and both the pointer's declaration and the function's definition can be
/// located in the same syntax tree; in every other case the assignment is
/// silently skipped. The violation is positioned at the start of
/// `assign_node` (1-based line and column).
fn check_function_pointer_assignment<'a>(
    &self,
    assign_node: &Node<'a>,
    source: &'a str,
    violations: &mut Vec<RuleViolation>,
) {
    let right = match assign_node.child_by_field_name("right") {
        Some(rhs) if rhs.kind() == "identifier" => rhs,
        _ => return,
    };
    let func_name = get_node_text(&right, source);

    let left = match assign_node.child_by_field_name("left") {
        Some(lhs) => lhs,
        None => return,
    };
    let ptr_params = match self.find_function_pointer_params(&left, source, assign_node) {
        Some(count) => count,
        None => return,
    };
    let func_params = match self.find_function_definition_params(func_name, source, assign_node)
    {
        Some(count) => count,
        None => return,
    };

    if ptr_params == func_params {
        return;
    }

    let position = assign_node.start_position();
    violations.push(RuleViolation {
        rule_id: self.rule_id().to_string(),
        line: position.row + 1,
        column: position.column + 1,
        message: format!(
            "Function pointer assignment has mismatched signature: pointer expects {} parameters but function '{}' has {} parameters",
            ptr_params, func_name, func_params
        ),
        severity: self.severity(),
        file_path: String::new(),
        suggestion: Some(format!(
            "Declare function pointer with correct number of parameters to match '{}'",
            func_name
        )),
        requires_manual_review: None,
    });
}
/// Returns the parameter count declared for the function pointer variable
/// named by `var_node`.
///
/// Only a plain `identifier` target is supported; any other node kind
/// yields `None`. The declaration is searched for from the root of the tree
/// that contains `context`.
fn find_function_pointer_params<'a>(
    &self,
    var_node: &Node<'a>,
    source: &'a str,
    context: &Node<'a>,
) -> Option<usize> {
    if var_node.kind() != "identifier" {
        return None;
    }
    let var_name = get_node_text(var_node, source);
    let tree_root = self.get_root(context);
    self.find_declaration_params(&tree_root, source, var_name)
}
/// Follows `parent()` links from `node` until a node without a parent is
/// reached and returns that node. Returns `node` itself when it has no
/// parent.
fn get_root<'a>(&self, node: &Node<'a>) -> Node<'a> {
    std::iter::successors(Some(*node), |current| current.parent())
        .last()
        .unwrap_or(*node)
}
/// Searches the subtree rooted at `node` (depth-first, in document order)
/// for a function pointer declaration that mentions `var_name` and returns
/// its parameter count.
///
/// A `declaration` node is a candidate when its text contains `(*` and
/// contains `var_name` as a whole identifier. The name must not be
/// embedded in a longer identifier, so looking for `fp` does not bind to a
/// declaration of `fp_table`. This is still a textual heuristic: scoping is
/// not considered, and the first candidate in the tree is used.
///
/// Returns `None` when no candidate declaration yields a parameter count.
fn find_declaration_params<'a>(
    &self,
    node: &Node<'a>,
    source: &'a str,
    var_name: &str,
) -> Option<usize> {
    /// True for characters that can appear inside an identifier.
    fn is_identifier_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    /// True when `name` occurs in `text` delimited by non-identifier
    /// characters (or the start/end of `text`) on both sides.
    fn mentions_identifier(text: &str, name: &str) -> bool {
        !name.is_empty()
            && text.match_indices(name).any(|(start, matched)| {
                let preceded = text[..start]
                    .chars()
                    .next_back()
                    .map_or(false, is_identifier_char);
                let followed = text[start + matched.len()..]
                    .chars()
                    .next()
                    .map_or(false, is_identifier_char);
                !preceded && !followed
            })
    }

    if node.kind() == "declaration" {
        let decl_text = get_node_text(node, source);
        if decl_text.contains("(*") && mentions_identifier(decl_text, var_name) {
            return self.count_params_in_declaration(node, source);
        }
    }

    (0..node.child_count())
        .filter_map(|index| node.child(index))
        .find_map(|child| self.find_declaration_params(&child, source, var_name))
}
/// Returns the parameter count found in the first direct child of
/// `decl_node` whose subtree contains a parameter node, or `None` when no
/// child contains one.
fn count_params_in_declaration<'a>(
    &self,
    decl_node: &Node<'a>,
    source: &'a str,
) -> Option<usize> {
    (0..decl_node.child_count())
        .filter_map(|index| decl_node.child(index))
        .find_map(|child| self.find_and_count_params(&child, source))
}
/// Depth-first search for the first `parameter_list` or
/// `parameter_declaration` node at or below `node`; returns the count
/// computed for that node, or `None` when the subtree has no such node.
fn find_and_count_params<'a>(&self, node: &Node<'a>, source: &'a str) -> Option<usize> {
    if matches!(node.kind(), "parameter_list" | "parameter_declaration") {
        return Some(self.count_parameters(node, source));
    }
    (0..node.child_count())
        .filter_map(|index| node.child(index))
        .find_map(|child| self.find_and_count_params(&child, source))
}
/// Counts the parameters described by `params_node`.
///
/// Each direct `parameter_declaration` child counts as one parameter,
/// except a lone `void`, which denotes an empty parameter list and is not
/// counted. (The C grammar is expected to parse `(void)` as a single
/// `parameter_declaration`; without this exception `f(void)` would be
/// reported as taking one parameter.) A nested `parameter_list` child
/// replaces the running count with its own count.
///
/// When no parameter was counted this way, the node text is used as a
/// fallback: an empty list or `void` (with or without surrounding
/// whitespace) is zero parameters, anything else is the number of commas
/// plus one. This covers identifier lists such as `(a, b)`.
fn count_parameters<'a>(&self, params_node: &Node<'a>, source: &'a str) -> usize {
    let mut count = 0;
    for child in (0..params_node.child_count()).filter_map(|index| params_node.child(index)) {
        match child.kind() {
            "parameter_declaration" => {
                // `void` on its own is the "no parameters" marker, not a parameter.
                if get_node_text(&child, source).trim() != "void" {
                    count += 1;
                }
            }
            "parameter_list" => count = self.count_parameters(&child, source),
            _ => {}
        }
    }

    if count == 0 {
        let text = get_node_text(params_node, source);
        let trimmed = text.trim();
        // Look only at what is between the parentheses so that spacing
        // variants such as `( )` and `( void )` are recognised as empty.
        let inner = trimmed.strip_prefix('(').unwrap_or(trimmed);
        let inner = inner.strip_suffix(')').unwrap_or(inner).trim();
        if !inner.is_empty() && inner != "void" {
            count = inner.matches(',').count() + 1;
        }
    }
    count
}
/// Returns the parameter count of the definition of `func_name`, searching
/// from the root of the tree that contains `context`. Returns `None` when
/// no matching definition yields a count.
fn find_function_definition_params<'a>(
    &self,
    func_name: &str,
    source: &'a str,
    context: &Node<'a>,
) -> Option<usize> {
    let tree_root = self.get_root(context);
    self.search_function_definition(&tree_root, source, func_name)
}
/// Depth-first search below (and including) `node` for a
/// `function_definition` whose declared name equals `func_name`.
///
/// When such a definition is found, the result of counting its parameters
/// is returned immediately, even if that result is `None`; the
/// definition's own children are then not searched. Nodes that do not
/// match are searched recursively in child order.
fn search_function_definition<'a>(
    &self,
    node: &Node<'a>,
    source: &'a str,
    func_name: &str,
) -> Option<usize> {
    if node.kind() == "function_definition" {
        if let Some(declarator) = node.child_by_field_name("declarator") {
            let declared_name = self.get_function_name(&declarator, source);
            if declared_name == Some(func_name) {
                return self.count_function_params(&declarator, source);
            }
        }
    }

    (0..node.child_count())
        .filter_map(|index| node.child(index))
        .find_map(|child| self.search_function_definition(&child, source, func_name))
}
/// Extracts the identifier named by `declarator`.
///
/// * An `identifier` node yields its own text.
/// * A `function_declarator` or `pointer_declarator` is resolved through
///   its `declarator` field; `None` when that field is absent.
/// * Any other node kind yields the first name found among its children,
///   searched recursively in child order.
fn get_function_name<'a>(&self, declarator: &Node<'a>, source: &'a str) -> Option<&'a str> {
    match declarator.kind() {
        "identifier" => Some(get_node_text(declarator, source)),
        "function_declarator" | "pointer_declarator" => declarator
            .child_by_field_name("declarator")
            .and_then(|inner| self.get_function_name(&inner, source)),
        // An identifier child is resolved by the recursive call's
        // "identifier" arm, so no separate kind test is needed here.
        _ => (0..declarator.child_count())
            .filter_map(|index| declarator.child(index))
            .find_map(|child| self.get_function_name(&child, source)),
    }
}
fn count_function_params<'a>(&self, declarator: &Node<'a>, source: &'a str) -> Option<usize> {
if declarator.kind() == "function_declarator" {
if let Some(params) = declarator.child_by_field_name("parameters") {
return Some(self.count_parameters(¶ms, source));
}
}
for i in 0..declarator.child_count() {
if let Some(child) = declarator.child(i) {
if let Some(count) = self.count_function_params(&child, source) {
return Some(count);
}
}
}
None
}
}
impl CertRule for Dcl07C {
    /// Identifier attached to every violation this rule reports.
    fn rule_id(&self) -> &'static str {
        "DCL07-C"
    }

    /// One-line summary of the recommendation.
    fn description(&self) -> &'static str {
        "Include the appropriate type information in function declarators"
    }

    /// Severity assigned to violations of this rule.
    fn severity(&self) -> Severity {
        Severity::Low
    }

    /// This rule is a recommendation.
    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    /// CERT identifier of the rule; same value as `rule_id`.
    fn cert_id(&self) -> &'static str {
        "DCL07-C"
    }

    /// Replaces the stored cross-file function names with the union of the
    /// context's known functions, header-declared functions and macro alias
    /// names.
    fn set_project_context(&self, context: &ProjectContext) {
        let mut names = context.known_functions.clone();
        names.extend(context.header_declared_functions.clone());
        names.extend(context.macro_aliases.keys().cloned());
        *self.cross_file_functions.borrow_mut() = names;
    }

    /// Runs all DCL07-C checks on the subtree rooted at `node` and returns
    /// the violations in the order they were found.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found = Vec::new();
        self.check_node(node, source, &mut found);
        found
    }
}