use super::super::{CertRule, RuleViolation};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils::get_node_text;
use std::collections::{HashMap, HashSet};
use tree_sitter::Node;
/// Checker for CERT C rule POS39-C ("Use the correct byte ordering when transferring data
/// between systems").
///
/// The analysis runs three passes over the syntax tree:
/// 1. collect variables declared with a multi-byte integer type;
/// 2. collect expressions that are passed to, or assigned from, a byte-order conversion;
/// 3. report receive calls that write into such a variable without a conversion.
pub struct Pos39C;

impl CertRule for Pos39C {
    /// Identifier under which violations of this rule are reported.
    fn rule_id(&self) -> &'static str {
        "POS39-C"
    }

    /// Human-readable summary of the rule.
    fn description(&self) -> &'static str {
        "Use the correct byte ordering when transferring data between systems"
    }

    /// Every violation of this rule is reported at medium severity.
    fn severity(&self) -> Severity {
        Severity::Medium
    }

    /// POS39-C is a CERT *rule* (as opposed to a recommendation).
    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    /// The CERT catalogue identifier, identical to [`CertRule::rule_id`] for this checker.
    fn cert_id(&self) -> &'static str {
        "POS39-C"
    }

    /// Runs the three analysis passes over `node` and returns the violations found.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        // Pass 1: variable name -> declared multi-byte type.
        let mut wide_vars: HashMap<String, String> = HashMap::new();
        self.find_multi_byte_vars(node, source, &mut wide_vars);

        // Pass 2: expressions already routed through ntohl/ntohs/htonl/htons.
        let mut already_converted: HashSet<String> = HashSet::new();
        self.find_byte_order_conversions(node, source, &mut already_converted);

        // Pass 3: receive calls filling an unconverted multi-byte variable.
        let mut found: Vec<RuleViolation> = Vec::new();
        let mut recv_sites: HashMap<String, (usize, usize)> = HashMap::new();
        self.find_recv_calls(
            node,
            source,
            &wide_vars,
            &already_converted,
            &mut recv_sites,
            &mut found,
        );
        found
    }
}
impl Pos39C {
    /// Integer type spellings treated as multi-byte, and therefore byte-order sensitive.
    ///
    /// Unsigned spellings are listed before their signed counterparts because `"int32_t"` is
    /// a substring of `"uint32_t"`. [`Self::extract_type_name`] returns the first match, so
    /// this ordering matters.
    const MULTI_BYTE_TYPES: [&'static str; 8] = [
        "uint64_t",
        "uint32_t",
        "uint16_t",
        "int64_t",
        "int32_t",
        "int16_t",
        "unsigned int",
        "unsigned short",
    ];

    /// Recursively records every variable declared with a multi-byte integer type.
    ///
    /// `vars` maps each variable name to the type spelling returned by
    /// [`Self::extract_type_name`]. The map is not scope-aware: a later declaration of the
    /// same name overwrites an earlier one.
    fn find_multi_byte_vars(&self, node: &Node, source: &str, vars: &mut HashMap<String, String>) {
        if node.kind() == "declaration" {
            // Inspect only the declared type (the grammar's `type` field) when available.
            // Otherwise a type named only in an initializer, e.g. `char c = (uint16_t)x;`,
            // would make `c` look multi-byte. Fall back to the whole declaration text.
            let type_text = match node.child_by_field_name("type") {
                Some(type_node) => get_node_text(&type_node, source).to_string(),
                None => get_node_text(node, source).to_string(),
            };
            if self.is_multi_byte_type(&type_text) {
                let type_name = self.extract_type_name(&type_text);
                // Record every declarator, not just the first one (`uint32_t a, b;`).
                for var_name in self.extract_var_names(node, source) {
                    vars.insert(var_name, type_name.clone());
                }
            }
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                self.find_multi_byte_vars(&child, source, vars);
            }
        }
    }

    /// Returns `true` if `text` mentions any of [`Self::MULTI_BYTE_TYPES`].
    fn is_multi_byte_type(&self, text: &str) -> bool {
        Self::MULTI_BYTE_TYPES.iter().any(|ty| text.contains(*ty))
    }

    /// Returns the first entry of [`Self::MULTI_BYTE_TYPES`] contained in `decl_text`.
    ///
    /// Returns `"multi-byte"` when none matches.
    fn extract_type_name(&self, decl_text: &str) -> String {
        Self::MULTI_BYTE_TYPES
            .iter()
            .find(|ty| decl_text.contains(**ty))
            .map_or_else(|| "multi-byte".to_string(), |ty| (*ty).to_string())
    }

    /// Picks the network-to-host conversion to suggest for a value of type `type_name`.
    ///
    /// - 16-bit types (`*16_t`, `unsigned short`) map to `ntohs`.
    /// - 64-bit types map to `be64toh`. There is no 64-bit `ntoh*` in `<arpa/inet.h>`;
    ///   `be64toh` comes from `<endian.h>` on common platforms.
    /// - Everything else maps to `ntohl`.
    fn conversion_function(type_name: &str) -> &'static str {
        if type_name.contains("64") {
            "be64toh"
        } else if type_name.contains("16") || type_name.contains("short") {
            "ntohs"
        } else {
            "ntohl"
        }
    }

    /// Returns the names of all plain or initialized declarators in a `declaration` node.
    ///
    /// Pointer, array and function declarators that are direct children are skipped.
    fn extract_var_names(&self, decl: &Node, source: &str) -> Vec<String> {
        let mut names = Vec::new();
        for i in 0..decl.child_count() {
            if let Some(child) = decl.child(i) {
                if matches!(child.kind(), "init_declarator" | "identifier") {
                    if let Some(name) = self.find_identifier(&child, source) {
                        names.push(name);
                    }
                }
            }
        }
        names
    }

    /// Depth-first search for the first `identifier` node at or below `node`.
    fn find_identifier(&self, node: &Node, source: &str) -> Option<String> {
        if node.kind() == "identifier" {
            return Some(get_node_text(node, source).to_string());
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                if let Some(name) = self.find_identifier(&child, source) {
                    return Some(name);
                }
            }
        }
        None
    }

    /// Recursively collects the source text of expressions that undergo a byte-order
    /// conversion. Two forms are collected:
    ///
    /// - the left side of an assignment whose right side mentions a conversion function;
    /// - the argument of a direct `ntohl`/`ntohs`/`htonl`/`htons` call, with surrounding
    ///   whitespace and fully-enclosing parentheses removed.
    ///
    /// The analysis is flow-insensitive: a conversion anywhere in the tree counts.
    fn find_byte_order_conversions(
        &self,
        node: &Node,
        source: &str,
        converted: &mut HashSet<String>,
    ) {
        if node.kind() == "assignment_expression" {
            if let Some(left) = node.child_by_field_name("left") {
                let var_name = get_node_text(&left, source);
                if let Some(right) = node.child_by_field_name("right") {
                    let right_text = get_node_text(&right, source);
                    if self.has_byte_order_conversion(&right_text) {
                        converted.insert(var_name.to_string());
                    }
                }
            }
        }
        if node.kind() == "call_expression" {
            if let Some(function) = node.child_by_field_name("function") {
                let func_name = get_node_text(&function, source);
                if self.is_byte_order_function(&func_name) {
                    if let Some(args) = node.child_by_field_name("arguments") {
                        let args_text = get_node_text(&args, source);
                        converted.insert(Self::strip_outer_parens(&args_text).to_string());
                    }
                }
            }
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                self.find_byte_order_conversions(&child, source, converted);
            }
        }
    }

    /// Trims whitespace and removes parentheses that enclose the whole of `text`.
    ///
    /// For example, `"( x )"` and `"((x))"` become `"x"`. Non-enclosing parentheses are
    /// left alone: `"((uint32_t)x)"` becomes `"(uint32_t)x"`, and `"(a)+(b)"` is unchanged.
    fn strip_outer_parens(text: &str) -> &str {
        let mut current = text.trim();
        while let Some(inner) = current
            .strip_prefix('(')
            .and_then(|rest| rest.strip_suffix(')'))
        {
            // Only strip when the outer pair actually matches: "(a)+(b)" must stay intact.
            if !Self::parens_balanced(inner) {
                break;
            }
            current = inner.trim();
        }
        current
    }

    /// Returns `true` if every `)` in `text` closes an earlier `(` and none remain open.
    fn parens_balanced(text: &str) -> bool {
        let mut depth: usize = 0;
        for ch in text.chars() {
            match ch {
                '(' => depth += 1,
                ')' => {
                    if depth == 0 {
                        return false;
                    }
                    depth -= 1;
                }
                _ => {}
            }
        }
        depth == 0
    }

    /// Returns `true` if `text` mentions any of the four byte-order conversion functions.
    fn has_byte_order_conversion(&self, text: &str) -> bool {
        text.contains("ntohl")
            || text.contains("ntohs")
            || text.contains("htonl")
            || text.contains("htons")
    }

    /// Returns `true` if `name` is exactly one of the four byte-order conversion functions.
    fn is_byte_order_function(&self, name: &str) -> bool {
        matches!(name, "ntohl" | "ntohs" | "htonl" | "htons")
    }

    /// Returns `true` if `name` is one of the receive calls this rule inspects.
    fn is_receive_function(name: &str) -> bool {
        matches!(name, "recv" | "recvfrom" | "read")
    }

    /// Returns `true` if `args_text` contains `&var_name` as a whole identifier.
    ///
    /// A match must not be followed by an identifier character, so `&len` does not match
    /// `&length`. This also covers casts such as `(void *)&var_name`.
    fn takes_address_of(args_text: &str, var_name: &str) -> bool {
        let needle = format!("&{}", var_name);
        args_text.match_indices(needle.as_str()).any(|(start, _)| {
            let rest = &args_text[start + needle.len()..];
            !matches!(rest.chars().next(), Some(c) if c == '_' || c.is_alphanumeric())
        })
    }

    /// Recursively walks the tree and reports every receive call that writes into an
    /// unconverted multi-byte variable (see [`Self::check_recv_call`]).
    ///
    /// `_received_vars` is currently unused. It is kept so the signature stays compatible
    /// with existing callers.
    fn find_recv_calls(
        &self,
        node: &Node,
        source: &str,
        multi_byte_vars: &HashMap<String, String>,
        converted_vars: &HashSet<String>,
        _received_vars: &mut HashMap<String, (usize, usize)>,
        violations: &mut Vec<RuleViolation>,
    ) {
        if node.kind() == "call_expression" {
            self.check_recv_call(node, source, multi_byte_vars, converted_vars, violations);
        }
        for i in 0..node.child_count() {
            if let Some(child) = node.child(i) {
                self.find_recv_calls(
                    &child,
                    source,
                    multi_byte_vars,
                    converted_vars,
                    _received_vars,
                    violations,
                );
            }
        }
    }

    /// Examines one `call_expression`. If it is `recv`/`recvfrom`/`read` and an argument
    /// takes the address of a multi-byte variable that never goes through a byte-order
    /// conversion, pushes one violation per such variable, located at the call.
    fn check_recv_call(
        &self,
        call: &Node,
        source: &str,
        multi_byte_vars: &HashMap<String, String>,
        converted_vars: &HashSet<String>,
        violations: &mut Vec<RuleViolation>,
    ) {
        let (function, args) = match (
            call.child_by_field_name("function"),
            call.child_by_field_name("arguments"),
        ) {
            (Some(function), Some(args)) => (function, args),
            _ => return,
        };
        if !Self::is_receive_function(&get_node_text(&function, source)) {
            return;
        }
        let args_text = get_node_text(&args, source);

        // Sort by variable name so the report order is deterministic: std `HashMap`
        // iteration order is randomized per process.
        let mut unconverted: Vec<(&String, &String)> = multi_byte_vars
            .iter()
            .filter(|(name, _)| !converted_vars.contains(*name))
            .filter(|(name, _)| Self::takes_address_of(&args_text, name))
            .collect();
        unconverted.sort();

        let position = call.start_position();
        for (var_name, var_type) in unconverted {
            violations.push(RuleViolation {
                rule_id: self.rule_id().to_string(),
                message: format!(
                    "Network data received into '{}' ({}) without byte order conversion. \
                     Use ntohl()/ntohs() to convert from network byte order.",
                    var_name, var_type
                ),
                severity: self.severity(),
                line: position.row + 1,
                column: position.column + 1,
                file_path: String::new(),
                suggestion: Some(format!(
                    "Add: {} = {}({});",
                    var_name,
                    Self::conversion_function(var_type),
                    var_name
                )),
                requires_manual_review: None,
            });
        }
    }
}