use super::super::{CertRule, RuleViolation};
use crate::analyze::cfg::FunctionCfg;
use crate::analyze::const_eval::{self, MacroConstantMap, VarRangeMap};
use crate::analyze::context::ProjectContext;
use crate::analyze::value_range::{self, RangeAnalysisResult};
use crate::manifest::{RuleCategory, Severity};
use crate::utility::cert_c::ast_utils;
use std::cell::RefCell;
use std::collections::HashMap;
use tree_sitter::Node;
/// Checker for CERT C rule INT34-C (shift amounts must be non-negative and
/// smaller than the width of the shifted operand).
///
/// Every field sits behind a `RefCell` so the trait methods, which only
/// receive `&self`, can replace the stored data.
pub struct Int34C {
/// Macro constants copied from the project context by `set_project_context`.
project_macros: RefCell<MacroConstantMap>,
/// Project macros merged with the macros collected from the tree being
/// checked; rebuilt at the start of every `check` call.
current_macros: RefCell<MacroConstantMap>,
/// Control-flow graphs; looked up with the containing function's start byte.
function_cfgs: RefCell<HashMap<usize, FunctionCfg>>,
/// Value-range analysis results; looked up with the same start-byte key.
vra_results: RefCell<HashMap<usize, RangeAnalysisResult>>,
}
impl Int34C {
    /// Creates a checker whose macro tables, CFG map and range-analysis
    /// results are all empty.
    pub fn new() -> Self {
        Self {
            project_macros: RefCell::new(MacroConstantMap::new()),
            current_macros: RefCell::new(MacroConstantMap::new()),
            function_cfgs: RefCell::new(HashMap::new()),
            vra_results: RefCell::new(HashMap::new()),
        }
    }
}

impl Default for Int34C {
    /// Equivalent to [`Int34C::new`]; provided because `new` takes no
    /// arguments.
    fn default() -> Self {
        Self::new()
    }
}
impl CertRule for Int34C {
    /// Identifier under which violations are reported.
    fn rule_id(&self) -> &'static str {
        "INT34-C"
    }

    /// Human-readable statement of the rule.
    fn description(&self) -> &'static str {
        "Do not shift an expression by a negative number of bits or by greater than or equal to the number of bits that exist in the operand"
    }

    /// Severity attached to every violation of this rule.
    fn severity(&self) -> Severity {
        Severity::Medium
    }

    /// Category of this entry in the manifest.
    fn category(&self) -> RuleCategory {
        RuleCategory::Rule
    }

    /// CERT identifier; the same string as `rule_id`.
    fn cert_id(&self) -> &'static str {
        "INT34-C"
    }

    /// Stores a copy of the project-wide macro constants.
    fn set_project_context(&self, context: &ProjectContext) {
        self.project_macros
            .borrow_mut()
            .clone_from(&context.macro_constants);
    }

    /// Stores a copy of the per-function control-flow graphs.
    fn set_function_cfgs(&self, cfgs: &HashMap<usize, FunctionCfg>) {
        self.function_cfgs.borrow_mut().clone_from(cfgs);
    }

    /// Stores a field-by-field copy of every range-analysis result.
    fn set_vra_results(&self, results: &HashMap<usize, RangeAnalysisResult>) {
        let copied: HashMap<usize, RangeAnalysisResult> = results
            .iter()
            .map(|(&key, result)| {
                let duplicate = RangeAnalysisResult {
                    block_entry_ranges: result.block_entry_ranges.clone(),
                    block_exit_ranges: result.block_exit_ranges.clone(),
                    return_ranges: result.return_ranges.clone(),
                };
                (key, duplicate)
            })
            .collect();
        *self.vra_results.borrow_mut() = copied;
    }

    /// This rule asks for value-range analysis results.
    fn needs_vra(&self) -> bool {
        true
    }

    /// Checks the tree rooted at `node` and returns every violation found.
    ///
    /// The macro table used during the walk is the project table extended
    /// with the macros collected from `node` itself.
    fn check(&self, node: &Node, source: &str) -> Vec<RuleViolation> {
        let merged = {
            let mut table = self.project_macros.borrow().clone();
            table.extend(const_eval::collect_macro_constants(node, source));
            table
        };
        *self.current_macros.borrow_mut() = merged;

        let mut found = Vec::new();
        self.check_recursive(node, source, &mut found);
        found
    }
}
impl Int34C {
/// Visits `node` and all of its descendants in pre-order and runs the shift
/// check on every binary expression whose operator is `<<` or `>>`.
fn check_recursive(&self, node: &Node, source: &str, violations: &mut Vec<RuleViolation>) {
    let mut pending = vec![*node];
    while let Some(current) = pending.pop() {
        if current.kind() == "binary_expression" {
            match ast_utils::get_binary_operator(&current, source) {
                Some(op) if op == "<<" || op == ">>" => {
                    self.check_shift_operation(&current, source, op, violations);
                }
                _ => {}
            }
        }
        // Push children right-to-left so they are popped left-to-right,
        // which keeps violations in source order.
        for index in (0..current.child_count()).rev() {
            if let Some(child) = current.child(index) {
                pending.push(child);
            }
        }
    }
}
/// Decides whether one shift expression must be reported.
///
/// `node` is the `<<` / `>>` binary expression and `operator` its operator
/// text. The shift is accepted without a report when any of the following
/// holds, tested in this order:
///
/// 1. the shift amount is a non-negative integer literal (its magnitude is
///    not examined);
/// 2. the shift amount is a `% constant` expression accepted by
///    `shift_amount_bounded_by_modulo`;
/// 3. value-range analysis bounds the amount to `0..32`;
/// 4. constant evaluation, using loop-variable ranges and enclosing `if`
///    conditions, bounds the amount to `0..32`;
/// 5. the left operand looks unsigned and the operator is not `<<`;
/// 6. `is_shift_amount_validated` finds a guarding check.
///
/// The width of 32 is hard-coded; the operand's real type is not consulted.
fn check_shift_operation(
    &self,
    node: &Node,
    source: &str,
    operator: &str,
    violations: &mut Vec<RuleViolation>,
) {
    let (left_node, right_node) = match (
        node.child_by_field_name("left"),
        node.child_by_field_name("right"),
    ) {
        (Some(left), Some(right)) => (left, right),
        _ => return,
    };
    let right_text = ast_utils::get_node_text(&right_node, source);
    let left_text = ast_utils::get_node_text(&left_node, source);

    if self.is_non_negative_integer_literal(&right_node, source) {
        return;
    }
    if self.shift_amount_bounded_by_modulo(&right_node, source) {
        return;
    }
    if let Some(range) = self.eval_shift_range_via_vra(node, &right_node, source) {
        if range.min >= 0 && range.max < 32 {
            return;
        }
    }
    {
        // Scoped so the macro-table borrow is released before the remaining
        // checks run.
        let macros = self.current_macros.borrow();
        let mut var_ranges = const_eval::extract_loop_var_ranges(node, source, &macros);
        Self::extract_if_condition_ranges(node, source, &macros, &mut var_ranges);
        if let Some(range) =
            const_eval::try_evaluate_range(&right_node, source, &macros, &var_ranges)
        {
            if range.min >= 0 && range.max < 32 {
                return;
            }
        }
    }

    // For an operand that looks unsigned only left shifts are reported.
    let exempt = self.is_likely_unsigned(left_text, &left_node, source) && operator != "<<";
    if !exempt && !self.is_shift_amount_validated(node, &right_node, source) {
        self.report_violation(
            node,
            left_text.to_string(),
            right_text.to_string(),
            source,
            violations,
        );
    }
}
/// Appends a violation for the shift expression `node`.
///
/// The message quotes the whole expression and the shift amount
/// (`right_text`); `_left_text` is accepted but not used. Line and column
/// are the node's start position converted to 1-based values, and
/// `file_path` is left empty.
fn report_violation(
    &self,
    node: &Node,
    _left_text: String,
    right_text: String,
    source: &str,
    violations: &mut Vec<RuleViolation>,
) {
    let start = node.start_position();
    let message = format!(
        "Shift operation '{}' by '{}' without validating shift amount is non-negative and within type width",
        ast_utils::get_node_text(node, source),
        right_text
    );
    let suggestion = format!(
        "Check that '{}' is >= 0 and < the bit width of the operand before shifting",
        right_text
    );
    violations.push(RuleViolation {
        rule_id: self.rule_id().to_string(),
        severity: self.severity(),
        message,
        file_path: String::new(),
        line: start.row + 1,
        column: start.column + 1,
        suggestion: Some(suggestion),
        ..Default::default()
    });
}
/// Returns `true` when `node` is `expr % N`, optionally wrapped in
/// parentheses, where `N` is a decimal integer literal in `1..=64`.
///
/// Suffix characters `u` / `l` (either case) are stripped before parsing.
/// A modulus written in hex, octal or binary does not parse as decimal here
/// and is therefore not accepted.
fn shift_amount_bounded_by_modulo(&self, node: &Node, source: &str) -> bool {
    match node.kind() {
        "binary_expression" => {
            let is_modulo =
                ast_utils::get_binary_operator(node, source).map_or(false, |op| op == "%");
            if !is_modulo {
                return false;
            }
            let modulus = node
                .child_by_field_name("right")
                .filter(|rhs| self.is_non_negative_integer_literal(rhs, source))
                .and_then(|rhs| {
                    let lowered = ast_utils::get_node_text(&rhs, source)
                        .trim()
                        .to_ascii_lowercase();
                    lowered.trim_end_matches(['u', 'l']).parse::<u64>().ok()
                });
            matches!(modulus, Some(1..=64))
        }
        // Child 0 is the opening parenthesis; child 1 is taken as the
        // wrapped expression.
        "parenthesized_expression" => node
            .child(1)
            .map_or(false, |inner| self.shift_amount_bounded_by_modulo(&inner, source)),
        _ => false,
    }
}
/// Returns `true` when `node` is a `number_literal` whose text, after
/// lower-casing and removing trailing `u` / `l` suffix characters, parses
/// as a `u64`.
///
/// Recognised forms: `0x…` (hex), `0b…` (binary), a leading `0` followed by
/// more characters (octal), and plain decimal.
fn is_non_negative_integer_literal(&self, node: &Node, source: &str) -> bool {
    if node.kind() != "number_literal" {
        return false;
    }
    let lowered = ast_utils::get_node_text(node, source)
        .trim()
        .to_ascii_lowercase();
    let unsuffixed = lowered.trim_end_matches(['u', 'l']);

    let (digits, radix) = if let Some(hex) = unsuffixed.strip_prefix("0x") {
        (hex, 16)
    } else if let Some(bin) = unsuffixed.strip_prefix("0b") {
        (bin, 2)
    } else if unsuffixed.starts_with('0') && unsuffixed.len() > 1 {
        (&unsuffixed[1..], 8)
    } else {
        (unsuffixed, 10)
    };
    u64::from_str_radix(digits, radix).is_ok()
}
/// Heuristically decides whether the operand named `var_name` is unsigned.
///
/// Returns `true` when the name starts with `ui_`, `u_` or `unsigned_`, or
/// when the function containing `node` declares it with an `unsigned` type,
/// either as a parameter or in a declaration that is a direct child of the
/// function body.
fn is_likely_unsigned(&self, var_name: &str, node: &Node, source: &str) -> bool {
    let has_unsigned_prefix = ["ui_", "u_", "unsigned_"]
        .iter()
        .any(|prefix| var_name.starts_with(*prefix));
    if has_unsigned_prefix {
        return true;
    }

    let func = match ast_utils::find_containing_function(node) {
        Some(func) => func,
        None => return false,
    };

    let declared_as_param = self.find_parameter_list(&func).map_or(false, |params| {
        (0..params.child_count())
            .filter_map(|i| params.child(i))
            .any(|param| {
                param.kind() == "parameter_declaration"
                    && self.decl_has_unsigned_var(&param, var_name, source)
            })
    });
    if declared_as_param {
        return true;
    }

    func.child_by_field_name("body")
        .map_or(false, |body| self.body_has_unsigned_var(&body, var_name, source))
}
fn find_parameter_list<'a>(&self, func: &'a Node) -> Option<Node<'a>> {
for i in 0..func.child_count() {
if let Some(child) = func.child(i) {
if child.kind() == "function_declarator" {
for j in 0..child.child_count() {
if let Some(grandchild) = child.child(j) {
if grandchild.kind() == "parameter_list" {
return Some(grandchild);
}
}
}
}
}
}
None
}
fn decl_has_unsigned_var(&self, decl: &Node, var_name: &str, source: &str) -> bool {
let mut has_unsigned = false;
let mut declares_var = false;
for i in 0..decl.child_count() {
if let Some(child) = decl.child(i) {
match child.kind() {
"sized_type_specifier" => {
let text = ast_utils::get_node_text(&child, source);
if text.contains("unsigned") {
has_unsigned = true;
}
}
"identifier" if ast_utils::get_node_text(&child, source) == var_name => {
declares_var = true;
}
"pointer_declarator" | "array_declarator" | "init_declarator"
if self.declarator_contains_name(&child, var_name, source) =>
{
declares_var = true;
}
_ => {}
}
}
}
has_unsigned && declares_var
}
/// Returns `true` when the declarator `node` declares `var_name`.
///
/// When `node` has a `declarator` field, only that field is followed. This
/// keeps identifiers that merely appear in an initializer value or an array
/// size from being mistaken for the declared name: in
/// `unsigned int y = x;` only `y` is declared. Nodes without such a field
/// fall back to searching every child.
fn declarator_contains_name(&self, node: &Node, var_name: &str, source: &str) -> bool {
    if node.kind() == "identifier" {
        return ast_utils::get_node_text(node, source) == var_name;
    }
    if let Some(inner) = node.child_by_field_name("declarator") {
        return self.declarator_contains_name(&inner, var_name, source);
    }
    (0..node.child_count())
        .filter_map(|i| node.child(i))
        .any(|child| self.declarator_contains_name(&child, var_name, source))
}
fn body_has_unsigned_var(&self, body: &Node, var_name: &str, source: &str) -> bool {
for i in 0..body.child_count() {
if let Some(child) = body.child(i) {
if child.kind() == "declaration"
&& self.decl_has_unsigned_var(&child, var_name, source)
{
return true;
}
}
}
false
}
/// Looks for evidence that the shift amount was checked before the shift.
///
/// Two sources are consulted:
/// * an earlier `if` in the containing function's body that tests the
///   amount and bails out (`has_validation_check`);
/// * any ancestor of `shift_node` that is an `if` whose condition tests the
///   amount and whose branches contain the shift, or a loop whose condition
///   bounds the amount.
fn is_shift_amount_validated(
    &self,
    shift_node: &Node,
    shift_amount: &Node,
    source: &str,
) -> bool {
    let amount_text = ast_utils::get_node_text(shift_amount, source);

    let guarded_earlier = ast_utils::find_containing_function(shift_node)
        .and_then(|func| func.child_by_field_name("body"))
        .map_or(false, |body| {
            self.has_validation_check(&body, amount_text, source, shift_node)
        });
    if guarded_earlier {
        return true;
    }

    let mut cursor = shift_node.parent();
    while let Some(ancestor) = cursor {
        let guarded_here = match ancestor.kind() {
            "if_statement" => ancestor
                .child_by_field_name("condition")
                .map_or(false, |condition| {
                    self.checks_shift_bounds(&condition, amount_text, source)
                        && self.is_in_safe_branch(&ancestor, shift_node)
                }),
            "while_statement" | "for_statement" | "do_statement" => ancestor
                .child_by_field_name("condition")
                .map_or(false, |condition| {
                    self.loop_bounds_shift_amount(&condition, shift_amount, source)
                }),
            _ => false,
        };
        if guarded_here {
            return true;
        }
        cursor = ancestor.parent();
    }
    false
}
/// Returns `true` when the loop `condition` constrains an identifier that
/// occurs in `shift_amount`.
///
/// Either the condition bounds it with a small constant
/// (`condition_bounds_var_small`) or it compares a right shift by it
/// against zero (`condition_is_shift_to_zero_check`). A shift amount with
/// no identifiers yields `false`.
fn loop_bounds_shift_amount(
    &self,
    condition: &Node,
    shift_amount: &Node,
    source: &str,
) -> bool {
    let mut names = Vec::new();
    Self::collect_identifiers_from(shift_amount, source, &mut names);
    if names.is_empty() {
        return false;
    }

    // Look through one level of parentheses around the condition.
    let inner = match condition.kind() {
        "parenthesized_expression" => match condition.child(1) {
            Some(expr) => expr,
            None => return false,
        },
        _ => *condition,
    };

    self.condition_bounds_var_small(&inner, &names, source)
        || self.condition_is_shift_to_zero_check(&inner, &names, source)
}
/// Appends to `names` the text of every `identifier` in the subtree rooted
/// at `node`, in pre-order, skipping names that are already present.
fn collect_identifiers_from(node: &Node, source: &str, names: &mut Vec<String>) {
    let mut pending = vec![*node];
    while let Some(current) = pending.pop() {
        if current.kind() == "identifier" {
            let name = ast_utils::get_node_text(&current, source).to_string();
            if !names.contains(&name) {
                names.push(name);
            }
        }
        // Reverse push order so children are processed left-to-right.
        for index in (0..current.child_count()).rev() {
            if let Some(child) = current.child(index) {
                pending.push(child);
            }
        }
    }
}
/// Returns `true` when `cond` places an upper bound on one of `var_names`
/// that keeps it below 32.
///
/// Accepted shapes are `v < C`, `v <= C`, `C > v` and `C >= v`, where `C`
/// is a decimal literal or an expression the constant evaluator can
/// resolve. For `&&` it is enough that either operand qualifies. Exclusive
/// comparisons accept `C <= 32`; inclusive ones accept `C <= 31`, because
/// `v <= 32` still permits a shift by 32. No lower bound is checked here.
fn condition_bounds_var_small(&self, cond: &Node, var_names: &[String], source: &str) -> bool {
    if cond.kind() != "binary_expression" {
        return false;
    }
    let op = ast_utils::get_binary_operator(cond, source).unwrap_or_default();
    if op == "&&" {
        return ["left", "right"].iter().any(|side| {
            cond.child_by_field_name(*side).map_or(false, |operand| {
                self.condition_bounds_var_small(&operand, var_names, source)
            })
        });
    }

    let (left, right) = match (
        cond.child_by_field_name("left"),
        cond.child_by_field_name("right"),
    ) {
        (Some(l), Some(r)) => (l, r),
        _ => return false,
    };
    let left_text = ast_utils::get_node_text(&left, source);
    let right_text = ast_utils::get_node_text(&right, source);

    // Normalise to "variable compared against bound".
    let (var_text, bound_node, bound_text, inclusive) = match op {
        "<" | "<=" => (left_text, right, right_text, op == "<="),
        ">" | ">=" => (right_text, left, left_text, op == ">="),
        _ => return false,
    };
    if !var_names.iter().any(|v| v == var_text) {
        return false;
    }

    let max_bound: i64 = if inclusive { 31 } else { 32 };
    if let Ok(bound) = bound_text.trim().parse::<i64>() {
        return bound <= max_bound;
    }
    let macros = self.current_macros.borrow();
    const_eval::try_evaluate_expr(&bound_node, source, &macros)
        .map_or(false, |val| val <= max_bound)
}
/// Returns `true` when `cond` has the shape `(x >> v) != 0` or
/// `(x >> v) == 0`, with the zero on either side, where the text of the
/// shift's right operand equals one of `var_names`.
fn condition_is_shift_to_zero_check(
    &self,
    cond: &Node,
    var_names: &[String],
    source: &str,
) -> bool {
    if cond.kind() != "binary_expression" {
        return false;
    }
    let op = ast_utils::get_binary_operator(cond, source).unwrap_or_default();
    if !(op == "!=" || op == "==") {
        return false;
    }
    let (left, right) = match (
        cond.child_by_field_name("left"),
        cond.child_by_field_name("right"),
    ) {
        (Some(l), Some(r)) => (l, r),
        _ => return false,
    };

    // Whichever operand is not the zero literal is the candidate shift.
    let candidate = if self.is_zero_literal(&right, source) {
        left
    } else if self.is_zero_literal(&left, source) {
        right
    } else {
        return false;
    };
    let shift_expr = match candidate.kind() {
        "parenthesized_expression" => candidate.named_child(0).unwrap_or(candidate),
        _ => candidate,
    };

    if shift_expr.kind() != "binary_expression" {
        return false;
    }
    let is_right_shift =
        ast_utils::get_binary_operator(&shift_expr, source).map_or(false, |o| o == ">>");
    if !is_right_shift {
        return false;
    }
    shift_expr.child_by_field_name("right").map_or(false, |rhs| {
        let rhs_text = ast_utils::get_node_text(&rhs, source);
        var_names.iter().any(|v| v == rhs_text)
    })
}
/// Returns `true` when the trimmed text of `node` is a decimal `0` followed
/// by any combination of the integer suffix letters `u` / `l` in either
/// case (`0`, `0u`, `0UL`, `0ll`, ...).
fn is_zero_literal(&self, node: &Node, source: &str) -> bool {
    let text = ast_utils::get_node_text(node, source).trim();
    text.trim_end_matches(|c: char| matches!(c, 'u' | 'U' | 'l' | 'L')) == "0"
}
fn has_validation_check(
&self,
scope: &Node,
var_name: &str,
source: &str,
shift_node: &Node,
) -> bool {
let shift_line = shift_node.start_position().row;
for i in 0..scope.named_child_count() {
if let Some(child) = scope.named_child(i) {
let child_line = child.start_position().row;
if child_line >= shift_line {
break;
}
if child.kind() == "if_statement" {
if let Some(condition) = child.child_by_field_name("condition") {
if self.checks_shift_bounds(&condition, var_name, source) {
if let Some(consequence) = child.child_by_field_name("consequence") {
if Self::has_return_or_error_handling(&consequence, source) {
return true;
}
}
}
}
}
}
}
false
}
fn checks_shift_bounds(&self, condition: &Node, var_name: &str, source: &str) -> bool {
let condition_text = ast_utils::get_node_text(condition, source);
let has_negative_check = condition_text.contains(&format!("{} < 0", var_name))
|| condition_text.contains(&format!("0 > {}", var_name))
|| condition_text.contains(&format!("{} >= 0", var_name))
|| condition_text.contains(&format!("0 <= {}", var_name));
let has_width_check = condition_text.contains(&format!("{} <", var_name))
|| condition_text.contains(&format!("{} >=", var_name))
|| condition_text.contains("PRECISION")
|| condition_text.contains("CHAR_BIT")
|| condition_text.contains("_MAX");
if has_negative_check || has_width_check {
return true;
}
for i in 0..condition.child_count() {
if let Some(child) = condition.child(i) {
if child.kind() == "binary_expression" {
if let Some(operator) = ast_utils::get_binary_operator(&child, source) {
if operator == "<"
|| operator == ">"
|| operator == "<="
|| operator == ">="
{
let left = child.child_by_field_name("left");
let right = child.child_by_field_name("right");
if let (Some(l), Some(r)) = (left, right) {
let left_text = ast_utils::get_node_text(&l, source);
let right_text = ast_utils::get_node_text(&r, source);
if left_text == var_name || right_text == var_name {
if right_text.contains("PRECISION")
|| right_text.contains("CHAR_BIT")
|| right_text.contains("MAX")
|| left_text.contains("PRECISION")
|| left_text.contains("CHAR_BIT")
|| left_text.contains("MAX")
|| right_text == "0"
|| left_text == "0"
{
return true;
}
}
}
}
}
}
}
}
false
}
/// Returns `true` when `node` appears to leave the normal control flow.
///
/// That is the case when its text contains `return`, `error` or `exit`,
/// or when a `return`, `break` or `continue` statement exists anywhere
/// beneath it.
fn has_return_or_error_handling(node: &Node, source: &str) -> bool {
    let text = ast_utils::get_node_text(node, source);
    let mentions_bailout = ["return", "error", "exit"]
        .iter()
        .any(|keyword| text.contains(*keyword));
    if mentions_bailout {
        return true;
    }

    (0..node.child_count())
        .filter_map(|i| node.child(i))
        .any(|child| {
            matches!(
                child.kind(),
                "return_statement" | "break_statement" | "continue_statement"
            ) || Self::has_return_or_error_handling(&child, source)
        })
}
/// Returns `true` when `shift_node` lies inside the `consequence` or the
/// `alternative` of `if_node`. Either branch counts; the polarity of the
/// condition is not examined.
fn is_in_safe_branch(&self, if_node: &Node, shift_node: &Node) -> bool {
    ["consequence", "alternative"]
        .iter()
        .filter_map(|field| if_node.child_by_field_name(*field))
        .any(|branch| Self::is_descendant(&branch, shift_node))
}
/// Narrows `ranges` using every enclosing `if` whose *consequence* contains
/// `node`.
///
/// Containment is decided by byte offsets. The condition is unwrapped from
/// its parentheses and handed to `extract_comparison_bounds`. `if`
/// statements that contain `node` only in their `else` part contribute
/// nothing.
fn extract_if_condition_ranges(
    node: &Node,
    source: &str,
    macros: &MacroConstantMap,
    ranges: &mut VarRangeMap,
) {
    let mut cursor = node.parent();
    while let Some(ancestor) = cursor {
        // Advance first so the guard clauses below can simply `continue`.
        cursor = ancestor.parent();

        if ancestor.kind() != "if_statement" {
            continue;
        }
        let inside_consequence = ancestor
            .child_by_field_name("consequence")
            .map_or(false, |branch| {
                node.start_byte() >= branch.start_byte() && node.end_byte() <= branch.end_byte()
            });
        if !inside_consequence {
            continue;
        }
        let condition = match ancestor.child_by_field_name("condition") {
            Some(condition) => condition,
            None => continue,
        };
        let unwrapped = if condition.kind() == "parenthesized_expression" {
            condition.child(1)
        } else {
            Some(condition)
        };
        if let Some(expr) = unwrapped {
            Self::extract_comparison_bounds(&expr, source, macros, ranges);
        }
    }
}
fn extract_comparison_bounds(
node: &Node,
source: &str,
macros: &MacroConstantMap,
ranges: &mut VarRangeMap,
) {
if node.kind() != "binary_expression" {
return;
}
let op = match ast_utils::get_binary_operator(node, source) {
Some(o) => o,
None => return,
};
if op == "&&" {
if let Some(left) = node.child_by_field_name("left") {
Self::extract_comparison_bounds(&left, source, macros, ranges);
}
if let Some(right) = node.child_by_field_name("right") {
Self::extract_comparison_bounds(&right, source, macros, ranges);
}
return;
}
let (left, right) = match (
node.child_by_field_name("left"),
node.child_by_field_name("right"),
) {
(Some(l), Some(r)) => (l, r),
_ => return,
};
match op {
"<" => {
if left.kind() == "identifier" {
if let Some(bound) = const_eval::try_evaluate_expr(&right, source, macros) {
let name = ast_utils::get_node_text(&left, source).to_string();
let entry = ranges
.entry(name)
.or_insert(const_eval::ValueRange::new(i64::MIN, i64::MAX));
entry.max = entry.max.min(bound - 1);
}
}
if right.kind() == "identifier" {
if let Some(bound) = const_eval::try_evaluate_expr(&left, source, macros) {
let name = ast_utils::get_node_text(&right, source).to_string();
let entry = ranges
.entry(name)
.or_insert(const_eval::ValueRange::new(i64::MIN, i64::MAX));
entry.min = entry.min.max(bound + 1);
}
}
}
"<=" => {
if left.kind() == "identifier" {
if let Some(bound) = const_eval::try_evaluate_expr(&right, source, macros) {
let name = ast_utils::get_node_text(&left, source).to_string();
let entry = ranges
.entry(name)
.or_insert(const_eval::ValueRange::new(i64::MIN, i64::MAX));
entry.max = entry.max.min(bound);
}
}
if right.kind() == "identifier" {
if let Some(bound) = const_eval::try_evaluate_expr(&left, source, macros) {
let name = ast_utils::get_node_text(&right, source).to_string();
let entry = ranges
.entry(name)
.or_insert(const_eval::ValueRange::new(i64::MIN, i64::MAX));
entry.min = entry.min.max(bound);
}
}
}
">" => {
if left.kind() == "identifier" {
if let Some(bound) = const_eval::try_evaluate_expr(&right, source, macros) {
let name = ast_utils::get_node_text(&left, source).to_string();
let entry = ranges
.entry(name)
.or_insert(const_eval::ValueRange::new(i64::MIN, i64::MAX));
entry.min = entry.min.max(bound + 1);
}
}
if right.kind() == "identifier" {
if let Some(bound) = const_eval::try_evaluate_expr(&left, source, macros) {
let name = ast_utils::get_node_text(&right, source).to_string();
let entry = ranges
.entry(name)
.or_insert(const_eval::ValueRange::new(i64::MIN, i64::MAX));
entry.max = entry.max.min(bound - 1);
}
}
}
">=" => {
if left.kind() == "identifier" {
if let Some(bound) = const_eval::try_evaluate_expr(&right, source, macros) {
let name = ast_utils::get_node_text(&left, source).to_string();
let entry = ranges
.entry(name)
.or_insert(const_eval::ValueRange::new(i64::MIN, i64::MAX));
entry.min = entry.min.max(bound);
}
}
if right.kind() == "identifier" {
if let Some(bound) = const_eval::try_evaluate_expr(&left, source, macros) {
let name = ast_utils::get_node_text(&right, source).to_string();
let entry = ranges
.entry(name)
.or_insert(const_eval::ValueRange::new(i64::MIN, i64::MAX));
entry.max = entry.max.min(bound);
}
}
}
"!=" => {
}
_ => {}
}
}
fn is_descendant(node: &Node, target: &Node) -> bool {
if node.id() == target.id() {
return true;
}
for i in 0..node.child_count() {
if let Some(child) = node.child(i) {
if Self::is_descendant(&child, target) {
return true;
}
}
}
false
}
/// Evaluates the value range of the shift amount `right_node` using the
/// stored value-range analysis.
///
/// Returns `None` when no analysis results or CFGs are stored, when
/// `shift_node` is not inside a function, when nothing is stored under
/// that function's start byte, or when the function has no `body`.
/// Otherwise the result of `value_range::eval_expr_range_at` is returned.
fn eval_shift_range_via_vra(
    &self,
    shift_node: &Node,
    right_node: &Node,
    source: &str,
) -> Option<const_eval::ValueRange> {
    let vra_results = self.vra_results.borrow();
    let cfgs = self.function_cfgs.borrow();
    if vra_results.is_empty() || cfgs.is_empty() {
        return None;
    }

    let func = ast_utils::find_containing_function(shift_node)?;
    // Both maps are looked up with the function's start byte.
    let key = func.start_byte();
    let cfg = cfgs.get(&key)?;
    let vra = vra_results.get(&key)?;
    let body = func.child_by_field_name("body")?;

    let macros = self.current_macros.borrow();
    value_range::eval_expr_range_at(vra, cfg, &body, source, &macros, right_node)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the tree-sitter C grammar.
    fn parse_c_code(source: &str) -> tree_sitter::Tree {
        let mut parser = tree_sitter::Parser::new();
        parser
            .set_language(&tree_sitter_c::language())
            .expect("Error loading C grammar");
        parser.parse(source, None).expect("Error parsing C code")
    }

    /// Runs a fresh `Int34C` over `code` and returns what it reports.
    fn run_rule(code: &str) -> Vec<RuleViolation> {
        let tree = parse_c_code(code);
        let rule = Int34C::new();
        let violations = rule.check(&tree.root_node(), code);
        violations
    }

    /// A shift by an unchecked parameter must be reported.
    #[test]
    fn test_unchecked_shift() {
        let code = r#"
void func(unsigned int a, unsigned int b) {
unsigned int result = a << b;
}
"#;
        let violations = run_rule(code);
        assert!(!violations.is_empty(), "Should detect unchecked shift");
    }

    /// A shift in the `else` branch of a bounds check must not be reported.
    #[test]
    fn test_validated_shift() {
        let code = r#"
#include <limits.h>
void func(unsigned int a, unsigned int b) {
unsigned int result = 0;
if (b >= 32) {
/* Handle error */
} else {
result = a << b;
}
}
"#;
        let violations = run_rule(code);
        assert!(
            violations.is_empty(),
            "Should not flag validated shift: {:?}",
            violations
        );
    }
}