use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::{HashMap, HashSet};
use tree_sitter::Node;
/// Checker for CERT C FLP04-C ("Check floating-point inputs for exceptional
/// values").
///
/// The type carries no state, so it is `Copy` and `Default`; deriving
/// `Default` also satisfies clippy's `new_without_default` for `Flp04C::new`.
#[derive(Debug, Default, Clone, Copy)]
pub struct Flp04C;
impl Flp04C {
pub fn new() -> Self {
Flp04C
}
fn has_float_format_specifier(&self, format_str: &str) -> bool {
let cleaned = format_str.replace("\\\"", "").replace("\"", "");
let float_specs = [
"%f", "%lf", "%e", "%E", "%g", "%G", "%a", "%A", "%Lf", "%Le", "%LE", "%Lg", "%LG",
"%La", "%LA",
];
for spec in &float_specs {
if cleaned.contains(spec) {
return true;
}
}
false
}
fn is_float_input_function(&self, node: &Node, source: &str) -> Option<Vec<String>> {
if node.kind() != "call_expression" {
return None;
}
let func_name = if let Some(func_node) = node.child_by_field_name("function") {
get_node_text(&func_node, source)
} else {
return None;
};
let input_funcs = ["scanf", "fscanf", "sscanf", "vfscanf", "vscanf", "vsscanf"];
if !input_funcs.contains(&func_name) {
return None;
}
let args = node.child_by_field_name("arguments")?;
let mut all_args = Vec::new();
let mut cursor = args.walk();
for child in args.children(&mut cursor) {
if child.is_named() && child.kind() != "," {
all_args.push(child);
}
}
if all_args.is_empty() {
return None;
}
let format_index = if func_name == "fscanf" || func_name == "vfscanf" {
1
} else {
0
};
if all_args.len() <= format_index {
return None;
}
let format_arg = all_args[format_index];
let format_string = get_node_text(&format_arg, source);
if !self.has_float_format_specifier(&format_string) {
return None;
}
let mut float_vars = Vec::new();
for arg in all_args.iter().skip(format_index + 1) {
let arg = *arg;
if arg.kind() == "pointer_expression" {
if let Some(var_node) = arg.child_by_field_name("argument") {
let var_name = get_node_text(&var_node, source).to_string();
float_vars.push(var_name);
}
}
}
if !float_vars.is_empty() {
Some(float_vars)
} else {
None
}
}
fn has_validation_check(&self, node: &Node, source: &str, var_name: &str) -> bool {
let text = get_node_text(node, source);
if (text.contains("isinf") || text.contains("isnan")) && text.contains(var_name) {
return true;
}
if text.contains("fpclassify") && text.contains(var_name) {
return true;
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if self.has_validation_check(&child, source, var_name) {
return true;
}
}
false
}
fn is_var_used(&self, node: &Node, source: &str, var_name: &str) -> bool {
if node.kind() == "identifier" {
let name = get_node_text(node, source);
if name == var_name {
return true;
}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if self.is_var_used(&child, source, var_name) {
return true;
}
}
false
}
fn check_function(&self, func_node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
let mut float_inputs: HashMap<String, (usize, usize)> = HashMap::new(); let mut validated_vars: HashSet<String> = HashSet::new();
self.collect_float_inputs(func_node, source, &mut float_inputs);
self.collect_validations(func_node, source, &float_inputs, &mut validated_vars);
for (var_name, (line, col)) in &float_inputs {
if !validated_vars.contains(var_name) {
if self.is_var_used(func_node, source, var_name) {
violations.push(RuleViolation {
rule_id: "FLP04-C".to_string(),
severity: Severity::Medium,
line: *line,
column: *col,
message: format!(
"Floating-point input variable '{}' is not validated for exceptional values (NaN, infinity) before use",
var_name
),
file_path: String::new(),
suggestion: Some(
format!("Add validation: if (isinf({0}) || isnan({0})) {{ /* handle error */ }}", var_name)
),
requires_manual_review: Some(false),
});
}
}
}
}
fn collect_float_inputs(
&self,
node: &Node,
source: &str,
float_inputs: &mut HashMap<String, (usize, usize)>,
) {
if let Some(vars) = self.is_float_input_function(node, source) {
for var in vars {
float_inputs.insert(
var,
(
node.start_position().row + 1,
node.start_position().column + 1,
),
);
}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.collect_float_inputs(&child, source, float_inputs);
}
}
fn collect_validations(
&self,
node: &Node,
source: &str,
float_inputs: &HashMap<String, (usize, usize)>,
validated_vars: &mut HashSet<String>,
) {
for var_name in float_inputs.keys() {
if self.has_validation_check(node, source, var_name) {
validated_vars.insert(var_name.clone());
}
}
}
fn traverse(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
if node.kind() == "function_definition" {
self.check_function(node, source, violations);
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.traverse(&child, source, violations);
}
}
}
/// Plugs the FLP04-C checker into the shared `CertRule` interface.
impl CertRule for Flp04C {
    /// Walks the whole syntax tree from `root` and returns every violation
    /// found in its function definitions.
    fn check(&self, root: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found: Vec<RuleViolation> = Vec::new();
        self.traverse(root, source, &mut found);
        found
    }

    /// Identifier of this checker.
    fn rule_id(&self) -> &'static str {
        "FLP04-C"
    }

    /// CERT C identifier (identical to `rule_id` for this checker).
    fn cert_id(&self) -> &'static str {
        "FLP04-C"
    }

    /// One-line human-readable summary of the rule.
    fn description(&self) -> &'static str {
        "Check floating-point inputs for exceptional values"
    }

    /// Default severity for this checker.
    fn severity(&self) -> Severity {
        Severity::Medium
    }

    // NOTE(review): CERT numbers FLP00-FLP29 as recommendations rather than
    // rules; confirm `RuleCategory::Rule` is the intended classification.
    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }
}