use super::super::{CertRule, RuleViolation};
use crate::analyze::cfg;
use crate::analyze::const_eval;
use crate::analyze::context::ProjectContext;
use crate::analyze::function_summary::FunctionSummary;
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::{get_node_text, is_function_parameter};
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use tree_sitter::Node;
/// C library functions treated as entry points for externally controlled data.
///
/// How each one taints its arguments or result is decided by
/// `Str02C::check_taint_from_call`.
const TAINT_SOURCES: &[&str] = &[
    // Socket, descriptor and stream reads.
    "recv",
    "recvfrom",
    "recvmsg",
    "read",
    "fread",
    // Line readers.
    "fgets",
    "fgetws",
    "gets",
    // Narrow scanf family.
    "scanf",
    "fscanf",
    "sscanf",
    // Environment lookup.
    "getenv",
    // Single-character readers.
    "getchar",
    "getwchar",
    "fgetc",
    "fgetwc",
    "getc",
    "getwc",
    // `_s` line readers.
    "gets_s",
    "fgets_s",
    // Wide scanf family.
    "wscanf",
    "fwscanf",
    "swscanf",
];
/// C library functions that copy or format data into their first argument.
///
/// When any later argument is tainted, the destination becomes tainted too.
const TAINT_PROPAGATORS: &[&str] = &[
    // Narrow string copy / concatenation.
    "strcpy",
    "strncpy",
    "strcat",
    "strncat",
    // Narrow formatting.
    "sprintf",
    "snprintf",
    // Raw memory copies.
    "memcpy",
    "memmove",
    // Wide string copy / concatenation.
    "wcscpy",
    "wcsncpy",
    "wcscat",
    "wcsncat",
    // Wide formatting.
    "swprintf",
];
/// State for the STR02-C rule ("Sanitize data passed to complex subsystems").
///
/// All fields use `RefCell` because the `CertRule` methods that update them
/// (`set_project_context`, `check`) only receive `&self`.
pub struct Str02C {
/// Macro aliases copied from `ProjectContext::macro_aliases` in `set_project_context`.
project_aliases: RefCell<HashMap<String, String>>,
/// Aliases in effect for the tree currently being checked: the project aliases
/// extended with whatever `const_eval::collect_macro_aliases` finds in that tree.
/// Rebuilt at the start of every `check` call and read by `resolve_name`.
current_aliases: RefCell<HashMap<String, String>>,
/// Per-function summaries copied from `ProjectContext::function_summaries`.
function_summaries: RefCell<HashMap<String, FunctionSummary>>,
/// Reverse call graph: callee name -> names of the functions that call it.
/// Built from `ProjectContext::call_graph` in `set_project_context`.
callers: RefCell<HashMap<String, HashSet<String>>>,
}
impl Str02C {
/// Creates a rule instance with empty alias tables, summaries and caller map.
///
/// Project-wide data is filled in later through `set_project_context`.
pub fn new() -> Self {
    Self {
        project_aliases: RefCell::default(),
        current_aliases: RefCell::default(),
        function_summaries: RefCell::default(),
        callers: RefCell::default(),
    }
}
/// Maps `name` through the aliases active for the current file.
///
/// Returns the alias target when one is registered, otherwise `name` itself.
fn resolve_name(&self, name: &str) -> String {
    self.current_aliases
        .borrow()
        .get(name)
        .cloned()
        .unwrap_or_else(|| name.to_string())
}
/// Walks the tree rooted at `node` and dispatches the appropriate analysis.
///
/// Function definitions get the taint-aware analysis and are not descended
/// into any further; calls that sit outside every function fall back to the
/// legacy, non-taint-aware check.
fn check_functions(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
    match node.kind() {
        "function_definition" => {
            self.check_single_function(node, source, violations);
            return;
        }
        "call_expression" if self.find_containing_function(node).is_none() => {
            self.check_dangerous_function_call_legacy(node, source, violations);
        }
        _ => {}
    }
    for child in (0..node.child_count()).filter_map(|idx| node.child(idx)) {
        self.check_functions(&child, source, violations);
    }
}
fn find_containing_function<'a>(&self, node: &Node<'a>) -> Option<Node<'a>> {
let mut current = *node;
while let Some(parent) = current.parent() {
if parent.kind() == "function_definition" {
return Some(parent);
}
current = parent;
}
None
}
/// Runs the taint analysis for one function definition.
///
/// The tainted set is seeded with the parameter names, extended with
/// variables written by taint sources/propagators, and finally used to
/// judge every dangerous call found in the function.
fn check_single_function(
    &self,
    func_node: &Node,
    source: &str,
    violations: &mut Vec<RuleViolation>,
) {
    let mut tainted_names: HashSet<String> = HashSet::new();

    // Step 1: every parameter is considered potentially attacker-controlled.
    self.collect_param_names(func_node, source, &mut tainted_names);
    // Step 2: add variables filled by input functions or copied from tainted data.
    self.collect_tainted_vars(func_node, source, &mut tainted_names);
    // Step 3: report dangerous calls that consume a tainted variable.
    self.check_sinks(func_node, source, &tainted_names, func_node, violations);
}
/// Adds the parameter names of `func_node` to `tainted`.
///
/// Does nothing when the function node has no `declarator` field.
fn collect_param_names(&self, func_node: &Node, source: &str, tainted: &mut HashSet<String>) {
    let Some(signature) = func_node.child_by_field_name("declarator") else {
        return;
    };
    self.find_param_names(&signature, source, tainted);
}
/// Recursively collects the base names of all `parameter_declaration` nodes
/// under `node` into `tainted`.
///
/// A parameter declaration is a leaf for this walk: its own children are not
/// searched for further parameters. Declarators that reduce to an empty base
/// name are skipped.
fn find_param_names(&self, node: &Node, source: &str, tainted: &mut HashSet<String>) {
    if node.kind() != "parameter_declaration" {
        for child in (0..node.child_count()).filter_map(|idx| node.child(idx)) {
            self.find_param_names(&child, source, tainted);
        }
        return;
    }

    let Some(declarator) = node.child_by_field_name("declarator") else {
        return;
    };
    let base = extract_base_var(&get_node_text(&declarator, source));
    if !base.is_empty() {
        tainted.insert(base);
    }
}
/// Pre-order walk that feeds every call expression under `node` to
/// `check_taint_from_call`, growing `tainted` as sources and propagators
/// are encountered.
fn collect_tainted_vars(&self, node: &Node, source: &str, tainted: &mut HashSet<String>) {
    if node.kind() == "call_expression" {
        self.check_taint_from_call(node, source, tainted);
    }
    (0..node.child_count())
        .filter_map(|idx| node.child(idx))
        .for_each(|child| self.collect_tainted_vars(&child, source, tainted));
}
/// Updates `tainted` according to the effect of a single call expression.
///
/// * Taint sources mark the variable that receives the external data:
///   - `recv`/`recvfrom`/`recvmsg`/`read`: the second argument (buffer or message).
///   - `fread`, `fgets`-style and `gets`-style readers: the first argument.
///   - `scanf`/`wscanf`: every argument after the format string.
///   - `fscanf`/`sscanf`/`fwscanf`/`swscanf`: every argument after the
///     stream/input and the format string.
///   - everything else (`getenv`, `getchar`, ...): the variable the call's
///     result is assigned to.
/// * Taint propagators mark their destination (first argument) when any of
///   the remaining arguments is already tainted.
///
/// Expressions that do not reduce to an identifier (for example string
/// literals) are never recorded, so the tainted set never contains `""`.
fn check_taint_from_call(&self, call_node: &Node, source: &str, tainted: &mut HashSet<String>) {
    /// Records the base variable of `expr`, ignoring expressions without one.
    fn mark(tainted: &mut HashSet<String>, expr: &str) {
        let base = extract_base_var(expr);
        if !base.is_empty() {
            tainted.insert(base);
        }
    }

    let Some(callee) = call_node.child_by_field_name("function") else {
        return;
    };
    let func_name = self.resolve_name(&get_node_text(&callee, source));
    let Some(args_node) = call_node.child_by_field_name("arguments") else {
        return;
    };
    let args = self.collect_arguments(&args_node, source);

    if TAINT_SOURCES.contains(&func_name.as_str()) {
        match func_name.as_str() {
            // (fd, buf, ...): the destination is the second argument.
            "recv" | "recvfrom" | "recvmsg" | "read" => {
                if let Some(buf_name) = args.get(1) {
                    mark(tainted, buf_name);
                }
            }
            // (buf, ...): the destination is the first argument. `fread` belongs
            // here: its signature is fread(ptr, size, nmemb, stream).
            "fread" | "fgets" | "fgetws" | "fgets_s" | "gets" | "gets_s" => {
                if let Some(buf_name) = args.first() {
                    mark(tainted, buf_name);
                }
            }
            // (format, outputs...): skip only the format string.
            "scanf" | "wscanf" => {
                for arg in args.iter().skip(1) {
                    mark(tainted, arg);
                }
            }
            // (stream_or_input, format, outputs...): skip both leading arguments
            // so the format string is not mistaken for an output.
            "fscanf" | "sscanf" | "fwscanf" | "swscanf" => {
                for arg in args.iter().skip(2) {
                    mark(tainted, arg);
                }
            }
            // Value-returning sources: taint whatever the result is stored in.
            _ => {
                self.taint_assignment_target(call_node, source, tainted);
            }
        }
    }

    if TAINT_PROPAGATORS.contains(&func_name.as_str()) {
        let has_tainted_source = args.iter().skip(1).any(|arg| {
            let base = extract_base_var(arg);
            !base.is_empty() && tainted.contains(&base)
        });
        if has_tainted_source {
            if let Some(dest) = args.first() {
                mark(tainted, dest);
            }
        }
    }
}
/// Taints the variable that receives the result of `call_node`.
///
/// Handles both plain assignment (`x = getenv(..)`) and initialisation
/// (`char *x = getenv(..)`). Casts and parentheses wrapped around the call
/// (`x = (char *)getenv(..)`) are looked through, since they do not change
/// where the value ends up. Targets that do not reduce to an identifier are
/// ignored so that the tainted set never contains an empty name.
fn taint_assignment_target(
    &self,
    call_node: &Node,
    source: &str,
    tainted: &mut HashSet<String>,
) {
    // Climb past value-preserving wrappers to the node that consumes the result.
    let mut consumer = call_node.parent();
    while let Some(wrapper) = consumer {
        if !matches!(
            wrapper.kind(),
            "cast_expression" | "parenthesized_expression"
        ) {
            break;
        }
        consumer = wrapper.parent();
    }
    let Some(consumer) = consumer else {
        return;
    };

    let target = match consumer.kind() {
        "assignment_expression" => consumer.child_by_field_name("left"),
        "init_declarator" => consumer.child_by_field_name("declarator"),
        _ => None,
    };
    if let Some(target) = target {
        let base = extract_base_var(&get_node_text(&target, source));
        if !base.is_empty() {
            tainted.insert(base);
        }
    }
}
fn collect_arguments(&self, args_node: &Node, source: &str) -> Vec<String> {
let mut args = Vec::new();
for i in 0..args_node.child_count() {
if let Some(child) = args_node.child(i) {
if child.kind() != "(" && child.kind() != ")" && child.kind() != "," {
args.push(get_node_text(&child, source).to_string());
}
}
}
args
}
/// Pre-order walk that checks every call expression under `node` against
/// the dangerous-function list, using the already computed `tainted` set.
///
/// `func_scope` is the enclosing function definition and is passed through
/// unchanged to the per-call check.
fn check_sinks(
    &self,
    node: &Node,
    source: &str,
    tainted: &HashSet<String>,
    func_scope: &Node,
    violations: &mut Vec<RuleViolation>,
) {
    if node.kind() == "call_expression" {
        self.check_dangerous_function_call(node, source, tainted, func_scope, violations);
    }
    for child in (0..node.child_count()).filter_map(|idx| node.child(idx)) {
        self.check_sinks(&child, source, tainted, func_scope, violations);
    }
}
/// Classifies a call expression inside a function and forwards it to the
/// matching check.
///
/// The callee name is resolved through the macro aliases first, so a macro
/// that expands to `system` is treated like `system` itself. `system`/`popen`
/// go to the taint-aware command-injection check; the `exec*` family goes to
/// the `getenv`-argument check. Any other callee is ignored.
fn check_dangerous_function_call(
    &self,
    node: &Node,
    source: &str,
    tainted: &HashSet<String>,
    func_scope: &Node,
    violations: &mut Vec<RuleViolation>,
) {
    if node.kind() != "call_expression" {
        return;
    }
    let Some(callee) = node.child_by_field_name("function") else {
        return;
    };
    let func_name = get_node_text(&callee, source);
    let resolved = self.resolve_name(&func_name);
    let target = resolved.as_str();

    if matches!(target, "system" | "popen") {
        self.check_command_injection_risk(
            node, source, &func_name, &resolved, tainted, func_scope, violations,
        );
    } else if matches!(
        target,
        "execl"
            | "execle"
            | "execlp"
            | "execv"
            | "execvp"
            | "execve"
            | "_execl"
            | "_execle"
            | "_execlp"
            | "_execv"
            | "_execvp"
            | "_execve"
    ) {
        self.check_exec_family_call(node, source, &func_name, &resolved, violations);
    }
}
/// Reports a `system`/`popen` call whose first argument is tainted.
///
/// No violation is produced when:
/// * the call has no arguments or the first argument is a string literal;
/// * the argument's base variable is not in `tainted`;
/// * the variable is a parameter of `func_scope`, the function itself calls
///   no taint source, and every transitive caller is known to be clean.
///
/// `display_name` is the callee as written; `resolved_name` is the function
/// it resolves to after alias expansion. When they differ, the message names
/// both.
fn check_command_injection_risk(
    &self,
    node: &Node,
    source: &str,
    display_name: &str,
    resolved_name: &str,
    tainted: &HashSet<String>,
    func_scope: &Node,
    violations: &mut Vec<RuleViolation>,
) {
    let Some(args_node) = node.child_by_field_name("arguments") else {
        return;
    };
    let Some(first_arg) = self.get_first_argument(&args_node) else {
        return;
    };
    if self.is_string_literal(&first_arg) {
        return;
    }

    let arg_text = get_node_text(&first_arg, source);
    let base_var = extract_base_var(&arg_text);
    if !tainted.contains(&base_var) {
        return;
    }

    // A parameter is only tainted "by default"; drop the report when nothing
    // in this function or in any of its callers introduces external data.
    let parameter_proven_clean = is_function_parameter(func_scope, &base_var, source)
        && !self.scope_has_taint_source(func_scope, source)
        && self.callers_are_all_clean(func_scope, source);
    if parameter_proven_clean {
        return;
    }

    let label = if display_name == resolved_name {
        resolved_name.to_string()
    } else {
        format!("{} (macro for {})", display_name, resolved_name)
    };
    let message = format!(
        "Call to {}() with tainted argument '{}'. Data from external input sources must be sanitized before passing to command processors.",
        label, arg_text.trim()
    );
    let suggestion = format!(
        "Sanitize the string argument before passing to {}() by whitelisting acceptable characters, or use exec*() functions instead of system() to avoid shell interpretation.",
        resolved_name
    );
    let position = node.start_position();

    violations.push(RuleViolation {
        rule_id: self.rule_id().to_string(),
        severity: self.severity(),
        message,
        file_path: String::new(),
        line: position.row + 1,
        column: position.column + 1,
        suggestion: Some(suggestion),
        ..Default::default()
    });
}
/// Non-taint-aware check used for calls that are not inside any function.
///
/// `system`/`popen` are reported whenever their first argument is not a
/// string literal; `exec*` calls are forwarded to `check_exec_family_call`.
/// The callee name is resolved through the macro aliases first.
fn check_dangerous_function_call_legacy(
    &self,
    node: &Node,
    source: &str,
    violations: &mut Vec<RuleViolation>,
) {
    if node.kind() != "call_expression" {
        return;
    }
    let Some(function_node) = node.child_by_field_name("function") else {
        return;
    };
    let func_name = get_node_text(&function_node, source);
    let resolved = self.resolve_name(&func_name);

    match resolved.as_str() {
        // Handled below, after the match.
        "system" | "popen" => {}
        "execl" | "execle" | "execlp" | "execv" | "execvp" | "execve" | "_execl"
        | "_execle" | "_execlp" | "_execv" | "_execvp" | "_execve" => {
            self.check_exec_family_call(node, source, &func_name, &resolved, violations);
            return;
        }
        _ => return,
    }

    let Some(args_node) = node.child_by_field_name("arguments") else {
        return;
    };
    let Some(first_arg) = self.get_first_argument(&args_node) else {
        return;
    };
    if self.is_string_literal(&first_arg) {
        return;
    }

    let arg_text = get_node_text(&first_arg, source);
    let label = if func_name != resolved {
        format!("{} (macro for {})", func_name, resolved)
    } else {
        resolved.to_string()
    };
    let position = node.start_position();

    violations.push(RuleViolation {
        rule_id: self.rule_id().to_string(),
        severity: self.severity(),
        message: format!(
            "Call to {}() with non-literal argument '{}' detected. This may allow command injection if the string contains unsanitized user input or environment variables.",
            label, arg_text.trim()
        ),
        file_path: String::new(),
        line: position.row + 1,
        column: position.column + 1,
        suggestion: Some(format!(
            "Sanitize the string argument before passing to {}() by whitelisting acceptable characters, or use exec*() functions instead of system() to avoid shell interpretation.",
            resolved
        )),
        ..Default::default()
    });
}
/// Reports an `exec*` call that passes `getenv(...)` directly as an argument.
///
/// The check is textual. It is suppressed when:
/// * the argument list does not contain `getenv(`;
/// * a `"--"` literal appears in the argument text before the first `getenv(`;
/// * the text of the enclosing scope mentions `strspn(`, `strcspn(` or
///   `ok_chars`, which is taken as evidence of sanitisation.
///
/// These violations are always reported with `Severity::Medium`.
fn check_exec_family_call(
    &self,
    node: &Node,
    source: &str,
    display_name: &str,
    resolved_name: &str,
    violations: &mut Vec<RuleViolation>,
) {
    let Some(args_node) = node.child_by_field_name("arguments") else {
        return;
    };
    let args_text = get_node_text(&args_node, source);
    let Some(getenv_pos) = args_text.find("getenv(") else {
        return;
    };

    // An end-of-options marker ahead of the environment value is accepted.
    if args_text[..getenv_pos].contains("\"--\"") {
        return;
    }

    if let Some(scope) = self.find_containing_scope(node) {
        let scope_text = get_node_text(&scope, source);
        let looks_sanitized = ["strspn(", "strcspn(", "ok_chars"]
            .iter()
            .any(|marker| scope_text.contains(*marker));
        if looks_sanitized {
            return;
        }
    }

    let label = if display_name == resolved_name {
        resolved_name.to_string()
    } else {
        format!("{} (macro for {})", display_name, resolved_name)
    };
    let position = node.start_position();

    violations.push(RuleViolation {
        rule_id: self.rule_id().to_string(),
        severity: Severity::Medium,
        message: format!(
            "Call to {}() with getenv() in arguments without '--' end-of-options marker. Environment variables may contain values that could be interpreted as command options.",
            label
        ),
        file_path: String::new(),
        line: position.row + 1,
        column: position.column + 1,
        suggestion: Some(
            "Add '--' argument before user-controlled data to prevent option injection, or sanitize the data before passing to exec*() functions."
                .to_string(),
        ),
        ..Default::default()
    });
}
fn find_containing_scope<'a>(&self, node: &Node<'a>) -> Option<Node<'a>> {
let mut current = *node;
while let Some(parent) = current.parent() {
if parent.kind() == "function_definition"
|| parent.kind() == "compound_statement"
|| parent.kind() == "translation_unit"
{
return Some(parent);
}
current = parent;
}
None
}
fn get_first_argument<'a>(&self, args_node: &Node<'a>) -> Option<Node<'a>> {
for i in 0..args_node.child_count() {
if let Some(child) = args_node.child(i) {
if child.kind() != "(" && child.kind() != ")" && child.kind() != "," {
return Some(child);
}
}
}
None
}
/// True when `node` is a string literal or a concatenation of string literals.
fn is_string_literal(&self, node: &Node) -> bool {
    matches!(node.kind(), "string_literal" | "concatenated_string")
}
/// Reports whether any call inside `scope` targets a known taint source.
///
/// Thin wrapper around `walk_for_taint`, which does the actual search.
fn scope_has_taint_source(&self, scope: &Node, source: &str) -> bool {
    let mut hit = false;
    self.walk_for_taint(scope, source, &mut hit);
    hit
}
/// Depth-first search for a call to a taint source; sets `*found` and stops
/// at the first match.
///
/// Only the trailing identifier of the callee text is considered, so
/// `obj->read(...)` is matched as `read`. The identifier is tested both
/// as written and after alias resolution.
fn walk_for_taint(&self, node: &Node, source: &str, found: &mut bool) {
    if *found {
        return;
    }

    if node.kind() == "call_expression" {
        let calls_source = node
            .child_by_field_name("function")
            .map_or(false, |function| {
                let ident = trailing_identifier(get_node_text(&function, source));
                let resolved = self.resolve_name(ident);
                TAINT_SOURCES.contains(&resolved.as_str()) || TAINT_SOURCES.contains(&ident)
            });
        if calls_source {
            *found = true;
            return;
        }
    }

    for child in (0..node.child_count()).filter_map(|idx| node.child(idx)) {
        self.walk_for_taint(&child, source, found);
        if *found {
            return;
        }
    }
}
/// Returns `true` only when every direct and transitive caller of the
/// function `scope` is known to be free of external input.
///
/// The answer is `false` (assume tainted) when:
/// * the function name cannot be determined;
/// * the reverse call graph has no entry, or an empty entry, for it;
/// * any caller reached has no summary, or its summary has
///   `has_env03_taint_source` or `returns_tainted` set.
///
/// The walk goes up the reverse call graph with an explicit stack; the
/// `visited` set keeps it finite on recursive call chains.
fn callers_are_all_clean(&self, scope: &Node, source: &str) -> bool {
    let Some(name) = cfg::get_function_name(scope, source) else {
        return false;
    };
    let callers = self.callers.borrow();
    let Some(root_callers) = callers.get(name) else {
        return false;
    };
    if root_callers.is_empty() {
        return false;
    }
    let summaries = self.function_summaries.borrow();
    let mut visited: HashSet<String> = HashSet::new();
    let mut stack: Vec<String> = root_callers.iter().cloned().collect();
    while let Some(current) = stack.pop() {
        if !visited.insert(current.clone()) {
            continue;
        }
        // A missing summary is treated the same as a tainted one.
        match summaries.get(&current) {
            Some(s) if !s.has_env03_taint_source && !s.returns_tainted => {}
            _ => return false,
        }
        if let Some(next) = callers.get(&current) {
            for c in next {
                if !visited.contains(c) {
                    stack.push(c.clone());
                }
            }
        }
    }
    true
}
}
/// Returns the run of identifier characters (alphanumerics and `_`) at the
/// end of `name`, e.g. `"obj->read"` yields `"read"`.
///
/// The result is empty when `name` ends with a non-identifier character,
/// and is `name` itself when it contains only identifier characters.
fn trailing_identifier(name: &str) -> &str {
    let last_separator = name
        .char_indices()
        .rev()
        .find(|(_, c)| !(c.is_alphanumeric() || *c == '_'));
    match last_separator {
        Some((idx, sep)) => &name[idx + sep.len_utf8()..],
        None => name,
    }
}
/// Reduces a C expression to the identifier at its root.
///
/// Leading address-of / dereference operators, casts and wrapping
/// parentheses are peeled off repeatedly, in any order, until none remain;
/// the leading run of identifier characters of what is left is returned.
/// Examples: `&buf` -> `buf`, `(char *)&buf` -> `buf`, `**argv` -> `argv`,
/// `(*p).len` -> `p`, `req->cmd` -> `req`.
///
/// Returns an empty string when the expression does not start with an
/// identifier (for instance a string literal).
fn extract_base_var(expr: &str) -> String {
    let mut rest = expr.trim();
    loop {
        let unprefixed = rest.trim_start_matches(|c: char| c == '&' || c == '*');
        let next = strip_casts(unprefixed);
        // `next` is always a sub-slice of `rest`, so equality means nothing
        // more could be removed and the loop is guaranteed to terminate.
        if next == rest {
            break;
        }
        rest = next;
    }
    rest.chars()
        .take_while(|c| c.is_alphanumeric() || *c == '_')
        .collect()
}

/// Removes leading C casts from `s` and unwraps a leading parenthesised
/// expression.
///
/// * `(type)expr`, where `type` is recognised by `is_cast_type`, yields
///   `expr` (applied repeatedly for stacked casts).
/// * `(expr)...`, where the parenthesised text is not a recognised type,
///   yields the trimmed text inside the first pair of parentheses.
/// * Anything else is returned trimmed.
///
/// Casts to typedef names that are not in the built-in list cannot be told
/// apart from parenthesised expressions and are handled as the latter.
fn strip_casts(s: &str) -> &str {
    let mut rest = s;
    loop {
        let trimmed = rest.trim();
        if !trimmed.starts_with('(') {
            return trimmed;
        }
        let Some(close) = trimmed.find(')') else {
            return trimmed;
        };
        let inside = trimmed[1..close].trim();
        if !is_cast_type(inside) {
            return inside;
        }
        rest = &trimmed[close + 1..];
    }
}

/// Heuristically decides whether the text between parentheses is a type name.
///
/// A pointer type contains `*` but never starts with it (a leading `*` is a
/// dereference such as `(*p)`). Otherwise every whitespace-separated word
/// must be a known scalar type keyword or qualifier, which covers multi-word
/// types like `unsigned char` or `long long`.
fn is_cast_type(inside: &str) -> bool {
    let inside = inside.trim();
    if inside.is_empty() || inside.starts_with('*') {
        return false;
    }
    if inside.contains('*') {
        return true;
    }
    inside.split_whitespace().all(|word| {
        matches!(
            word,
            "char"
                | "int"
                | "long"
                | "short"
                | "unsigned"
                | "signed"
                | "void"
                | "size_t"
                | "ssize_t"
                | "uint8_t"
                | "int8_t"
                | "const"
                | "volatile"
        )
    })
}
/// `CertRule` implementation for STR02-C.
impl CertRule for Str02C {
    /// Identifier attached to every violation this rule reports.
    fn rule_id(&self) -> &'static str {
        "STR02-C"
    }

    /// Human-readable title of the rule.
    fn description(&self) -> &'static str {
        "Sanitize data passed to complex subsystems"
    }

    /// Default severity; `exec*` findings override it with `Severity::Medium`.
    fn severity(&self) -> Severity {
        Severity::High
    }

    /// Category of this entry.
    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    /// CERT identifier; identical to `rule_id`.
    fn cert_id(&self) -> &'static str {
        "STR02-C"
    }

    /// Caches the project-wide data the rule needs: macro aliases, function
    /// summaries and the call graph inverted into a callee -> callers map.
    fn set_project_context(&self, context: &ProjectContext) {
        // Invert the graph first so the stored map is replaced in one step.
        let mut reverse_graph: HashMap<String, HashSet<String>> = HashMap::new();
        for (caller, callees) in &context.call_graph {
            for callee in callees {
                reverse_graph
                    .entry(callee.clone())
                    .or_default()
                    .insert(caller.clone());
            }
        }

        *self.project_aliases.borrow_mut() = context.macro_aliases.clone();
        *self.function_summaries.borrow_mut() = context.function_summaries.clone();
        *self.callers.borrow_mut() = reverse_graph;
    }

    /// Checks the tree rooted at `node` and returns every violation found.
    ///
    /// The active alias table is rebuilt first: project aliases, extended
    /// with the aliases found in this tree.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        {
            let mut effective = self.project_aliases.borrow().clone();
            effective.extend(const_eval::collect_macro_aliases(node, source));
            *self.current_aliases.borrow_mut() = effective;
        }

        let mut found = Vec::new();
        self.check_functions(node, source, &mut found);
        found
    }
}