use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::HashSet;
use tree_sitter::Node;
/// CERT C recommendation FIO18-C: never expect `fwrite()` to terminate the
/// writing process at a null character.
///
/// Reports `fwrite()` calls whose count argument is a `sizeof` expression, or
/// is neither a constant nor traceable to `strlen()`.
pub struct Fio18C;

impl CertRule for Fio18C {
    /// Identifier attached to every violation this rule emits.
    fn rule_id(&self) -> &'static str {
        "FIO18-C"
    }

    /// One-line human-readable statement of the guideline.
    fn description(&self) -> &'static str {
        "Never expect fwrite() to terminate the writing process at a null character"
    }

    /// Severity reported for every finding of this rule.
    fn severity(&self) -> Severity {
        Severity::Medium
    }

    /// FIO18-C is classified as a CERT recommendation rather than a rule.
    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    /// CERT guideline number; identical to [`CertRule::rule_id`] here.
    fn cert_id(&self) -> &'static str {
        "FIO18-C"
    }

    /// Runs the rule over the syntax tree rooted at `node`.
    ///
    /// The first pass gathers the names of variables assigned from `strlen`
    /// anywhere under `node`. The second pass inspects every `fwrite()` call
    /// against that set and returns the resulting violations.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let strlen_vars = {
            let mut names: HashSet<String> = HashSet::new();
            self.collect_strlen_assignments(node, source, &mut names);
            names
        };

        let mut found = Vec::new();
        self.check_fwrite_usage(node, source, &strlen_vars, &mut found);
        found
    }
}
impl Fio18C {
fn collect_strlen_assignments(
&self,
node: &Node,
source: &str,
strlen_vars: &mut HashSet<String>,
) {
if node.kind() == "assignment_expression" || node.kind() == "init_declarator" {
let node_text = get_node_text(node, source);
if node_text.contains("strlen") {
if let Some(left) = node.child_by_field_name("left") {
let var_name = get_node_text(&left, source).trim().to_string();
strlen_vars.insert(var_name);
} else if let Some(declarator) = node.child_by_field_name("declarator") {
let var_name = self.extract_identifier(&declarator, source);
if !var_name.is_empty() {
strlen_vars.insert(var_name);
}
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.collect_strlen_assignments(&child, source, strlen_vars);
}
}
}
fn extract_identifier(&self, node: &Node, source: &str) -> String {
if node.kind() == "identifier" {
return get_node_text(node, source).trim().to_string();
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
let result = self.extract_identifier(&child, source);
if !result.is_empty() {
return result;
}
}
}
String::new()
}
fn check_fwrite_usage(
&self,
node: &Node,
source: &str,
strlen_vars: &HashSet<String>,
violations: &mut Vec<RuleViolation>,
) {
if node.kind() == "call_expression" {
if let Some(function) = node.child_by_field_name("function") {
let func_name = get_node_text(&function, source);
if func_name == "fwrite" {
if let Some(args) = node.child_by_field_name("arguments") {
self.analyze_fwrite_args(&args, source, node, strlen_vars, violations);
}
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.check_fwrite_usage(&child, source, strlen_vars, violations);
}
}
}
fn analyze_fwrite_args(
&self,
args: &Node,
source: &str,
call_node: &Node,
strlen_vars: &HashSet<String>,
violations: &mut Vec<RuleViolation>,
) {
let arg_list = self.extract_args(args, source);
if arg_list.len() >= 3 {
let buffer_arg = &arg_list[0];
let nmemb = &arg_list[2];
if nmemb.contains("strlen") {
return;
}
let nmemb_trimmed = nmemb.trim();
if strlen_vars.contains(nmemb_trimmed) {
return;
}
if nmemb.contains("sizeof") {
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
message: format!(
"fwrite() using sizeof() for count argument when writing '{}'. \
May write uninitialized data beyond null terminator.",
buffer_arg
),
severity: self.severity(),
line: call_node.start_position().row + 1,
column: call_node.start_position().column + 1,
file_path: String::new(),
suggestion: Some(
"For strings, use strlen(buffer)+1 to write only the string content. \
sizeof() writes the entire buffer regardless of string length."
.to_string(),
),
requires_manual_review: Some(true),
});
return;
}
if !nmemb.chars().all(|c| c.is_ascii_digit()) {
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
message: format!(
"fwrite() count argument '{}' not derived from strlen({}). \
May write incorrect number of bytes.",
nmemb, buffer_arg
),
severity: self.severity(),
line: call_node.start_position().row + 1,
column: call_node.start_position().column + 1,
file_path: String::new(),
suggestion: Some(format!(
"For null-terminated strings, use strlen({}) + 1 to include \
the null terminator but avoid writing uninitialized data.",
buffer_arg
)),
requires_manual_review: Some(true),
});
}
}
}
fn extract_args(&self, args: &Node, source: &str) -> Vec<String> {
let mut result = Vec::new();
for i in 0..args.child_count() {
if let Some(child) = args.child(i) {
if child.kind() != "(" && child.kind() != ")" && child.kind() != "," {
result.push(get_node_text(&child, source).trim().to_string());
}
}
}
result
}
}