use std::{collections::BTreeSet, mem};
use crate::{
analysis::{
cff_taint_config, ConstValue, SsaEvaluator, SsaFunction, SsaOp, SsaVarId, SsaVariable,
TaintAnalysis,
},
deobfuscation::passes::unflattening::{tracer::types::TracedDispatcher, UnflattenConfig},
utils::BitSet,
CilObject,
};
/// Mutable state shared by the tree tracer while it symbolically walks the
/// SSA CFG of a control-flow-flattened method.
pub struct TreeTraceContext<'a> {
    /// SSA form of the function being traced.
    ssa: &'a SsaFunction,
    /// Evaluator used to track concrete values (e.g. the dispatcher state).
    evaluator: SsaEvaluator<'a>,
    /// Containing assembly, when available.
    assembly: Option<&'a CilObject>,
    /// Dispatcher being traced, if one was supplied via `with_dispatcher`.
    dispatcher: Option<TracedDispatcher>,
    /// Bit per SSA var index; set when the var is tainted by the state variable.
    state_tainted: BitSet,
    /// Next id handed out by `next_id`.
    next_node_id: usize,
    /// Blocks visited so far; checked against `max_block_visits`.
    total_visits: usize,
    /// `(block, state)` pairs already visited, for revisit detection.
    visited_states: BTreeSet<(usize, i64)>,
    /// Index of the most recently dispatched case (`usize::MAX` = none yet).
    last_case_index: usize,
    /// Saturating visit count per dispatcher case (one extra slot for the default).
    visited_case_counts: Vec<u8>,
    /// Last observed state value per case (one extra slot for the default).
    last_case_state: Vec<Option<i64>>,
    /// Visit budget taken from `UnflattenConfig`.
    max_block_visits: usize,
    /// Depth budget taken from `UnflattenConfig`.
    max_tree_depth: usize,
    /// Blocks that belong to other dispatchers in the same function.
    other_dispatcher_blocks: Vec<usize>,
    /// When true, tracing must not fork (set while in an expr-switch false arm).
    no_fork: bool,
}
impl<'a> TreeTraceContext<'a> {
    /// Creates a fresh context with no dispatcher attached; visit/depth
    /// budgets and pointer size are taken from `config`.
    pub fn new(
        ssa: &'a SsaFunction,
        config: &UnflattenConfig,
        assembly: Option<&'a CilObject>,
    ) -> Self {
        Self {
            ssa,
            evaluator: SsaEvaluator::new(ssa, config.pointer_size),
            assembly,
            dispatcher: None,
            state_tainted: BitSet::new(ssa.var_id_capacity()),
            next_node_id: 0,
            total_visits: 0,
            visited_states: BTreeSet::new(),
            // `usize::MAX` marks "no case dispatched yet".
            last_case_index: usize::MAX,
            visited_case_counts: Vec::new(),
            last_case_state: Vec::new(),
            max_block_visits: config.max_block_visits,
            max_tree_depth: config.max_tree_depth,
            other_dispatcher_blocks: Vec::new(),
            no_fork: false,
        }
    }

    /// Creates a context primed for `dispatcher`: taints every SSA variable
    /// reachable from the dispatcher's state variable (including the phi
    /// operands feeding it), and when the initial state value is known, seeds
    /// the evaluator with a concrete value on the entry edge into the
    /// dispatcher block.
    pub fn with_dispatcher(
        ssa: &'a SsaFunction,
        dispatcher: TracedDispatcher,
        config: &UnflattenConfig,
        assembly: Option<&'a CilObject>,
    ) -> Self {
        let mut ctx = Self::new(ssa, config, assembly);
        if let Some(state_var) = dispatcher.state_var {
            let state_origin = ssa.variable(state_var).map(SsaVariable::origin);
            let taint_config = cff_taint_config(ssa, dispatcher.block, state_origin);
            let mut taint = TaintAnalysis::new(taint_config);
            taint.add_tainted_var(state_var);
            // Phi operands merging into the state variable carry state too.
            if let Some(disp_block) = ssa.block(dispatcher.block) {
                for phi in disp_block.phi_nodes() {
                    if phi.result() == state_var {
                        for op in phi.operands() {
                            taint.add_tainted_var(op.value());
                        }
                    }
                }
            }
            taint.propagate(ssa);
            for var in taint.tainted_variables() {
                ctx.state_tainted.insert(var.index());
            }
        }
        if let (Some(state_var), Some(initial)) = (dispatcher.state_var, dispatcher.initial_state) {
            // Find the block that falls into the dispatcher: walk forward from
            // the entry block along unconditional jumps, bounded to 20 hops so
            // the walk always terminates.
            let entry_pred = {
                let mut pred = 0usize;
                let mut current = 0usize;
                for _ in 0..20 {
                    if current == dispatcher.block {
                        break;
                    }
                    pred = current;
                    match ssa.block(current).and_then(|b| b.terminator_op()) {
                        Some(SsaOp::Jump { target }) => current = *target,
                        _ => break,
                    }
                }
                pred
            };
            // Seed the known initial state on the phi operand coming from the
            // entry predecessor, so evaluation starts from a concrete state.
            if let Some(disp_block) = ssa.block(dispatcher.block) {
                for phi in disp_block.phi_nodes() {
                    if phi.result() == state_var {
                        for op in phi.operands() {
                            if op.predecessor() == entry_pred {
                                #[allow(clippy::cast_possible_truncation)]
                                ctx.evaluator
                                    .set_concrete(op.value(), ConstValue::I32(initial as i32));
                            }
                        }
                    }
                }
            }
        }
        // One slot per switch target, plus one for the default case.
        ctx.visited_case_counts = vec![0u8; dispatcher.targets.len() + 1];
        ctx.last_case_state = vec![None; dispatcher.targets.len() + 1];
        ctx.dispatcher = Some(dispatcher);
        ctx
    }

    /// SSA function being traced.
    pub fn ssa(&self) -> &'a SsaFunction {
        self.ssa
    }

    /// Shared access to the evaluator.
    pub fn evaluator(&self) -> &SsaEvaluator<'a> {
        &self.evaluator
    }

    /// Mutable access to the evaluator.
    pub fn evaluator_mut(&mut self) -> &mut SsaEvaluator<'a> {
        &mut self.evaluator
    }

    /// Containing assembly, when one was supplied.
    pub fn assembly(&self) -> Option<&'a CilObject> {
        self.assembly
    }

    /// Returns the next node id and advances the counter.
    pub fn next_id(&mut self) -> usize {
        let id = self.next_node_id;
        self.next_node_id += 1;
        id
    }

    /// True if `block` is the traced dispatcher's own block.
    pub fn is_dispatcher_block(&self, block: usize) -> bool {
        self.dispatcher.as_ref().is_some_and(|d| d.block == block)
    }

    /// True if `block` is one of the dispatcher's switch targets or its
    /// default target.
    pub fn is_dispatch_target(&self, block: usize) -> bool {
        self.dispatcher
            .as_ref()
            .is_some_and(|d| d.targets.contains(&block) || d.default == block)
    }

    /// The dispatcher's state variable, if known.
    pub fn state_var(&self) -> Option<SsaVarId> {
        self.dispatcher.as_ref().and_then(|d| d.state_var)
    }

    /// The dispatcher's block index, if a dispatcher is attached.
    pub fn dispatcher_block(&self) -> Option<usize> {
        self.dispatcher.as_ref().map(|d| d.block)
    }

    /// True if `block` belongs to another dispatcher in this function.
    pub fn is_other_dispatcher(&self, block: usize) -> bool {
        self.other_dispatcher_blocks.contains(&block)
    }

    /// Records the blocks of the other dispatchers in this function.
    pub fn set_other_dispatcher_blocks(&mut self, blocks: Vec<usize>) {
        self.other_dispatcher_blocks = blocks;
    }

    /// True if `var` is tainted by the dispatcher state.
    pub fn is_tainted(&self, var: SsaVarId) -> bool {
        self.state_tainted.contains(var.index())
    }

    /// True if any of `vars` is tainted.
    pub fn any_tainted(&self, vars: &[SsaVarId]) -> bool {
        vars.iter().any(|v| self.is_tainted(*v))
    }

    /// Marks `var` as state-tainted.
    pub fn taint(&mut self, var: SsaVarId) {
        self.state_tainted.insert(var.index());
    }

    /// Shared access to the taint bitset.
    pub fn state_tainted(&self) -> &BitSet {
        &self.state_tainted
    }

    /// Mutable access to the taint bitset.
    pub fn state_tainted_mut(&mut self) -> &mut BitSet {
        &mut self.state_tainted
    }

    /// Re-propagates taint forward through the SSA graph.
    pub fn propagate_taint_forward(&mut self) {
        super::helpers::propagate_taint_forward(self.ssa, &mut self.state_tainted);
    }

    /// Current concrete value of the state variable, if the evaluator knows it.
    pub fn current_state(&self) -> Option<i64> {
        self.dispatcher
            .as_ref()
            .and_then(|d| d.state_var)
            .and_then(|v| self.evaluator.get_concrete(v))
            .and_then(ConstValue::as_i64)
    }

    /// State component of the visited-state key: the concrete state when
    /// known, otherwise a synthetic value combining the last case index and
    /// its visit count so repeated visits still get distinct keys.
    fn visit_state(&self) -> i64 {
        self.current_state().unwrap_or_else(|| {
            let count = if self.last_case_index < self.visited_case_counts.len() {
                self.visited_case_counts[self.last_case_index] as i64
            } else {
                0
            };
            (self.last_case_index as i64)
                .wrapping_mul(256)
                .wrapping_add(count)
        })
    }

    /// True if `block` was already visited under the current visit state.
    pub fn is_visited(&self, block: usize) -> bool {
        self.visited_states.contains(&(block, self.visit_state()))
    }

    /// Records a visit of `block` under the current visit state.
    pub fn mark_visited(&mut self, block: usize) {
        self.visited_states.insert((block, self.visit_state()));
    }

    /// Counts one more visit; returns `true` when the budget is now exceeded.
    pub fn check_visit_budget(&mut self) -> bool {
        self.total_visits += 1;
        self.total_visits > self.max_block_visits
    }

    /// Maximum allowed trace-tree depth.
    pub fn max_tree_depth(&self) -> usize {
        self.max_tree_depth
    }

    /// Records that the dispatcher routed to `case_idx` (count saturates at 255).
    pub fn record_case_dispatch(&mut self, case_idx: usize) {
        if case_idx < self.visited_case_counts.len() {
            self.visited_case_counts[case_idx] =
                self.visited_case_counts[case_idx].saturating_add(1);
        }
        self.last_case_index = case_idx;
    }

    /// True if `case_idx` was last seen with exactly `current_state`, i.e.
    /// the state did not change since the previous dispatch to this case.
    pub fn case_state_is_stuck(&self, case_idx: usize, current_state: i64) -> bool {
        self.last_case_state
            .get(case_idx)
            .copied()
            .flatten()
            .is_some_and(|prev| prev == current_state)
    }

    /// Remembers the state value observed when dispatching to `case_idx`.
    pub fn record_case_state(&mut self, case_idx: usize, state: i64) {
        if case_idx < self.last_case_state.len() {
            self.last_case_state[case_idx] = Some(state);
        }
    }

    /// Heuristic loop check: a case visited at least `max(targets_len / 2, 2)`
    /// times is treated as a loop.
    pub fn is_case_loop(&self, case_idx: usize, targets_len: usize) -> bool {
        // Clamp rather than `as u8`: a dispatcher with >= 512 targets would
        // otherwise wrap the threshold down and flag loops far too early.
        let loop_threshold = u8::try_from((targets_len / 2).max(2)).unwrap_or(u8::MAX);
        case_idx < self.visited_case_counts.len()
            && self.visited_case_counts[case_idx] >= loop_threshold
    }

    /// True while forking is disabled (inside an expr-switch false arm).
    pub fn no_fork(&self) -> bool {
        self.no_fork
    }

    /// Captures the per-path mutable state so it can be restored after a fork.
    pub fn snapshot(&self) -> ContextSnapshot<'a> {
        ContextSnapshot {
            evaluator: self.evaluator.clone(),
            visited_states: self.visited_states.clone(),
            last_case_index: self.last_case_index,
            visited_case_counts: self.visited_case_counts.clone(),
            last_case_state: self.last_case_state.clone(),
        }
    }

    /// Restores state previously captured with `snapshot`.
    pub fn restore(&mut self, snap: ContextSnapshot<'a>) {
        self.evaluator = snap.evaluator;
        self.visited_states = snap.visited_states;
        self.last_case_index = snap.last_case_index;
        self.visited_case_counts = snap.visited_case_counts;
        self.last_case_state = snap.last_case_state;
    }

    /// Copy of the per-case visit counters.
    pub fn case_counts_snapshot(&self) -> Vec<u8> {
        self.visited_case_counts.clone()
    }

    /// Replaces the per-case visit counters.
    pub fn set_case_counts(&mut self, counts: Vec<u8>) {
        self.visited_case_counts = counts;
    }

    /// Enters the false arm of an expression switch: resets the visit budget
    /// and disables forking. Returns the saved `(total_visits, no_fork)` to
    /// hand back to `exit_expr_switch_false_arm`.
    pub fn enter_expr_switch_false_arm(&mut self) -> (usize, bool) {
        let saved = (self.total_visits, self.no_fork);
        self.total_visits = 0;
        self.no_fork = true;
        saved
    }

    /// Restores the state saved by `enter_expr_switch_false_arm`.
    pub fn exit_expr_switch_false_arm(&mut self, saved: (usize, bool)) {
        self.total_visits = saved.0;
        self.no_fork = saved.1;
    }

    /// Detaches and returns the dispatcher.
    pub fn take_dispatcher(&mut self) -> Option<TracedDispatcher> {
        self.dispatcher.take()
    }

    /// Moves the taint bitset out, leaving an empty one behind.
    pub fn take_state_tainted(&mut self) -> BitSet {
        mem::take(&mut self.state_tainted)
    }

    /// Exception-handler entry blocks that are in range and were never
    /// visited under any state.
    pub fn unvisited_handler_blocks(&self) -> Vec<usize> {
        self.ssa
            .exception_handlers()
            .iter()
            .filter_map(|h| h.handler_start_block)
            .filter(|&block| {
                block < self.ssa.block_count()
                    && !self.visited_states.iter().any(|(b, _)| *b == block)
            })
            .collect()
    }

    /// Creates an independent context for tracing a handler region: fresh
    /// evaluator and visit bookkeeping, shared dispatcher/taint configuration.
    /// Node ids continue from `node_id_offset`.
    pub fn fork_for_handler(&self, node_id_offset: usize) -> Self {
        let case_count_len = self.visited_case_counts.len();
        Self {
            ssa: self.ssa,
            evaluator: SsaEvaluator::new(self.ssa, self.evaluator.pointer_size()),
            assembly: self.assembly,
            dispatcher: self.dispatcher.clone(),
            state_tainted: self.state_tainted.clone(),
            next_node_id: node_id_offset,
            total_visits: 0,
            visited_states: BTreeSet::new(),
            last_case_index: usize::MAX,
            visited_case_counts: vec![0u8; case_count_len],
            last_case_state: vec![None; case_count_len],
            max_block_visits: self.max_block_visits,
            max_tree_depth: self.max_tree_depth,
            other_dispatcher_blocks: self.other_dispatcher_blocks.clone(),
            no_fork: false,
        }
    }

    /// Fast-forwards the node-id counter (e.g. after merging a forked trace).
    pub fn advance_node_id(&mut self, new_id: usize) {
        self.next_node_id = new_id;
    }

    /// Visit budget for this trace.
    pub fn max_block_visits(&self) -> usize {
        self.max_block_visits
    }
}
/// Saved copy of the per-path mutable trace state, captured by
/// `TreeTraceContext::snapshot` and consumed by `TreeTraceContext::restore`.
pub struct ContextSnapshot<'a> {
    /// Evaluator state at snapshot time.
    evaluator: SsaEvaluator<'a>,
    /// `(block, state)` visit set at snapshot time.
    visited_states: BTreeSet<(usize, i64)>,
    /// Last dispatched case index at snapshot time.
    last_case_index: usize,
    /// Per-case visit counters at snapshot time.
    visited_case_counts: Vec<u8>,
    /// Per-case last-seen state values at snapshot time.
    last_case_state: Vec<Option<i64>>,
}
impl<'a> ContextSnapshot<'a> {
    /// Produces an independent deep copy of this snapshot.
    pub fn clone_snapshot(&self) -> Self {
        let evaluator = self.evaluator.clone();
        let visited_states = self.visited_states.clone();
        let visited_case_counts = self.visited_case_counts.clone();
        let last_case_state = self.last_case_state.clone();
        Self {
            evaluator,
            visited_states,
            last_case_index: self.last_case_index,
            visited_case_counts,
            last_case_state,
        }
    }
}