use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use crate::utility::cert_c::declarator_utils::is_pointer_declarator;
use tree_sitter::Node;
/// CERT C recommendation API02-C: functions that read or write to or from an array should
/// take an argument that specifies the source or target size.
///
/// The detection logic itself lives in the inherent `impl Api02C` block (`check_node`).
pub struct Api02C;

impl CertRule for Api02C {
    /// Identifier attached to every violation this rule reports.
    fn rule_id(&self) -> &'static str {
        "API02-C"
    }

    /// Human-readable summary of the recommendation.
    fn description(&self) -> &'static str {
        "Functions that read or write to or from an array should take an argument to specify the source or target size"
    }

    fn severity(&self) -> Severity {
        Severity::High
    }

    /// API02-C is a CERT *recommendation*, not a rule.
    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    /// The CERT identifier is the same string as the rule identifier.
    fn cert_id(&self) -> &'static str {
        self.rule_id()
    }

    /// Scans the syntax subtree rooted at `node` and returns every violation found, in the
    /// order the traversal encounters them.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found = Vec::new();
        self.check_node(node, source, &mut found);
        found
    }
}
impl Api02C {
fn check_node(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
if node.kind() == "declaration" {
if let Some(declarator) = node.child_by_field_name("declarator") {
if self.is_function_declarator(&declarator) {
self.check_function_parameters(&declarator, node, source, violations);
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.check_node(&child, source, violations);
}
}
}
fn is_function_declarator(&self, node: &Node) -> bool {
if node.kind() == "function_declarator" {
return true;
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if self.is_function_declarator(&child) {
return true;
}
}
}
false
}
fn check_function_parameters(
&self,
declarator: &Node,
declaration: &Node,
source: &str,
violations: &mut Vec<RuleViolation>,
) {
self.check_parameters_recursive(declarator, declaration, source, violations);
}
fn check_parameters_recursive(
&self,
node: &Node,
declaration: &Node,
source: &str,
violations: &mut Vec<RuleViolation>,
) {
if node.kind() == "parameter_list" {
let mut params = Vec::new();
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "parameter_declaration" {
params.push(child);
}
}
}
for i in 0..params.len() {
if self.is_pointer_parameter(¶ms[i], source) {
if i + 1 >= params.len() || !self.is_size_t_parameter(¶ms[i + 1], source) {
self.report_violation(declaration, ¶ms[i], source, violations);
}
}
}
return;
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.check_parameters_recursive(&child, declaration, source, violations);
}
}
}
fn is_pointer_parameter(&self, param: &Node, source: &str) -> bool {
let type_node = match param.child_by_field_name("type") {
Some(t) => t,
None => return false,
};
let type_text = get_node_text(&type_node, source);
let param_text = get_node_text(param, source);
let normalized = param_text.split_whitespace().collect::<Vec<_>>().join(" ");
if normalized.starts_with("const char *") || normalized.starts_with("const char*") {
return false;
}
if type_text == "char" {
let mut has_const = false;
for i in 0..param.child_count() {
if let Some(child) = param.child(i) {
if child.kind() == "type_qualifier" {
let q = get_node_text(&child, source);
if q == "const" {
has_const = true;
}
}
}
}
if has_const {
if let Some(declarator) = param.child_by_field_name("declarator") {
if is_pointer_declarator(&declarator) {
return false;
}
}
}
}
if type_text == "wchar_t" {
let mut has_const = false;
for i in 0..param.child_count() {
if let Some(child) = param.child(i) {
if child.kind() == "type_qualifier" && get_node_text(&child, source) == "const"
{
has_const = true;
}
}
}
if has_const {
if let Some(declarator) = param.child_by_field_name("declarator") {
if is_pointer_declarator(&declarator) {
return false;
}
}
}
}
if self.is_user_defined_type(&type_text) {
return false;
}
if type_text == "void" {
if let Some(declarator) = param.child_by_field_name("declarator") {
if is_pointer_declarator(&declarator) {
return false;
}
}
}
if type_text.contains('*') {
if !type_text.contains("(*") && !type_text.contains("(* ") {
return true;
}
}
if let Some(declarator) = param.child_by_field_name("declarator") {
if is_pointer_declarator(&declarator) {
return true;
}
}
false
}
fn is_user_defined_type(&self, type_text: &str) -> bool {
let stripped = type_text
.replace("const", "")
.replace("volatile", "")
.replace("restrict", "")
.replace("struct", "")
.replace("union", "")
.replace("enum", "")
.trim()
.to_string();
if type_text.contains("struct ")
|| type_text.contains("union ")
|| type_text.contains("enum ")
{
return true;
}
let primitive_types = [
"char",
"int",
"short",
"long",
"float",
"double",
"void",
"signed",
"unsigned",
"_Bool",
"bool",
"int8_t",
"int16_t",
"int32_t",
"int64_t",
"uint8_t",
"uint16_t",
"uint32_t",
"uint64_t",
"size_t",
"ssize_t",
"ptrdiff_t",
"intptr_t",
"uintptr_t",
"wchar_t",
"FILE",
];
!primitive_types.iter().any(|p| stripped == *p)
}
fn is_size_t_parameter(&self, param: &Node, source: &str) -> bool {
let type_node = match param.child_by_field_name("type") {
Some(t) => t,
None => return false,
};
let type_text = get_node_text(&type_node, source);
matches!(
type_text,
"size_t"
| "uint32_t"
| "uint16_t"
| "uint8_t"
| "int32_t"
| "int"
| "unsigned"
| "unsigned int"
| "rsize_t"
)
}
fn report_violation(
&self,
declaration: &Node,
pointer_param: &Node,
source: &str,
violations: &mut Vec<RuleViolation>,
) {
let param_text = get_node_text(pointer_param, source);
let decl_text = get_node_text(declaration, source);
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
severity: Severity::High,
message: format!(
"Function has pointer parameter without size argument: '{}' - Add size_t parameter to specify array capacity",
decl_text.lines().next().unwrap_or(decl_text).trim()
),
file_path: String::new(),
line: declaration.start_position().row + 1,
column: declaration.start_position().column + 1,
suggestion: Some(format!(
"Add a size_t parameter after '{}' to specify the maximum number of elements in the array",
param_text.trim()
)),
..Default::default()
});
}
}