use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::HashSet;
use tree_sitter::Node;
/// Checker for CERT C rule FIO37-C: "Do not assume that fgets() or fgetws()
/// returns a nonempty string when successful".
pub struct Fio37C;

/// Rule metadata for FIO37-C plus the entry point used by the rule framework.
/// The detection itself is implemented by `check_node` on `Fio37C`.
impl CertRule for Fio37C {
    /// Identifier attached to every violation this rule emits.
    fn rule_id(&self) -> &'static str {
        "FIO37-C"
    }

    /// Human-readable statement of the rule.
    fn description(&self) -> &'static str {
        "Do not assume that fgets() or fgetws() returns a nonempty string when successful"
    }

    /// Severity reported for this rule.
    fn severity(&self) -> Severity {
        Severity::High
    }

    /// Category reported for this rule.
    // NOTE(review): the SEI CERT C standard lists FIO37-C among its rules rather
    // than its recommendations — confirm this category is intentional.
    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    /// CERT identifier; identical to the rule id for this checker.
    fn cert_id(&self) -> &'static str {
        self.rule_id()
    }

    /// Scans the syntax subtree rooted at `node` and returns every violation found.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found = Vec::new();
        self.check_node(node, source, &mut found);
        found
    }
}
impl Fio37C {
fn check_node(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
if node.kind() == "compound_statement" {
self.check_function_body(node, source, violations);
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.check_node(&child, source, violations);
}
}
}
fn check_function_body(&self, body: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
let mut fgets_vars = HashSet::new();
self.collect_fgets_vars(body, source, &mut fgets_vars);
self.check_strlen_usage(body, source, &fgets_vars, violations);
}
fn collect_fgets_vars(&self, node: &Node, source: &str, fgets_vars: &mut HashSet<String>) {
if node.kind() == "call_expression" {
if let Some(function) = node.child_by_field_name("function") {
let func_name = get_node_text(&function, source);
if func_name == "fgets" || func_name == "fgetws" {
if let Some(args) = node.child_by_field_name("arguments") {
for i in 0..args.child_count() {
if let Some(arg) = args.child(i) {
if arg.kind() != "(" && arg.kind() != ")" && arg.kind() != "," {
let var_name = get_node_text(&arg, source);
fgets_vars.insert(var_name.to_string());
break;
}
}
}
}
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.collect_fgets_vars(&child, source, fgets_vars);
}
}
}
fn check_strlen_usage(
&self,
node: &Node,
source: &str,
fgets_vars: &HashSet<String>,
violations: &mut Vec<RuleViolation>,
) {
if node.kind() == "binary_expression" {
let operator = self.get_operator(node, source);
if operator == "-" {
if let Some(left) = node.child_by_field_name("left") {
if let Some(var_name) = self.is_strlen_of_var(&left, source) {
if fgets_vars.contains(&var_name) {
self.report_violation(node, &var_name, source, violations);
}
}
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.check_strlen_usage(&child, source, fgets_vars, violations);
}
}
}
fn get_operator(&self, binary_expr: &Node, source: &str) -> String {
for i in 0..binary_expr.child_count() {
if let Some(child) = binary_expr.child(i) {
let text = get_node_text(&child, source);
if text == "-" || text == "+" || text == "*" || text == "/" {
return text.to_string();
}
}
}
String::new()
}
fn is_strlen_of_var(&self, expr: &Node, source: &str) -> Option<String> {
if expr.kind() == "call_expression" {
if let Some(function) = expr.child_by_field_name("function") {
let func_name = get_node_text(&function, source);
if func_name == "strlen" {
if let Some(args) = expr.child_by_field_name("arguments") {
for i in 0..args.child_count() {
if let Some(arg) = args.child(i) {
if arg.kind() != "(" && arg.kind() != ")" && arg.kind() != "," {
let var_name = get_node_text(&arg, source);
return Some(var_name.to_string());
}
}
}
}
}
}
}
None
}
fn report_violation(
&self,
node: &Node,
var_name: &str,
source: &str,
violations: &mut Vec<RuleViolation>,
) {
let expr_text = get_node_text(&node, source);
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
severity: Severity::High,
message: format!(
"Dangerous use of strlen() on fgets() result '{}': '{}' - fgets() may return empty string (strlen=0), causing underflow",
var_name, expr_text.trim()
),
file_path: String::new(),
line: node.start_position().row + 1,
column: node.start_position().column + 1,
suggestion: Some(format!(
"Check if '{}' is non-empty before using strlen() in arithmetic. Use strchr() to find characters instead of assuming strlen > 0",
var_name
)),
..Default::default()
});
}
}