use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::HashMap;
use tree_sitter::Node;
/// Checker for CERT C rule STR38-C: "Do not confuse narrow and wide character strings and
/// functions". Stateless; all analysis happens in `CertRule::check`.
pub struct Str38C;
/// Narrow (`char`-based) string functions whose first argument is a narrow string or buffer.
///
/// Passing a `wchar_t` variable as the first argument of any of these is reported.
const NARROW_FUNCTIONS: &[&str] = &[
    "strlen", "strcpy", "strncpy", "strcat", "strncat", "strcmp", "strncmp", "strchr", "strstr",
    "strdup", "sprintf", "snprintf", "sscanf", "strnlen", "strrchr", "strspn", "strcspn",
    "strpbrk", "strtok", "strcoll", "strxfrm", "vsprintf", "vsnprintf", "vsscanf",
];
/// Wide (`wchar_t`-based) string functions whose first argument is a wide string or buffer.
///
/// Passing a `char` variable as the first argument of any of these is reported.
const WIDE_FUNCTIONS: &[&str] = &[
    "wcslen", "wcscpy", "wcsncpy", "wcscat", "wcsncat", "wcscmp", "wcsncmp", "wcschr", "wcsstr",
    "wcsdup", "swprintf", "swscanf", "wcsnlen", "wcsrchr", "wcsspn", "wcscspn", "wcspbrk",
    "wcstok", "wcscoll", "wcsxfrm", "vswprintf", "vswscanf",
];
impl CertRule for Str38C {
    fn rule_id(&self) -> &'static str {
        "STR38-C"
    }

    fn description(&self) -> &'static str {
        "Do not confuse narrow and wide character strings and functions"
    }

    fn severity(&self) -> Severity {
        Severity::High
    }

    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    fn cert_id(&self) -> &'static str {
        "STR38-C"
    }

    /// Reports string-function calls whose first argument has the wrong character width.
    ///
    /// Each `function_definition` is analysed exactly once with its own symbol table. File-scope
    /// declarations are collected first, then the declarations inside the definition, which
    /// overwrite file-scope entries of the same name. Any other node is only traversed to
    /// reach function definitions. This avoids two problems: reporting every call twice
    /// (once from the file and once from the function), and mixing up same-named locals of
    /// different functions.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        if node.kind() != "function_definition" {
            let mut violations = Vec::new();
            for i in 0..node.child_count() {
                if let Some(child) = node.child(i) {
                    violations.extend(self.check(&child, source));
                }
            }
            return violations;
        }

        let mut var_types: HashMap<String, VarType> = HashMap::new();

        // File-scope declarations are visible inside every function: walk up to the root,
        // then gather every declaration that is not nested in a function definition.
        let mut root = *node;
        while let Some(parent) = root.parent() {
            root = parent;
        }
        let mut pending = vec![root];
        while let Some(current) = pending.pop() {
            match current.kind() {
                // Locals of other functions are not visible here.
                "function_definition" => {}
                "declaration" => self.collect_var_types(&current, source, &mut var_types),
                // Pushed in reverse so declarations are visited in source order.
                _ => pending.extend((0..current.child_count()).rev().filter_map(|i| current.child(i))),
            }
        }

        // Declarations made inside this definition shadow file-scope ones, so they go last.
        self.collect_var_types(node, source, &mut var_types);

        let mut violations = Vec::new();
        if let Some(body) = node.child_by_field_name("body") {
            self.check_calls(&body, source, &var_types, &mut violations);
        }
        // No recursion into this definition's children: the whole body was scanned above,
        // and scanning it again would report the same calls twice.
        violations
    }
}
/// Character width of a string variable, inferred from its declared type.
#[derive(Clone, Copy)]
enum VarType {
    /// Declared with `wchar_t`.
    Wide,
    /// Declared with `char` (and not `wchar_t`).
    Narrow,
}
impl Str38C {
fn collect_var_types(&self, node: &Node, source: &str, types: &mut HashMap<String, VarType>) {
if node.kind() == "declaration" {
let decl_text = get_node_text(&node, source);
if decl_text.contains("wchar_t") {
if let Some(var_name) = self.extract_var_name(node, source) {
types.insert(var_name, VarType::Wide);
}
} else if decl_text.contains("char") && !decl_text.contains("wchar_t") {
if let Some(var_name) = self.extract_var_name(node, source) {
types.insert(var_name, VarType::Narrow);
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.collect_var_types(&child, source, types);
}
}
}
fn extract_var_name(&self, decl: &Node, source: &str) -> Option<String> {
for i in 0..decl.child_count() {
if let Some(child) = decl.child(i) {
if child.kind() == "init_declarator" {
if let Some(declarator) = child.child_by_field_name("declarator") {
return self.get_identifier(&declarator, source);
}
} else if child.kind() == "array_declarator"
|| child.kind() == "pointer_declarator"
|| child.kind() == "identifier"
{
if let Some(name) = self.get_identifier(&child, source) {
return Some(name);
}
}
}
}
None
}
fn get_identifier(&self, node: &Node, source: &str) -> Option<String> {
if node.kind() == "identifier" {
return Some(get_node_text(node, source).to_string());
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if let Some(id) = self.get_identifier(&child, source) {
return Some(id);
}
}
}
None
}
fn check_calls(
&self,
node: &Node,
source: &str,
types: &HashMap<String, VarType>,
violations: &mut Vec<RuleViolation>,
) {
if node.kind() == "call_expression" {
if let Some(function) = node.child_by_field_name("function") {
let func_name = get_node_text(&function, source);
let is_narrow = NARROW_FUNCTIONS.contains(&func_name);
let is_wide = WIDE_FUNCTIONS.contains(&func_name);
if is_narrow || is_wide {
if let Some(args) = node.child_by_field_name("arguments") {
let arg_text = self.get_first_arg_text(&args, source);
if let Some(arg_text) = arg_text {
for (var_name, var_type) in types {
if arg_text.contains(var_name) {
let mismatch = match (is_narrow, var_type) {
(true, VarType::Wide) => true, (false, VarType::Narrow) => true, _ => false,
};
if mismatch {
let expected = if is_narrow { "wide" } else { "narrow" };
let actual = if matches!(var_type, VarType::Wide) {
"wide"
} else {
"narrow"
};
violations.push(RuleViolation {
rule_id: self.rule_id().to_string(),
severity: Severity::High,
message: format!(
"{} string function '{}' used on {} string variable '{}' - type mismatch",
if is_narrow { "Narrow" } else { "Wide" },
func_name,
actual,
var_name
),
file_path: String::new(),
line: node.start_position().row + 1,
column: node.start_position().column + 1,
suggestion: Some(format!(
"Use {} string function instead (e.g., {})",
expected,
if is_narrow { "wcs* functions" } else { "str* functions" }
)),
..Default::default()
});
break;
}
}
}
}
}
}
}
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
self.check_calls(&child, source, types, violations);
}
}
}
fn get_first_arg_text(&self, args: &Node, source: &str) -> Option<String> {
for i in 0..args.child_count() {
if let Some(child) = args.child(i) {
if child.kind() != "(" && child.kind() != ")" && child.kind() != "," {
return Some(get_node_text(&child, source).to_string());
}
}
}
None
}
}