use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::{HashMap, HashSet};
use tree_sitter::Node;
/// Checker for CERT C CON30-C ("Clean up thread-specific storage").
///
/// Stateless: all per-file analysis state is created inside `check`.
pub struct Con30C;
/// What the checker records about a single `tss_create()` call.
#[derive(Debug, Clone)]
struct TssKeyInfo {
    /// Key expression text extracted from the `tss_create()` call.
    // The same text is used as the map key, so the checker never reads this
    // field (it is kept for `Debug` output). The allow is scoped to this field
    // only, so the lint still reports any other field that becomes unused.
    #[allow(dead_code)]
    key_name: String,
    /// Whether the destructor argument of `tss_create()` was non-null.
    has_destructor: bool,
    /// 1-based line of the `tss_create()` call.
    create_line: usize,
    /// 1-based column of the `tss_create()` call.
    create_column: usize,
}
impl CertRule for Con30C {
    fn rule_id(&self) -> &'static str {
        "CON30-C"
    }

    fn description(&self) -> &'static str {
        "Clean up thread-specific storage"
    }

    fn severity(&self) -> Severity {
        Severity::Medium
    }

    // NOTE(review): in the CERT C numbering scheme, identifiers numbered 30 and
    // above are rules rather than recommendations. Confirm whether a rule
    // category should be returned here instead.
    fn category(&self) -> RuleCategory {
        RuleCategory::Recommendation
    }

    fn cert_id(&self) -> &'static str {
        "CON30-C"
    }

    /// Flags every thread-specific storage key that
    ///
    /// 1. was recorded as created by `tss_create()` without a destructor,
    /// 2. has a value stored into it via `tss_set()`, and
    /// 3. never has its `tss_get()` result passed to `free()`.
    ///
    /// Each violation points at the `tss_create()` call. Violations are
    /// returned ordered by source position.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let mut tss_keys: HashMap<String, TssKeyInfo> = HashMap::new();
        let mut tss_set_calls: HashSet<String> = HashSet::new();
        let mut tss_get_freed: HashSet<String> = HashSet::new();
        self.analyze_tss_operations(
            node,
            source,
            &mut tss_keys,
            &mut tss_set_calls,
            &mut tss_get_freed,
        );

        let mut leaking: Vec<_> = tss_keys
            .iter()
            .filter(|(key_name, key_info)| {
                !key_info.has_destructor
                    && tss_set_calls.contains(key_name.as_str())
                    && !tss_get_freed.contains(key_name.as_str())
            })
            .collect();
        // HashMap iteration order is unspecified (std's default hasher is
        // randomly seeded). Sort so that reports appear in source order and
        // are reproducible across runs.
        leaking.sort_by_key(|(key_name, key_info)| {
            (key_info.create_line, key_info.create_column, *key_name)
        });

        leaking
            .into_iter()
            .map(|(key_name, key_info)| RuleViolation {
                rule_id: self.rule_id().to_string(),
                message: format!(
                    "Thread-specific storage key '{}' created without destructor and \
                     memory stored via tss_set() is never freed. Register a destructor \
                     in tss_create() or explicitly free(tss_get({})).",
                    key_name, key_name
                ),
                severity: self.severity(),
                line: key_info.create_line,
                column: key_info.create_column,
                file_path: String::new(),
                suggestion: Some(format!(
                    "Either register a destructor: tss_create(&{}, free) or \
                     explicitly cleanup: free(tss_get({}))",
                    key_name, key_name
                )),
                requires_manual_review: None,
            })
            .collect()
    }
}
impl Con30C {
fn analyze_tss_operations(
&self,
node: &Node,
source: &str,
tss_keys: &mut HashMap<String, TssKeyInfo>,
tss_set_calls: &mut HashSet<String>,
tss_get_freed: &mut HashSet<String>,
) {
if node.kind() == "call_expression" {
if let Some(function) = node.child_by_field_name("function") {
let func_name = get_node_text(&function, source);
if func_name == "tss_create" {
if let Some((key_name, has_destructor)) =
self.extract_tss_create_info(node, source)
{
tss_keys.insert(
key_name.clone(),
TssKeyInfo {
key_name,
has_destructor,
create_line: node.start_position().row + 1,
create_column: node.start_position().column + 1,
},
);
}
}
if func_name == "tss_set" {
if let Some(key_name) = self.extract_tss_key_name(node, source) {
tss_set_calls.insert(key_name);
}
}
if func_name == "free" {
if let Some(key_name) = self.check_tss_get_in_free(node, source) {
tss_get_freed.insert(key_name);
}
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.analyze_tss_operations(&child, source, tss_keys, tss_set_calls, tss_get_freed);
}
}
}
fn extract_tss_create_info(&self, call_node: &Node, source: &str) -> Option<(String, bool)> {
if let Some(args) = call_node.child_by_field_name("arguments") {
let arg_list = self.get_arguments(args, source);
if arg_list.len() >= 2 {
let key_arg = arg_list[0].trim();
let key_name = key_arg
.strip_prefix('&')
.map_or_else(|| key_arg.to_string(), |s| s.trim().to_string());
let destructor_arg = arg_list[1].trim();
let has_destructor = destructor_arg != "NULL"
&& destructor_arg != "0"
&& destructor_arg != "nullptr"
&& !destructor_arg.is_empty();
return Some((key_name, has_destructor));
}
}
None
}
fn extract_tss_key_name(&self, call_node: &Node, source: &str) -> Option<String> {
if let Some(args) = call_node.child_by_field_name("arguments") {
let arg_list = self.get_arguments(args, source);
if !arg_list.is_empty() {
return Some(arg_list[0].trim().to_string());
}
}
None
}
fn check_tss_get_in_free(&self, call_node: &Node, source: &str) -> Option<String> {
if let Some(args) = call_node.child_by_field_name("arguments") {
for i in 0..args.child_count() {
if let Some(child) = args.child(i) {
if let Some(key) = self.find_tss_get_key(&child, source) {
return Some(key);
}
}
}
}
None
}
fn find_tss_get_key(&self, node: &Node, source: &str) -> Option<String> {
if node.kind() == "call_expression" {
if let Some(function) = node.child_by_field_name("function") {
let func_name = get_node_text(&function, source);
if func_name == "tss_get" {
return self.extract_tss_key_name(node, source);
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if let Some(key) = self.find_tss_get_key(&child, source) {
return Some(key);
}
}
}
None
}
fn get_arguments(&self, args_node: Node, source: &str) -> Vec<String> {
let mut arguments = Vec::new();
for i in 0..args_node.child_count() {
if let Some(child) = args_node.child(i) {
let kind = child.kind();
if kind != "," && kind != "(" && kind != ")" {
let arg_text = get_node_text(&child, source).to_string();
arguments.push(arg_text);
}
}
}
arguments
}
}