use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use tree_sitter::Node;
/// Checker for CERT C recommendation CON07-C: compound operations on
/// variables that may be shared between threads must be atomic.
pub struct Con07C;

impl CertRule for Con07C {
    fn rule_id(&self) -> &'static str {
        "CON07-C"
    }

    fn cert_id(&self) -> &'static str {
        "CON07-C"
    }

    fn description(&self) -> &'static str {
        "Ensure that compound operations on shared variables are atomic"
    }

    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    fn severity(&self) -> Severity {
        Severity::Medium
    }

    /// Runs the rule over a parsed C syntax tree.
    ///
    /// Works in two passes: first the potentially shared variables of the
    /// whole tree are gathered, then function definitions are inspected for
    /// unsynchronized compound updates of those variables.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let shared_vars = self.collect_static_variables(node, source);
        let mut findings = Vec::new();
        self.check_node(node, source, &shared_vars, &mut findings);
        findings
    }
}
impl Con07C {
fn collect_static_variables(&self, node: &Node, source: &str) -> Vec<String> {
let mut static_vars = Vec::new();
self.find_static_variables(node, source, &mut static_vars);
static_vars
}
fn find_static_variables(&self, node: &Node, source: &str, static_vars: &mut Vec<String>) {
if node.kind() == "declaration" {
let mut is_static = false;
let mut is_file_scope = node
.parent()
.is_some_and(|p| p.kind() == "translation_unit");
let mut is_mutex_or_thread_type = false;
let mut var_names = Vec::new();
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
match child.kind() {
"storage_class_specifier" => {
if get_node_text(&child, source) == "static" {
is_static = true;
}
if get_node_text(&child, source) == "extern" {
is_file_scope = false;
}
}
"type_identifier" => {
let type_name = get_node_text(&child, source);
if type_name.contains("mutex")
|| type_name.contains("pthread")
|| type_name.contains("thrd")
|| type_name.contains("cnd")
{
is_mutex_or_thread_type = true;
}
}
"init_declarator" => {
if let Some(declarator) = child.child_by_field_name("declarator") {
if let Some(name) = self.get_identifier_name(&declarator, source) {
var_names.push(name);
}
}
}
"identifier" => {
var_names.push(get_node_text(&child, source).to_string());
}
_ => {}
}
}
}
if is_static || (is_file_scope && !is_mutex_or_thread_type) {
static_vars.extend(var_names);
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.find_static_variables(&child, source, static_vars);
}
}
}
fn get_identifier_name(&self, node: &Node, source: &str) -> Option<String> {
if node.kind() == "identifier" {
return Some(get_node_text(node, source).to_string());
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if let Some(name) = self.get_identifier_name(&child, source) {
return Some(name);
}
}
}
None
}
fn check_node(
&self,
node: &Node,
source: &str,
static_vars: &[String],
violations: &mut Vec<RuleViolation>,
) {
if node.kind() == "function_definition" {
self.check_function_for_non_atomic_operations(node, source, static_vars, violations);
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.check_node(&child, source, static_vars, violations);
}
}
}
fn check_function_for_non_atomic_operations(
&self,
function_node: &Node,
source: &str,
static_vars: &[String],
violations: &mut Vec<RuleViolation>,
) {
let func_name = self
.get_function_name(function_node, source)
.unwrap_or_else(|| "<unknown>".to_string());
if func_name.to_lowercase().contains("init") {
return;
}
if self.uses_mutex_lock(function_node, source) {
return;
}
if self.uses_atomic_operations(function_node, source) {
return;
}
let body = match function_node.child_by_field_name("body") {
Some(b) => b,
None => return,
};
let static_var_accesses = self.find_static_var_accesses(&body, source, static_vars);
if static_var_accesses.len() > 1 {
let has_compound_write = static_var_accesses
.iter()
.any(|v| self.has_compound_operation_on_var(&body, source, v));
if has_compound_write {
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
severity: Severity::Medium,
message: format!(
"Function '{}' performs compound operation on shared static variables ({}) without synchronization",
func_name,
static_var_accesses.join(", ")
),
file_path: String::new(),
line: function_node.start_position().row + 1,
column: function_node.start_position().column + 1,
suggestion: Some(
"Use mutex locks (mtx_lock/mtx_unlock) or atomic operations to ensure atomicity".to_string()
),
..Default::default()
});
}
} else if static_var_accesses.len() == 1 {
if self.has_compound_operation_on_var(&body, source, &static_var_accesses[0]) {
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
severity: Severity::Medium,
message: format!(
"Function '{}' performs compound operation on shared static variable '{}' without synchronization",
func_name,
static_var_accesses[0]
),
file_path: String::new(),
line: function_node.start_position().row + 1,
column: function_node.start_position().column + 1,
suggestion: Some(
"Use mutex locks (mtx_lock/mtx_unlock) or atomic operations to ensure atomicity".to_string()
),
..Default::default()
});
}
}
}
fn get_function_name(&self, function_node: &Node, source: &str) -> Option<String> {
for i in 0..function_node.child_count() {
if let Some(child) = function_node.child(i) {
if child.kind() == "function_declarator" {
if let Some(name) = self.get_identifier_name(&child, source) {
return Some(name);
}
}
}
}
None
}
fn uses_mutex_lock(&self, node: &Node, source: &str) -> bool {
if node.kind() == "call_expression" {
if let Some(func) = node.child_by_field_name("function") {
let func_name = get_node_text(&func, source);
if matches!(
func_name,
"mtx_lock"
| "mtx_unlock"
| "pthread_mutex_lock"
| "pthread_mutex_unlock"
| "stdThreadLockAcquire"
| "stdThreadLockRelease"
) {
return true;
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if self.uses_mutex_lock(&child, source) {
return true;
}
}
}
false
}
fn uses_atomic_operations(&self, node: &Node, source: &str) -> bool {
if node.kind() == "call_expression" {
if let Some(func) = node.child_by_field_name("function") {
let func_name = get_node_text(&func, source);
if func_name.starts_with("atomic_") {
return true;
}
}
}
if node.kind() == "type_qualifier" {
let text = get_node_text(node, source);
if text == "_Atomic" {
return true;
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if self.uses_atomic_operations(&child, source) {
return true;
}
}
}
false
}
fn find_static_var_accesses(
&self,
node: &Node,
source: &str,
static_vars: &[String],
) -> Vec<String> {
let mut accesses = Vec::new();
self.collect_static_var_accesses(node, source, static_vars, &mut accesses);
accesses.sort();
accesses.dedup();
accesses
}
fn collect_static_var_accesses(
&self,
node: &Node,
source: &str,
static_vars: &[String],
accesses: &mut Vec<String>,
) {
if node.kind() == "identifier" {
let name = get_node_text(node, source);
if static_vars.contains(&name.to_string()) {
accesses.push(name.to_string());
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.collect_static_var_accesses(&child, source, static_vars, accesses);
}
}
}
fn has_compound_operation_on_var(&self, node: &Node, source: &str, var_name: &str) -> bool {
if node.kind() == "assignment_expression" {
let left = node.child_by_field_name("left");
let right = node.child_by_field_name("right");
if let Some(left_node) = left {
let left_text = get_node_text(&left_node, source);
if left_text == var_name {
if let Some(operator) = node.child_by_field_name("operator") {
let op_text = get_node_text(&operator, source);
if matches!(
op_text,
"+=" | "-=" | "*=" | "/=" | "%=" | "<<=" | ">>=" | "&=" | "^=" | "|="
) {
return true;
}
}
if let Some(right_node) = right {
let right_text = get_node_text(&right_node, source);
if right_text.contains(var_name) {
return true;
}
}
}
}
}
if matches!(node.kind(), "update_expression") {
if let Some(argument) = node.child_by_field_name("argument") {
if get_node_text(&argument, source) == var_name {
return true;
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if self.has_compound_operation_on_var(&child, source, var_name) {
return true;
}
}
}
false
}
}