use crate::manifest::{RuleCategory, Severity};
use crate::rules::{CertRule, RuleViolation};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::HashSet;
use tree_sitter::Node;
/// Checker for SEI CERT C rule CON05-C: "Do not perform operations that can block while
/// holding a lock".
///
/// A stateless unit struct; all behavior lives in its `CertRule` implementation and the
/// private helpers in its inherent `impl`.
pub struct Con05C;
impl CertRule for Con05C {
    /// Identifier attached to every violation produced by this checker.
    fn rule_id(&self) -> &'static str {
        "CON05-C"
    }

    /// One-line human-readable summary of the rule.
    fn description(&self) -> &'static str {
        "Do not perform operations that can block while holding a lock"
    }

    /// Declared severity of this rule.
    fn severity(&self) -> Severity {
        Severity::Medium
    }

    /// This checker is categorized as `RuleCategory::Rule`.
    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    /// The CERT identifier is the same string as the rule id.
    fn cert_id(&self) -> &'static str {
        self.rule_id()
    }

    /// Runs the checker over the subtree rooted at `node` and returns whatever
    /// violations `check_function` collects.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut found = Vec::new();
        self.check_function(node, source, &mut found);
        found
    }
}
impl Con05C {
fn check_function(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
if node.kind() == "function_definition" {
self.analyze_function_body(node, source, violations);
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.check_function(&child, source, violations);
}
}
}
fn analyze_function_body(
&self,
function_node: &Node,
source: &str,
violations: &mut Vec<RuleViolation>,
) {
if let Some(body) = function_node.child_by_field_name("body") {
let mut locked_regions = Vec::new();
self.find_locked_regions(&body, source, &mut locked_regions);
for (lock_call, unlock_call) in locked_regions {
self.check_for_blocking_ops_in_region(
&body,
source,
&lock_call,
&unlock_call,
violations,
);
}
}
}
fn find_locked_regions<'a>(
&self,
node: &Node<'a>,
source: &str,
regions: &mut Vec<(Node<'a>, Node<'a>)>,
) {
let mut locks = Vec::new();
let mut unlocks = Vec::new();
self.collect_lock_unlock_calls(node, source, &mut locks, &mut unlocks);
let mut used_unlocks = HashSet::new();
for lock in &locks {
for unlock in &unlocks {
if used_unlocks.contains(&unlock.id()) {
continue;
}
if unlock.start_byte() > lock.start_byte() {
regions.push((*lock, *unlock));
used_unlocks.insert(unlock.id());
break;
}
}
}
}
fn collect_lock_unlock_calls<'a>(
&self,
node: &Node<'a>,
source: &str,
locks: &mut Vec<Node<'a>>,
unlocks: &mut Vec<Node<'a>>,
) {
if node.kind() == "call_expression" {
if let Some(function) = node.child_by_field_name("function") {
let func_name = get_node_text(&function, source);
if func_name == "mtx_lock" || func_name == "pthread_mutex_lock" {
locks.push(*node);
} else if func_name == "mtx_unlock" || func_name == "pthread_mutex_unlock" {
unlocks.push(*node);
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.collect_lock_unlock_calls(&child, source, locks, unlocks);
}
}
}
fn check_for_blocking_ops_in_region(
&self,
body: &Node,
source: &str,
lock_call: &Node,
unlock_call: &Node,
violations: &mut Vec<RuleViolation>,
) {
let mut calls_in_region = Vec::new();
self.collect_calls_in_region(
body,
source,
lock_call.start_byte(),
unlock_call.start_byte(),
&mut calls_in_region,
);
for call in calls_in_region {
if let Some(function) = call.child_by_field_name("function") {
let func_name = get_node_text(&function, source);
if self.is_blocking_operation(&func_name) {
let start_point = call.start_position();
let call_text = get_node_text(&call, source);
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
severity: Severity::High,
message: format!(
"Blocking operation '{}' called while holding a lock. This can cause other threads to block unnecessarily",
call_text
),
file_path: String::new(),
line: start_point.row + 1,
column: start_point.column + 1,
suggestion: Some(
"Move blocking I/O operations outside the locked region. Perform I/O before acquiring the lock or after releasing it".to_string()
),
..Default::default()
});
}
}
}
}
fn collect_calls_in_region<'a>(
&self,
node: &Node<'a>,
source: &str,
lock_start: usize,
unlock_start: usize,
calls: &mut Vec<Node<'a>>,
) {
if node.kind() == "call_expression" {
let call_start = node.start_byte();
if call_start > lock_start && call_start < unlock_start {
if let Some(function) = node.child_by_field_name("function") {
let func_name = get_node_text(&function, source);
if func_name != "mtx_unlock" && func_name != "pthread_mutex_unlock" {
calls.push(*node);
}
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.collect_calls_in_region(&child, source, lock_start, unlock_start, calls);
}
}
}
fn is_blocking_operation(&self, func_name: &str) -> bool {
matches!(
func_name,
"fopen"
| "fclose"
| "fread"
| "fwrite"
| "fprintf"
| "fscanf"
| "fgets"
| "fputs"
| "fflush"
| "open"
| "close"
| "read"
| "write"
| "printf"
| "scanf"
| "puts"
| "gets"
| "getchar"
| "putchar"
| "send"
| "recv"
| "sendto"
| "recvfrom"
| "connect"
| "accept"
| "listen"
| "sleep"
| "usleep"
| "nanosleep"
| "thrd_sleep"
)
}
}