use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::HashSet;
use tree_sitter::Node;
/// Checker for CERT C recommendation FLP02-C, "Avoid using floating-point numbers when
/// precise computation is needed".
///
/// It reports `==` / `!=` comparisons whose two operands both appear to be floating-point.
/// Comparisons against a literal zero are not reported.
#[derive(Debug, Default)]
pub struct Flp02C;
impl Flp02C {
    /// Math-library functions whose result is treated as floating-point when one of them is
    /// the callee of a `call_expression`.
    const FLOAT_FUNCS: &'static [&'static str] = &[
        "sqrtf", "sqrt", "sqrtl", "powf", "pow", "powl", "sinf", "sin", "sinl", "cosf", "cos",
        "cosl", "tanf", "tan", "tanl", "logf", "log", "logl", "expf", "exp", "expl", "fabsf",
        "fabs", "fabsl", "floorf", "floor", "floorl", "ceilf", "ceil", "ceill", "fmodf", "fmod",
        "fmodl",
    ];

    /// Creates the FLP02-C checker.
    pub fn new() -> Self {
        Flp02C
    }

    /// Returns `true` when `type_text` names a floating-point type.
    ///
    /// This is a substring test, so `long double`, `float_t` and `double_t` also match.
    /// Callers must pass the text of a type node only, never a whole declaration. Otherwise
    /// variable names or string literals containing "float"/"double" would match too.
    fn is_float_type(&self, type_text: &str) -> bool {
        type_text.contains("float") || type_text.contains("double")
    }

    /// Returns `true` for the C equality operators `==` and `!=`.
    fn is_equality_operator(&self, op: &str) -> bool {
        op == "==" || op == "!="
    }

    /// Returns `true` when `text` is a decimal literal whose value is zero.
    ///
    /// Examples: `0`, `0.`, `.0`, `-0.0f`, `0.00`, `0e5`. Comparisons against such a
    /// literal are deliberately not reported by this rule.
    fn is_zero_float_literal(text: &str) -> bool {
        let t = text
            .trim()
            .trim_start_matches(['-', '+'])
            .trim_end_matches(['f', 'F', 'l', 'L']);
        // Only the mantissa decides the value: `0.0e5` is still zero.
        let mantissa = t.split(['e', 'E']).next().unwrap_or(t);
        mantissa.contains('0')
            && mantissa.matches('.').count() <= 1
            && mantissa.chars().all(|c| c == '0' || c == '.')
    }

    /// Returns `true` when the text of a C `number_literal` denotes a floating constant.
    ///
    /// Hexadecimal literals are special-cased. In them `e`/`E`/`f`/`F` are ordinary digits
    /// (`0xFF`, `0x1E`), and only a binary exponent `p`/`P` makes the literal floating.
    fn is_float_literal(text: &str) -> bool {
        let text = text.trim();
        if let Some(hex_digits) = text.strip_prefix("0x").or_else(|| text.strip_prefix("0X")) {
            return hex_digits.contains(['p', 'P']);
        }
        text.contains(['.', 'e', 'E']) || text.ends_with(['f', 'F'])
    }

    /// Returns `true` if `node` or any descendant has floating-point characteristics
    /// (see `has_float_characteristics`).
    // Not called by `check`/`traverse` in this file; `dead_code` is allowed so the helper can
    // be kept. NOTE(review): confirm whether it is still needed.
    #[allow(dead_code)]
    fn appears_to_be_float_expression(&self, node: &Node, source: &str) -> bool {
        if self.has_float_characteristics(node, source) {
            return true;
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            if self.appears_to_be_float_expression(&child, source) {
                return true;
            }
        }
        false
    }

    /// Returns `true` when `node` itself looks floating-point. This covers three cases:
    ///
    /// * a floating `number_literal`;
    /// * a cast to a floating type;
    /// * a direct call to one of `FLOAT_FUNCS`.
    ///
    /// Children are not inspected.
    fn has_float_characteristics(&self, node: &Node, source: &str) -> bool {
        match node.kind() {
            "number_literal" => Self::is_float_literal(get_node_text(node, source)),
            "cast_expression" => match node.child_by_field_name("type") {
                Some(type_node) => self.is_float_type(get_node_text(&type_node, source)),
                None => false,
            },
            "call_expression" => match node.child_by_field_name("function") {
                Some(func_node) if func_node.kind() == "identifier" => {
                    let func_name = get_node_text(&func_node, source);
                    Self::FLOAT_FUNCS.iter().any(|f| *f == func_name)
                }
                _ => false,
            },
            _ => false,
        }
    }

    /// Walks the whole tree and records every name declared with a floating-point type.
    ///
    /// It looks at variable declarations, function parameters and function definitions.
    /// Names are collected file-wide with no scope tracking, so a name declared `double` in
    /// one function is also treated as floating in every other function. Pointers to
    /// floating types are included as well.
    fn collect_float_variables(&self, node: &Node, source: &str, float_vars: &mut HashSet<String>) {
        let mut cursor = node.walk();
        if matches!(
            node.kind(),
            "declaration" | "parameter_declaration" | "function_definition"
        ) {
            // Only the `type` field decides. Testing the whole declaration text would also
            // match names such as `float_count` or initializers such as `"double"`.
            let declares_float = match node.child_by_field_name("type") {
                Some(type_node) => self.is_float_type(get_node_text(&type_node, source)),
                None => false,
            };
            if declares_float {
                for declarator in node.children_by_field_name("declarator", &mut cursor) {
                    self.extract_identifiers(&declarator, source, float_vars);
                }
            }
        }
        for child in node.children(&mut cursor) {
            self.collect_float_variables(&child, source, float_vars);
        }
    }

    /// Adds to `identifiers` the name(s) introduced by the declarator subtree `node`.
    ///
    /// Only `declarator` fields are followed, through init/pointer/array/function
    /// declarators. Initializer values, array sizes and function parameter lists are
    /// therefore never collected. For example, `count` in `double d = count;` is skipped.
    fn extract_identifiers(&self, node: &Node, source: &str, identifiers: &mut HashSet<String>) {
        match node.kind() {
            "identifier" => {
                identifiers.insert(get_node_text(node, source).to_string());
            }
            // These wrap their inner declarator without a `declarator` field name.
            "parenthesized_declarator" | "attributed_declarator" => {
                let mut cursor = node.walk();
                for child in node.named_children(&mut cursor) {
                    self.extract_identifiers(&child, source, identifiers);
                }
            }
            _ => {
                if let Some(inner) = node.child_by_field_name("declarator") {
                    self.extract_identifiers(&inner, source, identifiers);
                }
            }
        }
    }

    /// Returns `true` if the expression `node` appears to produce a floating-point value.
    ///
    /// That is the case when it, or a sub-expression, is a known floating variable or has
    /// floating characteristics. `sizeof`/`alignof` and casts stop the search because their
    /// result type does not depend on their operand.
    fn involves_float_variable(
        &self,
        node: &Node,
        source: &str,
        float_vars: &HashSet<String>,
    ) -> bool {
        match node.kind() {
            "identifier" => {
                if float_vars.contains(get_node_text(node, source)) {
                    return true;
                }
            }
            // `sizeof`/`alignof` always yield an integer, whatever their operand is.
            "sizeof_expression" | "alignof_expression" => return false,
            // A cast's result type is its target type: `(int)d` is integral even when `d` is
            // a double.
            "cast_expression" => return self.has_float_characteristics(node, source),
            _ => {}
        }
        if self.has_float_characteristics(node, source) {
            return true;
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            if self.involves_float_variable(&child, source, float_vars) {
                return true;
            }
        }
        false
    }

    /// Pushes a violation when `node` is an `==` / `!=` `binary_expression` whose operands
    /// both look floating-point.
    ///
    /// No violation is pushed if either operand is a literal zero (see
    /// `is_zero_float_literal`).
    fn check_float_equality(
        &self,
        node: &Node,
        source: &str,
        float_vars: &HashSet<String>,
        violations: &mut Vec<RuleViolation>,
    ) {
        if node.kind() != "binary_expression" {
            return;
        }
        let (operator_node, left, right) = match (
            node.child_by_field_name("operator"),
            node.child_by_field_name("left"),
            node.child_by_field_name("right"),
        ) {
            (Some(op), Some(l), Some(r)) => (op, l, r),
            _ => return,
        };
        let operator = get_node_text(&operator_node, source);
        if !self.is_equality_operator(operator) {
            return;
        }
        // Cheap textual exemption first, before walking either operand subtree.
        if Self::is_zero_float_literal(get_node_text(&left, source))
            || Self::is_zero_float_literal(get_node_text(&right, source))
        {
            return;
        }
        if !(self.involves_float_variable(&left, source, float_vars)
            && self.involves_float_variable(&right, source, float_vars))
        {
            return;
        }
        let comparison = if operator == "==" { "equality" } else { "inequality" };
        violations.push(RuleViolation {
            rule_id: "FLP02-C".to_string(),
            severity: Severity::Low,
            line: node.start_position().row + 1,
            column: node.start_position().column + 1,
            message: format!(
                "Floating-point {} comparison may produce unexpected results due to representation error",
                comparison
            ),
            file_path: String::new(),
            suggestion: Some(
                "Use epsilon-based comparison (e.g., fabs(x - y) < EPSILON) or consider using integer arithmetic for precise computation".to_string(),
            ),
            requires_manual_review: Some(false),
        });
    }

    /// Depth-first walk that runs `check_float_equality` on every `binary_expression`.
    fn traverse(
        &self,
        node: &Node,
        source: &str,
        float_vars: &HashSet<String>,
        violations: &mut Vec<RuleViolation>,
    ) {
        if node.kind() == "binary_expression" {
            self.check_float_equality(node, source, float_vars, violations);
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            self.traverse(&child, source, float_vars, violations);
        }
    }
}
impl CertRule for Flp02C {
    /// Identifier used for this rule in reports.
    fn rule_id(&self) -> &'static str {
        "FLP02-C"
    }

    /// One-line human-readable summary of the rule.
    fn description(&self) -> &'static str {
        "Avoid using floating-point numbers when precise computation is needed"
    }

    /// FLP02-C is a CERT recommendation rather than a rule.
    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    /// Default severity; matches the severity attached to each reported violation.
    fn severity(&self) -> Severity {
        Severity::Low
    }

    /// CERT identifier; identical to `rule_id` for this checker.
    fn cert_id(&self) -> &'static str {
        self.rule_id()
    }

    /// Runs the rule over the tree rooted at `root`.
    ///
    /// Pass 1 gathers names declared with a floating type. Pass 2 reports equality
    /// comparisons between floating operands.
    fn check(&self, root: &Node, source: &str) -> Vec<RuleViolation> {
        let float_vars = {
            let mut names = HashSet::new();
            self.collect_float_variables(root, source, &mut names);
            names
        };
        let mut found = Vec::new();
        self.traverse(root, source, &float_vars, &mut found);
        found
    }
}