use crate::manifest::{RuleCategory, Severity};
use crate::prelude::RuleViolation;
use crate::rules::cert_c::CertRule;
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::{HashMap, HashSet};
use tree_sitter::Node;
/// Checker for CERT C rule STR32-C: "Do not pass a non-null-terminated character sequence to a
/// library function that expects a string".
///
/// The type is a field-less unit struct; all analysis is carried out by its methods.
pub struct Str32C;
/// `CertRule` implementation: static rule metadata plus the three-pass `check` entry point.
impl CertRule for Str32C {
    /// Identifier attached to every violation this rule reports.
    fn rule_id(&self) -> &'static str {
        "STR32-C"
    }

    /// CERT identifier of the rule, without the language suffix.
    fn cert_id(&self) -> &'static str {
        "STR32"
    }

    /// One-sentence statement of what the rule forbids.
    fn description(&self) -> &'static str {
        "Do not pass a non-null-terminated character sequence to a library function that expects a string"
    }

    /// Severity assigned to this rule.
    fn severity(&self) -> Severity {
        Severity::High
    }

    /// Manifest category of this rule.
    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    /// Analyses the subtree rooted at `node` and returns every violation found.
    ///
    /// The work is split into three tree walks that must run in this order:
    /// 1. collect names that may hold an unterminated character sequence,
    /// 2. drop the names that receive an explicit null terminator afterwards,
    /// 3. report string-function calls that still receive one of the remaining names.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        // Pass 1: candidate names and the (row, column) where each became suspect.
        let mut suspects = HashSet::<String>::new();
        let mut origins = HashMap::<String, (usize, usize)>::new();
        self.find_unsafe_arrays(node, source, &mut suspects, &mut origins);

        // Pass 2: explicit terminators clear candidates again.
        self.find_explicit_null_termination(node, source, &mut suspects, &origins);

        // Pass 3: whatever is still suspect and reaches a string function is reported.
        let mut findings = Vec::new();
        self.check_unsafe_usage(node, source, &suspects, &origins, &mut findings);
        findings
    }
}
impl Str32C {
/// Walks the subtree rooted at `node` and collects every name that may refer to a character
/// sequence without a terminating null character.
///
/// * `unsafe_arrays` receives the suspect names.
/// * `array_locations` receives, per name, the `(row, column)` at which it became suspect.
///
/// Declarations are inspected for undersized string-literal initializers; call expressions
/// are inspected for `strncpy` and `realloc`.
fn find_unsafe_arrays(
    &self,
    node: &Node,
    source: &str,
    unsafe_arrays: &mut HashSet<String>,
    array_locations: &mut HashMap<String, (usize, usize)>,
) {
    match node.kind() {
        "declaration" => {
            self.check_declaration_for_unsafe_array(node, source, unsafe_arrays, array_locations);
        }
        "call_expression" => {
            self.check_strncpy_call(node, source, unsafe_arrays, array_locations);
            self.check_realloc_call(node, source, unsafe_arrays, array_locations);
        }
        _ => {}
    }

    // Descend unconditionally: interesting nodes can be nested inside any other node.
    let mut walker = node.walk();
    for descendant in node.children(&mut walker) {
        self.find_unsafe_arrays(&descendant, source, unsafe_arrays, array_locations);
    }
}
/// Records fixed-size character arrays whose string-literal initializer leaves no room for
/// the terminating null character (for example `char a[3] = "abc";`).
///
/// `node` is the declaration to inspect. A declarator is recorded only when all of the
/// following hold:
/// * the declaration's type text contains `char`,
/// * the declarator is an `init_declarator` wrapping an `array_declarator`,
/// * the array size is a plain decimal integer,
/// * the initializer is a single `string_literal`,
/// * the array size is less than or equal to the literal's character count.
///
/// The recorded location is the start of the whole declaration.
fn check_declaration_for_unsafe_array(
    &self,
    node: &Node,
    source: &str,
    unsafe_arrays: &mut HashSet<String>,
    array_locations: &mut HashMap<String, (usize, usize)>,
) {
    // A missing type node or a type without `char` means there is nothing to check.
    let mentions_char = node
        .child_by_field_name("type")
        .map_or(false, |ty| get_node_text(&ty, source).trim().contains("char"));
    if !mentions_char {
        return;
    }

    let declared_at = node.start_position();
    let mut walker = node.walk();
    for init in node.children(&mut walker) {
        if init.kind() != "init_declarator" {
            continue;
        }
        let declarator = match init.child_by_field_name("declarator") {
            Some(d) if d.kind() == "array_declarator" => d,
            _ => continue,
        };
        let (name_node, size_node) = match (
            declarator.child_by_field_name("declarator"),
            declarator.child_by_field_name("size"),
        ) {
            (Some(name), Some(size)) => (name, size),
            _ => continue,
        };
        // Sizes that are not plain decimal integers (macros, expressions) are skipped.
        let capacity = match get_node_text(&size_node, source).trim().parse::<usize>() {
            Ok(n) => n,
            Err(_) => continue,
        };
        let initializer = match init.child_by_field_name("value") {
            Some(v) if v.kind() == "string_literal" => v,
            _ => continue,
        };

        let literal_len =
            self.get_string_literal_length(get_node_text(&initializer, source).trim());
        // The array needs `literal_len + 1` elements to hold the terminator as well.
        if capacity <= literal_len {
            let name = get_node_text(&name_node, source).trim().to_string();
            unsafe_arrays.insert(name.clone());
            array_locations.insert(name, (declared_at.row, declared_at.column));
        }
    }
}
/// Marks the destination of a `strncpy` call as possibly lacking a null terminator.
///
/// `node` is the call expression to inspect. When the callee text is exactly `strncpy` and
/// the call has at least one argument, the text of the first argument is added to
/// `unsafe_arrays` and the call's start `(row, column)` is stored in `array_locations`,
/// replacing any earlier location recorded for the same text.
fn check_strncpy_call(
    &self,
    node: &Node,
    source: &str,
    unsafe_arrays: &mut HashSet<String>,
    array_locations: &mut HashMap<String, (usize, usize)>,
) {
    let calls_strncpy = node
        .child_by_field_name("function")
        .map_or(false, |callee| get_node_text(&callee, source).trim() == "strncpy");
    if !calls_strncpy {
        return;
    }

    let arg_list = match node.child_by_field_name("arguments") {
        Some(list) => list,
        None => return,
    };
    let call_args = self.extract_arguments(&arg_list, source);

    // The destination buffer is the first argument of strncpy.
    if let Some(dest) = call_args.first() {
        let dest_text = get_node_text(dest, source).trim();
        let called_at = node.start_position();
        unsafe_arrays.insert(dest_text.to_string());
        array_locations.insert(dest_text.to_string(), (called_at.row, called_at.column));
    }
}
/// Marks the pointer handed to `realloc` as possibly lacking a null terminator.
///
/// `node` is the call expression to inspect. When the callee text is exactly `realloc` and
/// the call has at least one argument, the text of the first argument is added to
/// `unsafe_arrays` and the call's start `(row, column)` is stored in `array_locations`,
/// replacing any earlier location recorded for the same text.
// NOTE(review): presumably this targets resized buffers whose contents may no longer end in
// a terminator; the rationale is not stated in the code.
fn check_realloc_call(
    &self,
    node: &Node,
    source: &str,
    unsafe_arrays: &mut HashSet<String>,
    array_locations: &mut HashMap<String, (usize, usize)>,
) {
    let calls_realloc = node
        .child_by_field_name("function")
        .map_or(false, |callee| get_node_text(&callee, source).trim() == "realloc");
    if !calls_realloc {
        return;
    }

    let arg_list = match node.child_by_field_name("arguments") {
        Some(list) => list,
        None => return,
    };
    let call_args = self.extract_arguments(&arg_list, source);

    // The pointer being resized is the first argument of realloc.
    if let Some(pointer) = call_args.first() {
        let ptr_text = get_node_text(pointer, source).trim();
        let called_at = node.start_position();
        unsafe_arrays.insert(ptr_text.to_string());
        array_locations.insert(ptr_text.to_string(), (called_at.row, called_at.column));
    }
}
/// Walks the subtree rooted at `node` and clears suspect names that are explicitly
/// null-terminated after the point at which they became suspect.
///
/// An explicit terminator is an assignment whose left side is a subscript expression
/// (`name[...]`) and whose right side is textually `'\0'`, `L'\0'` or `0`. The name is removed
/// from `unsafe_arrays` when the assignment starts strictly after the `(row, column)` stored
/// for it in `array_locations`.
///
/// Positions are compared as `(row, column)` pairs rather than by row alone, so a terminator
/// written later on the same line (`strncpy(a, b, n); a[n - 1] = '\0';`) is honoured too.
///
/// The check is purely positional, not flow-sensitive: once a qualifying assignment exists,
/// the name is cleared for the whole analysed tree.
fn find_explicit_null_termination(
    &self,
    node: &Node,
    source: &str,
    unsafe_arrays: &mut HashSet<String>,
    array_locations: &HashMap<String, (usize, usize)>,
) {
    if node.kind() == "assignment_expression" {
        let sides = (
            node.child_by_field_name("left"),
            node.child_by_field_name("right"),
        );
        if let (Some(left), Some(right)) = sides {
            let stores_null = matches!(
                get_node_text(&right, source).trim(),
                "'\\0'" | "L'\\0'" | "0"
            );
            if stores_null && left.kind() == "subscript_expression" {
                if let Some(array_node) = left.child_by_field_name("argument") {
                    let array_name = get_node_text(&array_node, source).trim();
                    let assigned_at = node.start_position();
                    if let Some(&marked_at) = array_locations.get(array_name) {
                        // Tuple ordering is lexicographic: a later row wins, and on the same
                        // row a later column wins.
                        if (assigned_at.row, assigned_at.column) > marked_at {
                            unsafe_arrays.remove(array_name);
                        }
                    }
                }
            }
        }
    }

    // Assignments can be nested anywhere, so always descend.
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        self.find_explicit_null_termination(&child, source, unsafe_arrays, array_locations);
    }
}
/// Walks the subtree rooted at `node` and reports every call to a string-consuming library
/// function that receives a suspect name as an argument.
///
/// An argument matches when its trimmed source text is exactly one of the entries in
/// `unsafe_arrays`. One violation is pushed onto `violations` per matching argument, located
/// at the start of the call expression (converted to 1-based line and column).
/// `array_locations` is not consulted here; it is only forwarded to the recursive calls.
fn check_unsafe_usage(
    &self,
    node: &Node,
    source: &str,
    unsafe_arrays: &HashSet<String>,
    array_locations: &HashMap<String, (usize, usize)>,
    violations: &mut Vec<RuleViolation>,
) {
    if node.kind() == "call_expression" {
        let parts = (
            node.child_by_field_name("function"),
            node.child_by_field_name("arguments"),
        );
        if let (Some(callee), Some(arg_list)) = parts {
            let func_name = get_node_text(&callee, source).trim();
            if self.is_string_function(func_name) {
                let call_site = node.start_position();
                let call_args = self.extract_arguments(&arg_list, source);
                for arg in call_args.iter() {
                    let arg_text = get_node_text(arg, source).trim();
                    if !unsafe_arrays.contains(arg_text) {
                        continue;
                    }

                    let message = format!(
                        "Character array '{}' may not be null-terminated but is passed to '{}()' which expects a null-terminated string. \
                         This can cause buffer overflows or information disclosure.",
                        arg_text, func_name
                    );
                    let suggestion = format!(
                        "Ensure '{}' is properly null-terminated before passing to '{}()'. \
                         For strncpy(), explicitly add a null terminator. For array declarations, \
                         ensure the bound is large enough to include the null terminator.",
                        arg_text, func_name
                    );
                    violations.push(RuleViolation {
                        rule_id: self.rule_id().to_string(),
                        severity: Severity::High,
                        message,
                        file_path: String::new(),
                        // tree-sitter positions are 0-based; reports are 1-based.
                        line: call_site.row + 1,
                        column: call_site.column + 1,
                        suggestion: Some(suggestion),
                        ..Default::default()
                    });
                }
            }
        }
    }

    // Calls can be nested inside any node (including other calls), so always descend.
    let mut walker = node.walk();
    for descendant in node.children(&mut walker) {
        self.check_unsafe_usage(&descendant, source, unsafe_arrays, array_locations, violations);
    }
}
/// Returns `true` when `func_name` is one of the library functions this rule treats as
/// requiring null-terminated (narrow or wide) string arguments.
///
/// The comparison is exact and case-sensitive.
fn is_string_function(&self, func_name: &str) -> bool {
    // Narrow-string functions first, then their wide-character counterparts.
    const STRING_CONSUMERS: &[&str] = &[
        "strlen", "strcpy", "strcat", "strcmp", "strncmp", "strstr", "strchr", "strrchr",
        "strspn", "strcspn", "strpbrk", "strtok", "printf", "fprintf", "sprintf", "snprintf",
        "puts", "fputs", "wcslen", "wcscpy", "wcscat", "wcscmp", "wcsncmp", "wcsstr", "wcschr",
        "wcsrchr", "wcsspn", "wcscspn", "wcspbrk", "wcstok", "wprintf", "fwprintf", "swprintf",
    ];
    STRING_CONSUMERS.contains(&func_name)
}
/// Returns the children of an argument-list node in source order, leaving out the `(`, `)`
/// and `,` punctuation tokens.
///
/// Every other child is kept as-is. `_source` is unused.
fn extract_arguments<'a>(&self, arguments: &'a Node, _source: &str) -> Vec<Node<'a>> {
    let mut walker = arguments.walk();
    // Bound to a local so the iterator borrowing `walker` is dropped before `walker` itself.
    let collected: Vec<Node<'a>> = arguments
        .children(&mut walker)
        .filter(|child| !matches!(child.kind(), "(" | ")" | ","))
        .collect();
    collected
}
/// Returns the number of characters a C string literal denotes, excluding the implicit
/// terminating null character.
///
/// `literal` is the literal's source text, including its quotes and an optional encoding
/// prefix (`L`, `u`, `U`, `u8`). Text without any quote is counted as-is.
///
/// Handling of the literal's parts:
/// * Only the opening quote (with any prefix before it) and one closing quote are removed, so
///   an escaped quote at the end of the content (`"ab\""`) is still counted.
/// * Simple escapes (`\n`, `\\`, `\"`, ...) count as one character.
/// * Octal escapes consume one to three octal digits and count as one character.
/// * `\x` consumes every following hexadecimal digit and counts as one character.
/// * `\u` / `\U` consume up to four / eight hexadecimal digits and count as one character.
/// * A dangling backslash at the very end contributes nothing.
///
/// Counting is per Unicode scalar value of the source text.
// NOTE(review): in a narrow literal, non-ASCII characters and universal character names can
// occupy several `char` elements once encoded; that is not modelled here.
fn get_string_literal_length(&self, literal: &str) -> usize {
    /// Consumes at most `limit` leading digits of the given `radix` from `chars`.
    fn skip_digits(
        chars: &mut std::iter::Peekable<std::str::Chars<'_>>,
        limit: usize,
        radix: u32,
    ) {
        let mut taken = 0;
        while taken < limit {
            match chars.peek() {
                Some(c) if c.is_digit(radix) => {
                    chars.next();
                    taken += 1;
                }
                _ => break,
            }
        }
    }

    // Everything up to and including the first quote is the encoding prefix plus the opening
    // delimiter. `"` is a single byte, so `open + 1` is always a valid boundary.
    let after_open = match literal.find('"') {
        Some(open) => &literal[open + 1..],
        None => literal,
    };
    // Remove exactly one closing quote; trimming all quotes would eat an escaped `\"`.
    let content = after_open.strip_suffix('"').unwrap_or(after_open);

    let mut length = 0;
    let mut chars = content.chars().peekable();
    while let Some(ch) = chars.next() {
        if ch != '\\' {
            length += 1;
            continue;
        }
        match chars.next() {
            // Hex escapes extend over all following hex digits.
            Some('x') => {
                skip_digits(&mut chars, usize::MAX, 16);
                length += 1;
            }
            // Octal escapes are one to three digits long; the first is already consumed.
            Some('0'..='7') => {
                skip_digits(&mut chars, 2, 8);
                length += 1;
            }
            // Universal character names.
            Some('u') => {
                skip_digits(&mut chars, 4, 16);
                length += 1;
            }
            Some('U') => {
                skip_digits(&mut chars, 8, 16);
                length += 1;
            }
            // Simple escapes and anything unrecognised denote a single character.
            Some(_) => length += 1,
            // Backslash at the very end of the content: nothing to count.
            None => {}
        }
    }
    length
}
}