use crate::analysis::{SsaFunction, SsaOp, SsaVarId};
/// A switch-based state dispatcher: the switch block, its operand, its case
/// targets, and how a raw state value is mapped onto those cases.
#[derive(Debug, Clone)]
pub struct Dispatcher {
/// Index of the dispatcher block (copied into `DispatcherInfo::Switch::block` by `to_info`).
pub block: usize,
/// SSA variable the switch branches on.
pub switch_var: SsaVarId,
/// Case target block indices; `cases[i]` is taken when the transformed state equals `i`.
pub cases: Vec<usize>,
/// Block index taken when the transformed state does not index into `cases`.
pub default: usize,
/// Optional state variable, set via `with_state_phi` (`None` after `new`).
/// NOTE(review): presumably the phi that merges the state at the dispatcher — confirm with callers.
pub state_phi: Option<SsaVarId>,
/// Transform applied to a state value before it indexes `cases` (`Identity` after `new`).
pub transform: StateTransform,
/// Confidence score, set via `with_confidence` (`0.0` after `new`).
/// NOTE(review): no range is enforced here; presumably 0.0..=1.0 — confirm.
pub confidence: f64,
}
impl Dispatcher {
    /// Creates a dispatcher for the switch in `block` that branches on `switch_var`.
    ///
    /// `cases[i]` is the block reached when the transformed state equals `i`; every other
    /// value reaches `default`. The result starts with no state phi, the identity
    /// transform, and a confidence of `0.0`; use the `with_*` builders to change them.
    #[must_use]
    pub fn new(block: usize, switch_var: SsaVarId, cases: Vec<usize>, default: usize) -> Self {
        let transform = StateTransform::Identity;
        Self {
            block,
            switch_var,
            cases,
            default,
            state_phi: None,
            transform,
            confidence: 0.0,
        }
    }

    /// Returns `self` with the state phi set to `phi`.
    #[must_use]
    pub fn with_state_phi(self, phi: SsaVarId) -> Self {
        Self {
            state_phi: Some(phi),
            ..self
        }
    }

    /// Returns `self` with the state transform replaced by `transform`.
    #[must_use]
    pub fn with_transform(self, transform: StateTransform) -> Self {
        Self { transform, ..self }
    }

    /// Returns `self` with the confidence score replaced by `confidence`.
    #[must_use]
    pub fn with_confidence(self, confidence: f64) -> Self {
        Self { confidence, ..self }
    }

    /// Number of explicit case targets (the default target is not counted).
    #[must_use]
    pub fn case_count(&self) -> usize {
        self.cases.len()
    }

    /// Every reachable target: the case targets in order (duplicates kept), followed by
    /// the default target unless it already appears among the cases.
    #[must_use]
    pub fn all_targets(&self) -> Vec<usize> {
        let default_is_new = !self.cases.contains(&self.default);
        let mut targets = Vec::with_capacity(self.cases.len() + usize::from(default_is_new));
        targets.extend_from_slice(&self.cases);
        if default_is_new {
            targets.push(self.default);
        }
        targets
    }

    /// Block the dispatcher jumps to for the raw `state`.
    ///
    /// The state is first run through the transform; a negative or out-of-range result
    /// selects the default target.
    #[must_use]
    pub fn target_for_state(&self, state: i32) -> usize {
        usize::try_from(self.transform.apply(state))
            .ok()
            .and_then(|slot| self.cases.get(slot).copied())
            .unwrap_or(self.default)
    }

    /// Converts this dispatcher into the equivalent [`DispatcherInfo::Switch`]
    /// (state phi and confidence are not carried over).
    #[must_use]
    pub fn to_info(&self) -> DispatcherInfo {
        let cases = self.cases.clone();
        let transform = self.transform.clone();
        DispatcherInfo::Switch {
            block: self.block,
            switch_var: self.switch_var,
            cases,
            default: self.default,
            transform,
        }
    }
}
/// Arithmetic applied to a raw state value before it is used as a switch index.
///
/// A dispatcher may switch on `f(state)` rather than on `state` itself; each variant
/// models one such `f`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum StateTransform {
    /// The state is used unchanged.
    #[default]
    Identity,
    /// `state % n`, computed on the state's bit pattern read as `u32`.
    Modulo(u32),
    /// `(state ^ xor_key) % divisor`, with the XOR result read as `u32`.
    XorModulo {
        /// Constant XORed into the state before the remainder.
        xor_key: i32,
        /// Unsigned divisor of the remainder.
        divisor: u32,
    },
    /// `state & mask`.
    And(u32),
    /// Logical (unsigned) right shift of the state.
    Shr(u32),
}

impl StateTransform {
    /// Applies the transform to `state`, returning the value the switch actually sees.
    ///
    /// Remainders and shifts operate on the bit pattern of `state` read as `u32`, and the
    /// result is reinterpreted as `i32`.
    ///
    /// Degenerate parameters never panic:
    /// * a divisor of `0` skips the remainder step, so `Modulo(0)` returns `state` and
    ///   `XorModulo { divisor: 0, .. }` returns `state ^ xor_key`;
    /// * shift amounts use only their low five bits (C#/x86 semantics for 32-bit shifts),
    ///   so `Shr(40)` behaves like `Shr(8)` instead of overflowing.
    #[must_use]
    pub fn apply(&self, state: i32) -> i32 {
        match self {
            Self::Identity => state,
            Self::Modulo(n) => Self::unsigned_rem(state, *n),
            Self::XorModulo { xor_key, divisor } => Self::unsigned_rem(state ^ xor_key, *divisor),
            Self::And(mask) => state & (*mask).cast_signed(),
            // `wrapping_shr` masks the amount to 0..32, avoiding the debug-mode overflow
            // panic that a plain `>>` raises for amounts >= 32.
            Self::Shr(amount) => state.cast_unsigned().wrapping_shr(*amount).cast_signed(),
        }
    }

    /// `value % divisor` on the unsigned bit pattern of `value`; returns `value`
    /// unchanged when `divisor` is zero instead of panicking.
    fn unsigned_rem(value: i32, divisor: u32) -> i32 {
        let bits = value.cast_unsigned();
        bits.checked_rem(divisor).unwrap_or(bits).cast_signed()
    }

    /// Divisor of a `Modulo` or `XorModulo` transform; `None` for the others.
    #[must_use]
    pub fn modulo_divisor(&self) -> Option<u32> {
        match self {
            Self::Modulo(n) => Some(*n),
            Self::XorModulo { divisor, .. } => Some(*divisor),
            _ => None,
        }
    }

    /// XOR key of an `XorModulo` transform; `None` for the others.
    #[must_use]
    pub fn xor_key(&self) -> Option<i32> {
        match self {
            Self::XorModulo { xor_key, .. } => Some(*xor_key),
            _ => None,
        }
    }

    /// Whether this is [`StateTransform::Identity`].
    #[must_use]
    pub fn is_identity(&self) -> bool {
        matches!(self, Self::Identity)
    }

    /// Whether this is [`StateTransform::XorModulo`].
    #[must_use]
    pub fn is_xor_modulo(&self) -> bool {
        matches!(self, Self::XorModulo { .. })
    }
}
/// A recognized state dispatcher, in one of the shapes the analysis understands.
#[derive(Debug, Clone)]
pub enum DispatcherInfo {
    /// A `switch` over a (possibly transformed) state variable.
    Switch {
        /// Index of the block containing the switch.
        block: usize,
        /// SSA variable the switch branches on.
        switch_var: SsaVarId,
        /// Case targets; `cases[i]` is taken when the transformed state equals `i`.
        cases: Vec<usize>,
        /// Target taken when the transformed state does not index into `cases`.
        default: usize,
        /// Transform applied to a state value before it indexes `cases`.
        transform: StateTransform,
    },
    /// A sequence of equality tests of the state against constants.
    IfElseChain {
        /// Block where the chain starts (reported by `block()`).
        head_block: usize,
        /// SSA variable compared against each constant.
        state_var: SsaVarId,
        /// `(state value, target block)` pairs, tested in order; the first match wins.
        comparisons: Vec<(i32, usize)>,
        /// Target when no comparison matches, if one is known.
        default: Option<usize>,
    },
    /// A jump whose target is chosen by indexing `jump_table` with the case value.
    ComputedJump {
        /// Index of the block performing the jump.
        block: usize,
        /// SSA variable holding the jump selector.
        target_var: SsaVarId,
        /// Target blocks, indexed by the case value.
        jump_table: Vec<usize>,
        /// NOTE(review): presumably the address the table is relative to; it is only
        /// reported by `base_address()` in this file — confirm with producers.
        base_address: Option<u64>,
    },
}

impl DispatcherInfo {
    /// Index of the block that performs the dispatch (the chain head for if/else chains).
    #[must_use]
    pub fn block(&self) -> usize {
        match *self {
            Self::Switch { block, .. } | Self::ComputedJump { block, .. } => block,
            Self::IfElseChain { head_block, .. } => head_block,
        }
    }

    /// Number of explicit cases (a default target is not counted).
    #[must_use]
    pub fn case_count(&self) -> usize {
        match self {
            Self::Switch { cases: entries, .. }
            | Self::ComputedJump {
                jump_table: entries,
                ..
            } => entries.len(),
            Self::IfElseChain { comparisons, .. } => comparisons.len(),
        }
    }

    /// Block reached when the dispatcher sees `case_value`.
    ///
    /// * `Switch`: the value is transformed first; a negative or out-of-range index
    ///   yields `Some(default)`.
    /// * `IfElseChain`: the first comparison equal to the value wins, otherwise the
    ///   optional default.
    /// * `ComputedJump`: a direct table lookup; negative or out-of-range values give `None`.
    #[must_use]
    pub fn target_for_case(&self, case_value: i32) -> Option<usize> {
        match self {
            Self::Switch {
                cases,
                default,
                transform,
                ..
            } => {
                let slot = usize::try_from(transform.apply(case_value)).ok();
                Some(slot.and_then(|i| cases.get(i).copied()).unwrap_or(*default))
            }
            Self::IfElseChain {
                comparisons,
                default,
                ..
            } => comparisons
                .iter()
                .find(|&&(value, _)| value == case_value)
                .map(|&(_, target)| target)
                .or(*default),
            Self::ComputedJump { jump_table, .. } => usize::try_from(case_value)
                .ok()
                .and_then(|i| jump_table.get(i).copied()),
        }
    }

    /// Every reachable target in case order (duplicates kept), followed by the default
    /// target when there is one that is not already listed. Computed jumps report their
    /// table as-is.
    #[must_use]
    pub fn all_targets(&self) -> Vec<usize> {
        /// Appends `extra` to `targets` unless it is `None` or already present.
        fn append_if_absent(mut targets: Vec<usize>, extra: Option<usize>) -> Vec<usize> {
            if let Some(extra) = extra {
                if !targets.contains(&extra) {
                    targets.push(extra);
                }
            }
            targets
        }

        match self {
            Self::Switch { cases, default, .. } => append_if_absent(cases.clone(), Some(*default)),
            Self::IfElseChain {
                comparisons,
                default,
                ..
            } => {
                let targets = comparisons.iter().map(|&(_, target)| target).collect();
                append_if_absent(targets, *default)
            }
            Self::ComputedJump { jump_table, .. } => jump_table.clone(),
        }
    }

    /// Transform applied before dispatch; only switches carry one, the other shapes
    /// report [`StateTransform::Identity`].
    #[must_use]
    pub fn transform(&self) -> StateTransform {
        if let Self::Switch { transform, .. } = self {
            transform.clone()
        } else {
            StateTransform::Identity
        }
    }

    /// SSA variable the dispatch decision is made on.
    #[must_use]
    pub fn dispatch_var(&self) -> SsaVarId {
        match *self {
            Self::Switch {
                switch_var: var, ..
            }
            | Self::IfElseChain { state_var: var, .. }
            | Self::ComputedJump {
                target_var: var, ..
            } => var,
        }
    }

    /// Whether this is a [`DispatcherInfo::ComputedJump`].
    #[must_use]
    pub fn is_computed_jump(&self) -> bool {
        matches!(self, Self::ComputedJump { .. })
    }

    /// Base address of a computed jump; `None` for other shapes or when unknown.
    #[must_use]
    pub fn base_address(&self) -> Option<u64> {
        if let Self::ComputedJump { base_address, .. } = self {
            *base_address
        } else {
            None
        }
    }
}
/// Builds a [`DispatcherInfo::Switch`] from the last `Switch` instruction of block `block_idx`.
///
/// Returns `None` when the block does not exist or contains no `Switch`. The definition
/// of the switch operand is inspected to recover any [`StateTransform`] applied to the
/// state before dispatch.
pub fn analyze_switch_dispatcher(ssa: &SsaFunction, block_idx: usize) -> Option<DispatcherInfo> {
    // Scan backwards so the last switch in the block is the one reported.
    let (switch_var, cases, default) = ssa
        .block(block_idx)?
        .instructions()
        .iter()
        .rev()
        .find_map(|instr| match instr.op() {
            SsaOp::Switch {
                value,
                targets,
                default,
            } => Some((*value, targets.clone(), *default)),
            _ => None,
        })?;

    Some(DispatcherInfo::Switch {
        block: block_idx,
        switch_var,
        cases,
        default,
        transform: analyze_switch_transform(ssa, switch_var),
    })
}
/// Recovers the [`StateTransform`] applied to a switch operand.
///
/// Inspects the instruction defining `switch_var` and recognizes:
/// * an unsigned remainder by a non-zero constant, giving `Modulo`, or `XorModulo` when
///   the dividend is an XOR with a constant (see `find_xor_key`);
/// * an `And` with a constant on either operand, giving `And`;
/// * a `Shr` by a constant, giving `Shr` with the amount reduced to its low five bits.
///
/// Anything else, including a remainder by constant zero, yields
/// [`StateTransform::Identity`].
fn analyze_switch_transform(ssa: &SsaFunction, switch_var: SsaVarId) -> StateTransform {
    // Resolves `var` to its `i32` value when it is defined by a constant.
    let const_i32 = |var: SsaVarId| match ssa.get_definition(var) {
        Some(SsaOp::Const { value, .. }) => value.as_i32(),
        _ => None,
    };

    let Some(def) = ssa.get_definition(switch_var) else {
        return StateTransform::Identity;
    };
    match def {
        SsaOp::Rem {
            left,
            right,
            unsigned: true,
            ..
        } => {
            // A remainder by constant zero has no meaningful state mapping, and a
            // `Modulo(0)` transform would divide by zero when applied.
            let Some(divisor) = const_i32(*right)
                .map(i32::cast_unsigned)
                .filter(|&d| d != 0)
            else {
                return StateTransform::Identity;
            };
            match find_xor_key(ssa, *left) {
                Some(xor_key) => StateTransform::XorModulo { xor_key, divisor },
                None => StateTransform::Modulo(divisor),
            }
        }
        // `and` is commutative, so accept the constant mask on either side (matching how
        // `find_xor_key` treats `xor`).
        SsaOp::And { left, right, .. } => const_i32(*right)
            .or_else(|| const_i32(*left))
            .map_or(StateTransform::Identity, |mask| {
                StateTransform::And(mask.cast_unsigned())
            }),
        // NOTE(review): arithmetic vs. logical shift is not distinguished here; the
        // transform always shifts logically — confirm against `SsaOp::Shr` semantics.
        SsaOp::Shr { amount, .. } => const_i32(*amount).map_or(StateTransform::Identity, |shift| {
            // 32-bit shifts only use the low five bits of the amount (C#/x86 semantics),
            // so normalize here rather than store an amount that would overflow a shift.
            StateTransform::Shr(shift.cast_unsigned() & 0x1F)
        }),
        _ => StateTransform::Identity,
    }
}
/// Returns the constant XOR key if `var` is defined as `x ^ const` or `const ^ x`,
/// following `Copy` chains back to the XOR.
///
/// When both operands are constants, the right-hand one is returned.
fn find_xor_key(ssa: &SsaFunction, var: SsaVarId) -> Option<i32> {
    match ssa.get_definition(var)? {
        SsaOp::Copy { src, .. } => find_xor_key(ssa, *src),
        SsaOp::Xor { left, right, .. } => {
            let const_i32 = |operand: &SsaVarId| match ssa.get_definition(*operand) {
                Some(SsaOp::Const { value, .. }) => value.as_i32(),
                _ => None,
            };
            const_i32(right).or_else(|| const_i32(left))
        }
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::{Dispatcher, DispatcherInfo, StateTransform};
    use crate::analysis::SsaVarId;

    #[test]
    fn test_state_transform_identity() {
        let transform = StateTransform::Identity;
        assert_eq!(transform.apply(42), 42);
        assert_eq!(transform.apply(-5), -5);
        assert!(transform.is_identity());
    }

    #[test]
    fn test_state_transform_modulo() {
        let transform = StateTransform::Modulo(7);
        assert_eq!(transform.apply(10), 3);
        assert_eq!(transform.apply(7), 0);
        assert_eq!(transform.apply(0), 0);
        assert_eq!(transform.modulo_divisor(), Some(7));
        assert!(!transform.is_identity());
    }

    #[test]
    fn test_state_transform_and() {
        let transform = StateTransform::And(0xFF);
        assert_eq!(transform.apply(0x12345678_u32 as i32), 0x78);
        assert_eq!(transform.apply(255), 255);
    }

    #[test]
    fn test_state_transform_shr() {
        let transform = StateTransform::Shr(8);
        assert_eq!(transform.apply(0x1234), 0x12);
    }

    #[test]
    fn test_state_transform_xor_modulo() {
        let transform = StateTransform::XorModulo {
            xor_key: -576502913_i32,
            divisor: 7,
        };
        assert!(!transform.is_identity());
        assert!(transform.is_xor_modulo());
        assert_eq!(transform.modulo_divisor(), Some(7));
        assert_eq!(transform.xor_key(), Some(-576502913_i32));

        let state = -781784372_i32;
        let xored = state ^ -576502913_i32;
        let expected = ((xored as u32) % 7) as i32;
        assert_eq!(transform.apply(state), expected);
    }

    #[test]
    fn test_dispatcher_info_switch() {
        let dispatcher = DispatcherInfo::Switch {
            block: 0,
            switch_var: SsaVarId::new(),
            cases: vec![1, 2, 3, 4, 5],
            default: 6,
            transform: StateTransform::Modulo(5),
        };
        assert_eq!(dispatcher.block(), 0);
        assert_eq!(dispatcher.case_count(), 5);
        assert_eq!(dispatcher.target_for_case(0), Some(1));
        assert_eq!(dispatcher.target_for_case(1), Some(2));
        // 5 % 5 == 0 and 7 % 5 == 2: the modulo wraps states back into the case table.
        assert_eq!(dispatcher.target_for_case(5), Some(1));
        assert_eq!(dispatcher.target_for_case(7), Some(3));
    }

    #[test]
    fn test_dispatcher_info_switch_identity_out_of_range_uses_default() {
        let dispatcher = DispatcherInfo::Switch {
            block: 0,
            switch_var: SsaVarId::new(),
            cases: vec![1, 2, 3],
            default: 9,
            transform: StateTransform::Identity,
        };
        assert_eq!(dispatcher.target_for_case(2), Some(3));
        assert_eq!(dispatcher.target_for_case(3), Some(9));
        assert_eq!(dispatcher.target_for_case(-1), Some(9));
    }

    #[test]
    fn test_dispatcher_info_switch_with_xor_modulo() {
        let cases = vec![2, 3, 4, 5, 6, 7, 8];
        let dispatcher = DispatcherInfo::Switch {
            block: 1,
            switch_var: SsaVarId::new(),
            cases: cases.clone(),
            default: 9,
            transform: StateTransform::XorModulo {
                xor_key: -576502913_i32,
                divisor: 7,
            },
        };
        assert_eq!(dispatcher.case_count(), 7);

        let state = -781784372_i32;
        let xored = state ^ -576502913_i32;
        let expected_case_idx = ((xored as u32) % 7) as usize;
        // Compare against the case table itself rather than `all_targets()`, whose
        // ordering is not a lookup contract.
        assert_eq!(
            dispatcher.target_for_case(state),
            Some(cases[expected_case_idx])
        );
    }

    #[test]
    fn test_dispatcher_info_if_else() {
        let dispatcher = DispatcherInfo::IfElseChain {
            head_block: 0,
            state_var: SsaVarId::new(),
            comparisons: vec![(10, 1), (20, 2), (30, 3)],
            default: Some(4),
        };
        assert_eq!(dispatcher.block(), 0);
        assert_eq!(dispatcher.case_count(), 3);
        assert_eq!(dispatcher.target_for_case(10), Some(1));
        assert_eq!(dispatcher.target_for_case(20), Some(2));
        // No comparison matches, so the default is taken.
        assert_eq!(dispatcher.target_for_case(99), Some(4));
    }

    #[test]
    fn test_dispatcher_info_if_else_without_default() {
        let dispatcher = DispatcherInfo::IfElseChain {
            head_block: 2,
            state_var: SsaVarId::new(),
            comparisons: vec![(10, 1), (10, 2)],
            default: None,
        };
        assert_eq!(dispatcher.block(), 2);
        // The first matching comparison wins.
        assert_eq!(dispatcher.target_for_case(10), Some(1));
        assert_eq!(dispatcher.target_for_case(11), None);
        assert_eq!(dispatcher.all_targets(), vec![1, 2]);
        assert!(dispatcher.transform().is_identity());
    }

    #[test]
    fn test_dispatcher_all_targets() {
        let dispatcher = DispatcherInfo::Switch {
            block: 0,
            switch_var: SsaVarId::new(),
            // Duplicate case targets are allowed.
            cases: vec![1, 2, 2, 3],
            default: 4,
            transform: StateTransform::Identity,
        };
        let targets = dispatcher.all_targets();
        assert!(targets.contains(&1));
        assert!(targets.contains(&2));
        assert!(targets.contains(&3));
        assert!(targets.contains(&4));
    }

    #[test]
    fn test_dispatcher_all_targets_does_not_repeat_default() {
        let dispatcher = DispatcherInfo::Switch {
            block: 0,
            switch_var: SsaVarId::new(),
            cases: vec![1, 2],
            default: 2,
            transform: StateTransform::Identity,
        };
        assert_eq!(dispatcher.all_targets(), vec![1, 2]);
    }

    #[test]
    fn test_dispatcher_info_computed_jump() {
        let dispatcher = DispatcherInfo::ComputedJump {
            block: 5,
            target_var: SsaVarId::new(),
            jump_table: vec![10, 20, 30, 40],
            base_address: Some(0x1000),
        };
        assert_eq!(dispatcher.block(), 5);
        assert_eq!(dispatcher.case_count(), 4);
        assert!(dispatcher.is_computed_jump());
        assert_eq!(dispatcher.base_address(), Some(0x1000));
        assert_eq!(dispatcher.target_for_case(0), Some(10));
        assert_eq!(dispatcher.target_for_case(1), Some(20));
        assert_eq!(dispatcher.target_for_case(3), Some(40));
        assert_eq!(dispatcher.target_for_case(4), None);
        assert_eq!(dispatcher.target_for_case(-1), None);

        let targets = dispatcher.all_targets();
        assert_eq!(targets, vec![10, 20, 30, 40]);
    }

    #[test]
    fn test_dispatcher_builder_and_lookup() {
        let dispatcher = Dispatcher::new(3, SsaVarId::new(), vec![10, 20, 30], 99);
        assert!(dispatcher.state_phi.is_none());
        assert!(dispatcher.transform.is_identity());
        assert_eq!(dispatcher.case_count(), 3);
        assert_eq!(dispatcher.target_for_state(2), 30);
        assert_eq!(dispatcher.target_for_state(3), 99);
        assert_eq!(dispatcher.target_for_state(-1), 99);
        assert_eq!(dispatcher.all_targets(), vec![10, 20, 30, 99]);

        let dispatcher = dispatcher
            .with_state_phi(SsaVarId::new())
            .with_transform(StateTransform::Modulo(3))
            .with_confidence(0.75);
        assert!(dispatcher.state_phi.is_some());
        // 4 % 3 == 1 selects the second case.
        assert_eq!(dispatcher.target_for_state(4), 20);
        assert!((dispatcher.confidence - 0.75).abs() < f64::EPSILON);

        let info = dispatcher.to_info();
        assert_eq!(info.block(), 3);
        assert_eq!(info.case_count(), 3);
        assert_eq!(info.transform(), StateTransform::Modulo(3));
        assert_eq!(info.target_for_case(4), Some(20));
    }
}