use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils;
use std::collections::HashSet;
use tree_sitter::Node;
/// CERT C rule EXP11-C: "Do not make assumptions regarding the layout of structures
/// with bit-fields".
///
/// The analysis works on a whole translation unit in three passes:
/// 1. gather the tags of `struct` definitions whose bodies declare bit-fields,
/// 2. gather the names declared with one of those struct types,
/// 3. report casts that take the address of such a name and convert it to a
///    character-pointer type.
pub struct Exp11C;

impl CertRule for Exp11C {
    fn rule_id(&self) -> &'static str {
        "EXP11-C"
    }

    fn description(&self) -> &'static str {
        "Do not make assumptions regarding the layout of structures with bit-fields"
    }

    fn severity(&self) -> Severity {
        Severity::Medium
    }

    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    fn cert_id(&self) -> &'static str {
        "EXP11-C"
    }

    /// Runs the three-pass analysis when `node` is a `translation_unit`.
    ///
    /// Any other node kind yields an empty result.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        if node.kind() != "translation_unit" {
            return Vec::new();
        }

        let structs_with_bitfields = collect_bitfield_structs(node, source);
        let tracked_variables = collect_bitfield_variables(node, source, &structs_with_bitfields);

        let mut findings = Vec::new();
        find_bitfield_pointer_casts(node, source, &tracked_variables, &mut findings);
        findings
    }
}
/// Returns the tags of every `struct` definition below `root` whose body declares a
/// bit-field.
///
/// Each direct child of `root` is scanned recursively. The returned strings borrow
/// from `source`.
fn collect_bitfield_structs<'a>(root: &Node, source: &'a str) -> HashSet<&'a str> {
    let mut found: HashSet<&'a str> = HashSet::new();
    let mut walker = root.walk();
    root.children(&mut walker)
        .for_each(|item| collect_bitfield_structs_recursive(&item, source, &mut found));
    found
}
/// Visits `node` and all its descendants.
///
/// For every named `struct_specifier` whose body contains a bit-field (as decided by
/// `contains_bitfield`), the struct's tag text is inserted into `bitfield_structs`.
fn collect_bitfield_structs_recursive<'a>(
    node: &Node,
    source: &'a str,
    bitfield_structs: &mut HashSet<&'a str>,
) {
    let is_bitfield_struct = node.kind() == "struct_specifier"
        && matches!(node.child_by_field_name("body"), Some(body) if contains_bitfield(&body));

    if is_bitfield_struct {
        // Anonymous structs have no `name` field, so there is no tag to record.
        if let Some(tag) = node.child_by_field_name("name") {
            bitfield_structs.insert(ast_utils::get_node_text(&tag, source));
        }
    }

    let mut walker = node.walk();
    for child in node.children(&mut walker) {
        collect_bitfield_structs_recursive(&child, source, bitfield_structs);
    }
}
fn contains_bitfield(body: &Node) -> bool {
let mut cursor = body.walk();
for child in body.children(&mut cursor) {
if child.kind() == "field_declaration" {
let mut field_cursor = child.walk();
for field_child in child.children(&mut field_cursor) {
if field_child.kind() == "bitfield_clause" {
return true;
}
}
}
}
false
}
/// Returns the names of variables declared with a struct type whose tag is in
/// `bitfield_structs`.
///
/// Each direct child of `root` is scanned recursively. The returned strings borrow
/// from `source`.
fn collect_bitfield_variables<'a>(
    root: &Node,
    source: &'a str,
    bitfield_structs: &HashSet<&'a str>,
) -> HashSet<&'a str> {
    let mut names: HashSet<&'a str> = HashSet::new();
    let mut walker = root.walk();
    root.children(&mut walker).for_each(|item| {
        collect_bitfield_vars_recursive(&item, source, bitfield_structs, &mut names)
    });
    names
}
/// Visits `node` and all its descendants, recording into `bitfield_vars` every name
/// declared with a struct type whose tag is in `bitfield_structs`.
///
/// Both ordinary `declaration`s (`struct bf x;`, `struct bf a[2], *p;`) and function
/// `parameter_declaration`s (`void f(struct bf s)`) are inspected. Either one can bind a
/// name whose address is later cast. Names are tracked by text only, with no notion of
/// scope.
fn collect_bitfield_vars_recursive<'a>(
    node: &Node,
    source: &'a str,
    bitfield_structs: &HashSet<&'a str>,
    bitfield_vars: &mut HashSet<&'a str>,
) {
    if matches!(node.kind(), "declaration" | "parameter_declaration") {
        if let Some(type_node) = node.child_by_field_name("type") {
            let type_text = ast_utils::get_node_text(&type_node, source);
            let is_bitfield_type = matches!(
                extract_struct_name(type_text),
                Some(struct_name) if bitfield_structs.contains(struct_name)
            );
            if is_bitfield_type {
                extract_declared_variable_names(node, source, bitfield_vars);
            }
        }
    }
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        collect_bitfield_vars_recursive(&child, source, bitfield_structs, bitfield_vars);
    }
}
/// Extracts the tag from the text of a struct type specifier.
///
/// Accepts any whitespace between the `struct` keyword and the tag. Only the tag
/// identifier itself is returned, so a specifier that also carries an inline definition
/// (`struct bf { unsigned a : 1; }`) still yields `bf`.
///
/// Returns `None` for non-struct types, for anonymous structs (`struct { ... }`), and
/// for words that merely start with `struct` (e.g. `structure`).
fn extract_struct_name(type_text: &str) -> Option<&str> {
    let after_keyword = type_text.trim().strip_prefix("struct")?;
    // The keyword must be followed by a separator; otherwise this is some other identifier.
    if !after_keyword.starts_with(char::is_whitespace) {
        return None;
    }
    let rest = after_keyword.trim_start();
    let end = rest
        .find(|c: char| !(c.is_alphanumeric() || c == '_'))
        .unwrap_or(rest.len());
    let tag = &rest[..end];
    if tag.is_empty() {
        None
    } else {
        Some(tag)
    }
}
/// Inserts into `var_names` the identifier bound by each declarator child of
/// `decl_node`.
///
/// `init_declarator` children (`x = ...`) are unwrapped to their `declarator` field.
/// Bare `identifier`, `pointer_declarator` and `array_declarator` children are used
/// as-is. Every other child kind is ignored.
fn extract_declared_variable_names<'a>(
    decl_node: &Node,
    source: &'a str,
    var_names: &mut HashSet<&'a str>,
) {
    let mut walker = decl_node.walk();
    for child in decl_node.children(&mut walker) {
        let declarator = match child.kind() {
            "init_declarator" => child.child_by_field_name("declarator"),
            "identifier" | "pointer_declarator" | "array_declarator" => Some(child),
            _ => None,
        };
        if let Some(name) =
            declarator.and_then(|d| extract_identifier_from_declarator(&d, source))
        {
            var_names.insert(name);
        }
    }
}
/// Follows the `declarator` field through nested `pointer_declarator` /
/// `array_declarator` nodes and returns the source text of the `identifier` at the
/// bottom.
///
/// Returns `None` if a different node kind is reached, or if a wrapper has no
/// `declarator` field.
fn extract_identifier_from_declarator<'a>(declarator: &Node, source: &'a str) -> Option<&'a str> {
    let mut current = *declarator;
    loop {
        match current.kind() {
            "identifier" => return Some(ast_utils::get_node_text(&current, source)),
            "pointer_declarator" | "array_declarator" => {
                current = current.child_by_field_name("declarator")?;
            }
            _ => return None,
        }
    }
}
/// Walks the subtree rooted at `node` in pre-order and passes every `cast_expression`
/// to `check_cast_expression`, which appends any findings to `violations`.
///
/// An explicit work stack replaces recursion. Children are pushed in reverse so they
/// are popped left-to-right, which preserves the visiting order of a recursive
/// pre-order walk.
fn find_bitfield_pointer_casts(
    node: &Node,
    source: &str,
    bitfield_vars: &HashSet<&str>,
    violations: &mut Vec<RuleViolation>,
) {
    let mut pending = vec![*node];
    while let Some(current) = pending.pop() {
        if current.kind() == "cast_expression" {
            check_cast_expression(&current, source, bitfield_vars, violations);
        }
        let mut walker = current.walk();
        let children: Vec<_> = current.children(&mut walker).collect();
        pending.extend(children.into_iter().rev());
    }
}
/// Appends an EXP11-C violation when the cast `node` has both of these properties:
/// - its target `type` text is a character-pointer type (`is_char_pointer_type`);
/// - its `value` is an address-of expression (`extract_address_of_variable`) naming a
///   variable in `bitfield_vars`.
///
/// The reported line and column are the cast's tree-sitter start row/column plus one.
fn check_cast_expression(
    node: &Node,
    source: &str,
    bitfield_vars: &HashSet<&str>,
    violations: &mut Vec<RuleViolation>,
) {
    let targets_char_pointer = match node.child_by_field_name("type") {
        Some(type_node) => is_char_pointer_type(ast_utils::get_node_text(&type_node, source)),
        None => false,
    };
    if !targets_char_pointer {
        return;
    }

    let operand = node
        .child_by_field_name("value")
        .and_then(|value_node| extract_address_of_variable(&value_node, source));
    let var_name = match operand {
        Some(name) if bitfield_vars.contains(name) => name,
        _ => return,
    };

    let position = node.start_position();
    violations.push(RuleViolation {
        rule_id: "EXP11-C".to_string(),
        severity: Severity::Medium,
        message: format!(
            "Casting address of bit-field structure '{}' to pointer type makes assumptions about bit-field layout. Bit-field layout is implementation-defined.",
            var_name
        ),
        file_path: String::new(),
        line: position.row + 1,
        column: position.column + 1,
        suggestion: Some(
            "Access bit-fields directly through structure members instead of pointer arithmetic"
                .to_string(),
        ),
        ..Default::default()
    });
}
/// Returns `true` when a cast's target type text names a pointer to a character type,
/// such as `char *`, `signed char *`, `unsigned char *` or `const char *`.
///
/// All whitespace (spaces, tabs, newlines) is ignored, so `char\t*` and `char *` are
/// treated alike. Because the check is a substring test, multi-level pointers such as
/// `char **` also match.
fn is_char_pointer_type(type_text: &str) -> bool {
    let normalized: String = type_text.chars().filter(|c| !c.is_whitespace()).collect();
    // `unsigned char*` and `signed char*` normalize to strings containing `char*`,
    // so a single substring test covers every character-pointer spelling.
    normalized.contains("char*")
}
/// If `node` is an address-of expression naming a plain identifier, returns that
/// identifier's source text.
///
/// Redundant parentheses are looked through on both sides of the `&`, so `&x`, `(&x)`,
/// `&(x)` and `((&x))` all yield `x`. Anything else (`&x.field`, `&arr[0]`, `*p`, a bare
/// identifier) returns `None`.
fn extract_address_of_variable<'a>(node: &Node, source: &'a str) -> Option<&'a str> {
    let expr = strip_parentheses(*node);
    if expr.kind() != "pointer_expression" {
        return None;
    }
    // `pointer_expression` covers both `*e` and `&e`; only the address-of form matters.
    let mut cursor = expr.walk();
    let is_address_of = expr.children(&mut cursor).any(|child| child.kind() == "&");
    if !is_address_of {
        return None;
    }
    let argument = strip_parentheses(expr.child_by_field_name("argument")?);
    if argument.kind() == "identifier" {
        Some(ast_utils::get_node_text(&argument, source))
    } else {
        None
    }
}

/// Peels off any `parenthesized_expression` wrappers and returns the innermost node.
fn strip_parentheses<'tree>(mut node: Node<'tree>) -> Node<'tree> {
    while node.kind() == "parenthesized_expression" {
        match node.named_child(0) {
            Some(inner) => node = inner,
            None => break,
        }
    }
    node
}