use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use tree_sitter::Node;
/// CERT C rule ERR30-C ("Take care when reading errno").
///
/// Reports calls to in-band error functions (such as `strtol`) that are not
/// preceded by `errno = 0`, and `errno` tests that follow an out-of-band error
/// function (such as `ftell`) without a prior check of its return value.
pub struct Err30C;

impl CertRule for Err30C {
    /// Identifier attached to every violation produced by this rule.
    fn rule_id(&self) -> &'static str {
        "ERR30-C"
    }

    fn description(&self) -> &'static str {
        "Take care when reading errno"
    }

    fn severity(&self) -> Severity {
        Severity::Medium
    }

    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    /// The CERT identifier is the same string as the rule identifier.
    fn cert_id(&self) -> &'static str {
        self.rule_id()
    }

    /// Runs both ERR30-C checks over the whole tree rooted at `node`.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut findings = Vec::new();
        self.check_node(node, source, &mut findings);
        findings
    }
}
impl Err30C {
fn check_node(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
match node.kind() {
"if_statement" => {
self.check_errno_in_if(node, source, violations);
}
"call_expression" => {
self.check_inband_function_call(node, source, violations);
}
_ => {}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.check_node(&child, source, violations);
}
}
}
fn check_errno_in_if(&self, if_node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
if let Some(condition) = if_node.child_by_field_name("condition") {
let condition_text = get_node_text(&condition, source);
if condition_text.contains("errno") {
if let Some(recent_call) = self.find_recent_outofband_call(if_node, source) {
let function_name =
if let Some(func_node) = recent_call.child_by_field_name("function") {
get_node_text(&func_node, source)
} else {
"<unknown>"
};
if !self.has_return_value_check_before_errno(&recent_call, if_node, source) {
let start_point = condition.start_position();
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
severity: Severity::Medium,
message: format!(
"errno checked without verifying error occurred via return value of '{}' - must check return value first",
function_name
),
file_path: String::new(),
line: start_point.row + 1,
column: start_point.column + 1,
suggestion: Some(format!(
"Check return value before errno: 'if ({}(...) == ERROR_VALUE) {{ if (errno) {{ ... }} }}'",
function_name
)),
..Default::default()
});
}
}
}
}
}
fn check_inband_function_call(
&self,
call_node: &Node,
source: &str,
violations: &mut Vec<RuleViolation>,
) {
if let Some(function_node) = call_node.child_by_field_name("function") {
let function_name = get_node_text(&function_node, source);
if self.is_inband_function(&function_name) {
if !self.has_errno_reset_before(call_node, source) {
let start_point = call_node.start_position();
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
severity: Severity::Medium,
message: format!(
"In-band function '{}' called without setting errno = 0 first - errno must be reset before calling in-band functions",
function_name
),
file_path: String::new(),
line: start_point.row + 1,
column: start_point.column + 1,
suggestion: Some(format!(
"Set errno before call: 'errno = 0; val = {}(...); if (errno == ERANGE) {{ ... }}'",
function_name
)),
..Default::default()
});
}
}
}
}
fn has_errno_reset_before(&self, call_node: &Node, source: &str) -> bool {
let mut current = call_node.parent();
while let Some(parent) = current {
if parent.kind() == "compound_statement" || parent.kind() == "function_definition" {
let call_byte_start = call_node.start_byte();
for i in 0..parent.child_count() {
if let Some(child) = parent.child(i) {
if child.end_byte() < call_byte_start {
if self.contains_errno_reset(&child, source) {
return true;
}
}
}
}
break;
}
current = parent.parent();
}
false
}
fn contains_errno_reset(&self, node: &Node, source: &str) -> bool {
if node.kind() == "assignment_expression" {
if let (Some(left), Some(right)) = (
node.child_by_field_name("left"),
node.child_by_field_name("right"),
) {
let left_text = get_node_text(&left, source);
let right_text = get_node_text(&right, source);
if left_text == "errno" && right_text == "0" {
return true;
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if self.contains_errno_reset(&child, source) {
return true;
}
}
}
false
}
fn find_recent_outofband_call<'a>(&self, if_node: &Node<'a>, source: &str) -> Option<Node<'a>> {
let mut current = if_node.parent();
while let Some(parent) = current {
if parent.kind() == "compound_statement" || parent.kind() == "function_definition" {
let if_byte_start = if_node.start_byte();
let mut most_recent: Option<Node> = None;
for i in 0..parent.child_count() {
if let Some(child) = parent.child(i) {
if child.end_byte() < if_byte_start {
if let Some(call) = self.find_outofband_call_in_node(&child, source) {
most_recent = Some(call);
}
}
}
}
return most_recent;
}
current = parent.parent();
}
None
}
fn find_outofband_call_in_node<'a>(&self, node: &Node<'a>, source: &str) -> Option<Node<'a>> {
if node.kind() == "call_expression" {
if let Some(function_node) = node.child_by_field_name("function") {
let function_name = get_node_text(&function_node, source);
if self.is_outofband_function(&function_name) {
return Some(*node);
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if let Some(call) = self.find_outofband_call_in_node(&child, source) {
return Some(call);
}
}
}
None
}
fn has_return_value_check_before_errno(
&self,
call_node: &Node,
errno_if_node: &Node,
source: &str,
) -> bool {
let call_end = call_node.end_byte();
let errno_start = errno_if_node.start_byte();
if let Some(function_node) = call_node.child_by_field_name("function") {
let function_name = get_node_text(&function_node, source);
let mut current = call_node.parent();
while let Some(parent) = current {
if parent.kind() == "compound_statement" || parent.kind() == "function_definition" {
for i in 0..parent.child_count() {
if let Some(child) = parent.child(i) {
if child.start_byte() >= call_end && child.end_byte() <= errno_start {
if child.kind() == "if_statement" {
if let Some(condition) = child.child_by_field_name("condition")
{
let condition_text = get_node_text(&condition, source);
if self.checks_return_value_for_function(
&condition_text,
&function_name,
) {
return true;
}
}
}
}
}
}
break;
}
current = parent.parent();
}
}
false
}
fn checks_return_value_for_function(&self, condition: &str, function_name: &str) -> bool {
match function_name {
"ftell" => condition.contains("== -1") || condition.contains("== -1L"),
"fopen" | "freopen" => {
condition.contains("== NULL")
|| condition.contains("!= NULL")
|| condition.contains("== 0")
|| condition.contains("!= 0")
}
"signal" => condition.contains("== SIG_ERR"),
"mbrtowc" | "wcrtomb" => {
condition.contains("== (size_t)-1") || condition.contains("< 0")
}
_ => {
condition.contains("==")
|| condition.contains("!=")
|| condition.contains("<")
|| condition.contains(">")
}
}
}
fn is_outofband_function(&self, function_name: &str) -> bool {
matches!(
function_name,
"ftell"
| "fopen"
| "freopen"
| "fclose"
| "fflush"
| "fseek"
| "signal"
| "mbrtowc"
| "wcrtomb"
| "mbtowc"
| "wctomb"
)
}
fn is_inband_function(&self, function_name: &str) -> bool {
matches!(
function_name,
"strtol"
| "strtoul"
| "strtoll"
| "strtoull"
| "strtof"
| "strtod"
| "strtold"
| "fgetwc"
| "fgetws"
| "getwc"
| "getwchar"
)
}
}