use super::super::{CertRule, RuleViolation};
use crate::analyze::context::ProjectContext;
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use tree_sitter::Node;
/// CERT rule checker that reports direct and indirect recursion (MSC04-C).
///
/// Direct recursion is detected from the function body alone; indirect
/// recursion additionally needs the project call graph stored in `call_graph`.
/// `Default` yields the same empty state as `Msc04C::new()`.
#[derive(Debug, Default)]
pub struct Msc04C {
    /// Project call graph: function name -> names of the functions it calls.
    /// Interior mutability lets the graph be replaced through `&self`.
    call_graph: RefCell<HashMap<String, HashSet<String>>>,
}
impl Msc04C {
    /// Creates a checker with an empty call graph.
    ///
    /// Until a non-empty project call graph is installed, only direct recursion
    /// is reported (`check_function` skips the cycle search on an empty graph).
    pub fn new() -> Self {
        Msc04C {
            call_graph: RefCell::new(HashMap::new()),
        }
    }

    /// Returns the name declared by a `function_definition` node, or `None`
    /// when its declarator does not resolve to a plain identifier.
    fn extract_func_name(&self, node: &Node, source: &str) -> Option<String> {
        let declarator = node.child_by_field_name("declarator")?;
        self.find_identifier_in_declarator(&declarator, source)
    }

    /// Unwraps a (possibly nested) C declarator down to the identifier it declares.
    ///
    /// Function, pointer and array declarators are followed through their
    /// `declarator` field. Parenthesized and attributed declarators are handled by
    /// scanning their named children for the first one that resolves. Any other
    /// node kind, or an empty identifier, yields `None`.
    fn find_identifier_in_declarator(&self, node: &Node, source: &str) -> Option<String> {
        match node.kind() {
            "identifier" => {
                let name = get_node_text(node, source);
                if name.is_empty() {
                    None
                } else {
                    Some(name.to_string())
                }
            }
            "function_declarator" | "pointer_declarator" | "array_declarator" => {
                let inner = node.child_by_field_name("declarator")?;
                self.find_identifier_in_declarator(&inner, source)
            }
            // These wrappers may not expose their inner declarator under a
            // `declarator` field (it is an unnamed child next to optional
            // modifiers/attributes), so a field lookup can miss it entirely.
            "parenthesized_declarator" | "attributed_declarator" => {
                (0..node.named_child_count())
                    .filter_map(|i| node.named_child(i))
                    .find_map(|child| self.find_identifier_in_declarator(&child, source))
            }
            _ => None,
        }
    }

    /// Adds to `callees` the name of every function called through a plain
    /// identifier (`f(...)`) anywhere inside `node`.
    ///
    /// Calls whose callee is any other expression (e.g. `(*fp)(...)`, `s.f(...)`)
    /// are not recorded.
    fn collect_callees(&self, node: &Node, source: &str, callees: &mut HashSet<String>) {
        if node.kind() == "call_expression" {
            if let Some(function) = node.child_by_field_name("function") {
                if function.kind() == "identifier" {
                    let name = get_node_text(&function, source);
                    if !name.is_empty() {
                        callees.insert(name.to_string());
                    }
                }
            }
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                self.collect_callees(&child, source, callees);
            }
        }
    }

    /// Searches for a call path that leaves `start` and eventually calls `start` again.
    ///
    /// `start_callees` is used as the outgoing edge set of `start`. It reflects the
    /// definition currently being analysed and takes the place of whatever `graph`
    /// records for `start`. Every other function's callees come from `graph`.
    ///
    /// Returns the cycle as `[start, f1, ..., fk, start]` (k >= 1). Returns `None`
    /// when no path through at least one other function leads back to `start`.
    ///
    /// This is an iterative depth-first search with an explicit stack. Deep call
    /// chains therefore cannot overflow the native stack, and `graph` is only
    /// borrowed, never copied.
    fn find_cycle(
        &self,
        start: &str,
        start_callees: &HashSet<String>,
        graph: &HashMap<String, HashSet<String>>,
    ) -> Option<Vec<String>> {
        let mut visited: HashSet<&str> = HashSet::new();
        visited.insert(start);
        // Invariant: `path[i]` is the function whose remaining callees are `stack[i]`.
        let mut path: Vec<&str> = vec![start];
        let mut stack = vec![start_callees.iter()];

        while let Some(frame) = stack.last_mut() {
            let callee = match frame.next() {
                Some(callee) => callee.as_str(),
                None => {
                    // Every callee of the current function explored: backtrack.
                    stack.pop();
                    path.pop();
                    continue;
                }
            };
            if callee == start {
                // `start` calling itself directly is not an indirect cycle.
                if path.len() > 1 {
                    let mut cycle: Vec<String> = path.iter().map(|name| name.to_string()).collect();
                    cycle.push(start.to_string());
                    return Some(cycle);
                }
                continue;
            }
            if !visited.insert(callee) {
                continue;
            }
            // A function with no recorded callees cannot lead back to `start`.
            if let Some(next) = graph.get(callee) {
                path.push(callee);
                stack.push(next.iter());
            }
        }
        None
    }

    /// Heuristic for a terminating base case.
    ///
    /// Returns true when the function body contains an `if` whose condition
    /// mentions one of the function's parameters and whose then-branch contains
    /// a `return`. Functions without named parameters or without a body never
    /// qualify.
    fn has_bounded_base_case(&self, func_node: &Node, source: &str) -> bool {
        let params = self.collect_param_names(func_node, source);
        if params.is_empty() {
            return false;
        }
        match func_node.child_by_field_name("body") {
            Some(body) => self.find_param_guarded_return(&body, source, &params),
            None => false,
        }
    }

    /// Returns the parameter names found under the function's declarator
    /// (empty when the definition has no declarator).
    fn collect_param_names(&self, func_node: &Node, source: &str) -> HashSet<String> {
        let mut params = HashSet::new();
        if let Some(declarator) = func_node.child_by_field_name("declarator") {
            self.walk_for_params(&declarator, source, &mut params);
        }
        params
    }

    /// Collects the declared names of the `parameter_declaration`s under `node`.
    ///
    /// The walk does not descend into a parameter declaration. This keeps the
    /// parameters of a function-pointer parameter (the `x` in `int (*cb)(int x)`)
    /// from being mistaken for parameters of the function being analysed.
    fn walk_for_params(&self, node: &Node, source: &str, params: &mut HashSet<String>) {
        if node.kind() == "parameter_declaration" {
            if let Some(name) = node
                .child_by_field_name("declarator")
                .and_then(|decl| self.find_identifier_in_declarator(&decl, source))
            {
                params.insert(name);
            }
            return;
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                self.walk_for_params(&child, source, params);
            }
        }
    }

    /// Returns true if some `if` statement under `node` has a condition that
    /// references a name in `params` and a then-branch containing a `return`.
    fn find_param_guarded_return(
        &self,
        node: &Node,
        source: &str,
        params: &HashSet<String>,
    ) -> bool {
        if node.kind() == "if_statement" {
            if let Some(cond) = node.child_by_field_name("condition") {
                if self.references_any_param(&cond, source, params) {
                    if let Some(consequence) = node.child_by_field_name("consequence") {
                        if self.contains_return(&consequence) {
                            return true;
                        }
                    }
                }
            }
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                if self.find_param_guarded_return(&child, source, params) {
                    return true;
                }
            }
        }
        false
    }

    /// Returns true if any identifier inside `node` (whitespace-trimmed) is in `params`.
    fn references_any_param(&self, node: &Node, source: &str, params: &HashSet<String>) -> bool {
        if node.kind() == "identifier" {
            let name = get_node_text(node, source);
            if params.contains(name.trim()) {
                return true;
            }
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                if self.references_any_param(&child, source, params) {
                    return true;
                }
            }
        }
        false
    }

    /// Returns true if `node` is, or contains, a `return_statement`.
    fn contains_return(&self, node: &Node) -> bool {
        if node.kind() == "return_statement" {
            return true;
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                if self.contains_return(&child) {
                    return true;
                }
            }
        }
        false
    }

    /// Builds a violation of this rule anchored at the start of `node`.
    ///
    /// Line and column are reported 1-based. `file_path` is left empty here.
    fn make_violation(&self, node: &Node, message: String, suggestion: &str) -> RuleViolation {
        let start = node.start_position();
        RuleViolation {
            rule_id: self.rule_id().to_string(),
            severity: self.severity(),
            message,
            file_path: String::new(),
            line: start.row + 1,
            column: start.column + 1,
            suggestion: Some(suggestion.to_string()),
            requires_manual_review: None,
        }
    }

    /// Reports recursion involving the function defined by `node`.
    ///
    /// A direct self-call is reported unless `has_bounded_base_case` accepts the
    /// function. Either way, a direct self-call ends the analysis of this
    /// function. Otherwise, when a project call graph is installed, a cycle that
    /// returns to this function through other functions is reported as indirect
    /// recursion. For this function, the callees collected from the definition
    /// replace the graph's own entry.
    fn check_function(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
        let func_name = match self.extract_func_name(node, source) {
            Some(name) => name,
            None => return,
        };
        let mut callees = HashSet::new();
        if let Some(body) = node.child_by_field_name("body") {
            self.collect_callees(&body, source, &mut callees);
        }

        if callees.contains(&func_name) {
            if !self.has_bounded_base_case(node, source) {
                violations.push(self.make_violation(
                    node,
                    format!(
                        "Function '{}' calls itself directly (direct recursion)",
                        func_name
                    ),
                    "Refactor to use iteration instead of recursion",
                ));
            }
            // Direct recursion takes precedence; skip the indirect-cycle search.
            return;
        }

        let graph = self.call_graph.borrow();
        if graph.is_empty() {
            return;
        }
        if let Some(cycle) = self.find_cycle(&func_name, &callees, &graph) {
            let cycle_str = cycle.join(" -> ");
            violations.push(self.make_violation(
                node,
                format!(
                    "Function '{}' participates in indirect recursion: {}",
                    func_name, cycle_str
                ),
                "Refactor to eliminate the recursion cycle",
            ));
        }
    }

    /// Visits `node` and its descendants, running `check_function` on each
    /// `function_definition` found.
    ///
    /// The subtree of a `function_definition` is not walked further, so nested
    /// definitions inside a function body are not checked separately.
    fn walk_node(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
        match node.kind() {
            "function_definition" => {
                self.check_function(node, source, violations);
            }
            _ => {
                for i in 0..node.child_count() {
                    if let Some(child) = node.child(i) {
                        self.walk_node(&child, source, violations);
                    }
                }
            }
        }
    }
}
/// Rule-engine integration for `Msc04C`.
///
/// Provides metadata accessors, receives the project context (whose call graph
/// enables indirect-recursion detection), and exposes the per-tree entry point.
impl CertRule for Msc04C {
    fn rule_id(&self) -> &'static str {
        "MSC04-C"
    }

    fn cert_id(&self) -> &'static str {
        "MSC04-C"
    }

    fn description(&self) -> &'static str {
        "Do not use recursive function calls"
    }

    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    fn severity(&self) -> Severity {
        Severity::Medium
    }

    /// Installs a private copy of the project's call graph, discarding any
    /// graph stored by an earlier call.
    fn set_project_context(&self, context: &ProjectContext) {
        let previous = self.call_graph.replace(context.call_graph.clone());
        drop(previous);
    }

    /// Walks the tree rooted at `node` and returns every violation found.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found: Vec<RuleViolation> = Vec::new();
        self.walk_node(node, source, &mut found);
        found
    }
}