use std::{collections::HashMap, sync::Arc};
use crate::{
analysis::SsaFunction,
compiler::{pass::SsaPass, passes::utils::resolve_chain, CompilerContext, EventKind, EventLog},
metadata::token::Token,
CilObject, Result,
};
/// Upper bound on the number of trampoline-elimination rounds
/// [`BlockMergingPass`] runs on a single method; the pass stops earlier as soon
/// as a round makes no change.
const MAX_ITERATIONS: usize = 50;
/// SSA pass that eliminates trampoline blocks (blocks reported by
/// `SsaFunction::find_trampoline_blocks`, described by the pass as
/// "single-jump blocks").
///
/// Branches that target a trampoline are redirected to the block at the end of
/// its jump chain, and the trampoline's instructions are then cleared.
pub struct BlockMergingPass;
impl Default for BlockMergingPass {
fn default() -> Self {
Self::new()
}
}
impl BlockMergingPass {
    /// Creates a new block-merging pass.
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Maps each trampoline to the block its jump chain finally lands on.
    ///
    /// A trampoline is dropped from the result when [`resolve_chain`] yields a
    /// block that is itself still a trampoline, for example because the jumps
    /// form a cycle. Redirecting branches onto such a block and then clearing
    /// it (as `clear_trampolines` does for every returned key) would leave those
    /// branches pointing at an empty block.
    ///
    /// Every returned target lies outside the key set. Applying the returned
    /// redirections therefore gives the same result in any order, and each
    /// returned key can be cleared once all references to it are rewritten.
    fn resolve_ultimate_targets(trampolines: &HashMap<usize, usize>) -> HashMap<usize, usize> {
        trampolines
            .keys()
            .map(|&block| (block, resolve_chain(trampolines, block)))
            .filter(|(_, ultimate)| !trampolines.contains_key(ultimate))
            .collect()
    }

    /// Rewrites branch targets that name a resolvable trampoline so they point
    /// straight at that trampoline's final destination.
    ///
    /// `ultimate_targets` maps trampoline block -> final destination, as
    /// produced by `resolve_ultimate_targets`. Trampoline blocks themselves are
    /// skipped, because they are cleared right afterwards and rewriting their
    /// jumps would only produce misleading events.
    ///
    /// Records one `BranchSimplified` event per rewritten instruction and
    /// returns the number of rewritten instructions.
    fn redirect_to_ultimate_targets(
        ssa: &mut SsaFunction,
        ultimate_targets: &HashMap<usize, usize>,
        method_token: Token,
        changes: &mut EventLog,
    ) -> usize {
        if ultimate_targets.is_empty() {
            return 0;
        }
        let mut redirected = 0;
        for block_idx in 0..ssa.block_count() {
            if ultimate_targets.contains_key(&block_idx) {
                continue;
            }
            if let Some(block) = ssa.block_mut(block_idx) {
                for instr in block.instructions_mut() {
                    let op = instr.op_mut();
                    // Captured before mutation so the event can show old -> new.
                    let old_targets = op.successors();
                    let mut changed = false;
                    for (&trampoline, &ultimate) in ultimate_targets {
                        // `|=` (not `||`) so every applicable redirection runs.
                        changed |= op.redirect_target(trampoline, ultimate);
                    }
                    if changed {
                        let new_targets = op.successors();
                        changes
                            .record(EventKind::BranchSimplified)
                            .at(method_token, block_idx)
                            .message(format!(
                                "redirected through trampoline: {old_targets:?} -> {new_targets:?}"
                            ));
                        redirected += 1;
                    }
                }
            }
        }
        redirected
    }

    /// Empties every trampoline block listed in `ultimate_targets`.
    ///
    /// Blocks are visited in ascending index order so the recorded events come
    /// out in a deterministic order (`HashMap` iteration order is arbitrary).
    /// Blocks that are missing or already empty are left alone. Records one
    /// `BlockRemoved` event per cleared block and returns how many were cleared.
    fn clear_trampolines(
        ssa: &mut SsaFunction,
        ultimate_targets: &HashMap<usize, usize>,
        method_token: Token,
        changes: &mut EventLog,
    ) -> usize {
        let mut blocks: Vec<usize> = ultimate_targets.keys().copied().collect();
        blocks.sort_unstable();

        let mut cleared = 0;
        for block_idx in blocks {
            if let Some(block) = ssa.block_mut(block_idx) {
                if !block.instructions().is_empty() {
                    block.instructions_mut().clear();
                    changes
                        .record(EventKind::BlockRemoved)
                        .at(method_token, block_idx)
                        .message(format!("cleared trampoline block B{block_idx}"));
                    cleared += 1;
                }
            }
        }
        cleared
    }

    /// Performs one round of trampoline elimination.
    ///
    /// Finds the trampolines, keeps those whose jump chain resolves to a
    /// non-trampoline block, redirects references to them, and clears them.
    /// Returns the number of changes made (rewritten instructions plus cleared
    /// blocks); `0` means the round changed nothing.
    fn run_iteration(ssa: &mut SsaFunction, method_token: Token, changes: &mut EventLog) -> usize {
        let trampolines = ssa.find_trampoline_blocks(true);
        let ultimate_targets = Self::resolve_ultimate_targets(&trampolines);
        if ultimate_targets.is_empty() {
            return 0;
        }
        let redirected =
            Self::redirect_to_ultimate_targets(ssa, &ultimate_targets, method_token, changes);
        let cleared = Self::clear_trampolines(ssa, &ultimate_targets, method_token, changes);
        redirected + cleared
    }
}
impl SsaPass for BlockMergingPass {
    fn name(&self) -> &'static str {
        "block-merging"
    }

    fn description(&self) -> &'static str {
        "Eliminates trampoline blocks (single-jump blocks)"
    }

    /// Repeats trampoline elimination until a round makes no change or
    /// `MAX_ITERATIONS` rounds have run.
    ///
    /// Events gathered across all rounds are merged into `ctx.events` only if
    /// at least one was recorded. Returns `Ok(true)` exactly in that case.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        method_token: Token,
        ctx: &CompilerContext,
        _assembly: &Arc<CilObject>,
    ) -> Result<bool> {
        let mut log = EventLog::new();
        let mut rounds_left = MAX_ITERATIONS;
        while rounds_left > 0 && Self::run_iteration(ssa, method_token, &mut log) > 0 {
            rounds_left -= 1;
        }

        if log.is_empty() {
            return Ok(false);
        }
        ctx.events.merge(&log);
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        analysis::{SsaFunctionBuilder, SsaOp},
        test::helpers::test_assembly_arc,
    };

    /// Runs `BlockMergingPass` over `ssa` with a fresh context and returns
    /// whether the pass reported any change.
    fn run_pass(ssa: &mut SsaFunction) -> bool {
        let pass = BlockMergingPass::new();
        let ctx = crate::compiler::CompilerContext::new(std::sync::Arc::new(
            crate::analysis::CallGraph::new(),
        ));
        let assembly = test_assembly_arc();
        pass.run_on_method(ssa, Token::new(0x06000001), &ctx, &assembly)
            .unwrap()
    }

    #[test]
    fn test_redirect_simple() {
        let mut ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| b.jump(1));
            f.block(1, |b| b.jump(2));
            f.block(2, |b| b.ret());
        });
        assert!(run_pass(&mut ssa));

        // `expect` instead of `if let`: a missing or empty block must fail the
        // test rather than silently skip the assertion.
        let b0 = ssa.block(0).expect("B0 should still exist");
        let first = b0
            .instructions()
            .first()
            .expect("B0 should keep its jump instruction");
        assert!(
            matches!(first.op(), SsaOp::Jump { target } if *target == 2),
            "B0 should jump directly to B2"
        );

        let b1 = ssa.block(1).expect("B1 should still exist");
        assert!(
            b1.instructions().is_empty(),
            "B1 should be cleared, but has {} instructions",
            b1.instructions().len()
        );
    }

    #[test]
    fn test_chain_of_trampolines() {
        let mut ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| b.jump(1));
            f.block(1, |b| b.jump(2));
            f.block(2, |b| b.jump(3));
            f.block(3, |b| b.ret());
        });
        assert!(run_pass(&mut ssa));

        let b0 = ssa.block(0).expect("B0 should still exist");
        let first = b0
            .instructions()
            .first()
            .expect("B0 should keep its jump instruction");
        assert!(
            matches!(first.op(), SsaOp::Jump { target } if *target == 3),
            "B0 should jump directly to B3"
        );

        for i in 1..=2 {
            let block = ssa
                .block(i)
                .unwrap_or_else(|| panic!("B{i} should still exist"));
            assert!(block.instructions().is_empty(), "B{} should be cleared", i);
        }
    }

    #[test]
    fn test_no_trampolines_is_noop() {
        let mut ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| b.ret());
        });
        assert!(!run_pass(&mut ssa), "a function without trampolines must be unchanged");

        let b0 = ssa.block(0).expect("B0 should still exist");
        assert!(
            !b0.instructions().is_empty(),
            "B0 should keep its return instruction"
        );
    }
}