use super::{Transform, TransformCategory, TransformLevel};
use crate::mir::{Function, Instruction, Operand};
use std::collections::HashMap;
/// Control-flow cleanup pass: folds `Br`/`Select` instructions whose two arms are
/// identical and removes jump-only blocks reached from a single unconditional `Jmp`.
#[derive(Default)]
pub struct CfgSimplify;

impl Transform for CfgSimplify {
    /// Runs the pass; all of the work lives in the inherent `apply_internal`.
    fn apply(&self, func: &mut Function) -> Result<bool, String> {
        CfgSimplify::apply_internal(self, func)
    }

    fn name(&self) -> &'static str {
        "cfg_simplify"
    }

    fn description(&self) -> &'static str {
        "Simplifies trivial branches and selects"
    }

    fn category(&self) -> TransformCategory {
        TransformCategory::ControlFlowOptimization
    }

    fn level(&self) -> TransformLevel {
        TransformLevel::Stable
    }
}
impl CfgSimplify {
    /// Runs both simplification phases on `func` and reports whether anything changed.
    ///
    /// 1. Folds `Br`s whose two targets are equal into `Jmp`, and `Select`s whose two
    ///    values are equal into an `IntBinary` add of zero.
    /// 2. Removes jump-only blocks whose single incoming edge is an unconditional `Jmp`,
    ///    retargeting that `Jmp` to the removed block's successor.
    ///
    /// Blocks are finally sorted by label; the sort alone is not reported as a change.
    ///
    /// # Errors
    /// This implementation never returns `Err`.
    fn apply_internal(&self, func: &mut Function) -> Result<bool, String> {
        let mut changed = Self::fold_identical_arms(func);
        changed |= Self::merge_trivial_jumps(func);
        // Deterministic block order. Compares labels in place rather than cloning every
        // label as a sort key (same stable ordering as `sort_by_key`).
        // NOTE(review): if the entry block is identified by position, this sort could move
        // it — confirm how `Function` designates its entry.
        func.blocks.sort_by(|a, b| a.label.cmp(&b.label));
        Ok(changed)
    }

    /// Rewrites instructions whose two arms are identical; returns whether any were rewritten.
    ///
    /// - `Br` with `true_target == false_target` becomes an unconditional `Jmp`.
    /// - `Select` with `true_val == false_val` becomes `dst = true_val + 0`.
    fn fold_identical_arms(func: &mut Function) -> bool {
        let mut changed = false;
        for block in &mut func.blocks {
            for instr in &mut block.instructions {
                let replacement = match instr {
                    Instruction::Br {
                        true_target,
                        false_target,
                        ..
                    } if true_target == false_target => Instruction::Jmp {
                        target: true_target.clone(),
                    },
                    Instruction::Select {
                        dst,
                        ty,
                        true_val,
                        false_val,
                        ..
                    } if true_val == false_val => {
                        // NOTE(review): emits an integer add of an `I64(0)` immediate whatever
                        // `ty` is; confirm `Select` can't be float/pointer-typed here, or lower
                        // to a type-appropriate move instead.
                        Instruction::IntBinary {
                            op: crate::mir::IntBinOp::Add,
                            ty: *ty,
                            dst: dst.clone(),
                            lhs: true_val.clone(),
                            rhs: Operand::Immediate(crate::mir::instruction::Immediate::I64(0)),
                        }
                    }
                    _ => continue,
                };
                *instr = replacement;
                changed = true;
            }
        }
        changed
    }

    /// Maps each block label to the labels of blocks whose last instruction is a
    /// terminator (`Jmp`, `Br` or `Switch`) that targets it.
    ///
    /// A predecessor is listed once per edge, so a block reaching the same successor via
    /// several `Switch` edges appears several times.
    fn predecessors(func: &Function) -> HashMap<String, Vec<String>> {
        let mut preds: HashMap<String, Vec<String>> = HashMap::new();
        for block in &func.blocks {
            let Some(term) = block.instructions.last() else {
                continue;
            };
            if !term.is_terminator() {
                continue;
            }
            let mut add = |succ: &String| {
                preds
                    .entry(succ.clone())
                    .or_default()
                    .push(block.label.clone());
            };
            match term {
                Instruction::Jmp { target } => add(target),
                Instruction::Br {
                    true_target,
                    false_target,
                    ..
                } => {
                    add(true_target);
                    add(false_target);
                }
                Instruction::Switch { cases, default, .. } => {
                    add(default);
                    for (_, case_target) in cases {
                        add(case_target);
                    }
                }
                _ => {}
            }
        }
        preds
    }

    /// Removes jump-only blocks whose sole incoming edge is an unconditional `Jmp`,
    /// retargeting that `Jmp` to the removed block's successor. Returns whether any `Jmp`
    /// was actually retargeted.
    ///
    /// Chains such as `A: jmp B`, `B: jmp C`, `C: jmp D` collapse without leaving any
    /// edge pointing at a removed block:
    /// - each merge reads the trivial block's *current* successor, so an earlier merge
    ///   that retargeted it is honoured;
    /// - a candidate whose recorded predecessor was already removed is redirected to the
    ///   block that absorbed that predecessor's incoming edge.
    ///
    /// A block whose only predecessor is itself (`X: jmp X`) is left in place.
    fn merge_trivial_jumps(func: &mut Function) -> bool {
        let preds = Self::predecessors(func);

        // (predecessor, trivial block) pairs, in block order.
        let mut candidates: Vec<(String, String)> = Vec::new();
        for block in &func.blocks {
            if block.instructions.len() == 1
                && let Some(Instruction::Jmp { .. }) = block.instructions.last()
                && let Some(preds_list) = preds.get(&block.label)
                && let [only_pred] = preds_list.as_slice()
                && *only_pred != block.label
            {
                candidates.push((only_pred.clone(), block.label.clone()));
            }
        }

        let mut changed = false;
        // Removed block -> block that inherited its single incoming edge.
        let mut absorbed_by: HashMap<String, String> = HashMap::new();
        for (mut pred_label, trivial_label) in candidates {
            // The recorded predecessor may itself have been merged away; follow to the
            // block now holding the edge. This terminates: a block is only ever recorded
            // as absorbed by a block that was not itself absorbed at that moment.
            while let Some(holder) = absorbed_by.get(&pred_label) {
                pred_label = holder.clone();
            }
            if pred_label == trivial_label {
                continue;
            }
            // Read the successor now, not when the candidate was collected: an earlier
            // merge may have retargeted this block's `Jmp`.
            let Some(new_target) = func
                .blocks
                .iter()
                .find(|b| b.label == trivial_label)
                .and_then(|b| match b.instructions.last() {
                    Some(Instruction::Jmp { target }) => Some(target.clone()),
                    _ => None,
                })
            else {
                continue;
            };
            if let Some(pred_block) = func.blocks.iter_mut().find(|b| b.label == pred_label)
                && let Some(Instruction::Jmp { target }) = pred_block.instructions.last_mut()
                && *target == trivial_label
            {
                *target = new_target;
                changed = true;
                absorbed_by.insert(trivial_label, pred_label);
            }
        }

        func.blocks.retain(|b| !absorbed_by.contains_key(&b.label));
        changed
    }
}
/// Control-flow pass that redirects edges aimed at jump-only blocks straight to the end
/// of the jump chain, bypassing the intermediate blocks.
#[derive(Default)]
pub struct JumpThreading;

impl Transform for JumpThreading {
    /// Runs the pass; all of the work lives in the inherent `apply_internal`.
    fn apply(&self, func: &mut Function) -> Result<bool, String> {
        JumpThreading::apply_internal(self, func)
    }

    fn name(&self) -> &'static str {
        "jump_threading"
    }

    fn description(&self) -> &'static str {
        "Bypass trivial jump-only blocks in branch/jump targets"
    }

    fn category(&self) -> TransformCategory {
        TransformCategory::ControlFlowOptimization
    }

    fn level(&self) -> TransformLevel {
        TransformLevel::Stable
    }
}
impl JumpThreading {
    /// Retargets every control-transfer edge — `Jmp`, both `Br` targets, and every
    /// `Switch` case *and* its default — that points at a jump-only block, so it points
    /// directly at the first non-jump-only block of that chain. Returns whether any edge
    /// was retargeted.
    ///
    /// Chains that cycle or exceed `MAX_CHAIN` hops are left untouched. Jump-only blocks
    /// themselves are not removed.
    ///
    /// # Errors
    /// This implementation never returns `Err`.
    fn apply_internal(&self, func: &mut Function) -> Result<bool, String> {
        let resolved_targets = Self::resolve_jump_chains(func);
        if resolved_targets.is_empty() {
            return Ok(false);
        }

        let mut changed = false;
        for block in &mut func.blocks {
            for instr in &mut block.instructions {
                match instr {
                    Instruction::Jmp { target } => {
                        changed |= Self::retarget(&resolved_targets, target);
                    }
                    Instruction::Br {
                        true_target,
                        false_target,
                        ..
                    } => {
                        changed |= Self::retarget(&resolved_targets, true_target);
                        changed |= Self::retarget(&resolved_targets, false_target);
                    }
                    Instruction::Switch { cases, default, .. } => {
                        // The default edge is a successor too and must be threaded
                        // alongside the explicit cases.
                        changed |= Self::retarget(&resolved_targets, default);
                        for (_, case_target) in cases.iter_mut() {
                            changed |= Self::retarget(&resolved_targets, case_target);
                        }
                    }
                    _ => {}
                }
            }
        }
        Ok(changed)
    }

    /// Maps each jump-only block (exactly one instruction, a `Jmp`) to the first
    /// non-jump-only block its chain reaches. Blocks whose chain cycles, or needs more
    /// than `MAX_CHAIN` hops, are omitted.
    fn resolve_jump_chains(func: &Function) -> HashMap<String, String> {
        const MAX_CHAIN: usize = 100;

        let mut simple_jumps: HashMap<&str, &str> = HashMap::new();
        for block in &func.blocks {
            if block.instructions.len() == 1
                && let Instruction::Jmp { target } = &block.instructions[0]
            {
                simple_jumps.insert(block.label.as_str(), target.as_str());
            }
        }

        // A cycle-free chain visits each jump-only block at most once, so if it has not
        // left `simple_jumps` after `simple_jumps.len()` hops it must be cycling.
        let hop_limit = MAX_CHAIN.min(simple_jumps.len());
        let mut resolved = HashMap::new();
        for &start in simple_jumps.keys() {
            let mut current = start;
            for _ in 0..hop_limit {
                match simple_jumps.get(current) {
                    Some(&next) => current = next,
                    None => break,
                }
            }
            // Still on a jump-only block => cycle or chain too long: leave it alone.
            if !simple_jumps.contains_key(current) {
                resolved.insert(start.to_owned(), current.to_owned());
            }
        }
        resolved
    }

    /// Points `target` at its resolved destination, if one exists and differs.
    /// Returns whether `target` was changed.
    fn retarget(resolved: &HashMap<String, String>, target: &mut String) -> bool {
        match resolved.get(target.as_str()) {
            Some(dest) if *dest != *target => {
                target.clone_from(dest);
                true
            }
            _ => false,
        }
    }
}