use super::{
AstMeta, Cfg, EdgeKind, MAX_COND_VARS, MAX_CONDITION_TEXT_LEN, NodeInfo, StmtKind,
collect_idents, connect_all, detect_eq_with_const, detect_negation, has_call_descendant,
member_expr_text, push_node, text_of,
};
use crate::labels::{DataLabel, LangAnalysisRules, classify};
use crate::utils::snippet::truncate_at_char_boundary;
use petgraph::graph::NodeIndex;
use smallvec::SmallVec;
use tree_sitter::Node;
/// Short-circuiting boolean operator recognised by [`is_boolean_operator`].
///
/// The enum is fieldless, so equality is total and `Eq` is derived alongside `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum BoolOp {
    /// Logical conjunction (`&&` or `and`).
    And,
    /// Logical disjunction (`||` or `or`).
    Or,
}
/// Reports whether `node` is a short-circuiting `&&`/`||` (or `and`/`or`) expression.
///
/// Only `binary_expression`, `boolean_operator` and `binary` nodes are inspected. For
/// those, the direct children are scanned and the first operator token found decides
/// the result. Any other node kind, or a binary node using a different operator,
/// yields `None`.
pub(super) fn is_boolean_operator(node: Node) -> Option<BoolOp> {
    if !matches!(node.kind(), "binary_expression" | "boolean_operator" | "binary") {
        return None;
    }
    let mut cursor = node.walk();
    let op = node.children(&mut cursor).find_map(|child| match child.kind() {
        "&&" | "and" => Some(BoolOp::And),
        "||" | "or" => Some(BoolOp::Or),
        _ => None,
    });
    op
}
/// Strips any number of enclosing `parenthesized_expression` wrappers from `node`.
///
/// Tree-sitter attaches comments as named "extra" children wherever they occur. Those
/// are skipped when picking the wrapped expression, so `( /* why */ a && b )` unwraps
/// to `a && b` rather than to the comment. A parenthesized node that has no non-extra
/// named child is returned as-is. A node that is not parenthesized is returned
/// unchanged.
pub(super) fn unwrap_parens(node: Node) -> Node {
    let mut current = node;
    while current.kind() == "parenthesized_expression" {
        let mut cursor = current.walk();
        // Without the `is_extra` filter a leading comment would be mistaken for the
        // wrapped expression.
        let inner = current
            .named_children(&mut cursor)
            .find(|child| !child.is_extra());
        match inner {
            Some(inner) => current = inner,
            None => break,
        }
    }
    current
}
/// Returns the `(left, right)` operands of a binary boolean node.
///
/// The grammar's `left`/`right` fields are preferred. When either field is missing,
/// the function falls back to the first and last named children. Extra nodes such as
/// comments are ignored so they are never mistaken for an operand. Returns `None` when
/// fewer than two non-extra named children exist.
pub(super) fn get_boolean_operands<'a>(node: Node<'a>) -> Option<(Node<'a>, Node<'a>)> {
    if let (Some(left), Some(right)) = (
        node.child_by_field_name("left"),
        node.child_by_field_name("right"),
    ) {
        return Some((left, right));
    }
    // Positional fallback: walk the named children once, without collecting them.
    let mut cursor = node.walk();
    let mut operands = node
        .named_children(&mut cursor)
        .filter(|child| !child.is_extra());
    let first = operands.next()?;
    let last = operands.last()?;
    Some((first, last))
}
/// Adds a single `StmtKind::If` node for `cond_ast` and returns its index.
///
/// The condition first goes through `detect_negation`. Identifiers are then collected
/// from the expression it returns, sorted, de-duplicated and capped at
/// `MAX_COND_VARS`. That list is stored both as `condition_vars` and as the node's
/// taint `uses`.
///
/// The recorded text and span always describe the original `cond_ast`. The text is
/// truncated to `MAX_CONDITION_TEXT_LEN` on a char boundary. No edges are added;
/// callers connect the node themselves.
pub(super) fn push_condition_node<'a>(
    g: &mut Cfg,
    cond_ast: Node<'a>,
    lang: &str,
    code: &'a [u8],
    enclosing_func: Option<&str>,
) -> NodeIndex {
    let (operand, negated) = detect_negation(cond_ast, cond_ast, lang);

    let mut idents = Vec::new();
    collect_idents(operand, code, &mut idents);
    idents.sort();
    idents.dedup();
    idents.truncate(MAX_COND_VARS);

    let condition_text = text_of(cond_ast, code)
        .map(|raw| truncate_at_char_boundary(&raw, MAX_CONDITION_TEXT_LEN).to_string());
    // The taint view reads exactly the same identifiers as the condition itself.
    let taint = crate::cfg::TaintMeta {
        uses: idents.clone(),
        ..Default::default()
    };
    let ast = AstMeta {
        span: (cond_ast.start_byte(), cond_ast.end_byte()),
        enclosing_func: enclosing_func.map(str::to_string),
    };

    g.add_node(NodeInfo {
        kind: StmtKind::If,
        ast,
        condition_text,
        condition_vars: idents,
        condition_negated: negated,
        taint,
        ..Default::default()
    })
}
/// Recognises `let <pattern> = match … { … }` where some arm carries an `if` guard.
///
/// On success it returns two things:
/// - the condition node of the first guarded arm, scanning `match_arm` and
///   `last_match_arm` children of the match body in order;
/// - the first name that `collect_idents` produces for the `let` pattern.
///
/// Returns `None` in any of these cases:
/// - `ast` is not a `let_declaration`;
/// - its value is not a `match_expression`;
/// - no arm has a guard;
/// - the pattern yields no identifiers.
pub(super) fn detect_rust_let_match_guard<'a>(
    ast: Node<'a>,
    code: &[u8],
) -> Option<(Node<'a>, String)> {
    if ast.kind() != "let_declaration" {
        return None;
    }
    let matched = ast
        .child_by_field_name("value")
        .filter(|value| value.kind() == "match_expression")?;
    let arms = matched.child_by_field_name("body")?;

    let mut cursor = arms.walk();
    let guard = arms
        .children(&mut cursor)
        .filter(|arm| matches!(arm.kind(), "match_arm" | "last_match_arm"))
        .find_map(|arm| {
            arm.child_by_field_name("pattern")?
                .child_by_field_name("condition")
        })?;

    let binding = ast.child_by_field_name("pattern")?;
    let mut names = Vec::new();
    collect_idents(binding, code, &mut names);
    let first_name = names.into_iter().next()?;
    Some((guard, first_name))
}
/// Adds a `StmtKind::If` node for a Rust match-arm guard and returns its index.
///
/// The condition variables are the identifiers found in `guard`, plus `let_name`. They
/// are sorted, de-duplicated and capped at `MAX_COND_VARS`. Because truncation happens
/// after sorting, `let_name` is not guaranteed to survive the cap.
///
/// The text and span come from `guard`. Unlike `push_condition_node`, negation is never
/// detected (`condition_negated` is always `false`) and no taint `uses` are recorded.
/// No edges are added.
pub(super) fn emit_rust_match_guard_if<'a>(
    g: &mut Cfg,
    guard: Node<'a>,
    let_name: &str,
    code: &'a [u8],
    enclosing_func: Option<&str>,
) -> NodeIndex {
    let mut condition_vars = Vec::new();
    collect_idents(guard, code, &mut condition_vars);
    condition_vars.push(let_name.to_string());
    condition_vars.sort();
    condition_vars.dedup();
    condition_vars.truncate(MAX_COND_VARS);

    let condition_text = text_of(guard, code)
        .map(|raw| truncate_at_char_boundary(&raw, MAX_CONDITION_TEXT_LEN).to_string());
    let ast = AstMeta {
        span: (guard.start_byte(), guard.end_byte()),
        enclosing_func: enclosing_func.map(str::to_string),
    };

    g.add_node(NodeInfo {
        kind: StmtKind::If,
        ast,
        condition_text,
        condition_vars,
        condition_negated: false,
        ..Default::default()
    })
}
/// Lowers an assignment whose right-hand side is `cond ? a : b` into a diamond.
///
/// The diamond has three parts:
/// - an `If` node for the condition, linked from every node in `preds` with `pred_edge`;
/// - the consequence, lowered off the condition's `True` edge, and the alternative,
///   lowered off its `False` edge (each via `lower_ternary_branch`);
/// - an empty `Seq` join node with a zero-width span at the ternary's end, which every
///   branch exit is linked to with `EdgeKind::Seq`.
///
/// Returns the join as the sole exit. If the ternary lacks a `condition`,
/// `consequence` or `alternative` field, nothing is added and `preds` is returned
/// unchanged.
// The parameter list is the lowering context shared with `lower_ternary_branch`.
#[allow(clippy::too_many_arguments)]
pub(super) fn build_ternary_diamond<'a>(
    lhs_text: String,
    lhs_labels: SmallVec<[DataLabel; 2]>,
    ternary_ast: Node<'a>,
    preds: &[NodeIndex],
    pred_edge: EdgeKind,
    g: &mut Cfg,
    lang: &str,
    code: &'a [u8],
    enclosing_func: Option<&str>,
    call_ordinal: &mut u32,
    analysis_rules: Option<&LangAnalysisRules>,
) -> Vec<NodeIndex> {
    let fields = (
        ternary_ast.child_by_field_name("condition"),
        ternary_ast.child_by_field_name("consequence"),
        ternary_ast.child_by_field_name("alternative"),
    );
    let (cond_ast, cons_ast, alt_ast) = match fields {
        (Some(cond), Some(cons), Some(alt)) => {
            (unwrap_parens(cond), unwrap_parens(cons), unwrap_parens(alt))
        }
        _ => return preds.to_vec(),
    };

    let cond_if = push_condition_node(g, cond_ast, lang, code, enclosing_func);
    g[cond_if].is_eq_with_const = detect_eq_with_const(cond_ast, lang);
    connect_all(g, preds, cond_if, pred_edge);

    // The consequence is lowered before the alternative, so its nodes (and any call
    // ordinals) are created first.
    let mut branch_exits: Vec<Vec<NodeIndex>> = Vec::with_capacity(2);
    for (branch_ast, edge) in [(cons_ast, EdgeKind::True), (alt_ast, EdgeKind::False)] {
        branch_exits.push(lower_ternary_branch(
            branch_ast,
            &[cond_if],
            edge,
            &lhs_text,
            &lhs_labels,
            g,
            lang,
            code,
            enclosing_func,
            call_ordinal,
            analysis_rules,
        ));
    }

    let end = ternary_ast.end_byte();
    let join = g.add_node(NodeInfo {
        kind: StmtKind::Seq,
        ast: AstMeta {
            span: (end, end),
            enclosing_func: enclosing_func.map(str::to_string),
        },
        ..Default::default()
    });
    for exits in &branch_exits {
        connect_all(g, exits, join, EdgeKind::Seq);
    }
    vec![join]
}
/// Lowers one arm of a ternary assignment and returns its exit nodes.
///
/// A nested `ternary_expression` is delegated to `build_ternary_diamond`. Any other arm
/// becomes a single node:
/// - `StmtKind::Call` when `has_call_descendant` reports a call, consuming the next
///   value of `call_ordinal`;
/// - otherwise `StmtKind::Seq` with ordinal `0`.
///
/// The node defines `lhs_text` in its taint metadata and gains every label from
/// `lhs_labels` that it does not already carry. It is linked from each node in `preds`
/// with `pred_edge`.
// The parameter list is the lowering context shared with `build_ternary_diamond`.
#[allow(clippy::too_many_arguments)]
pub(super) fn lower_ternary_branch<'a>(
    branch_ast: Node<'a>,
    preds: &[NodeIndex],
    pred_edge: EdgeKind,
    lhs_text: &str,
    lhs_labels: &SmallVec<[DataLabel; 2]>,
    g: &mut Cfg,
    lang: &str,
    code: &'a [u8],
    enclosing_func: Option<&str>,
    call_ordinal: &mut u32,
    analysis_rules: Option<&LangAnalysisRules>,
) -> Vec<NodeIndex> {
    if branch_ast.kind() == "ternary_expression" {
        return build_ternary_diamond(
            lhs_text.to_string(),
            lhs_labels.clone(),
            branch_ast,
            preds,
            pred_edge,
            g,
            lang,
            code,
            enclosing_func,
            call_ordinal,
            analysis_rules,
        );
    }

    let (kind, ord) = if has_call_descendant(branch_ast, lang) {
        let assigned = *call_ordinal;
        *call_ordinal += 1;
        (StmtKind::Call, assigned)
    } else {
        (StmtKind::Seq, 0)
    };

    let node = push_node(
        g,
        kind,
        branch_ast,
        lang,
        code,
        enclosing_func,
        ord,
        analysis_rules,
    );

    let taint = &mut g[node].taint;
    taint.defines = Some(lhs_text.to_string());
    for label in lhs_labels.iter().copied() {
        if !taint.labels.contains(&label) {
            taint.labels.push(label);
        }
    }

    connect_all(g, preds, node, pred_edge);
    vec![node]
}
/// Finds the target and ternary right-hand side of a declaration or assignment
/// statement.
///
/// The direct children of `outer_ast` are scanned in order:
/// - `variable_declarator`: a second declarator makes the whole lookup return `None`.
///   Otherwise, when the declarator's `value` (with parentheses stripped) is a
///   `ternary_expression`, `(name, ternary)` is remembered and returned once the scan
///   finishes.
/// - `assignment_expression` whose `right` (with parentheses stripped) is a
///   `ternary_expression`: `(left, ternary)` is returned immediately.
///
/// Children missing the relevant fields are skipped.
pub(super) fn find_ternary_rhs_wrapper<'a>(outer_ast: Node<'a>) -> Option<(Node<'a>, Node<'a>)> {
    let mut cursor = outer_ast.walk();
    let mut seen_declarator = false;
    let mut candidate: Option<(Node<'a>, Node<'a>)> = None;

    for child in outer_ast.children(&mut cursor) {
        match child.kind() {
            "variable_declarator" => {
                // Multi-declarator statements are not rewritten as a single diamond.
                if seen_declarator {
                    return None;
                }
                seen_declarator = true;
                if let (Some(name), Some(value)) = (
                    child.child_by_field_name("name"),
                    child.child_by_field_name("value"),
                ) {
                    let rhs = unwrap_parens(value);
                    if rhs.kind() == "ternary_expression" {
                        candidate = Some((name, rhs));
                    }
                }
            }
            "assignment_expression" => {
                if let (Some(left), Some(right)) = (
                    child.child_by_field_name("left"),
                    child.child_by_field_name("right"),
                ) {
                    let rhs = unwrap_parens(right);
                    if rhs.kind() == "ternary_expression" {
                        return Some((left, rhs));
                    }
                }
            }
            _ => {}
        }
    }
    candidate
}
/// Computes the text and data labels for the left-hand side of a ternary assignment.
///
/// The text is chosen in this order:
/// - `member_expr_text` for the node;
/// - its raw source text;
/// - an empty string.
///
/// That text is classified with `classify`, passing the rules' `extra_labels` when
/// present. If it yields no label, the node's `property` field text, when there is
/// one, is classified instead. At most one label is returned.
pub(super) fn classify_ternary_lhs(
    lhs_ast: Node,
    lang: &str,
    code: &[u8],
    analysis_rules: Option<&LangAnalysisRules>,
) -> (String, SmallVec<[DataLabel; 2]>) {
    let extra = analysis_rules.map(|rules| rules.extra_labels.as_slice());
    let lhs_text = member_expr_text(lhs_ast, code)
        .or_else(|| text_of(lhs_ast, code))
        .unwrap_or_default();

    // Fall back to the bare property name (e.g. `x.innerHTML` -> `innerHTML`).
    let label = classify(lang, &lhs_text, extra).or_else(|| {
        let prop = lhs_ast.child_by_field_name("property")?;
        let prop_text = text_of(prop, code)?;
        classify(lang, &prop_text, extra)
    });

    let labels: SmallVec<[DataLabel; 2]> = label.into_iter().collect();
    (lhs_text, labels)
}
/// Lowers a (possibly compound) condition into a short-circuit chain of `If` nodes.
///
/// Returns `(true_exits, false_exits)`: the nodes whose `True` edges, and the nodes
/// whose `False` edges, leave the whole condition. The chain is built as follows:
/// - `a && b`: `b` hangs off `a`'s true exits via `EdgeKind::True`. True exits are
///   `b`'s; false exits are `a`'s followed by `b`'s.
/// - `a || b`: `b` hangs off `a`'s false exits via `EdgeKind::False`. True exits are
///   `a`'s followed by `b`'s; false exits are `b`'s.
/// - Anything else, including a boolean node whose operands cannot be found, becomes a
///   single `push_condition_node` linked from `preds` with `pred_edge`. That node is
///   both the true and the false exit.
///
/// Surrounding parentheses are stripped at every level.
pub(super) fn build_condition_chain<'a>(
    cond_ast: Node<'a>,
    preds: &[NodeIndex],
    pred_edge: EdgeKind,
    g: &mut Cfg,
    lang: &str,
    code: &'a [u8],
    enclosing_func: Option<&str>,
) -> (Vec<NodeIndex>, Vec<NodeIndex>) {
    let inner = unwrap_parens(cond_ast);
    let split = is_boolean_operator(inner)
        .and_then(|op| get_boolean_operands(inner).map(|(lhs, rhs)| (op, lhs, rhs)));

    let Some((op, lhs, rhs)) = split else {
        let leaf = push_condition_node(g, inner, lang, code, enclosing_func);
        connect_all(g, preds, leaf, pred_edge);
        return (vec![leaf], vec![leaf]);
    };

    let (lhs_true, lhs_false) =
        build_condition_chain(lhs, preds, pred_edge, g, lang, code, enclosing_func);

    match op {
        BoolOp::And => {
            let (rhs_true, rhs_false) = build_condition_chain(
                rhs,
                &lhs_true,
                EdgeKind::True,
                g,
                lang,
                code,
                enclosing_func,
            );
            let mut false_exits = lhs_false;
            false_exits.extend(rhs_false);
            (rhs_true, false_exits)
        }
        BoolOp::Or => {
            let (rhs_true, rhs_false) = build_condition_chain(
                rhs,
                &lhs_false,
                EdgeKind::False,
                g,
                lang,
                code,
                enclosing_func,
            );
            let mut true_exits = lhs_true;
            true_exits.extend(rhs_true);
            (true_exits, rhs_false)
        }
    }
}