use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use dashmap::DashSet;
use log::debug;
use crate::{
analysis::{ConstValue, SsaFunction, SsaOp, SsaVarId},
compiler::{CompilerContext, EventKind, ModificationScope, PassCapability, SsaPass},
deobfuscation::{utils::build_def_map, EmulationTemplatePool, ProcessCell},
emulation::{EmValue, EmulationProcess, HeapRef},
metadata::token::Token,
CilObject, Result,
};
/// SSA pass that removes opaque predicates based on static field chains
/// resolved via emulation.
///
/// The pass lazily forks an emulation process from `template_pool`, reads
/// static fields (and instance fields reachable from them) out of that
/// process, and uses the observed values to fold branches and field loads.
pub struct OpaqueFieldPredicatePass {
/// Lazily created emulation process; populated on first use through
/// `ensure_initialized` and cleared again in `finalize`.
lazy_process: ProcessCell,
/// Pool that supplies the assembly and the forked process used for warmup.
template_pool: Arc<EmulationTemplatePool>,
/// Static field tokens whose owning types' `.cctor`s are targeted when the
/// emulation process is created.
needed_static_fields: DashSet<Token>,
/// Methods this pass is allowed to run on (consulted by `should_run`).
affected_methods: DashSet<Token>,
/// Methods `run_on_method` has already completed for; they are skipped by
/// `should_run` afterwards.
processed_methods: DashSet<Token>,
/// Maps a method token to a static field token: calls to the key method are
/// resolved from whether the mapped static field is currently null.
sentinel_methods: HashMap<Token, Token>,
}
impl OpaqueFieldPredicatePass {
    /// Creates the pass.
    ///
    /// * `template_pool` - pool the emulation process is forked from on first use.
    /// * `needed_static_fields` - static fields the pass will need to read; the
    ///   static constructors of their owning types are targeted during warmup.
    /// * `affected_methods` - methods the pass should run on.
    /// * `sentinel_methods` - method token → static field token; calls to the
    ///   method are resolved from the null-ness of the field.
    #[must_use]
    pub fn new(
        template_pool: Arc<EmulationTemplatePool>,
        needed_static_fields: HashSet<Token>,
        affected_methods: HashSet<Token>,
        sentinel_methods: HashMap<Token, Token>,
    ) -> Self {
        Self {
            lazy_process: ProcessCell::new("opaque field"),
            template_pool,
            // The input sets are owned, so move their elements straight into
            // the concurrent sets instead of copying them one by one.
            needed_static_fields: needed_static_fields.into_iter().collect(),
            affected_methods: affected_methods.into_iter().collect(),
            processed_methods: DashSet::new(),
            sentinel_methods,
        }
    }

    /// Collects the `.cctor` tokens of every type that owns at least one of
    /// the needed static fields.
    ///
    /// Each needed token is considered both as given and, when the assembly's
    /// resolver maps it to another token, in its resolved form. The returned
    /// list contains every `.cctor` once, in type-registry iteration order.
    fn find_targeted_cctors(assembly: &CilObject, needed_fields: &DashSet<Token>) -> Vec<Token> {
        // Purely local scratch set: a plain `HashSet` is enough, no concurrent
        // container required.
        let mut resolved_fields: HashSet<Token> = HashSet::with_capacity(needed_fields.len());
        for token in needed_fields.iter() {
            resolved_fields.insert(*token);
            if let Some(resolved) = assembly.resolver().resolve_field(*token) {
                resolved_fields.insert(resolved);
            }
        }

        let registry = assembly.types();
        // `seen_cctors` gives O(1) de-duplication while `cctors` preserves the
        // discovery order.
        let mut seen_cctors: HashSet<Token> = HashSet::new();
        let mut cctors = Vec::new();
        for entry in registry.iter() {
            let type_ref = entry.value();
            let owns_needed_field = type_ref.fields.iter().any(|(_, field)| {
                field.flags.is_static() && resolved_fields.contains(&field.token)
            });
            if !owns_needed_field {
                continue;
            }
            let Some(cctor) = type_ref.cctor() else {
                continue;
            };
            if seen_cctors.insert(cctor) {
                debug!(
                    "Opaque field warmup: type {}.{} owns needed fields → .cctor 0x{:08X}",
                    type_ref.namespace,
                    type_ref.name,
                    cctor.value()
                );
                cctors.push(cctor);
            }
        }
        cctors
    }

    /// Forks an emulation process from the template pool, warmed up with the
    /// static constructors found by [`Self::find_targeted_cctors`].
    ///
    /// Returns `None` when the pool has no assembly or the fork yields none.
    fn create_process_from_pool(&self) -> Option<EmulationProcess> {
        let assembly = self.template_pool.assembly()?;
        let cctors = Self::find_targeted_cctors(&assembly, &self.needed_static_fields);
        self.template_pool.fork_for_targeted_warmup(&cctors)
    }

    /// Returns a read guard over the lazily created emulation process,
    /// creating it through [`Self::create_process_from_pool`] when needed.
    ///
    /// The second callback handed to the cell is a no-op.
    fn ensure_initialized(
        &self,
    ) -> Result<std::sync::RwLockReadGuard<'_, Option<EmulationProcess>>> {
        self.lazy_process
            .ensure_initialized(|| self.create_process_from_pool(), |_| {})
    }
}
/// Read-only helper that looks up static and instance field values in an
/// emulation process, retrying with the assembly-resolved field token when
/// the original token yields nothing.
struct FieldResolver<'a> {
/// Emulation process whose statics and address space are queried.
process: &'a EmulationProcess,
/// Assembly whose resolver maps a field token to its resolved token.
assembly: &'a CilObject,
}
impl<'a> FieldResolver<'a> {
    /// Builds a resolver over `process`, using `assembly` for token resolution.
    fn new(process: &'a EmulationProcess, assembly: &'a CilObject) -> Self {
        Self { process, assembly }
    }

    /// Reads a static field from the process.
    ///
    /// Tries `token` first; if the process has no value for it, retries with
    /// the token produced by the assembly's field resolver (when there is one).
    ///
    /// # Errors
    /// Propagates any error returned by the process' static lookup.
    fn get_static(&self, token: Token) -> Result<Option<EmValue>> {
        if let Some(val) = self.process.get_static(token)? {
            return Ok(Some(val));
        }
        if let Some(resolved) = self.assembly.resolver().resolve_field(token) {
            return self.process.get_static(resolved);
        }
        Ok(None)
    }

    /// Reads an instance field of the object behind `heap_ref`.
    ///
    /// Tries `token` first and then the assembly-resolved token. Lookup
    /// failures are logged at debug level and reported as `None` rather than
    /// as errors.
    fn get_field(&self, heap_ref: HeapRef, token: Token) -> Option<EmValue> {
        match self.process.address_space().get_field(heap_ref, token) {
            Ok(val) => return Some(val),
            Err(e) => {
                debug!(
                    "OpaqueFields: field access failed for 0x{:08X} on heap ref: {e}",
                    token.value()
                );
            }
        }
        if let Some(resolved) = self.assembly.resolver().resolve_field(token) {
            match self.process.address_space().get_field(heap_ref, resolved) {
                Ok(val) => return Some(val),
                Err(e) => {
                    debug!(
                        "OpaqueFields: field access failed for resolved 0x{:08X} on heap ref: {e}",
                        resolved.value()
                    );
                }
            }
        }
        None
    }

    /// Resolves `static_token.field_chain[0].field_chain[1]...` to a value.
    ///
    /// Returns `Ok(None)` when the chain is empty, the static is unknown, an
    /// instance field cannot be read, or a non-object value is encountered
    /// before the final step.
    ///
    /// NOTE(review): when the value reached at the *final* step is not an
    /// object reference, it is returned as-is without reading the last field —
    /// presumably intentional; confirm with the chain producer.
    ///
    /// # Errors
    /// Propagates errors from the static field lookup.
    fn resolve_chain(&self, static_token: Token, field_chain: &[Token]) -> Result<Option<EmValue>> {
        if field_chain.is_empty() {
            return Ok(None);
        }
        let Some(static_val) = self.get_static(static_token)? else {
            return Ok(None);
        };
        // Safe: the chain was checked to be non-empty above.
        let last_idx = field_chain.len() - 1;
        let mut current_val = static_val;
        for (i, &field_token) in field_chain.iter().enumerate() {
            match &current_val {
                EmValue::ObjectRef(heap_ref) => {
                    let Some(field_val) = self.get_field(*heap_ref, field_token) else {
                        return Ok(None);
                    };
                    current_val = field_val;
                }
                // A non-object in the middle of the chain cannot be followed.
                _ if i != last_idx => return Ok(None),
                _ => {}
            }
        }
        Ok(Some(current_val))
    }

    /// Evaluates a sentinel null-check: `Some(true)` when the sentinel static
    /// field is null, `Some(false)` otherwise.
    ///
    /// A static the process has no value for is treated as null, so this
    /// currently never returns `Ok(None)`.
    ///
    /// # Errors
    /// Propagates errors from the static field lookup.
    fn resolve_sentinel(&self, sentinel_field_token: Token) -> Result<Option<bool>> {
        let val = self
            .get_static(sentinel_field_token)?
            .unwrap_or(EmValue::Null);
        Ok(Some(val.is_null()))
    }
}
/// Walks backwards from `starting_op` through nested `LoadField` operations
/// until a `LoadStaticField` is reached.
///
/// On success returns the static field token together with the instance field
/// tokens ordered from the static root outwards (i.e. in access order).
/// Returns `None` when an operand has no definition in `defs`, when any other
/// kind of operation is encountered, or when no static load is found within
/// `MAX_CHAIN_DEPTH` inspected operations.
fn trace_field_chain(
    starting_op: &SsaOp,
    defs: &HashMap<SsaVarId, &SsaOp>,
) -> Option<(Token, Vec<Token>)> {
    const MAX_CHAIN_DEPTH: usize = 10;

    // Tokens are gathered innermost-first and flipped once the root is found.
    let mut path: Vec<Token> = Vec::new();
    let mut cursor = starting_op;
    for _ in 0..MAX_CHAIN_DEPTH {
        if let SsaOp::LoadStaticField { field, .. } = cursor {
            path.reverse();
            return Some((field.token(), path));
        }
        let SsaOp::LoadField { object, field, .. } = cursor else {
            return None;
        };
        path.push(field.token());
        cursor = *defs.get(object)?;
    }
    None
}
impl SsaPass for OpaqueFieldPredicatePass {
fn name(&self) -> &'static str {
"opaque-field-predicate-removal"
}
fn description(&self) -> &'static str {
"Removes opaque predicates based on static field chains resolved via emulation"
}
fn modification_scope(&self) -> ModificationScope {
ModificationScope::CfgModifying
}
fn provides(&self) -> &[PassCapability] {
&[PassCapability::ResolvedStaticFields]
}
fn should_run(&self, method_token: Token, _ctx: &CompilerContext) -> bool {
self.affected_methods.contains(&method_token)
&& !self.processed_methods.contains(&method_token)
}
fn initialize(&mut self, _ctx: &CompilerContext) -> Result<()> {
let remaining = self.affected_methods.len() - self.processed_methods.len();
if remaining > 0 {
debug!(
"Opaque field predicate pass: {} unique static fields in {} remaining methods ({} already processed)",
self.needed_static_fields.len(),
remaining,
self.processed_methods.len(),
);
}
Ok(())
}
fn run_on_method(
&self,
ssa: &mut SsaFunction,
method_token: Token,
ctx: &CompilerContext,
assembly: &CilObject,
) -> Result<bool> {
let guard = self.ensure_initialized()?;
let Some(process) = guard.as_ref() else {
return Ok(false);
};
let resolver = FieldResolver::new(process, assembly);
let defs = build_def_map(ssa);
let mut replacements: Vec<(usize, usize, usize)> = Vec::new();
for (block_idx, block) in ssa.blocks().iter().enumerate() {
let Some(terminator) = block.terminator_op() else {
continue;
};
let (condition, true_target, false_target) = match terminator {
SsaOp::Branch {
condition,
true_target,
false_target,
} => (*condition, *true_target, *false_target),
_ => continue,
};
let Some(cond_def) = defs.get(&condition) else {
continue;
};
if let Some((static_token, field_chain)) = trace_field_chain(cond_def, &defs) {
if let Some(is_truthy) = resolver
.resolve_chain(static_token, &field_chain)?
.map(|v| v.to_bool_cil())
{
let (target, dropped) = if is_truthy {
(true_target, false_target)
} else {
(false_target, true_target)
};
replacements.push((block_idx, target, dropped));
continue;
}
}
if let SsaOp::Call { method, .. } = cond_def {
if let Some(sentinel_field) = self.sentinel_methods.get(&method.token()) {
if let Some(is_truthy) = resolver.resolve_sentinel(*sentinel_field)? {
let (target, dropped) = if is_truthy {
(true_target, false_target)
} else {
(false_target, true_target)
};
replacements.push((block_idx, target, dropped));
}
}
}
}
let mut sentinel_call_replacements: Vec<(usize, usize, SsaOp)> = Vec::new();
if !self.sentinel_methods.is_empty() {
for (block_idx, block) in ssa.blocks().iter().enumerate() {
for (instr_idx, instr) in block.instructions().iter().enumerate() {
let SsaOp::Call {
dest: Some(dest),
method,
..
} = instr.op()
else {
continue;
};
let Some(sentinel_field) = self.sentinel_methods.get(&method.token()) else {
continue;
};
if let Some(is_truthy) = resolver.resolve_sentinel(*sentinel_field)? {
let value = if is_truthy {
ConstValue::I32(1)
} else {
ConstValue::I32(0)
};
sentinel_call_replacements.push((
block_idx,
instr_idx,
SsaOp::Const { dest: *dest, value },
));
}
}
}
}
let mut const_replacements: Vec<(usize, usize, SsaOp)> = Vec::new();
let mut replaced_object_vars: HashSet<SsaVarId> = HashSet::new();
for (block_idx, block) in ssa.blocks().iter().enumerate() {
for (instr_idx, instr) in block.instructions().iter().enumerate() {
let SsaOp::LoadField { dest, object, .. } = instr.op() else {
continue;
};
let load_field_op: SsaOp = instr.op().clone();
let Some((static_token, field_chain)) = trace_field_chain(&load_field_op, &defs)
else {
continue;
};
let Some(const_val) = resolver
.resolve_chain(static_token, &field_chain)?
.and_then(|v| v.to_const_value())
else {
continue;
};
const_replacements.push((
block_idx,
instr_idx,
SsaOp::Const {
dest: *dest,
value: const_val,
},
));
replaced_object_vars.insert(*object);
}
}
let mut dead_static_loads: Vec<(usize, usize)> = Vec::new();
for (block_idx, block) in ssa.blocks().iter().enumerate() {
for (instr_idx, instr) in block.instructions().iter().enumerate() {
let SsaOp::LoadStaticField { dest, .. } = instr.op() else {
continue;
};
if !replaced_object_vars.contains(dest) {
continue;
}
if let Some(variable) = ssa.variable(*dest) {
let all_uses_replaced = variable.uses().iter().all(|use_site| {
const_replacements
.iter()
.any(|(b, i, _)| *b == use_site.block && *i == use_site.instruction)
});
if all_uses_replaced {
dead_static_loads.push((block_idx, instr_idx));
}
}
}
}
drop(guard);
for &(block_idx, target, dropped) in &replacements {
if let Some(block) = ssa.block_mut(block_idx) {
if let Some(last) = block.instructions_mut().last_mut() {
last.set_op(SsaOp::Jump { target });
ctx.events
.record(EventKind::OpaquePredicateRemoved)
.at(method_token, block_idx)
.message(format!(
"removed opaque field predicate → jump to block {target}"
));
}
}
if dropped != target {
if let Some(dropped_block) = ssa.block_mut(dropped) {
for phi in dropped_block.phi_nodes_mut() {
phi.retain_operands(|pred| pred != block_idx);
}
}
}
}
for (block_idx, instr_idx, new_op) in &sentinel_call_replacements {
if let Some(block) = ssa.block_mut(*block_idx) {
if let Some(instr) = block.instructions_mut().get_mut(*instr_idx) {
instr.set_op(new_op.clone());
ctx.events
.record(EventKind::OpaquePredicateRemoved)
.at(method_token, *block_idx)
.message("resolved sentinel null-check call → constant");
}
}
}
for (block_idx, instr_idx, new_op) in &const_replacements {
if let Some(block) = ssa.block_mut(*block_idx) {
if let Some(instr) = block.instructions_mut().get_mut(*instr_idx) {
instr.set_op(new_op.clone());
ctx.events
.record(EventKind::ConstantFolded)
.at(method_token, *block_idx)
.message("resolved opaque field load → constant");
}
}
}
for (block_idx, instr_idx) in &dead_static_loads {
if let Some(block) = ssa.block_mut(*block_idx) {
if let Some(instr) = block.instructions_mut().get_mut(*instr_idx) {
instr.set_op(SsaOp::Nop);
}
}
}
let changed = !replacements.is_empty()
|| !const_replacements.is_empty()
|| !sentinel_call_replacements.is_empty();
self.processed_methods.insert(method_token);
Ok(changed)
}
fn finalize(&mut self, _ctx: &CompilerContext) -> Result<()> {
self.lazy_process.clear()
}
}