use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::{HashMap, HashSet};
use tree_sitter::Node;
/// Checker for CERT C rule CON32-C: "Prevent data races when accessing bit-fields
/// from multiple threads".
///
/// This is a field-less unit struct; all analysis state lives in locals created by
/// [`CertRule::check`].
pub struct Con32C;

impl CertRule for Con32C {
    /// Identifier stamped onto every violation this rule reports.
    fn rule_id(&self) -> &'static str {
        "CON32-C"
    }

    /// One-line, human-readable summary of the rule.
    fn description(&self) -> &'static str {
        "Prevent data races when accessing bit-fields from multiple threads"
    }

    fn severity(&self) -> Severity {
        Severity::Medium
    }

    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    /// CERT identifier; for this rule it is the same string as [`Self::rule_id`].
    fn cert_id(&self) -> &'static str {
        self.rule_id()
    }

    /// Runs the check over a syntax tree in two passes.
    ///
    /// The first pass indexes every struct that declares bit-field members. The second
    /// pass walks the tree again and inspects function definitions for accesses to
    /// those members.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let structs_with_bitfields = self.collect_bitfield_structs(node, source);
        let mut found = Vec::new();
        self.check_node(node, source, &structs_with_bitfields, &mut found);
        found
    }
}
impl Con32C {
fn collect_bitfield_structs(
&self,
node: &Node,
source: &str,
) -> HashMap<String, HashSet<String>> {
let mut bitfield_structs = HashMap::new();
self.find_bitfield_structs(node, source, &mut bitfield_structs);
bitfield_structs
}
fn find_bitfield_structs(
&self,
node: &Node,
source: &str,
bitfield_structs: &mut HashMap<String, HashSet<String>>,
) {
if node.kind() == "struct_specifier" {
if let Some(name_node) = node.child_by_field_name("name") {
let struct_name = get_node_text(&name_node, source).to_string();
if let Some(body) = node.child_by_field_name("body") {
let bitfield_members = self.find_bitfield_members(&body, source);
if !bitfield_members.is_empty() {
bitfield_structs.insert(struct_name, bitfield_members);
}
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.find_bitfield_structs(&child, source, bitfield_structs);
}
}
}
fn find_bitfield_members(&self, body: &Node, source: &str) -> HashSet<String> {
let mut members = HashSet::new();
for i in 0..body.child_count() {
if let Some(child) = body.child(i) {
if child.kind() == "field_declaration" {
if self.is_bitfield_declaration(&child, source) {
if let Some(name) = self.get_field_name(&child, source) {
members.insert(name);
}
}
}
}
}
members
}
fn is_bitfield_declaration(&self, node: &Node, _source: &str) -> bool {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if child.kind() == "bitfield_clause" {
return true;
}
}
}
false
}
fn get_field_name(&self, node: &Node, source: &str) -> Option<String> {
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
match child.kind() {
"field_identifier" => {
return Some(get_node_text(&child, source).to_string());
}
"field_declarator" => {
if let Some(name) = self.get_identifier_name(&child, source) {
return Some(name);
}
}
"identifier" => {
return Some(get_node_text(&child, source).to_string());
}
_ => {}
}
}
}
None
}
fn get_identifier_name(&self, node: &Node, source: &str) -> Option<String> {
if node.kind() == "identifier" {
return Some(get_node_text(node, source).to_string());
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if let Some(name) = self.get_identifier_name(&child, source) {
return Some(name);
}
}
}
None
}
fn check_node(
&self,
node: &Node,
source: &str,
bitfield_structs: &HashMap<String, HashSet<String>>,
violations: &mut Vec<RuleViolation>,
) {
if node.kind() == "function_definition" {
self.check_function_for_bitfield_access(node, source, bitfield_structs, violations);
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.check_node(&child, source, bitfield_structs, violations);
}
}
}
fn check_function_for_bitfield_access(
&self,
function_node: &Node,
source: &str,
bitfield_structs: &HashMap<String, HashSet<String>>,
violations: &mut Vec<RuleViolation>,
) {
let func_name = self
.get_function_name(function_node, source)
.unwrap_or_else(|| "<unknown>".to_string());
if !self.is_potential_thread_function(function_node, source, &func_name) {
return;
}
let body = match function_node.child_by_field_name("body") {
Some(b) => b,
None => return,
};
let has_mutex = self.uses_mutex_lock(&body, source);
let bitfield_accesses = self.find_bitfield_accesses(&body, source, bitfield_structs);
if !bitfield_accesses.is_empty() && !has_mutex {
for (struct_name, member_name, line) in bitfield_accesses {
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
severity: Severity::Medium,
message: format!(
"Function '{}' accesses bit-field '{}.{}' without mutex protection in a multi-threaded context",
func_name, struct_name, member_name
),
file_path: String::new(),
line,
column: 0,
suggestion: Some(
"Protect bit-field accesses with mutex locks or use separate byte-sized members instead of bit-fields".to_string()
),
..Default::default()
});
}
}
}
fn get_function_name(&self, function_node: &Node, source: &str) -> Option<String> {
for i in 0..function_node.child_count() {
if let Some(child) = function_node.child(i) {
if child.kind() == "function_declarator" {
if let Some(name) = self.get_identifier_name(&child, source) {
return Some(name);
}
}
}
}
None
}
fn is_potential_thread_function(
&self,
function_node: &Node,
source: &str,
func_name: &str,
) -> bool {
if func_name.to_lowercase().contains("thread") {
return true;
}
for i in 0..function_node.child_count() {
if let Some(child) = function_node.child(i) {
if child.kind() == "function_declarator" {
if let Some(params) = child.child_by_field_name("parameters") {
for j in 0..params.child_count() {
if let Some(param) = params.child(j) {
if param.kind() == "parameter_declaration" {
let param_text = get_node_text(¶m, source);
if param_text.contains("void") && param_text.contains("*") {
return true;
}
}
}
}
}
}
}
}
false
}
fn uses_mutex_lock(&self, node: &Node, source: &str) -> bool {
if node.kind() == "call_expression" {
if let Some(func) = node.child_by_field_name("function") {
let func_name = get_node_text(&func, source);
if matches!(
func_name,
"mtx_lock" | "mtx_unlock" | "pthread_mutex_lock" | "pthread_mutex_unlock"
) {
return true;
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if self.uses_mutex_lock(&child, source) {
return true;
}
}
}
false
}
fn find_bitfield_accesses(
&self,
node: &Node,
source: &str,
bitfield_structs: &HashMap<String, HashSet<String>>,
) -> Vec<(String, String, usize)> {
let mut accesses = Vec::new();
self.collect_bitfield_accesses(node, source, bitfield_structs, &mut accesses);
accesses
}
fn collect_bitfield_accesses(
&self,
node: &Node,
source: &str,
bitfield_structs: &HashMap<String, HashSet<String>>,
accesses: &mut Vec<(String, String, usize)>,
) {
if node.kind() == "field_expression" {
if let Some(field_node) = node.child_by_field_name("field") {
let field_name = get_node_text(&field_node, source).to_string();
if let Some(object) = node.child_by_field_name("argument") {
let object_text = get_node_text(&object, source);
for (struct_name, members) in bitfield_structs {
if members.contains(&field_name) {
let line = node.start_position().row + 1;
accesses.push((struct_name.clone(), field_name.clone(), line));
break;
}
}
if object_text.contains('.') {
for (struct_name, members) in bitfield_structs {
if members.contains(&field_name) {
let line = node.start_position().row + 1;
accesses.push((struct_name.clone(), field_name.clone(), line));
break;
}
}
}
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.collect_bitfield_accesses(&child, source, bitfield_structs, accesses);
}
}
}
}