use std::{collections::HashMap, sync::Arc};
use crate::{
analysis::{
LoopInfo, PhiNode, PhiOperand, SsaBlock, SsaFunction, SsaInstruction, SsaLoopAnalysis,
SsaOp, SsaVarId, VariableOrigin,
},
compiler::{pass::SsaPass, CompilerContext, EventKind, EventLog},
metadata::token::Token,
utils::graph::NodeId,
CilObject, Result,
};
/// SSA pass that rewrites loops into canonical form: a single preheader
/// feeding the loop header from outside the loop, and a single latch
/// branching back to the header from inside the loop.
///
/// The pass is stateless, so it is a cheap, copyable unit struct.
#[derive(Debug, Clone, Copy, Default)]
pub struct LoopCanonicalizationPass;
impl LoopCanonicalizationPass {
#[must_use]
pub fn new() -> Self {
Self
}
fn canonicalize_loops(
ssa: &mut SsaFunction,
method_token: Token,
changes: &mut EventLog,
) -> usize {
let mut total_modified = 0;
loop {
let forest = ssa.analyze_loops();
if forest.is_empty() {
break;
}
let mut modified_this_iteration = 0;
for loop_info in forest.by_depth_descending() {
if !loop_info.has_preheader() {
let non_loop_preds = Self::get_non_loop_predecessors(ssa, loop_info);
if non_loop_preds.len() > 1 {
Self::insert_preheader(
ssa,
loop_info,
&non_loop_preds,
method_token,
changes,
);
modified_this_iteration += 1;
break;
}
}
if !loop_info.has_single_latch() && loop_info.latches.len() > 1 {
Self::unify_latches(ssa, loop_info, method_token, changes);
modified_this_iteration += 1;
break;
}
}
total_modified += modified_this_iteration;
if modified_this_iteration == 0 {
break;
}
}
total_modified
}
fn get_non_loop_predecessors(ssa: &SsaFunction, loop_info: &LoopInfo) -> Vec<usize> {
let header_idx = loop_info.header.index();
let mut non_loop_preds = Vec::new();
for (block_idx, block) in ssa.iter_blocks() {
if let Some(op) = block.terminator_op() {
let targets = Self::get_targets(op);
if targets.contains(&header_idx)
&& !loop_info.body.contains(&NodeId::new(block_idx))
{
non_loop_preds.push(block_idx);
}
}
}
non_loop_preds
}
fn get_targets(op: &SsaOp) -> Vec<usize> {
match op {
SsaOp::Jump { target } | SsaOp::Leave { target } => vec![*target],
SsaOp::Branch {
true_target,
false_target,
..
} => vec![*true_target, *false_target],
SsaOp::Switch {
targets, default, ..
} => {
let mut all = targets.clone();
all.push(*default);
all
}
_ => vec![],
}
}
fn insert_preheader(
ssa: &mut SsaFunction,
loop_info: &LoopInfo,
non_loop_preds: &[usize],
method_token: Token,
changes: &mut EventLog,
) {
let header_idx = loop_info.header.index();
let preheader_idx = ssa.block_count();
let mut preheader = SsaBlock::new(preheader_idx);
preheader.add_instruction(SsaInstruction::synthetic(SsaOp::Jump {
target: header_idx,
}));
if let Some(header) = ssa.block(header_idx) {
let header_phis: Vec<_> = header.phi_nodes().to_vec();
for phi in &header_phis {
let non_loop_operands: Vec<_> = phi
.operands()
.iter()
.filter(|op| non_loop_preds.contains(&op.predecessor()))
.copied()
.collect();
if non_loop_operands.len() > 1 {
let new_var = SsaVarId::new();
let mut preheader_phi = PhiNode::new(new_var, phi.origin());
for op in &non_loop_operands {
preheader_phi.add_operand(*op);
}
preheader.phi_nodes_mut().push(preheader_phi);
}
}
}
ssa.add_block(preheader);
for &pred_idx in non_loop_preds {
Self::redirect_targets(ssa, pred_idx, header_idx, preheader_idx);
}
let preheader_phi_map: HashMap<VariableOrigin, SsaVarId> = ssa
.block(preheader_idx)
.map(|b| {
b.phi_nodes()
.iter()
.map(|p| (p.origin(), p.result()))
.collect()
})
.unwrap_or_default();
if let Some(header) = ssa.block_mut(header_idx) {
for phi in header.phi_nodes_mut() {
let origin = phi.origin();
let operands = phi.operands_mut();
let mut loop_operands: Vec<PhiOperand> = Vec::new();
let mut non_loop_values: Vec<PhiOperand> = Vec::new();
for op in operands.drain(..) {
if non_loop_preds.contains(&op.predecessor()) {
non_loop_values.push(op);
} else {
loop_operands.push(op);
}
}
operands.extend(loop_operands);
if !non_loop_values.is_empty() {
if non_loop_values.len() == 1 {
operands.push(PhiOperand::new(non_loop_values[0].value(), preheader_idx));
} else if let Some(&preheader_var) = preheader_phi_map.get(&origin) {
operands.push(PhiOperand::new(preheader_var, preheader_idx));
}
}
}
}
changes
.record(EventKind::ControlFlowRestructured)
.at(method_token, preheader_idx)
.message(format!(
"Inserted preheader B{preheader_idx} for loop at B{header_idx}"
));
}
fn unify_latches(
ssa: &mut SsaFunction,
loop_info: &LoopInfo,
method_token: Token,
changes: &mut EventLog,
) {
let header_idx = loop_info.header.index();
let latches: Vec<usize> = loop_info.latches.iter().map(|n| n.index()).collect();
let unified_latch_idx = ssa.block_count();
let mut unified_latch = SsaBlock::new(unified_latch_idx);
unified_latch.add_instruction(SsaInstruction::synthetic(SsaOp::Jump {
target: header_idx,
}));
let mut latch_phi_vars: HashMap<VariableOrigin, SsaVarId> = HashMap::new();
if let Some(header) = ssa.block(header_idx) {
for phi in header.phi_nodes() {
let latch_operands: Vec<_> = phi
.operands()
.iter()
.filter(|op| latches.contains(&op.predecessor()))
.copied()
.collect();
if latch_operands.len() > 1 {
let new_var = SsaVarId::new();
let mut latch_phi = PhiNode::new(new_var, phi.origin());
for op in &latch_operands {
latch_phi.add_operand(*op);
}
latch_phi_vars.insert(phi.origin(), new_var);
unified_latch.phi_nodes_mut().push(latch_phi);
} else if latch_operands.len() == 1 {
latch_phi_vars.insert(phi.origin(), latch_operands[0].value());
}
}
}
ssa.add_block(unified_latch);
for &latch_idx in &latches {
Self::redirect_targets(ssa, latch_idx, header_idx, unified_latch_idx);
}
if let Some(header) = ssa.block_mut(header_idx) {
for phi in header.phi_nodes_mut() {
let origin = phi.origin();
let operands = phi.operands_mut();
operands.retain(|op| !latches.contains(&op.predecessor()));
if let Some(&var) = latch_phi_vars.get(&origin) {
operands.push(PhiOperand::new(var, unified_latch_idx));
}
}
}
changes
.record(EventKind::ControlFlowRestructured)
.at(method_token, unified_latch_idx)
.message(format!(
"Unified {} latches into B{} for loop at B{}",
latches.len(),
unified_latch_idx,
header_idx
));
}
fn redirect_targets(
ssa: &mut SsaFunction,
block_idx: usize,
old_target: usize,
new_target: usize,
) {
if let Some(block) = ssa.block_mut(block_idx) {
if let Some(last) = block.instructions_mut().last_mut() {
let new_op = match last.op() {
SsaOp::Jump { target } if *target == old_target => {
Some(SsaOp::Jump { target: new_target })
}
SsaOp::Leave { target } if *target == old_target => {
Some(SsaOp::Leave { target: new_target })
}
SsaOp::Branch {
condition,
true_target,
false_target,
} => {
let new_true = if *true_target == old_target {
new_target
} else {
*true_target
};
let new_false = if *false_target == old_target {
new_target
} else {
*false_target
};
if new_true != *true_target || new_false != *false_target {
Some(SsaOp::Branch {
condition: *condition,
true_target: new_true,
false_target: new_false,
})
} else {
None
}
}
SsaOp::Switch {
value,
targets,
default,
} => {
let new_targets: Vec<_> = targets
.iter()
.map(|&t| if t == old_target { new_target } else { t })
.collect();
let new_default = if *default == old_target {
new_target
} else {
*default
};
if new_targets != *targets || new_default != *default {
Some(SsaOp::Switch {
value: *value,
targets: new_targets,
default: new_default,
})
} else {
None
}
}
_ => None,
};
if let Some(new_op) = new_op {
last.set_op(new_op);
}
}
}
}
}
impl SsaPass for LoopCanonicalizationPass {
    fn name(&self) -> &'static str {
        "LoopCanonicalization"
    }

    fn description(&self) -> &'static str {
        "Transforms loops into canonical form with single preheaders and single latches"
    }

    /// Canonicalizes every loop of `ssa`.
    ///
    /// Functions with fewer than two blocks are skipped. When at least one
    /// transformation was applied, the function is re-canonicalized and the
    /// recorded events are merged into `ctx.events`.
    ///
    /// Returns `Ok(true)` iff the function was modified.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        method_token: Token,
        ctx: &CompilerContext,
        _assembly: &Arc<CilObject>,
    ) -> Result<bool> {
        let too_small = ssa.block_count() < 2;
        if too_small {
            return Ok(false);
        }

        let mut events = EventLog::new();
        let transformations = Self::canonicalize_loops(ssa, method_token, &mut events);
        let changed = transformations > 0;
        if changed {
            ssa.canonicalize();
            ctx.events.merge(&events);
        }
        Ok(changed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::{SsaFunctionBuilder, SsaLoopAnalysis};

    /// Runs the canonicalization fixpoint on `ssa` with a throwaway event log
    /// and returns the number of transformations applied.
    fn canonicalize(ssa: &mut SsaFunction) -> usize {
        let mut events = EventLog::new();
        LoopCanonicalizationPass::canonicalize_loops(ssa, Token::new(0), &mut events)
    }

    #[test]
    fn test_preheader_insertion() {
        // Blocks 1 and 2 both enter the loop {3, 4} from outside.
        let mut ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| {
                let c = b.const_true();
                b.branch(c, 1, 2);
            });
            f.block(1, |b| b.jump(3));
            f.block(2, |b| b.jump(3));
            f.block(3, |b| b.jump(4));
            f.block(4, |b| {
                let c = b.const_true();
                b.branch(c, 3, 5);
            });
            f.block(5, |b| b.ret());
        });

        let before = ssa.analyze_loops();
        assert_eq!(before.len(), 1);
        assert!(!before.loops()[0].has_preheader());

        assert!(canonicalize(&mut ssa) > 0);

        let after = ssa.analyze_loops();
        assert_eq!(after.len(), 1);
        assert!(after.loops()[0].has_preheader());
    }

    #[test]
    fn test_latch_unification() {
        // Blocks 2 and 3 both branch back to header 1.
        let mut ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| b.jump(1));
            f.block(1, |b| {
                let c = b.const_true();
                b.branch(c, 2, 3);
            });
            f.block(2, |b| b.jump(1));
            f.block(3, |b| {
                let c = b.const_true();
                b.branch(c, 1, 4);
            });
            f.block(4, |b| b.ret());
        });

        let before = ssa.analyze_loops();
        assert_eq!(before.len(), 1);
        let header_loop = &before.loops()[0];
        assert!(!header_loop.has_single_latch());
        assert!(header_loop.latches.len() >= 2);

        assert!(canonicalize(&mut ssa) > 0);

        let after = ssa.analyze_loops();
        assert_eq!(after.len(), 1);
        assert!(after.loops()[0].has_single_latch());
    }

    #[test]
    fn test_already_canonical_loop() {
        let mut ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| b.jump(1));
            f.block(1, |b| b.jump(2));
            f.block(2, |b| {
                let c = b.const_true();
                b.branch(c, 1, 3);
            });
            f.block(3, |b| b.ret());
        });

        let forest = ssa.analyze_loops();
        assert_eq!(forest.len(), 1);
        assert!(forest.loops()[0].is_canonical());

        assert_eq!(canonicalize(&mut ssa), 0);
    }

    #[test]
    fn test_no_loops() {
        let mut ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
            f.block(0, |b| b.jump(1));
            f.block(1, |b| b.jump(2));
            f.block(2, |b| b.ret());
        });

        assert!(ssa.analyze_loops().is_empty());
        assert_eq!(canonicalize(&mut ssa), 0);
    }
}