use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::HashMap;
use tree_sitter::Node;
/// Checker for FIO50-C: a file stream must not alternate between input and
/// output without an intervening flush/positioning call.
///
/// `Default` is derived so the zero-argument `new()` constructor has the
/// matching standard trait (clippy `new_without_default`).
#[derive(Debug, Default)]
pub struct Fio50C;
#[derive(Clone, Debug, PartialEq)]
enum OperationType {
Input,
Output,
Positioning,
}
#[derive(Clone, Debug)]
struct FileOperation {
op_type: OperationType,
file_var: String,
line: usize,
column: usize,
}
impl Fio50C {
pub fn new() -> Self {
Fio50C
}
fn is_input_function(&self, name: &str) -> bool {
matches!(
name,
"fread" | "fgets" | "fscanf" | "getc" | "fgetc" | "fgetwc" | "fgetws" | "vfscanf"
)
}
fn is_output_function(&self, name: &str) -> bool {
matches!(
name,
"fwrite" | "fputs" | "fprintf" | "putc" | "fputc" | "fputwc" | "fputws" | "vfprintf"
)
}
fn is_positioning_function(&self, name: &str) -> bool {
matches!(name, "fflush" | "fseek" | "fsetpos" | "rewind")
}
fn get_file_argument(&self, arguments: &Node, source: &str) -> Option<String> {
let mut cursor = arguments.walk();
for child in arguments.children(&mut cursor) {
if child.kind() != "(" && child.kind() != ")" && child.kind() != "," {
return Some(get_node_text(&child, source).to_string());
}
}
None
}
fn analyze_scope(&self, scope_node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
let mut file_operations: HashMap<String, Vec<FileOperation>> = HashMap::new();
self.collect_file_operations(scope_node, source, &mut file_operations);
for ops in file_operations.values() {
self.detect_alternation_violations(ops, violations);
}
}
fn is_cpp_input_operator(&self, node: &Node, source: &str) -> Option<String> {
if node.kind() == "binary_expression" {
if let Some(operator_node) = node.child_by_field_name("operator") {
let operator = get_node_text(&operator_node, source);
if operator == ">>" {
if let Some(left) = node.child_by_field_name("left") {
return Some(get_node_text(&left, source).to_string());
}
}
}
}
None
}
fn is_cpp_output_operator(&self, node: &Node, source: &str) -> Option<String> {
if node.kind() == "binary_expression" {
if let Some(operator_node) = node.child_by_field_name("operator") {
let operator = get_node_text(&operator_node, source);
if operator == "<<" {
if let Some(left) = node.child_by_field_name("left") {
return Some(get_node_text(&left, source).to_string());
}
}
}
}
None
}
fn is_cpp_positioning_call(&self, node: &Node, source: &str) -> Option<String> {
if node.kind() == "call_expression" {
if let Some(function) = node.child_by_field_name("function") {
let func_text = get_node_text(&function, source);
if func_text.contains(".seekg")
|| func_text.contains(".seekp")
|| func_text.contains("->seekg")
|| func_text.contains("->seekp")
{
if let Some(dot_pos) = func_text.find('.') {
return Some(func_text[..dot_pos].to_string());
} else if let Some(arrow_pos) = func_text.find("->") {
return Some(func_text[..arrow_pos].to_string());
}
}
}
}
None
}
fn collect_file_operations(
&self,
node: &Node,
source: &str,
file_operations: &mut HashMap<String, Vec<FileOperation>>,
) {
if node.kind() == "call_expression" {
if let Some(function) = node.child_by_field_name("function") {
let func_name = get_node_text(&function, source);
let op_type = if self.is_input_function(func_name) {
Some(OperationType::Input)
} else if self.is_output_function(func_name) {
Some(OperationType::Output)
} else if self.is_positioning_function(func_name) {
Some(OperationType::Positioning)
} else {
None
};
if let Some(op_type) = op_type {
if let Some(arguments) = node.child_by_field_name("arguments") {
if let Some(file_var) = self.get_file_argument(&arguments, source) {
let operation = FileOperation {
op_type,
file_var: file_var.clone(),
line: node.start_position().row + 1,
column: node.start_position().column + 1,
};
file_operations.entry(file_var).or_default().push(operation);
}
}
}
}
if let Some(stream_var) = self.is_cpp_positioning_call(node, source) {
let operation = FileOperation {
op_type: OperationType::Positioning,
file_var: stream_var.clone(),
line: node.start_position().row + 1,
column: node.start_position().column + 1,
};
file_operations
.entry(stream_var)
.or_default()
.push(operation);
}
}
if node.kind() == "binary_expression" {
if let Some(stream_var) = self.is_cpp_input_operator(node, source) {
let operation = FileOperation {
op_type: OperationType::Input,
file_var: stream_var.clone(),
line: node.start_position().row + 1,
column: node.start_position().column + 1,
};
file_operations
.entry(stream_var)
.or_default()
.push(operation);
}
if let Some(stream_var) = self.is_cpp_output_operator(node, source) {
let operation = FileOperation {
op_type: OperationType::Output,
file_var: stream_var.clone(),
line: node.start_position().row + 1,
column: node.start_position().column + 1,
};
file_operations
.entry(stream_var)
.or_default()
.push(operation);
}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.collect_file_operations(&child, source, file_operations);
}
}
fn detect_alternation_violations(
&self,
operations: &[FileOperation],
violations: &mut Vec<RuleViolation>,
) {
for i in 0..operations.len().saturating_sub(1) {
let current = &operations[i];
let next = &operations[i + 1];
if next.op_type == OperationType::Positioning {
continue;
}
if current.op_type == OperationType::Output && next.op_type == OperationType::Input {
let has_positioning = operations
.iter()
.skip(i + 1)
.take_while(|op| op.line < next.line)
.any(|op| op.op_type == OperationType::Positioning);
if !has_positioning {
violations.push(RuleViolation {
rule_id: "FIO50-C".to_string(),
severity: Severity::Low,
line: next.line,
column: next.column,
message: format!(
"Input operation on file stream '{}' follows output without intervening positioning call (fflush, fseek, fsetpos, or rewind)",
next.file_var
),
file_path: String::new(),
suggestion: Some(
"Insert fflush() or a positioning function (fseek, fsetpos, rewind) between output and input operations".to_string(),
),
requires_manual_review: Some(false),
});
}
}
if current.op_type == OperationType::Input && next.op_type == OperationType::Output {
let has_positioning = operations
.iter()
.skip(i + 1)
.take_while(|op| op.line < next.line)
.any(|op| op.op_type == OperationType::Positioning);
if !has_positioning {
violations.push(RuleViolation {
rule_id: "FIO50-C".to_string(),
severity: Severity::Low,
line: next.line,
column: next.column,
message: format!(
"Output operation on file stream '{}' follows input without intervening positioning call (fseek, fsetpos, or rewind)",
next.file_var
),
file_path: String::new(),
suggestion: Some(
"Insert a positioning function (fseek, fsetpos, rewind) between input and output operations, unless input encountered EOF".to_string(),
),
requires_manual_review: Some(false),
});
}
}
}
}
fn traverse(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
if node.kind() == "function_definition" {
self.analyze_scope(node, source, violations);
}
if node.kind() == "translation_unit" {
self.analyze_scope(node, source, violations);
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
self.traverse(&child, source, violations);
}
}
}
impl CertRule for Fio50C {
    /// Identifier attached to every violation this rule reports.
    fn rule_id(&self) -> &'static str {
        "FIO50-C"
    }

    /// One-line summary of the guideline being enforced.
    fn description(&self) -> &'static str {
        "Do not alternately input and output from a file stream without an intervening positioning call"
    }

    /// Classification of this guideline in the rule manifest.
    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    /// Default severity for findings of this rule.
    fn severity(&self) -> Severity {
        Severity::Low
    }

    /// CERT identifier; the same string as the rule id for this checker.
    fn cert_id(&self) -> &'static str {
        self.rule_id()
    }

    /// Walks the syntax tree rooted at `root` and returns every finding.
    fn check(&self, root: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found: Vec<RuleViolation> = vec![];
        self.traverse(root, source, &mut found);
        found
    }
}