use std::collections::{HashMap, HashSet};

use crate::{
    analysis::{PhiTaintMode, SsaFunction, SsaOp, TaintConfig, TokenTaintBuilder},
    compiler::{CompilerContext, EventKind, ModificationScope, SsaPass},
    deobfuscation::utils::resolve_qualified_method_name,
    metadata::token::Token,
    CilObject, Result,
};
/// How many of a pass's sentinel patterns must be found in a method before the pass
/// acts on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SentinelCondition {
    /// Every configured pattern must have matched.
    All,
    /// At least one configured pattern must have matched.
    Any,
    /// At least `n` distinct patterns must have matched.
    AtLeast(usize),
}

impl SentinelCondition {
    /// Returns `true` when `matched` distinct patterns (out of `total` configured
    /// patterns) satisfy this condition.
    ///
    /// Edge cases: `All` is trivially satisfied when `total` is zero, and `AtLeast(0)`
    /// is always satisfied. Callers that need at least one actual match must check
    /// that separately.
    fn is_satisfied(&self, matched: usize, total: usize) -> bool {
        match self {
            Self::All => matched >= total,
            Self::Any => matched >= 1,
            Self::AtLeast(n) => matched >= *n,
        }
    }
}
/// Configurable SSA pass that removes code forward-tainted by calls to "sentinel" methods.
///
/// A call counts as a sentinel when its resolved, qualified callee name contains one of
/// `sentinel_patterns` as a substring. A method is only rewritten when `condition`
/// accepts the number of distinct patterns found in it.
pub struct SentinelTaintRemovalPass {
    /// Returned from `SsaPass::name`; also prefixes the recorded event message.
    pass_name: &'static str,
    /// Returned from `SsaPass::description`.
    pass_description: &'static str,
    /// Methods the pass is restricted to; an empty set means every method.
    target_methods: HashSet<Token>,
    /// Substrings searched for in qualified callee names.
    sentinel_patterns: Vec<&'static str>,
    /// Rule deciding how many distinct patterns must match before rewriting.
    condition: SentinelCondition,
}

impl SentinelTaintRemovalPass {
    /// Builds a pass from its display metadata, method filter, sentinel patterns and
    /// match condition. Every argument is stored unchanged.
    pub fn new(
        pass_name: &'static str,
        pass_description: &'static str,
        target_methods: HashSet<Token>,
        sentinel_patterns: Vec<&'static str>,
        condition: SentinelCondition,
    ) -> Self {
        Self {
            condition,
            sentinel_patterns,
            target_methods,
            pass_description,
            pass_name,
        }
    }
}
impl SsaPass for SentinelTaintRemovalPass {
fn name(&self) -> &'static str {
self.pass_name
}
fn description(&self) -> &'static str {
self.pass_description
}
fn modification_scope(&self) -> ModificationScope {
ModificationScope::CfgModifying
}
fn should_run(&self, method_token: Token, _ctx: &CompilerContext) -> bool {
self.target_methods.is_empty() || self.target_methods.contains(&method_token)
}
fn run_on_method(
&self,
ssa: &mut SsaFunction,
method_token: Token,
ctx: &CompilerContext,
assembly: &CilObject,
) -> Result<bool> {
let (sentinel_tokens, distinct_count) =
find_sentinel_tokens(ssa, assembly, &self.sentinel_patterns);
if !self
.condition
.is_satisfied(distinct_count, self.sentinel_patterns.len())
{
return Ok(false);
}
if sentinel_tokens.is_empty() {
return Ok(false);
}
let taint = TokenTaintBuilder::new(sentinel_tokens)
.with_config(TaintConfig {
forward: true,
backward: false,
phi_mode: PhiTaintMode::NoPropagation,
max_iterations: 100,
})
.analyze(ssa);
if taint.tainted_instr_count() == 0 {
return Ok(false);
}
let tainted_instrs: HashSet<(usize, usize)> =
taint.tainted_instructions().iter().copied().collect();
let mut branch_redirects: Vec<(usize, usize, usize)> = Vec::new();
for &(block_idx, instr_idx) in taint.tainted_instructions() {
if let Some(block) = ssa.block(block_idx) {
if let Some(instr) = block.instruction(instr_idx) {
if instr.is_terminator() {
match instr.op() {
SsaOp::Branch {
true_target,
false_target,
..
}
| SsaOp::BranchCmp {
true_target,
false_target,
..
} => {
let target = choose_clean_target(
ssa,
*true_target,
*false_target,
&tainted_instrs,
);
branch_redirects.push((block_idx, instr_idx, target));
}
_ => {}
}
}
}
}
}
let mut neutralized = 0usize;
for &(block_idx, instr_idx) in taint.tainted_instructions() {
if let Some(block) = ssa.block_mut(block_idx) {
if let Some(instr) = block.instruction_mut(instr_idx) {
match instr.op() {
SsaOp::Call { method, .. } | SsaOp::CallVirt { method, .. } => {
ctx.neutralized_tokens.insert(method.token());
}
_ => {}
}
if instr.is_terminator() {
if let Some(&(_, _, target)) = branch_redirects
.iter()
.find(|&&(bi, ii, _)| bi == block_idx && ii == instr_idx)
{
instr.set_op(SsaOp::Jump { target });
neutralized += 1;
}
} else {
instr.set_op(SsaOp::Nop);
neutralized += 1;
}
}
}
}
let mut tainted_phis: Vec<(usize, usize)> = taint.tainted_phis().iter().copied().collect();
tainted_phis.sort_by(|a, b| b.cmp(a));
for (block_idx, phi_idx) in tainted_phis {
if let Some(block) = ssa.block_mut(block_idx) {
if phi_idx < block.phi_nodes().len() {
block.phi_nodes_mut().remove(phi_idx);
neutralized += 1;
}
}
}
if neutralized > 0 {
ctx.events
.record(EventKind::InstructionRemoved)
.method(method_token)
.message(format!(
"{}: removed {neutralized} tainted instructions",
self.pass_name,
));
}
Ok(neutralized > 0)
}
}
/// Scans every call (`Call`/`CallVirt`) in `ssa` and collects the callee tokens whose
/// qualified name contains at least one of `sentinel_patterns` as a substring.
///
/// Returns the set of matching callee tokens together with the number of distinct
/// patterns (counted by index) that matched anywhere in the method. Calls whose name
/// cannot be resolved are skipped.
fn find_sentinel_tokens(
    ssa: &SsaFunction,
    assembly: &CilObject,
    sentinel_patterns: &[&str],
) -> (HashSet<Token>, usize) {
    let mut found_tokens = HashSet::new();
    let mut hit_patterns: HashSet<usize> = HashSet::new();

    for block in ssa.blocks() {
        for instr in block.instructions() {
            let (SsaOp::Call { method, .. } | SsaOp::CallVirt { method, .. }) = instr.op() else {
                continue;
            };
            let callee = method.token();
            let Some(qualified_name) = resolve_qualified_method_name(assembly, callee) else {
                continue;
            };

            let matching_indices = sentinel_patterns
                .iter()
                .enumerate()
                .filter(|(_, pattern)| qualified_name.contains(**pattern))
                .map(|(pattern_idx, _)| pattern_idx);
            for pattern_idx in matching_indices {
                found_tokens.insert(callee);
                hit_patterns.insert(pattern_idx);
            }
        }
    }

    (found_tokens, hit_patterns.len())
}
/// Picks the successor that a tainted two-way branch should be collapsed onto.
///
/// A successor counts as tainted when its block exists, is non-empty, and its first
/// instruction (index 0) is in `tainted_instrs`. The false target is chosen only when
/// the true target is tainted and the false target is not. In every other case
/// (neither or both tainted) the true target is kept.
fn choose_clean_target(
    ssa: &SsaFunction,
    true_target: usize,
    false_target: usize,
    tainted_instrs: &HashSet<(usize, usize)>,
) -> usize {
    let leads_with_taint = |block_idx: usize| -> bool {
        let non_empty = ssa
            .block(block_idx)
            .is_some_and(|block| !block.instructions().is_empty());
        non_empty && tainted_instrs.contains(&(block_idx, 0))
    };

    if leads_with_taint(true_target) && !leads_with_taint(false_target) {
        false_target
    } else {
        true_target
    }
}