use std::collections::{HashMap, HashSet, VecDeque};
use crate::{
analysis::{LoopAnalyzer, LoopInfo, SsaFunction, SsaInstruction, SsaOp, SsaVarId},
compiler::{
pass::{ModificationScope, SsaPass},
CompilerContext, EventKind,
},
metadata::token::Token,
utils::BitSet,
CilObject, Result,
};
/// Loop-invariant code motion pass.
///
/// The pass carries no configuration or state; it is a unit struct whose
/// behavior lives entirely in its `SsaPass` implementation.
#[derive(Default)]
pub struct LicmPass;

impl LicmPass {
    /// Creates a new LICM pass instance.
    ///
    /// Equivalent to `LicmPass::default()`.
    #[must_use]
    pub fn new() -> Self {
        LicmPass
    }
}
impl SsaPass for LicmPass {
    /// Returns the short identifier of this pass.
    fn name(&self) -> &'static str {
        "licm"
    }

    /// Returns a one-line, human-readable summary of what the pass does.
    fn description(&self) -> &'static str {
        "Moves loop-invariant computations to loop preheaders"
    }

    /// Reports the modification scope this pass declares to the pass manager.
    fn modification_scope(&self) -> ModificationScope {
        ModificationScope::InstructionsOnly
    }

    /// Hoists pure, loop-invariant instructions of every loop into that
    /// loop's preheader.
    ///
    /// Loops are visited in the order produced by
    /// `forest.by_depth_descending()`. A loop is skipped when it has no
    /// preheader, when the preheader's last instruction does not list the
    /// header among its successors, or when the header terminates in a
    /// `Switch`.
    ///
    /// For each remaining loop the candidates are narrowed in three steps:
    ///
    /// 1. `find_loop_invariants` + `can_hoist` select the initial set.
    /// 2. A fixed-point pass drops candidates that read a value which is
    ///    neither defined outside the loop nor produced by another
    ///    surviving candidate.
    /// 3. Candidates are dropped from blocks that would be left with only
    ///    `Nop`s and a terminator while a successor still has phi nodes.
    ///
    /// Survivors are cloned into the preheader (before its terminator, if
    /// any) in dependency order and replaced by `Nop` at their original
    /// position.
    ///
    /// # Returns
    ///
    /// `Ok(true)` if at least one instruction was hoisted, `Ok(false)`
    /// otherwise. The code in this function never produces an `Err`.
    fn run_on_method(
        &self,
        ssa: &mut SsaFunction,
        method_token: Token,
        ctx: &CompilerContext,
        _assembly: &CilObject,
    ) -> Result<bool> {
        let forest = LoopAnalyzer::new(ssa).analyze();
        if forest.is_empty() {
            return Ok(false);
        }

        let mut total_hoisted = 0;
        for loop_info in forest.by_depth_descending() {
            // Without a preheader there is no place to move code to.
            let Some(preheader) = loop_info.preheader else {
                continue;
            };
            let preheader_idx = preheader.index();
            let header_idx = loop_info.header.index();

            // Only use the preheader if its last instruction really
            // transfers control to the loop header.
            let preheader_is_pred = ssa
                .block(preheader_idx)
                .and_then(|b| b.instructions().last())
                .is_some_and(|i| i.op().successors().contains(&header_idx));
            if !preheader_is_pred {
                continue;
            }

            // Loops whose header ends in a switch are left untouched.
            let header_has_switch = ssa
                .block(header_idx)
                .and_then(|b| b.terminator_op())
                .is_some_and(|op| matches!(op, SsaOp::Switch { .. }));
            if header_has_switch {
                continue;
            }

            let invariants = find_loop_invariants(ssa, loop_info);
            if invariants.is_empty() {
                continue;
            }

            let mut hoistable: Vec<_> = invariants
                .into_iter()
                .filter(|(block_idx, instr_idx)| can_hoist(ssa, loop_info, *block_idx, *instr_idx))
                .collect();

            // Variables whose recorded definition site lies outside the loop.
            let mut outside_defs = BitSet::new(ssa.var_id_capacity());
            for v in ssa.variables() {
                if !loop_info.body.contains(v.def_site().block) {
                    outside_defs.insert(v.id().index());
                }
            }

            // Fixed point: a candidate may only stay if every operand is
            // defined outside the loop or by another surviving candidate.
            // Removing one candidate can invalidate others, hence the loop.
            loop {
                let mut hoistable_defs = BitSet::new(ssa.var_id_capacity());
                for (block_idx, instr_idx) in hoistable.iter() {
                    if let Some(def) = ssa
                        .block(*block_idx)
                        .and_then(|b| b.instruction(*instr_idx))
                        .and_then(|i| i.def())
                    {
                        hoistable_defs.insert(def.index());
                    }
                }

                let before = hoistable.len();
                hoistable.retain(|(block_idx, instr_idx)| {
                    let Some(block) = ssa.block(*block_idx) else {
                        return false;
                    };
                    let Some(instr) = block.instruction(*instr_idx) else {
                        return false;
                    };
                    instr.op().uses().iter().all(|operand| {
                        outside_defs.contains(operand.index())
                            || hoistable_defs.contains(operand.index())
                    })
                });
                if hoistable.len() == before {
                    break;
                }
            }

            // Do not empty out a block (leaving only Nops + terminator) when
            // one of its successors has phi nodes: drop all candidates that
            // come from such a block.
            {
                let mut hoist_count_per_block: HashMap<usize, usize> = HashMap::new();
                for (block_idx, _) in &hoistable {
                    *hoist_count_per_block.entry(*block_idx).or_default() += 1;
                }

                let mut trampoline_blocks = BitSet::new(ssa.block_count());
                for (&block_idx, &hoist_count) in &hoist_count_per_block {
                    let Some(block) = ssa.block(block_idx) else {
                        continue;
                    };
                    let non_term = block
                        .instructions()
                        .iter()
                        .filter(|i| !i.is_terminator() && !matches!(i.op(), SsaOp::Nop))
                        .count();
                    if hoist_count < non_term {
                        continue;
                    }
                    if let Some(term) = block.terminator_op() {
                        for succ in term.successors() {
                            if let Some(succ_block) = ssa.block(succ) {
                                if !succ_block.phi_nodes().is_empty() {
                                    trampoline_blocks.insert(block_idx);
                                }
                            }
                        }
                    }
                }

                if !trampoline_blocks.is_empty() {
                    hoistable.retain(|(block_idx, _)| !trampoline_blocks.contains(*block_idx));
                }
            }
            if hoistable.is_empty() {
                continue;
            }

            // Snapshot the operations in an order where every hoisted
            // definition precedes its hoisted uses. Plain block-index order
            // is not sufficient because block numbering need not follow
            // dominance.
            let to_hoist = order_hoists_by_dependency(ssa, &hoistable);

            // Insert before the preheader's terminator when it has one,
            // otherwise append at the end.
            let insert_base = if let Some(preheader_block) = ssa.block(preheader_idx) {
                let instrs = preheader_block.instructions();
                if instrs.is_empty() {
                    0
                } else if instrs.last().is_some_and(SsaInstruction::is_terminator) {
                    instrs.len().saturating_sub(1)
                } else {
                    instrs.len()
                }
            } else {
                0
            };

            let mut hoisted_from = BitSet::new(ssa.block_count());
            for (offset, (block_idx, instr_idx, op)) in to_hoist.into_iter().enumerate() {
                hoisted_from.insert(block_idx);
                if let Some(preheader_block) = ssa.block_mut(preheader_idx) {
                    preheader_block
                        .instructions_mut()
                        .insert(insert_base + offset, SsaInstruction::synthetic(op));
                }
                // The original slot becomes a Nop so instruction indices in
                // the source block stay stable.
                if let Some(block) = ssa.block_mut(block_idx) {
                    if let Some(instr) = block.instructions_mut().get_mut(instr_idx) {
                        instr.set_op(SsaOp::Nop);
                    }
                }
                total_hoisted += 1;
            }
            // NOTE(review): nothing here updates the variables' recorded
            // def sites after the move - confirm whether a later step
            // refreshes that metadata.

            // If a source block now holds only Nops and a terminator,
            // re-point phi operands in its successors from that block to the
            // preheader.
            // NOTE(review): the trampoline filter above already rejects such
            // blocks when a successor has phis, so this looks like a
            // defensive fallback - confirm.
            for source_block in hoisted_from.iter() {
                let is_trampoline = ssa.block(source_block).is_some_and(|b| {
                    b.instructions()
                        .iter()
                        .all(|i| i.is_terminator() || matches!(i.op(), SsaOp::Nop))
                });
                if !is_trampoline {
                    continue;
                }

                let successors: Vec<usize> = ssa
                    .block(source_block)
                    .map(|b| {
                        b.instructions()
                            .last()
                            .map(|i| i.op().successors())
                            .unwrap_or_default()
                    })
                    .unwrap_or_default();
                for succ in successors {
                    if let Some(succ_block) = ssa.block_mut(succ) {
                        for phi in succ_block.phi_nodes_mut() {
                            for operand in phi.operands_mut() {
                                if operand.predecessor() == source_block {
                                    operand.set_predecessor(preheader_idx);
                                }
                            }
                        }
                    }
                }
            }
        }

        if total_hoisted > 0 {
            ctx.events
                .record(EventKind::InstructionRemoved)
                .at(method_token, 0)
                .message(format!(
                    "LICM: hoisted {total_hoisted} loop-invariant instructions"
                ));
        }
        Ok(total_hoisted > 0)
    }
}

/// Snapshots the operations at `hoistable` positions, ordered so that every
/// instruction appears after the hoisted instructions defining its operands.
///
/// Candidates are first sorted by `(block index, instruction index)` and then
/// emitted in repeated sweeps: an instruction is emitted once none of its
/// operands is produced by a not-yet-emitted candidate. When the sorted order
/// already respects dependencies, a single sweep reproduces it unchanged.
///
/// Positions that no longer resolve to an instruction are skipped. If a sweep
/// makes no progress, the remaining candidates are omitted from the result
/// (and therefore stay in the loop) rather than being placed before their
/// operands.
fn order_hoists_by_dependency(
    ssa: &SsaFunction,
    hoistable: &[(usize, usize)],
) -> Vec<(usize, usize, SsaOp)> {
    let mut pending: Vec<_> = hoistable
        .iter()
        .filter_map(|&(block_idx, instr_idx)| {
            let instr = ssa.block(block_idx)?.instruction(instr_idx)?;
            Some((block_idx, instr_idx, instr))
        })
        .collect();
    pending.sort_by_key(|&(block_idx, instr_idx, _)| (block_idx, instr_idx));

    // Values produced by any candidate, and by candidates already emitted.
    let mut hoisted_defs = BitSet::new(ssa.var_id_capacity());
    for (_, _, instr) in &pending {
        if let Some(def) = instr.def() {
            hoisted_defs.insert(def.index());
        }
    }
    let mut emitted_defs = BitSet::new(ssa.var_id_capacity());

    let mut ordered = Vec::with_capacity(pending.len());
    while !pending.is_empty() {
        let before = pending.len();
        let mut blocked = Vec::new();
        for (block_idx, instr_idx, instr) in pending {
            let ready = instr.op().uses().iter().all(|operand| {
                !hoisted_defs.contains(operand.index()) || emitted_defs.contains(operand.index())
            });
            if ready {
                if let Some(def) = instr.def() {
                    emitted_defs.insert(def.index());
                }
                ordered.push((block_idx, instr_idx, instr.op().clone()));
            } else {
                blocked.push((block_idx, instr_idx, instr));
            }
        }
        if blocked.len() == before {
            // No progress: the rest cannot be ordered safely, leave them.
            break;
        }
        pending = blocked;
    }
    ordered
}
/// Collects the positions of all instructions in the loop whose operands are
/// loop-invariant.
///
/// An instruction qualifies when every operand is either defined outside the
/// loop body or defined by an instruction already classified as invariant,
/// and no operand is the result of a phi node in the loop header (see
/// `is_instruction_invariant`). Terminators and `Nop`s are never reported.
/// The scan repeats until no new instruction is added, so chains of
/// invariant computations are found regardless of block visiting order.
///
/// This function does not check purity or whether hoisting is otherwise
/// safe; callers are expected to filter the result further.
///
/// # Returns
///
/// `(block index, instruction index)` pairs, sorted ascending so the result
/// is reproducible across runs (the working set is a `HashSet`, whose
/// iteration order is unspecified).
fn find_loop_invariants(ssa: &SsaFunction, loop_info: &LoopInfo) -> Vec<(usize, usize)> {
    let mut invariants: HashSet<(usize, usize)> = HashSet::new();
    let mut invariant_defs = BitSet::new(ssa.var_id_capacity());

    // Results of phi nodes in the loop header.
    let mut header_phi_defs = BitSet::new(ssa.var_id_capacity());
    if let Some(header_block) = ssa.block(loop_info.header.index()) {
        for phi in header_block.phi_nodes() {
            header_phi_defs.insert(phi.result().index());
        }
    }

    // Variables whose recorded definition site is outside the loop body.
    let mut outside_defs = BitSet::new(ssa.var_id_capacity());
    for var in ssa.variables() {
        if !loop_info.body.contains(var.def_site().block) {
            outside_defs.insert(var.id().index());
        }
    }

    // Iterate to a fixed point: marking one instruction invariant can make
    // its users invariant on a later sweep.
    let mut changed = true;
    while changed {
        changed = false;
        for block_idx in loop_info.body.iter() {
            let Some(block) = ssa.block(block_idx) else {
                continue;
            };
            for (instr_idx, instr) in block.instructions().iter().enumerate() {
                if invariants.contains(&(block_idx, instr_idx))
                    || instr.is_terminator()
                    || matches!(instr.op(), SsaOp::Nop)
                {
                    continue;
                }
                if is_instruction_invariant(
                    instr,
                    &outside_defs,
                    &invariant_defs,
                    &header_phi_defs,
                ) {
                    invariants.insert((block_idx, instr_idx));
                    if let Some(def) = instr.def() {
                        invariant_defs.insert(def.index());
                    }
                    changed = true;
                }
            }
        }
    }

    // Sort so callers never observe the hash set's arbitrary order.
    let mut ordered: Vec<(usize, usize)> = invariants.into_iter().collect();
    ordered.sort_unstable();
    ordered
}
/// Tests whether every operand of `instr` is loop-invariant.
///
/// An operand is accepted when it is not the result of a header phi
/// (`header_phi_defs`) and is either defined outside the loop
/// (`outside_defs`) or produced by an instruction already known to be
/// invariant (`invariant_defs`). An instruction with no operands is
/// trivially invariant.
fn is_instruction_invariant(
    instr: &SsaInstruction,
    outside_defs: &BitSet,
    invariant_defs: &BitSet,
    header_phi_defs: &BitSet,
) -> bool {
    instr.op().uses().iter().all(|operand| {
        let var_idx = operand.index();
        !header_phi_defs.contains(var_idx)
            && (outside_defs.contains(var_idx) || invariant_defs.contains(var_idx))
    })
}
/// Decides whether the instruction at `(block_idx, instr_idx)` may be moved
/// out of the loop.
///
/// Returns `true` only when all of the following hold:
/// - the position resolves to an instruction,
/// - the instruction defines a value,
/// - its operation reports itself as pure,
/// - the loop has a preheader,
/// - the defined value does not reach a phi node of the loop through an
///   in-loop predecessor (see `feeds_phi_back_edge`).
fn can_hoist(ssa: &SsaFunction, loop_info: &LoopInfo, block_idx: usize, instr_idx: usize) -> bool {
    let instr = match ssa
        .block(block_idx)
        .and_then(|block| block.instruction(instr_idx))
    {
        Some(found) => found,
        None => return false,
    };

    match instr.def() {
        None => false,
        Some(dest) => {
            instr.op().is_pure()
                && loop_info.preheader.is_some()
                && !feeds_phi_back_edge(ssa, loop_info, dest)
        }
    }
}
/// Checks whether `var`, or any value computed from it inside the loop, is
/// read by a phi node in the loop along an edge coming from inside the loop.
///
/// Starting from `var`, the search follows def-use chains through the
/// instructions of the loop body: whenever an instruction uses the current
/// value and defines a result, that result is explored as well. For each
/// explored value, every phi node in every loop block is inspected; a phi
/// operand equal to that value whose predecessor block belongs to the loop
/// body makes the function return `true`.
///
/// Returns `false` when no such phi operand is reachable.
fn feeds_phi_back_edge(ssa: &SsaFunction, loop_info: &LoopInfo, var: SsaVarId) -> bool {
    let mut worklist: VecDeque<SsaVarId> = VecDeque::new();
    // Guards against revisiting a value (and thus against endless loops).
    let mut visited = BitSet::new(ssa.var_id_capacity());
    worklist.push_back(var);
    visited.insert(var.index());

    while let Some(current) = worklist.pop_front() {
        // Is `current` consumed by a phi through an in-loop predecessor?
        for phi_block_idx in loop_info.body.iter() {
            let Some(phi_block) = ssa.block(phi_block_idx) else {
                continue;
            };
            for phi in phi_block.phi_nodes() {
                for operand in phi.operands() {
                    if operand.value() == current && loop_info.body.contains(operand.predecessor())
                    {
                        return true;
                    }
                }
            }
        }

        // Otherwise continue with every in-loop value derived from `current`.
        for body_block_idx in loop_info.body.iter() {
            let Some(body_block) = ssa.block(body_block_idx) else {
                continue;
            };
            for instr in body_block.instructions() {
                if !instr.op().uses().contains(&current) {
                    continue;
                }
                if let Some(dest) = instr.def() {
                    // `insert` returns true only for values not seen before.
                    if visited.insert(dest.index()) {
                        worklist.push_back(dest);
                    }
                }
            }
        }
    }
    false
}
// Unit tests for the LICM pass metadata and for the `SsaOp` / `LoopAnalyzer`
// behavior the pass relies on (purity, operand lists, preheader detection).
#[cfg(test)]
mod tests {
use crate::{
analysis::{ConstValue, LoopAnalyzer, MethodRef, SsaFunctionBuilder, SsaOp, SsaVarId},
compiler::{LicmPass, SsaPass},
metadata::token::Token,
};
/// The pass reports the name "licm" and a non-empty description.
#[test]
fn test_pass_metadata() {
let pass = LicmPass::new();
assert_eq!(pass.name(), "licm");
assert!(!pass.description().is_empty());
}
/// `Add` and `Const` report themselves as pure; `Call` does not.
#[test]
fn test_op_is_pure() {
let add_op = SsaOp::Add {
dest: SsaVarId::from_index(0),
left: SsaVarId::from_index(1),
right: SsaVarId::from_index(2),
};
assert!(add_op.is_pure());
let const_op = SsaOp::Const {
dest: SsaVarId::from_index(3),
value: ConstValue::I32(42),
};
assert!(const_op.is_pure());
let call_op = SsaOp::Call {
dest: Some(SsaVarId::from_index(4)),
method: MethodRef::new(Token::new(0x06000001)),
args: vec![],
};
assert!(!call_op.is_pure());
}
/// `uses()` lists both operands of `Add` and is empty for `Const`.
#[test]
fn test_op_uses() {
let v1 = SsaVarId::from_index(0);
let v2 = SsaVarId::from_index(1);
let dest = SsaVarId::from_index(2);
let op = SsaOp::Add {
dest,
left: v1,
right: v2,
};
let uses = op.uses();
assert_eq!(uses.len(), 2);
assert!(uses.contains(&v1));
assert!(uses.contains(&v2));
let const_op = SsaOp::Const {
dest,
value: ConstValue::I32(42),
};
assert!(const_op.uses().is_empty());
}
/// A single block that loads a constant and returns yields an empty loop forest.
#[test]
fn test_no_loops() {
let ssa = SsaFunctionBuilder::new(0, 0)
.build_with(|f| {
f.block(0, |b| {
let _ = b.const_i32(42);
b.ret();
});
})
.unwrap();
let forest = LoopAnalyzer::new(&ssa).analyze();
assert!(forest.is_empty());
}
/// A self-looping block entered from two different blocks is reported
/// without a preheader.
#[test]
fn test_loop_without_preheader() {
// CFG: 0 -> {1, 2}; 1 -> 3; 2 -> 3; 3 -> {3, 4}; 4 returns.
let ssa = SsaFunctionBuilder::new(0, 0)
.build_with(|f| {
f.block(0, |b| {
let cond = b.const_true();
b.branch(cond, 1, 2);
});
f.block(1, |b| b.jump(3));
f.block(2, |b| b.jump(3));
f.block(3, |b| {
let cond = b.const_true();
b.branch(cond, 3, 4); });
f.block(4, |b| b.ret());
})
.unwrap();
let forest = LoopAnalyzer::new(&ssa).analyze();
assert!(!forest.is_empty());
let loop_info = &forest.loops()[0];
assert!(!loop_info.has_preheader());
}
/// A self-looping block entered from a single block is reported with a
/// preheader.
#[test]
fn test_simple_loop_has_preheader() {
// CFG: 0 -> 1; 1 -> {1, 2}; 2 returns.
let ssa = SsaFunctionBuilder::new(0, 0)
.build_with(|f| {
f.block(0, |b| b.jump(1));
f.block(1, |b| {
let cond = b.const_true();
b.branch(cond, 1, 2);
});
f.block(2, |b| b.ret());
})
.unwrap();
let forest = LoopAnalyzer::new(&ssa).analyze();
assert_eq!(forest.len(), 1);
let loop_info = &forest.loops()[0];
assert!(loop_info.has_preheader());
}
}
}