use std::collections::HashSet;
use crate::{
analysis::{ConstValue, SsaFunction, SsaOp},
deobfuscation::{
renamer::{
context::{OpcodeProfile, PhaseInfo},
features,
},
utils::is_obfuscated_name,
},
CilObject,
};
pub fn decompose_method(
ssa: &SsaFunction,
assembly: &CilObject,
small_threshold: usize,
) -> Vec<PhaseInfo> {
let block_count = ssa.blocks().len();
if block_count == 0 {
return Vec::new();
}
let total_instructions = ssa.instruction_count();
if total_instructions <= small_threshold {
let call_targets = features::collect_call_targets(ssa, assembly);
let profile = features::build_opcode_profile(ssa);
return vec![PhaseInfo {
label: String::new(), call_targets,
opcode_profile: Some(profile),
structure: Some("linear".to_string()),
}];
}
let (boundaries, loop_headers) = find_phase_boundaries(ssa, assembly);
build_phases_from_boundaries(ssa, assembly, &boundaries, &loop_headers)
}
pub fn build_call_site_skeleton(ssa: &SsaFunction, assembly: &CilObject) -> Option<String> {
let mut lines = Vec::new();
for (_block_idx, _instr_idx, instr) in ssa.iter_instructions() {
match instr.op() {
SsaOp::Call { method, args, .. } | SsaOp::CallVirt { method, args, .. } => {
let name = assembly
.resolve_method_name(method.token())
.unwrap_or_else(|| format!("method_0x{:08X}", method.token().value()));
let arg_list: Vec<String> = args
.iter()
.enumerate()
.map(|(i, _)| format!("var_{i}"))
.collect();
lines.push(format!(" {name}({});", arg_list.join(", ")));
}
SsaOp::NewObj { ctor, args, .. } => {
let name = assembly
.resolve_method_name(ctor.token())
.unwrap_or_else(|| "Type_0".to_string());
let arg_list: Vec<String> = args
.iter()
.enumerate()
.map(|(i, _)| format!("var_{i}"))
.collect();
lines.push(format!(" new {name}({});", arg_list.join(", ")));
}
SsaOp::LoadField {
dest,
object,
field,
} => {
let name = assembly
.resolve_method_name(field.token())
.unwrap_or_else(|| format!("field_0x{:08X}", field.token().value()));
lines.push(format!(
" var_{} = var_{}.{name};",
dest.index(),
object.index()
));
}
SsaOp::StoreField {
object,
field,
value,
} => {
let name = assembly
.resolve_method_name(field.token())
.unwrap_or_else(|| format!("field_0x{:08X}", field.token().value()));
lines.push(format!(
" var_{}.{name} = var_{};",
object.index(),
value.index()
));
}
SsaOp::LoadStaticField { dest, field, .. } => {
let name = assembly
.resolve_method_name(field.token())
.unwrap_or_else(|| format!("field_0x{:08X}", field.token().value()));
lines.push(format!(" var_{} = {name};", dest.index()));
}
SsaOp::StoreStaticField { field, value, .. } => {
let name = assembly
.resolve_method_name(field.token())
.unwrap_or_else(|| format!("field_0x{:08X}", field.token().value()));
lines.push(format!(" {name} = var_{};", value.index()));
}
SsaOp::Add { dest, left, right }
| SsaOp::AddOvf {
dest, left, right, ..
} => {
lines.push(format!(
" var_{} = var_{} + var_{};",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Sub { dest, left, right }
| SsaOp::SubOvf {
dest, left, right, ..
} => {
lines.push(format!(
" var_{} = var_{} - var_{};",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Mul { dest, left, right }
| SsaOp::MulOvf {
dest, left, right, ..
} => {
lines.push(format!(
" var_{} = var_{} * var_{};",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Div {
dest, left, right, ..
} => {
lines.push(format!(
" var_{} = var_{} / var_{};",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Rem {
dest, left, right, ..
} => {
lines.push(format!(
" var_{} = var_{} % var_{};",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Neg { dest, operand } => {
lines.push(format!(
" var_{} = -var_{};",
dest.index(),
operand.index()
));
}
SsaOp::And { dest, left, right } => {
lines.push(format!(
" var_{} = var_{} & var_{};",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Or { dest, left, right } => {
lines.push(format!(
" var_{} = var_{} | var_{};",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Xor { dest, left, right } => {
lines.push(format!(
" var_{} = var_{} ^ var_{};",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Not { dest, operand } => {
lines.push(format!(
" var_{} = ~var_{};",
dest.index(),
operand.index()
));
}
SsaOp::Shl {
dest,
value,
amount,
} => {
lines.push(format!(
" var_{} = var_{} << var_{};",
dest.index(),
value.index(),
amount.index()
));
}
SsaOp::Shr {
dest,
value,
amount,
..
} => {
lines.push(format!(
" var_{} = var_{} >> var_{};",
dest.index(),
value.index(),
amount.index()
));
}
SsaOp::Ceq { dest, left, right } => {
lines.push(format!(
" var_{} = (var_{} == var_{});",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Clt {
dest, left, right, ..
} => {
lines.push(format!(
" var_{} = (var_{} < var_{});",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Cgt {
dest, left, right, ..
} => {
lines.push(format!(
" var_{} = (var_{} > var_{});",
dest.index(),
left.index(),
right.index()
));
}
SsaOp::Conv {
dest,
operand,
target,
..
} => {
lines.push(format!(
" var_{} = ({target})var_{};",
dest.index(),
operand.index()
));
}
SsaOp::NewArr { dest, length, .. } => {
lines.push(format!(
" var_{} = new array[var_{}];",
dest.index(),
length.index()
));
}
SsaOp::LoadElement {
dest, array, index, ..
} => {
lines.push(format!(
" var_{} = var_{}[var_{}];",
dest.index(),
array.index(),
index.index()
));
}
SsaOp::StoreElement {
array,
index,
value,
..
} => {
lines.push(format!(
" var_{}[var_{}] = var_{};",
array.index(),
index.index(),
value.index()
));
}
SsaOp::ArrayLength { dest, array } => {
lines.push(format!(
" var_{} = var_{}.Length;",
dest.index(),
array.index()
));
}
SsaOp::Const {
value: ConstValue::DecryptedString(s),
..
} => {
let truncated = if s.len() > 30 {
format!("\"{}...\"", &s[..27])
} else {
format!("\"{s}\"")
};
lines.push(format!(" // string: {truncated}"));
}
SsaOp::Const {
value: ConstValue::String(idx),
..
} => {
if let Some(us) = assembly.userstrings() {
if let Ok(s) = us.get(*idx as usize) {
if let Ok(decoded) = s.to_string() {
if !decoded.is_empty() {
let truncated = if decoded.len() > 30 {
format!("\"{}...\"", &decoded[..27])
} else {
format!("\"{decoded}\"")
};
lines.push(format!(" // string: {truncated}"));
}
}
}
}
}
SsaOp::Return { value: Some(v) } => {
lines.push(format!(" return var_{};", v.index()));
}
SsaOp::Return { value: None } => {
lines.push(" return;".to_string());
}
_ => {}
}
}
if lines.is_empty() {
None
} else {
Some(lines.join("\n"))
}
}
fn find_phase_boundaries(ssa: &SsaFunction, assembly: &CilObject) -> (Vec<usize>, HashSet<usize>) {
let blocks = ssa.blocks();
let block_count = blocks.len();
let mut boundaries: HashSet<usize> = HashSet::new();
let mut loop_headers: HashSet<usize> = HashSet::new();
boundaries.insert(0);
for handler in ssa.exception_handlers() {
if let Some(handler_start) = handler.handler_start_block {
boundaries.insert(handler_start);
}
if let Some(try_start) = handler.try_start_block {
boundaries.insert(try_start);
}
}
let mut visited = vec![false; block_count];
let mut in_stack = vec![false; block_count];
detect_back_edges(ssa, 0, &mut visited, &mut in_stack, &mut loop_headers);
boundaries.extend(&loop_headers);
let mut prev_namespace: Option<String> = None;
for (block_idx, block) in blocks.iter().enumerate() {
for instr in block.instructions() {
if let SsaOp::Call { method, .. } | SsaOp::CallVirt { method, .. } = instr.op() {
if let Some(name) = assembly.resolve_method_name(method.token()) {
let ns = extract_namespace(&name);
if let Some(ref prev) = prev_namespace {
if ns != *prev {
boundaries.insert(block_idx);
}
}
prev_namespace = Some(ns);
}
}
}
}
detect_transform_boundaries(ssa, &mut boundaries);
let mut sorted: Vec<usize> = boundaries.into_iter().collect();
sorted.sort_unstable();
(sorted, loop_headers)
}
fn build_phases_from_boundaries(
ssa: &SsaFunction,
assembly: &CilObject,
boundaries: &[usize],
loop_headers: &HashSet<usize>,
) -> Vec<PhaseInfo> {
let blocks = ssa.blocks();
let block_count = blocks.len();
let mut phases = Vec::new();
for (i, &start) in boundaries.iter().enumerate() {
let end = if i + 1 < boundaries.len() {
boundaries[i + 1]
} else {
block_count
};
if start >= block_count {
continue;
}
let mut call_targets = Vec::new();
let mut call_seen = HashSet::new();
let mut profile = OpcodeProfile::default();
let mut structure = None;
for handler in ssa.exception_handlers() {
if handler.handler_start_block == Some(start) {
structure = Some("try/catch".to_string());
}
}
if structure.is_none() && loop_headers.contains(&start) {
structure = Some("loop".to_string());
}
for (block_idx, block) in blocks
.iter()
.enumerate()
.take(end.min(block_count))
.skip(start)
{
for instr in block.instructions() {
classify_op_into_profile(instr.op(), &mut profile);
if let SsaOp::Call { method, .. } | SsaOp::CallVirt { method, .. } = instr.op() {
if let Some(name) = assembly.resolve_method_name(method.token()) {
if !is_obfuscated_name(&name) && call_seen.insert(name.clone()) {
call_targets.push(name);
}
}
}
}
if structure.is_none() {
for succ in block.successors() {
if succ >= start && succ < end && succ <= block_idx {
structure = Some("loop".to_string());
}
}
}
}
if structure.is_none() {
if profile.calls == 0 && (profile.bitwise >= 3 || profile.arithmetic >= 3) {
structure = Some("transform".to_string());
} else if profile.comparison > 0 {
structure = Some("conditional".to_string());
} else {
structure = Some("linear".to_string());
}
}
phases.push(PhaseInfo {
label: String::new(), call_targets,
opcode_profile: Some(profile),
structure,
});
}
phases
}
fn detect_back_edges(
ssa: &SsaFunction,
block_idx: usize,
visited: &mut [bool],
in_stack: &mut [bool],
boundaries: &mut HashSet<usize>,
) {
if block_idx >= ssa.blocks().len() {
return;
}
visited[block_idx] = true;
in_stack[block_idx] = true;
for succ in ssa.blocks()[block_idx].successors() {
if succ < visited.len() {
if in_stack[succ] {
boundaries.insert(succ);
} else if !visited[succ] {
detect_back_edges(ssa, succ, visited, in_stack, boundaries);
}
}
}
in_stack[block_idx] = false;
}
fn detect_transform_boundaries(ssa: &SsaFunction, boundaries: &mut HashSet<usize>) {
let blocks = ssa.blocks();
let mut consecutive_transform = 0;
for (block_idx, block) in blocks.iter().enumerate() {
let mut has_calls = false;
let mut bitwise_count = 0u32;
let mut arithmetic_count = 0u32;
for instr in block.instructions() {
match instr.op() {
SsaOp::Call { .. } | SsaOp::CallVirt { .. } | SsaOp::CallIndirect { .. } => {
has_calls = true;
}
SsaOp::And { .. }
| SsaOp::Or { .. }
| SsaOp::Xor { .. }
| SsaOp::Not { .. }
| SsaOp::Shl { .. }
| SsaOp::Shr { .. } => {
bitwise_count += 1;
}
SsaOp::Add { .. }
| SsaOp::Sub { .. }
| SsaOp::Mul { .. }
| SsaOp::Div { .. }
| SsaOp::Rem { .. } => {
arithmetic_count += 1;
}
_ => {}
}
}
if !has_calls && (bitwise_count + arithmetic_count) >= 3 {
if consecutive_transform == 0 {
boundaries.insert(block_idx);
}
consecutive_transform += 1;
} else {
if consecutive_transform >= 3 {
boundaries.insert(block_idx);
}
consecutive_transform = 0;
}
}
}
fn classify_op_into_profile(op: &SsaOp, profile: &mut OpcodeProfile) {
match op {
SsaOp::Call { .. } | SsaOp::CallVirt { .. } | SsaOp::CallIndirect { .. } => {
profile.calls += 1;
}
SsaOp::Const {
value: ConstValue::String(_) | ConstValue::DecryptedString(_),
..
} => {
profile.strings += 1;
}
SsaOp::LoadField { .. }
| SsaOp::StoreField { .. }
| SsaOp::LoadStaticField { .. }
| SsaOp::StoreStaticField { .. } => {
profile.field_io += 1;
}
SsaOp::And { .. }
| SsaOp::Or { .. }
| SsaOp::Xor { .. }
| SsaOp::Not { .. }
| SsaOp::Shl { .. }
| SsaOp::Shr { .. } => {
profile.bitwise += 1;
}
SsaOp::Add { .. }
| SsaOp::AddOvf { .. }
| SsaOp::Sub { .. }
| SsaOp::SubOvf { .. }
| SsaOp::Mul { .. }
| SsaOp::MulOvf { .. }
| SsaOp::Div { .. }
| SsaOp::Rem { .. }
| SsaOp::Neg { .. } => {
profile.arithmetic += 1;
}
SsaOp::NewArr { .. }
| SsaOp::LoadElement { .. }
| SsaOp::StoreElement { .. }
| SsaOp::ArrayLength { .. } => {
profile.array += 1;
}
SsaOp::Ceq { .. }
| SsaOp::Clt { .. }
| SsaOp::Cgt { .. }
| SsaOp::Branch { .. }
| SsaOp::BranchCmp { .. } => {
profile.comparison += 1;
}
SsaOp::Conv { .. } => {
profile.conversion += 1;
}
_ => {}
}
}
fn extract_namespace(method_name: &str) -> String {
if let Some(idx) = method_name.rfind("::") {
let type_part = &method_name[..idx];
if let Some(dot_idx) = type_part.rfind('.') {
return type_part[..dot_idx].to_string();
}
return type_part.to_string();
}
method_name.to_string()
}
#[cfg(test)]
mod tests {
use crate::{
deobfuscation::renamer::phases::{
build_call_site_skeleton, decompose_method, extract_namespace,
},
metadata::validation::ValidationConfig,
CilObject,
};
const ORIGINAL_EXE: &str = "tests/samples/packers/confuserex/1.6.0/original.exe";
fn load_and_build_ssa(method_name: &str) -> Option<(crate::analysis::SsaFunction, CilObject)> {
let path = std::path::Path::new(ORIGINAL_EXE);
if !path.exists() {
return None;
}
let assembly =
CilObject::from_path_with_validation(path, ValidationConfig::analysis()).unwrap();
let method = assembly
.methods()
.iter()
.find(|e| e.value().name == method_name)?
.value()
.clone();
let ssa = method.ssa(&assembly).ok()?;
Some((ssa, assembly))
}
#[test]
fn test_extract_namespace() {
assert_eq!(
extract_namespace("System.IO.File::ReadAllText"),
"System.IO"
);
assert_eq!(
extract_namespace("System.Security.Cryptography.Aes::Create"),
"System.Security.Cryptography"
);
assert_eq!(extract_namespace("MyClass::DoWork"), "MyClass");
assert_eq!(extract_namespace("SimpleMethod"), "SimpleMethod");
}
#[test]
fn test_decompose_sayhello_single_phase() {
let Some((ssa, assembly)) = load_and_build_ssa("SayHello") else {
eprintln!("Skipping: original.exe not found");
return;
};
let phases = decompose_method(&ssa, &assembly, 100);
assert_eq!(phases.len(), 1, "SayHello should be a single phase");
assert_eq!(
phases[0].structure.as_deref(),
Some("linear"),
"Simple method should be classified as linear"
);
assert!(
!phases[0].call_targets.is_empty(),
"SayHello should have call targets (Console.WriteLine)"
);
assert!(
phases[0]
.call_targets
.iter()
.any(|t| t.contains("WriteLine")),
"SayHello should call WriteLine, got: {:?}",
phases[0].call_targets
);
}
#[test]
fn test_skeleton_sayhello() {
let Some((ssa, assembly)) = load_and_build_ssa("SayHello") else {
eprintln!("Skipping: original.exe not found");
return;
};
let skeleton = build_call_site_skeleton(&ssa, &assembly);
assert!(skeleton.is_some(), "SayHello should produce a skeleton");
let skeleton = skeleton.unwrap();
assert!(
skeleton.contains("WriteLine"),
"Skeleton should mention WriteLine, got:\n{skeleton}"
);
assert!(
skeleton.contains("return"),
"Skeleton should contain a return statement"
);
}
#[test]
fn test_decompose_demoloop_has_loop_phase() {
let Some((ssa, assembly)) = load_and_build_ssa("DemoLoop") else {
eprintln!("Skipping: original.exe not found");
return;
};
let phases = decompose_method(&ssa, &assembly, 0);
assert!(
phases.len() > 1,
"DemoLoop should produce multiple phases, got {}",
phases.len()
);
let has_loop = phases
.iter()
.any(|p| p.structure.as_deref() == Some("loop"));
assert!(
has_loop,
"DemoLoop should have at least one loop phase, structures: {:?}",
phases.iter().map(|p| &p.structure).collect::<Vec<_>>()
);
}
#[test]
fn test_decompose_fibonacci_has_loop_phase() {
let Some((ssa, assembly)) = load_and_build_ssa("Fibonacci") else {
eprintln!("Skipping: original.exe not found");
return;
};
let phases = decompose_method(&ssa, &assembly, 0);
assert!(
phases.len() > 1,
"Fibonacci should produce multiple phases, got {}",
phases.len()
);
let has_loop = phases
.iter()
.any(|p| p.structure.as_deref() == Some("loop"));
assert!(
has_loop,
"Fibonacci should have a loop phase, structures: {:?}",
phases.iter().map(|p| &p.structure).collect::<Vec<_>>()
);
}
#[test]
fn test_decompose_demoifelse_has_conditional() {
let Some((ssa, assembly)) = load_and_build_ssa("DemoIfElse") else {
eprintln!("Skipping: original.exe not found");
return;
};
let phases = decompose_method(&ssa, &assembly, 0);
assert!(
!phases.is_empty(),
"DemoIfElse should produce at least one phase"
);
let has_conditional = phases
.iter()
.any(|p| p.structure.as_deref() == Some("conditional"));
assert!(
has_conditional,
"DemoIfElse should have a conditional phase, structures: {:?}",
phases.iter().map(|p| &p.structure).collect::<Vec<_>>()
);
}
#[test]
fn test_decompose_add_opcode_profile() {
let Some((ssa, assembly)) = load_and_build_ssa("Add") else {
eprintln!("Skipping: original.exe not found");
return;
};
let phases = decompose_method(&ssa, &assembly, 100);
assert_eq!(phases.len(), 1, "Add should be a single phase");
let profile = phases[0].opcode_profile.as_ref().unwrap();
assert!(
profile.arithmetic > 0,
"Add method should have arithmetic operations in its profile"
);
}
}