use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use tree_sitter::Node;
/// Checker for SEI CERT C recommendation FIO40-C: the buffer handed to `fgets()` or
/// `fgetws()` must be reset when the call fails.
pub struct Fio40C;

/// Identifier reported both as the rule id and as the CERT id.
const FIO40_C_ID: &str = "FIO40-C";

impl CertRule for Fio40C {
    fn rule_id(&self) -> &'static str {
        FIO40_C_ID
    }

    fn cert_id(&self) -> &'static str {
        FIO40_C_ID
    }

    fn description(&self) -> &'static str {
        "Reset strings on fgets() or fgetws() failure"
    }

    /// Findings of this checker are reported at low severity.
    fn severity(&self) -> Severity {
        Severity::Low
    }

    /// FIO40-C is categorised as a recommendation.
    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    /// Walks the syntax tree rooted at `node` and returns every FIO40-C finding in it.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut findings = Vec::new();
        self.check_node(node, source, &mut findings);
        findings
    }
}
/// Result of recognising an `fgets()`/`fgetws()` result test inside an `if` condition.
struct FgetsCheck {
    /// Name of the called function: `fgets` or `fgetws`.
    function: String,
    /// Source text of the destination buffer (the call's first argument).
    buffer: String,
    /// `true` when the condition is true on failure (`== NULL`, `!fgets(..)`), so the
    /// `consequence` is the failure branch.
    /// `false` when the condition is true on success (`!= NULL`, bare `fgets(..)`), so the
    /// `else` branch is the failure path.
    failure_when_true: bool,
}

impl Fio40C {
    /// Recursively visits `node` and all of its descendants.
    /// Appends a violation for every `if` statement that tests an `fgets()`/`fgetws()`
    /// result but does not reset the buffer on the failure path.
    fn check_node(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
        if node.kind() == "if_statement" {
            if let Some(violation) = self.check_if_statement(node, source) {
                violations.push(violation);
            }
        }
        let mut cursor = node.walk();
        for child in node.children(&mut cursor) {
            self.check_node(&child, source, violations);
        }
    }

    /// Evaluates a single `if_statement`.
    ///
    /// Returns a violation when all of the following hold:
    /// - the condition tests an `fgets()`/`fgetws()` result;
    /// - the statement has a branch that runs on failure;
    /// - that branch does not reset the buffer.
    ///
    /// When the failure path would be a missing `else`, nothing is reported, because this
    /// statement has no failure branch to inspect.
    fn check_if_statement(&self, node: &Node, source: &str) -> Option<RuleViolation> {
        let condition = node.child_by_field_name("condition")?;
        let check = self.find_fgets_check(&condition, source)?;
        let failure_branch = self.failure_branch(node, check.failure_when_true)?;
        if self.has_buffer_reset(&failure_branch, &check.buffer, source) {
            return None;
        }
        let start_point = node.start_position();
        Some(RuleViolation {
            rule_id: self.rule_id().to_string(),
            severity: self.severity(),
            message: format!(
                "Buffer '{}' not reset after {}() failure - buffer contents are indeterminate and must be reset",
                check.buffer, check.function
            ),
            file_path: String::new(),
            line: start_point.row + 1,
            column: start_point.column + 1,
            suggestion: Some(format!(
                "Reset buffer in failure branch: '{}[0] = {}\\0';'",
                check.buffer,
                if check.function == "fgetws" { "L'" } else { "'" }
            )),
            ..Default::default()
        })
    }

    /// Recognises the ways an `if` condition can test an `fgets()`/`fgetws()` result:
    /// - `fgets(..)`, which is true on success;
    /// - `fgets(..) == NULL` / `NULL == fgets(..)`, which is true on failure;
    /// - `fgets(..) != NULL`, which is true on success;
    /// - `!expr`, which flips the polarity of `expr`.
    ///
    /// Parentheses and plain `=` assignments around the call are looked through.
    /// Both `NULL` and `0` are accepted as the null constant.
    fn find_fgets_check(&self, expr: &Node, source: &str) -> Option<FgetsCheck> {
        let expr = self.strip_wrappers(expr, source);
        match expr.kind() {
            "call_expression" => {
                let (function, buffer) = self.extract_fgets_call(&expr, source)?;
                Some(FgetsCheck {
                    function,
                    buffer,
                    failure_when_true: false,
                })
            }
            // Only logical negation inverts the test; `-`, `~`, `+` are not null checks.
            "unary_expression" if self.field_matches(&expr, "operator", source, "!") => {
                let argument = expr.child_by_field_name("argument")?;
                let inner = self.find_fgets_check(&argument, source)?;
                Some(FgetsCheck {
                    failure_when_true: !inner.failure_when_true,
                    ..inner
                })
            }
            "binary_expression" => {
                let failure_when_true = if self.field_matches(&expr, "operator", source, "==") {
                    true
                } else if self.field_matches(&expr, "operator", source, "!=") {
                    false
                } else {
                    return None;
                };
                let left = expr.child_by_field_name("left")?;
                let right = expr.child_by_field_name("right")?;
                let call = if self.is_null_constant(&right, source) {
                    left
                } else if self.is_null_constant(&left, source) {
                    right
                } else {
                    return None;
                };
                let (function, buffer) = self.extract_fgets_call(&call, source)?;
                Some(FgetsCheck {
                    function,
                    buffer,
                    failure_when_true,
                })
            }
            _ => None,
        }
    }

    /// Returns the statement executed when the fgets call fails.
    ///
    /// This is the `consequence` when `failure_when_true` is set, and the `else` statement
    /// otherwise. An `else_clause` wrapper, emitted by newer tree-sitter-c grammars, is
    /// unwrapped.
    fn failure_branch<'t>(&self, if_node: &Node<'t>, failure_when_true: bool) -> Option<Node<'t>> {
        if failure_when_true {
            return if_node.child_by_field_name("consequence");
        }
        let alternative = if_node.child_by_field_name("alternative")?;
        if alternative.kind() == "else_clause" {
            self.first_named_non_comment(&alternative)
        } else {
            Some(alternative)
        }
    }

    /// If `node` (after stripping parentheses/assignments) is a call to `fgets` or `fgetws`,
    /// returns `(function name, source text of its first argument)`.
    fn extract_fgets_call(&self, node: &Node, source: &str) -> Option<(String, String)> {
        let call = self.strip_wrappers(node, source);
        if call.kind() != "call_expression" {
            return None;
        }
        let function = call.child_by_field_name("function")?;
        let func_name = get_node_text(&function, source).trim().to_string();
        if func_name != "fgets" && func_name != "fgetws" {
            return None;
        }
        let arguments = call.child_by_field_name("arguments")?;
        let buffer = self.extract_arguments(&arguments).into_iter().next()?;
        Some((func_name, get_node_text(&buffer, source).trim().to_string()))
    }

    /// Returns true when `node` resets `buffer_var`.
    ///
    /// For a compound statement, any of its statements (including ones inside nested plain
    /// `{}` blocks) may perform the reset. Any other node must itself be a reset statement.
    fn has_buffer_reset(&self, node: &Node, buffer_var: &str, source: &str) -> bool {
        if node.kind() != "compound_statement" {
            return self.is_buffer_reset_statement(node, buffer_var, source);
        }
        let mut cursor = node.walk();
        let found = node
            .named_children(&mut cursor)
            .any(|statement| self.has_buffer_reset(&statement, buffer_var, source));
        found
    }

    /// Returns true when `node` is an expression statement containing `buffer_var[0] = NUL`
    /// or `*buffer_var = NUL`.
    fn is_buffer_reset_statement(&self, node: &Node, buffer_var: &str, source: &str) -> bool {
        node.kind() == "expression_statement" && self.contains_nul_store(node, buffer_var, source)
    }

    /// Searches `node` and its descendants for an assignment that stores NUL into the first
    /// element of `buffer_var`.
    fn contains_nul_store(&self, node: &Node, buffer_var: &str, source: &str) -> bool {
        if node.kind() == "assignment_expression" && self.is_nul_store(node, buffer_var, source) {
            return true;
        }
        let mut cursor = node.walk();
        let found = node
            .named_children(&mut cursor)
            .any(|child| self.contains_nul_store(&child, buffer_var, source));
        found
    }

    /// Returns true for a plain `=` assignment of a NUL value to the first element of
    /// `buffer_var`, written as `buffer_var[0]` or `*buffer_var`.
    ///
    /// The buffer is compared by its exact (whitespace-insensitive) text, so `mybuf[0]` does
    /// not count as resetting `buf`.
    fn is_nul_store(&self, assignment: &Node, buffer_var: &str, source: &str) -> bool {
        if !self.field_matches(assignment, "operator", source, "=") {
            return false;
        }
        let (target, value) = match (
            assignment.child_by_field_name("left"),
            assignment.child_by_field_name("right"),
        ) {
            (Some(left), Some(right)) => (left, right),
            _ => return false,
        };
        if !self.is_nul_value(&value, source) {
            return false;
        }
        match target.kind() {
            // `buf[0] = '\0'`
            "subscript_expression" => {
                self.field_matches(&target, "argument", source, buffer_var)
                    && self.field_matches(&target, "index", source, "0")
            }
            // `*buf = '\0'`
            "pointer_expression" => {
                self.field_matches(&target, "operator", source, "*")
                    && self.field_matches(&target, "argument", source, buffer_var)
            }
            _ => false,
        }
    }

    /// Returns true for the NUL character spellings accepted as a reset value.
    fn is_nul_value(&self, node: &Node, source: &str) -> bool {
        let value = self.strip_wrappers(node, source);
        matches!(
            get_node_text(&value, source).trim(),
            "0" | "'\\0'" | "L'\\0'" | "u'\\0'" | "U'\\0'"
        )
    }

    /// Returns true when `node` (after stripping parentheses) is a null pointer constant.
    fn is_null_constant(&self, node: &Node, source: &str) -> bool {
        let value = self.strip_wrappers(node, source);
        matches!(get_node_text(&value, source).trim(), "NULL" | "0" | "nullptr")
    }

    /// Looks through grouping parentheses and plain `=` assignments.
    /// Returns the innermost expression whose value the wrapper evaluates to.
    fn strip_wrappers<'t>(&self, node: &Node<'t>, source: &str) -> Node<'t> {
        let mut current = *node;
        loop {
            let inner = match current.kind() {
                "parenthesized_expression" => self.first_named_non_comment(&current),
                "assignment_expression"
                    if self.field_matches(&current, "operator", source, "=") =>
                {
                    current.child_by_field_name("right")
                }
                _ => None,
            };
            match inner {
                Some(next) => current = next,
                None => return current,
            }
        }
    }

    /// Returns the first named child of `node` that is not a comment.
    fn first_named_non_comment<'t>(&self, node: &Node<'t>) -> Option<Node<'t>> {
        let mut cursor = node.walk();
        let first = node
            .named_children(&mut cursor)
            .find(|child| child.kind() != "comment");
        first
    }

    /// Returns true when the child in `field` exists and its text equals `expected`,
    /// ignoring whitespace.
    fn field_matches(&self, node: &Node, field: &str, source: &str, expected: &str) -> bool {
        match node.child_by_field_name(field) {
            Some(child) => self.text_matches(&child, source, expected),
            None => false,
        }
    }

    /// Compares the source text of `node` with `expected`, ignoring all whitespace.
    fn text_matches(&self, node: &Node, source: &str, expected: &str) -> bool {
        let actual = get_node_text(node, source);
        let significant = |c: &char| !c.is_whitespace();
        actual
            .chars()
            .filter(significant)
            .eq(expected.chars().filter(significant))
    }

    /// Returns the argument expressions of an `argument_list` node.
    /// Punctuation tokens are excluded automatically (they are unnamed); comments are
    /// skipped explicitly.
    fn extract_arguments<'a>(&self, arguments_node: &Node<'a>) -> Vec<Node<'a>> {
        let mut cursor = arguments_node.walk();
        let args: Vec<Node<'a>> = arguments_node
            .named_children(&mut cursor)
            .filter(|child| child.kind() != "comment")
            .collect();
        args
    }
}