use std::collections::HashSet;
use crate::{
analysis::{ConstValue, SsaFunction, SsaOp, SsaVarId, TypeRef},
compiler::{CompilerContext, EventKind, ModificationScope, SsaPass},
metadata::token::Token,
CilObject, Result,
};
/// SSA pass that folds calls to metadata-token accessor methods back into
/// direct token loads (`ldtoken`).
///
/// The pass is configured with the set of accessor-method tokens; only calls
/// whose target is in this set are considered for folding.
pub struct TokenResolverPass {
    /// Tokens of the methods whose calls are treated as token accessors.
    accessor_tokens: HashSet<Token>,
}
impl TokenResolverPass {
#[must_use]
pub fn new(accessor_tokens: impl IntoIterator<Item = Token>) -> Self {
Self {
accessor_tokens: accessor_tokens.into_iter().collect(),
}
}
}
impl SsaPass for TokenResolverPass {
    /// Stable identifier of this pass.
    fn name(&self) -> &'static str {
        "netreactor-token-resolver"
    }

    /// Human-readable summary of what this pass does.
    fn description(&self) -> &'static str {
        "Folds NR anti-tamper metadata-token accessor calls back to ldtoken"
    }

    /// Only instruction ops are rewritten (a call becomes a `LoadToken`);
    /// the block structure is left untouched.
    fn modification_scope(&self) -> ModificationScope {
        ModificationScope::InstructionsOnly
    }

    /// Rewrites `dest = call accessor(const)` into `dest = LoadToken(const)`.
    ///
    /// A call is folded only when all of the following hold:
    /// - its target token is one of the configured accessor tokens,
    /// - it has exactly one argument and a destination variable,
    /// - that argument is a known `I32`/`U32` constant per `find_constants`,
    /// - the constant is a token `ldtoken` can legally reference
    ///   (see [`is_ldtoken_operand`]).
    ///
    /// One `EventKind::ValueResolved` event is recorded per rewritten call.
    ///
    /// # Returns
    /// `Ok(true)` if at least one call was rewritten, `Ok(false)` otherwise.
    ///
    /// # Errors
    /// This implementation does not produce errors itself.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        _method_token: Token,
        ctx: &CompilerContext,
        _assembly: &CilObject,
    ) -> Result<bool> {
        if self.accessor_tokens.is_empty() {
            return Ok(false);
        }

        let constants = ssa.find_constants();

        // Collect first, rewrite afterwards: iterating `ssa.blocks()` holds an
        // immutable borrow of `ssa` that must end before we mutate it.
        let mut replacements: Vec<(usize, usize, Token, SsaVarId)> = Vec::new();
        for (block_idx, block) in ssa.blocks().iter().enumerate() {
            for (instr_idx, instr) in block.instructions().iter().enumerate() {
                let SsaOp::Call { dest, method, args } = instr.op() else {
                    continue;
                };
                if !self.accessor_tokens.contains(&method.token()) {
                    continue;
                }
                if args.len() != 1 {
                    continue;
                }
                let Some(dest_var) = dest else {
                    continue;
                };
                let Some(const_val) = constants.get(&args[0]) else {
                    continue;
                };
                let raw_token = match const_val {
                    // Tokens are unsigned, but the constant may be typed as
                    // I32; `as` reinterprets the bits rather than converting.
                    ConstValue::I32(v) => *v as u32,
                    ConstValue::U32(v) => *v,
                    _ => continue,
                };
                // Leave the call in place rather than emit an invalid ldtoken
                // (nil row, or a table ldtoken cannot reference).
                if !is_ldtoken_operand(raw_token) {
                    continue;
                }
                replacements.push((block_idx, instr_idx, Token::new(raw_token), *dest_var));
            }
        }

        if replacements.is_empty() {
            return Ok(false);
        }

        for (block_idx, instr_idx, token, dest) in replacements {
            ssa.replace_instruction_op(
                block_idx,
                instr_idx,
                SsaOp::LoadToken {
                    dest,
                    token: TypeRef::new(token),
                },
            );
            ctx.events.record(EventKind::ValueResolved);
        }
        Ok(true)
    }
}

/// Returns `true` if `raw` is a metadata token that `ldtoken` may reference.
///
/// The high byte of a token is its metadata table id and the low 24 bits are
/// a 1-based row index, so a row of 0 denotes a nil token. ECMA-335 allows
/// `ldtoken` on method tokens (MethodDef, MemberRef, MethodSpec), type tokens
/// (TypeDef, TypeRef, TypeSpec) and field tokens (Field, MemberRef).
fn is_ldtoken_operand(raw: u32) -> bool {
    const TYPE_REF: u32 = 0x01;
    const TYPE_DEF: u32 = 0x02;
    const FIELD: u32 = 0x04;
    const METHOD_DEF: u32 = 0x06;
    const MEMBER_REF: u32 = 0x0A;
    const TYPE_SPEC: u32 = 0x1B;
    const METHOD_SPEC: u32 = 0x2B;

    let table = raw >> 24;
    let row = raw & 0x00FF_FFFF;
    row != 0
        && matches!(
            table,
            TYPE_REF | TYPE_DEF | FIELD | METHOD_DEF | MEMBER_REF | TYPE_SPEC | METHOD_SPEC
        )
}