use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use tree_sitter::Node;
/// Checker for CERT C FIO39-C: do not alternately input and output from a
/// stream without an intervening flush or positioning call.
pub struct Fio39C;
/// Standard C functions that write to a `FILE *` stream.
///
/// Includes the wide-character (`fwprintf`) and `va_list` (`vfprintf`,
/// `vfwprintf`) variants of `fprintf`, which are output operations as well.
const OUTPUT_FUNCTIONS: &[&str] = &[
    "fwrite", "fprintf", "fputs", "fputc", "putc", "fputwc", "putwc", "fputws", "fwprintf",
    "vfprintf", "vfwprintf",
];
/// Standard C functions that read from a `FILE *` stream.
///
/// `ungetc`/`ungetwc` are treated as input operations here. The wide-character
/// (`fwscanf`) and `va_list` (`vfscanf`, `vfwscanf`) variants of `fscanf` are
/// included as well.
const INPUT_FUNCTIONS: &[&str] = &[
    "fread", "fscanf", "fgets", "fgetc", "getc", "fgetwc", "getwc", "fgetws", "ungetc", "ungetwc",
    "fwscanf", "vfscanf", "vfwscanf",
];
/// Calls treated as separating input from output on a stream.
///
/// These are the C file positioning functions `fseek`, `fsetpos` and `rewind`,
/// plus `fflush`.
const POSITIONING_FUNCTIONS: &[&str] = &[
    "fseek",
    "fflush",
    "fsetpos",
    "rewind",
];
impl CertRule for Fio39C {
    /// Identifier under which this checker reports its findings.
    fn rule_id(&self) -> &'static str {
        "FIO39-C"
    }

    /// CERT identifier; it coincides with the rule id for this checker.
    fn cert_id(&self) -> &'static str {
        self.rule_id()
    }

    /// One-line summary of the rule, taken from the CERT rule title.
    fn description(&self) -> &'static str {
        "Do not alternately input and output from a stream without an intervening flush or positioning call"
    }

    /// Severity attached to this rule.
    fn severity(&self) -> Severity {
        Severity::Low
    }

    /// FIO39-C is classified as a recommendation.
    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    /// Scans every function definition under `node` and returns all findings.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut findings = Vec::new();
        self.check_function_body(node, source, &mut findings);
        findings
    }
}
impl Fio39C {
    /// Finds every `function_definition` in the tree rooted at `node` and
    /// analyzes its body, appending findings to `violations`.
    ///
    /// The walk continues below each definition. A nested definition (GNU C
    /// extension) is therefore analyzed both on its own and as part of its
    /// enclosing body.
    fn check_function_body(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
        if node.kind() == "function_definition" {
            if let Some(body) = node.child_by_field_name("body") {
                self.analyze_compound_statement(&body, source, violations);
            }
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                self.check_function_body(&child, source, violations);
            }
        }
    }

    /// Reports direction switches (input <-> output) on the same stream that
    /// are not separated by a call that makes the switch legal.
    ///
    /// Per C11 7.21.5.3p7, which CERT FIO39-C quotes:
    /// - output followed by input needs an intervening `fflush`, `fseek`,
    ///   `fsetpos` or `rewind`;
    /// - input followed by output needs an intervening `fseek`, `fsetpos` or
    ///   `rewind`. `fflush` does not qualify here.
    ///
    /// Limitations:
    /// - Streams are identified by the source text of the stream argument, so
    ///   two different expressions that alias one `FILE *` are tracked separately.
    /// - The analysis is flow-insensitive: calls are processed in evaluation
    ///   order without regard to branches or loop back-edges.
    /// - It cannot see whether an input call hit end-of-file. The standard
    ///   exempts that case, so it may cause false positives.
    fn analyze_compound_statement(
        &self,
        node: &Node,
        source: &str,
        violations: &mut Vec<RuleViolation>,
    ) {
        use std::collections::HashMap;

        #[derive(Clone, Copy, PartialEq)]
        enum IoOp {
            Input,
            Output,
        }

        /// How a recognized call affects a stream's I/O state.
        enum CallKind {
            /// `fflush` or a file positioning function.
            Separator,
            /// An input or output operation.
            Io(IoOp),
        }

        // For each stream: the direction of its most recent I/O call and that call's name.
        let mut last_ops: HashMap<String, (IoOp, String)> = HashMap::new();

        for call in Self::collect_calls_in_order(node) {
            let func_name = match call.child_by_field_name("function") {
                Some(function) => get_node_text(&function, source).to_string(),
                None => continue,
            };
            let name = func_name.as_str();
            let kind = if POSITIONING_FUNCTIONS.contains(&name) {
                CallKind::Separator
            } else if OUTPUT_FUNCTIONS.contains(&name) {
                CallKind::Io(IoOp::Output)
            } else if INPUT_FUNCTIONS.contains(&name) {
                CallKind::Io(IoOp::Input)
            } else {
                continue;
            };
            // Without an identifiable stream argument the call cannot be attributed.
            let stream = match Self::stream_argument(&call, name, source) {
                Some(stream) => stream,
                None => continue,
            };

            let op = match kind {
                CallKind::Separator if name == "fflush" => {
                    if matches!(stream.as_str(), "NULL" | "0" | "nullptr") {
                        // fflush(NULL) flushes every output stream at once.
                        last_ops.retain(|_, last| last.0 != IoOp::Output);
                    } else if matches!(last_ops.get(&stream), Some((IoOp::Output, _))) {
                        // A flush only legitimizes output -> input. After input the
                        // state is kept, so a following output is still reported.
                        last_ops.remove(&stream);
                    }
                    continue;
                }
                CallKind::Separator => {
                    last_ops.remove(&stream);
                    continue;
                }
                CallKind::Io(op) => op,
            };

            if let Some((prev_op, prev_name)) = last_ops.get(&stream) {
                if *prev_op != op {
                    let (message, suggestion) = match op {
                        IoOp::Output => (
                            format!(
                                "Output function '{}' called after input function '{}' without intervening fseek/fsetpos/rewind",
                                func_name, prev_name
                            ),
                            "Add fseek(), fsetpos(), or rewind() between input and output operations",
                        ),
                        IoOp::Input => (
                            format!(
                                "Input function '{}' called after output function '{}' without intervening fseek/fflush/fsetpos/rewind",
                                func_name, prev_name
                            ),
                            "Add fseek(), fflush(), fsetpos(), or rewind() between output and input operations",
                        ),
                    };
                    let pos = call.start_position();
                    violations.push(RuleViolation {
                        rule_id: self.rule_id().to_string(),
                        severity: self.severity(),
                        message,
                        file_path: String::new(),
                        line: pos.row + 1,
                        column: pos.column + 1,
                        suggestion: Some(suggestion.to_string()),
                        ..Default::default()
                    });
                }
            }
            last_ops.insert(stream, (op, func_name));
        }
    }

    /// Returns every `call_expression` in `node`'s subtree in approximate
    /// evaluation order.
    ///
    /// The walk is post-order: a call's callee and arguments are visited before
    /// the call itself. In `fputs(fgets(buf, n, fp), fp)` the inner `fgets` is
    /// therefore listed first, matching the order in which the calls execute.
    /// Sibling subtrees remain in source order.
    fn collect_calls_in_order<'tree>(node: &Node<'tree>) -> Vec<Node<'tree>> {
        let mut calls = Vec::new();
        Self::collect_calls_recursive(node, &mut calls);
        calls
    }

    /// Post-order helper for [`Self::collect_calls_in_order`].
    fn collect_calls_recursive<'tree>(node: &Node<'tree>, calls: &mut Vec<Node<'tree>>) {
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                Self::collect_calls_recursive(&child, calls);
            }
        }
        if node.kind() == "call_expression" {
            calls.push(*node);
        }
    }

    /// Returns the stream argument of `call`, which invokes `func_name`.
    ///
    /// Whitespace is removed so that spacing differences do not split one
    /// stream into two. Returns `None` when the call has no argument list or
    /// no arguments.
    fn stream_argument(call: &Node, func_name: &str, source: &str) -> Option<String> {
        // These take the stream as their first argument. Every other function in
        // OUTPUT_FUNCTIONS / INPUT_FUNCTIONS takes it last; single-argument
        // calls such as fgetc(fp) satisfy both rules.
        const STREAM_FIRST: &[&str] = &[
            "fprintf", "fwprintf", "vfprintf", "vfwprintf", "fscanf", "fwscanf", "vfscanf",
            "vfwscanf", "fseek", "fflush", "fsetpos", "rewind",
        ];

        let args = call.child_by_field_name("arguments")?;
        let mut exprs = Vec::new();
        for i in 0..args.named_child_count() {
            if let Some(arg) = args.named_child(i) {
                // Comments are named "extra" nodes and may appear among the arguments.
                if arg.kind() != "comment" {
                    exprs.push(arg);
                }
            }
        }
        let stream = if STREAM_FIRST.contains(&func_name) {
            exprs.first()
        } else {
            exprs.last()
        }?;
        Some(
            get_node_text(stream, source)
                .chars()
                .filter(|c| !c.is_whitespace())
                .collect(),
        )
    }
}