use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use crate::{
analysis::{find_token_dependencies, SsaFunction, SsaOp},
compiler::{CompilerContext, EventKind, SsaPass},
metadata::token::Token,
CilObject, Result,
};
/// Rewrite planned for one tainted instruction; applied via `set_op` once all
/// decisions for the method have been collected.
enum InstrAction {
/// Overwrite the instruction with `SsaOp::Nop`.
Nop,
/// Overwrite the instruction with an unconditional `SsaOp::Jump` to the
/// block index carried here.
Jump(usize),
}
/// SSA pass that neutralizes phi nodes and instructions depending on a set of
/// removed tokens.
pub struct NeutralizationPass<'a> {
/// Borrowed set of removed tokens; it is handed to `find_token_dependencies`
/// to locate the code that has to be neutralized.
removed_tokens: &'a HashSet<Token>,
}
impl<'a> NeutralizationPass<'a> {
#[must_use]
pub fn new(removed_tokens: &'a HashSet<Token>) -> Self {
Self { removed_tokens }
}
/// Reports whether every phi node and every instruction of a block is tainted.
///
/// Taint is looked up by `(block_idx, position)` in `tainted_phis` and
/// `tainted_instrs` respectively.
///
/// A block index that `ssa` cannot resolve counts as fully tainted, and so does a
/// block that has neither phi nodes nor instructions.
fn is_block_fully_tainted(
    ssa: &SsaFunction,
    block_idx: usize,
    tainted_instrs: &HashSet<(usize, usize)>,
    tainted_phis: &HashSet<(usize, usize)>,
) -> bool {
    match ssa.block(block_idx) {
        // Missing blocks are treated as fully tainted.
        None => true,
        Some(block) => {
            let all_phis_tainted = (0..block.phi_nodes().len())
                .all(|phi_idx| tainted_phis.contains(&(block_idx, phi_idx)));

            // Short-circuits: instructions are only inspected when all phis are tainted.
            all_phis_tainted
                && (0..block.instructions().len())
                    .all(|instr_idx| tainted_instrs.contains(&(block_idx, instr_idx)))
        }
    }
}
/// Collects the indices of all blocks from which a `Return` or `Throw` can be reached.
///
/// Every block containing a `Return` or `Throw` instruction is a seed. The set is
/// then grown backwards along predecessor edges until nothing new is added.
fn find_blocks_reaching_exit(ssa: &SsaFunction) -> HashSet<usize> {
    // Invert the successor edges: target block -> blocks that branch to it.
    let mut preds_of: HashMap<usize, Vec<usize>> = HashMap::new();
    for (idx, block) in ssa.blocks().iter().enumerate() {
        for succ in block.successors() {
            preds_of.entry(succ).or_default().push(idx);
        }
    }

    // Seed the worklist with every block that itself returns or throws.
    let mut pending: Vec<usize> = ssa
        .blocks()
        .iter()
        .enumerate()
        .filter(|(_, block)| {
            block
                .instructions()
                .iter()
                .any(|instr| matches!(instr.op(), SsaOp::Return { .. } | SsaOp::Throw { .. }))
        })
        .map(|(idx, _)| idx)
        .collect();

    let mut reaching = HashSet::new();
    while let Some(idx) = pending.pop() {
        // Already processed: its predecessors were queued the first time round.
        if !reaching.insert(idx) {
            continue;
        }
        if let Some(preds) = preds_of.get(&idx) {
            for &pred in preds {
                if !reaching.contains(&pred) {
                    pending.push(pred);
                }
            }
        }
    }
    reaching
}
/// Picks the successor that an unconditional jump should take when a tainted
/// conditional branch is neutralized.
///
/// Decision order:
/// 1. If exactly one target is fully tainted (see `is_block_fully_tainted`), the
///    other target is chosen.
/// 2. Otherwise, if only `false_target` is in `can_reach_exit`, it is chosen.
/// 3. In every remaining case `true_target` is chosen.
fn choose_branch_target(
    ssa: &SsaFunction,
    true_target: usize,
    false_target: usize,
    tainted_instrs: &HashSet<(usize, usize)>,
    tainted_phis: &HashSet<(usize, usize)>,
    can_reach_exit: &HashSet<usize>,
) -> usize {
    let true_tainted =
        Self::is_block_fully_tainted(ssa, true_target, tainted_instrs, tainted_phis);
    let false_tainted =
        Self::is_block_fully_tainted(ssa, false_target, tainted_instrs, tainted_phis);

    // Exactly one side is fully tainted: steer towards the other one.
    if true_tainted != false_tainted {
        return if true_tainted {
            false_target
        } else {
            true_target
        };
    }

    // Taint does not discriminate; fall back to exit reachability, defaulting to
    // the true edge.
    let only_false_reaches_exit =
        !can_reach_exit.contains(&true_target) && can_reach_exit.contains(&false_target);
    if only_false_reaches_exit {
        false_target
    } else {
        true_target
    }
}
/// Neutralizes everything in `ssa` that `find_token_dependencies` reports as
/// depending on one of the removed tokens.
///
/// * Tainted phi nodes are removed from their block.
/// * Tainted non-terminator instructions are replaced by `SsaOp::Nop`.
/// * Tainted `Branch` / `BranchCmp` terminators are replaced by an unconditional
///   `SsaOp::Jump` to the target picked by `choose_branch_target`.
/// * Any other tainted terminator is left unchanged.
///
/// Returns the number of phi nodes removed plus the number of instructions
/// rewritten; `0` means `ssa` was not modified.
fn neutralize_method(&self, ssa: &mut SsaFunction) -> usize {
    if self.removed_tokens.is_empty() {
        return 0;
    }

    let taint = find_token_dependencies(ssa, self.removed_tokens.iter().copied());
    if taint.tainted_instr_count() == 0 && taint.tainted_phis().is_empty() {
        return 0;
    }

    let tainted_instrs: HashSet<(usize, usize)> =
        taint.tainted_instructions().iter().copied().collect();
    let tainted_phis_set: HashSet<(usize, usize)> =
        taint.tainted_phis().iter().copied().collect();
    let can_reach_exit = Self::find_blocks_reaching_exit(ssa);

    // Phase 1: plan the instruction rewrites while `ssa` is still untouched.
    //
    // This has to happen before any phi node is removed: `choose_branch_target`
    // looks phis up by their position, and removing a phi shifts the positions of
    // the ones after it. Checking shifted positions against the taint set could
    // classify a block with surviving, untainted phis as fully tainted.
    let mut actions: Vec<(usize, usize, InstrAction)> = Vec::new();
    for &(block_idx, instr_idx) in taint.tainted_instructions() {
        if let Some(block) = ssa.block(block_idx) {
            if let Some(instr) = block.instructions().get(instr_idx) {
                let action = if instr.is_terminator() {
                    match instr.op() {
                        SsaOp::Branch {
                            true_target,
                            false_target,
                            ..
                        }
                        | SsaOp::BranchCmp {
                            true_target,
                            false_target,
                            ..
                        } => {
                            let target = Self::choose_branch_target(
                                ssa,
                                *true_target,
                                *false_target,
                                &tainted_instrs,
                                &tainted_phis_set,
                                &can_reach_exit,
                            );
                            Some(InstrAction::Jump(target))
                        }
                        // Other terminators are kept as they are.
                        _ => None,
                    }
                } else {
                    Some(InstrAction::Nop)
                };
                if let Some(action) = action {
                    actions.push((block_idx, instr_idx, action));
                }
            }
        }
    }

    let mut count = 0;

    // Phase 2: remove tainted phis, highest position first, so that a removal never
    // shifts a position that is still waiting to be removed. The list is built from
    // the de-duplicated set so that a repeated entry cannot remove a neighbouring
    // phi.
    let mut phi_removals: Vec<(usize, usize)> = tainted_phis_set.iter().copied().collect();
    phi_removals.sort_unstable_by(|a, b| b.cmp(a));
    for (block_idx, phi_idx) in phi_removals {
        if let Some(block) = ssa.block_mut(block_idx) {
            if phi_idx < block.phi_nodes().len() {
                block.phi_nodes_mut().remove(phi_idx);
                count += 1;
            }
        }
    }

    // Phase 3: apply the planned rewrites. Instructions are replaced in place, so
    // the positions recorded in phase 1 are still valid.
    // NOTE(review): turning a branch into a jump drops a CFG edge; phi operands in
    // the abandoned successor are not adjusted here — confirm a later pass handles it.
    for (block_idx, instr_idx, action) in actions {
        if let Some(block) = ssa.block_mut(block_idx) {
            if let Some(instr) = block.instruction_mut(instr_idx) {
                match action {
                    InstrAction::Nop => instr.set_op(SsaOp::Nop),
                    InstrAction::Jump(target) => instr.set_op(SsaOp::Jump { target }),
                }
                count += 1;
            }
        }
    }

    count
}
}
impl SsaPass for NeutralizationPass<'_> {
    /// Short identifier of this pass.
    fn name(&self) -> &'static str {
        "neutralization"
    }

    /// One-line summary of what this pass does.
    fn description(&self) -> &'static str {
        "Removes instructions referencing removed protection infrastructure"
    }

    /// Runs neutralization on a single method.
    ///
    /// Returns `Ok(true)` when at least one phi node or instruction was neutralized,
    /// in which case an `EventKind::InstructionRemoved` event is recorded for
    /// `method_token` on `ctx.events`. Returns `Ok(false)` otherwise. The assembly
    /// argument is not used.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        method_token: Token,
        ctx: &CompilerContext,
        _assembly: &Arc<CilObject>,
    ) -> Result<bool> {
        let neutralized = self.neutralize_method(ssa);
        if neutralized == 0 {
            // Nothing changed, so there is nothing to report.
            return Ok(false);
        }

        let message =
            format!("Neutralized {neutralized} instructions referencing removed tokens");
        ctx.events
            .record(EventKind::InstructionRemoved)
            .method(method_token)
            .message(message);
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use crate::{
        analysis::{ConstValue, MethodRef, SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaVarId},
        deobfuscation::passes::NeutralizationPass,
        metadata::token::Token,
    };

    /// Builds a one-block function:
    ///
    /// ```text
    /// v0 = const 42
    /// v1 = call <method_token>()
    /// v2 = add v0, v1
    /// ret v2
    /// ```
    ///
    /// Returns the function, the token of the called method, and a field token that
    /// the function never references.
    fn create_test_ssa() -> (SsaFunction, Token, Token) {
        let method_token = Token::new(0x06000001);
        let field_token = Token::new(0x04000001);

        let mut func = SsaFunction::new(0, 1);

        let v0 = SsaVarId::new();
        let v1 = SsaVarId::new();
        let v2 = SsaVarId::new();

        let mut entry = SsaBlock::new(0);
        entry.add_instruction(SsaInstruction::synthetic(SsaOp::Const {
            dest: v0,
            value: ConstValue::I32(42),
        }));
        entry.add_instruction(SsaInstruction::synthetic(SsaOp::Call {
            dest: Some(v1),
            method: MethodRef::new(method_token),
            args: vec![],
        }));
        entry.add_instruction(SsaInstruction::synthetic(SsaOp::Add {
            dest: v2,
            left: v0,
            right: v1,
        }));
        entry.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: Some(v2) }));
        func.add_block(entry);

        (func, method_token, field_token)
    }

    /// Removing the called method must turn the call and the dependent add into
    /// `Nop` while keeping the `Return` terminator.
    #[test]
    fn test_neutralization_removes_tainted_instructions() {
        let (mut func, method_token, _) = create_test_ssa();
        let removed = HashSet::from([method_token]);

        let count = NeutralizationPass::new(&removed).neutralize_method(&mut func);
        assert!(
            count >= 2,
            "Expected at least 2 instructions neutralized, got {}",
            count
        );

        let instrs = func.block(0).unwrap().instructions();
        assert!(matches!(instrs[1].op(), SsaOp::Nop));
        assert!(matches!(instrs[2].op(), SsaOp::Nop));
        assert!(matches!(instrs[3].op(), SsaOp::Return { .. }));
    }

    /// Removing a token the function never references must leave it untouched.
    #[test]
    fn test_neutralization_preserves_unrelated_code() {
        let (mut func, _, field_token) = create_test_ssa();
        let removed = HashSet::from([field_token]);

        let count = NeutralizationPass::new(&removed).neutralize_method(&mut func);
        assert_eq!(count, 0);

        let block = func.block(0).unwrap();
        for instr in &block.instructions()[..3] {
            assert!(!matches!(instr.op(), SsaOp::Nop));
        }
    }
}