use std::{
collections::{HashSet, VecDeque},
sync::Arc,
};
use crate::{
analysis::{LoopAnalyzer, LoopInfo, SsaFunction, SsaInstruction, SsaOp, SsaVarId},
compiler::{pass::SsaPass, CompilerContext, EventKind},
metadata::token::Token,
utils::graph::NodeId,
CilObject, Result,
};
/// Loop-invariant code motion (LICM) pass.
///
/// The pass carries no state of its own; all work happens in its `SsaPass`
/// implementation.
#[derive(Default)]
pub struct LicmPass;

impl LicmPass {
    /// Creates a new LICM pass instance.
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
impl SsaPass for LicmPass {
fn name(&self) -> &'static str {
"licm"
}
fn description(&self) -> &'static str {
"Moves loop-invariant computations to loop preheaders"
}
fn run_on_method(
&self,
ssa: &mut SsaFunction,
method_token: Token,
ctx: &CompilerContext,
_assembly: &Arc<CilObject>,
) -> Result<bool> {
let forest = LoopAnalyzer::new(ssa).analyze();
if forest.is_empty() {
return Ok(false);
}
let mut total_hoisted = 0;
for loop_info in forest.by_depth_descending() {
let Some(preheader) = loop_info.preheader else {
continue;
};
let invariants = find_loop_invariants(ssa, loop_info);
if invariants.is_empty() {
continue;
}
let hoistable: Vec<_> = invariants
.into_iter()
.filter(|(block_idx, instr_idx)| can_hoist(ssa, loop_info, *block_idx, *instr_idx))
.collect();
if hoistable.is_empty() {
continue;
}
let mut to_hoist: Vec<(usize, usize, SsaOp)> = Vec::new();
for (block_idx, instr_idx) in &hoistable {
if let Some(block) = ssa.block(*block_idx) {
if let Some(instr) = block.instruction(*instr_idx) {
to_hoist.push((*block_idx, *instr_idx, instr.op().clone()));
}
}
}
to_hoist.sort_by_key(|(block_idx, instr_idx, _)| (*block_idx, *instr_idx));
let insert_base = if let Some(preheader_block) = ssa.block(preheader.index()) {
let instrs = preheader_block.instructions();
if instrs.is_empty() {
0
} else if instrs.last().is_some_and(SsaInstruction::is_terminator) {
instrs.len().saturating_sub(1)
} else {
instrs.len()
}
} else {
0
};
for (i, (block_idx, instr_idx, op)) in to_hoist.iter().enumerate() {
if let Some(preheader_block) = ssa.block_mut(preheader.index()) {
let new_instr = SsaInstruction::synthetic(op.clone());
let instrs = preheader_block.instructions_mut();
instrs.insert(insert_base + i, new_instr);
}
if let Some(block) = ssa.block_mut(*block_idx) {
if let Some(instr) = block.instructions_mut().get_mut(*instr_idx) {
instr.set_op(SsaOp::Nop);
}
}
total_hoisted += 1;
}
}
if total_hoisted > 0 {
ctx.events
.record(EventKind::InstructionRemoved)
.at(method_token, 0)
.message(format!(
"LICM: hoisted {total_hoisted} loop-invariant instructions"
));
}
Ok(total_hoisted > 0)
}
}
/// Collects the `(block index, instruction index)` locations of all
/// instructions in the body of `loop_info` that are loop-invariant.
///
/// An instruction qualifies when `is_instruction_invariant` accepts it, i.e.
/// all of its operands are defined outside the loop body or by instructions
/// already found invariant, and none is a phi result of the loop header.
/// Terminators are never considered. The scan is repeated until a full sweep
/// over the body finds nothing new.
///
/// The returned locations come out of a hash set and are in no particular
/// order.
fn find_loop_invariants(ssa: &SsaFunction, loop_info: &LoopInfo) -> Vec<(usize, usize)> {
    // Results of the phi nodes sitting in the loop header.
    let header_phi_defs: HashSet<SsaVarId> = ssa
        .block(loop_info.header.index())
        .into_iter()
        .flat_map(|header| header.phi_nodes())
        .map(|phi| phi.result())
        .collect();

    // Variables whose definition site lies in a block outside the loop body.
    let outside_defs: HashSet<SsaVarId> = ssa
        .variables()
        .into_iter()
        .filter(|var| !loop_info.body.contains(&NodeId::new(var.def_site().block)))
        .map(|var| var.id())
        .collect();

    let mut invariants: HashSet<(usize, usize)> = HashSet::new();
    let mut invariant_defs: HashSet<SsaVarId> = HashSet::new();

    // Fixpoint: an instruction found invariant in one sweep can make its
    // users invariant in the next one.
    loop {
        let mut grew = false;
        for body_block in &loop_info.body {
            let block_idx = body_block.index();
            let Some(block) = ssa.block(block_idx) else {
                continue;
            };
            for (instr_idx, instr) in block.instructions().iter().enumerate() {
                let site = (block_idx, instr_idx);
                if invariants.contains(&site) || instr.is_terminator() {
                    continue;
                }
                if !is_instruction_invariant(
                    instr,
                    &outside_defs,
                    &invariant_defs,
                    &header_phi_defs,
                ) {
                    continue;
                }
                invariants.insert(site);
                if let Some(def) = instr.def() {
                    invariant_defs.insert(def);
                }
                grew = true;
            }
        }
        if !grew {
            break;
        }
    }

    invariants.into_iter().collect()
}
/// Decides whether a single instruction is invariant with respect to a loop.
///
/// Every operand of `instr` must satisfy both conditions:
/// * it is not the result of a header phi (`header_phi_defs`), and
/// * it is defined outside the loop (`outside_defs`) or by an instruction
///   already known to be invariant (`invariant_defs`).
///
/// An instruction without operands is therefore always reported invariant.
fn is_instruction_invariant(
    instr: &SsaInstruction,
    outside_defs: &HashSet<SsaVarId>,
    invariant_defs: &HashSet<SsaVarId>,
    header_phi_defs: &HashSet<SsaVarId>,
) -> bool {
    instr.op().uses().into_iter().all(|operand| {
        let is_header_phi = header_phi_defs.contains(&operand);
        let is_available = outside_defs.contains(&operand) || invariant_defs.contains(&operand);
        !is_header_phi && is_available
    })
}
fn can_hoist(ssa: &SsaFunction, loop_info: &LoopInfo, block_idx: usize, instr_idx: usize) -> bool {
let Some(block) = ssa.block(block_idx) else {
return false;
};
let Some(instr) = block.instruction(instr_idx) else {
return false;
};
if !instr.op().is_pure() {
return false;
}
if loop_info.preheader.is_none() {
return false;
}
if let Some(dest) = instr.def() {
if feeds_phi_back_edge(ssa, loop_info, dest) {
return false;
}
}
true
}
/// Returns `true` if `var`, or any value computed from it by an instruction
/// in the loop body, is an operand of a phi node in the loop header whose
/// predecessor block belongs to the loop body and is not the header itself.
///
/// The search is a breadth-first walk over def-use chains restricted to
/// instructions in the body of `loop_info`; `visited` ensures that each
/// variable is expanded at most once.
fn feeds_phi_back_edge(ssa: &SsaFunction, loop_info: &LoopInfo, var: SsaVarId) -> bool {
    let header_idx = loop_info.header.index();
    // The header does not change during the walk, so resolve it only once.
    let header_block = ssa.block(header_idx);

    let mut worklist: VecDeque<SsaVarId> = VecDeque::new();
    let mut visited: HashSet<SsaVarId> = HashSet::new();
    worklist.push_back(var);
    visited.insert(var);

    while let Some(current) = worklist.pop_front() {
        // Does `current` enter a header phi from inside the loop?
        if let Some(header_block) = header_block {
            for phi in header_block.phi_nodes() {
                for operand in phi.operands() {
                    if operand.value() != current {
                        continue;
                    }
                    let pred = operand.predecessor();
                    // NOTE(review): an edge from the header to itself is
                    // deliberately not counted here — confirm this is intended
                    // for single-block loops.
                    if pred != header_idx && loop_info.body.contains(&NodeId::new(pred)) {
                        return true;
                    }
                }
            }
        }

        // Follow `current` into every body instruction that uses it.
        for body_block_id in &loop_info.body {
            let Some(body_block) = ssa.block(body_block_id.index()) else {
                continue;
            };
            for instr in body_block.instructions() {
                if !instr.op().uses().contains(&current) {
                    continue;
                }
                if let Some(dest) = instr.def() {
                    if visited.insert(dest) {
                        worklist.push_back(dest);
                    }
                }
            }
        }
    }
    false
}
// Unit tests for the LICM pass. They cover the pass metadata, the `SsaOp`
// helpers the pass relies on (`is_pure`, `uses`) and the preheader detection
// of the loop analysis. None of them calls `run_on_method`.
#[cfg(test)]
mod tests {
use crate::{
analysis::{ConstValue, LoopAnalyzer, MethodRef, SsaFunctionBuilder, SsaOp, SsaVarId},
compiler::{LicmPass, SsaPass},
metadata::token::Token,
};
// The pass reports the name "licm" and a non-empty description.
#[test]
fn test_pass_metadata() {
let pass = LicmPass::new();
assert_eq!(pass.name(), "licm");
assert!(!pass.description().is_empty());
}
// `Add` and `Const` are reported as pure, a `Call` is not.
#[test]
fn test_op_is_pure() {
let add_op = SsaOp::Add {
dest: SsaVarId::new(),
left: SsaVarId::new(),
right: SsaVarId::new(),
};
assert!(add_op.is_pure());
let const_op = SsaOp::Const {
dest: SsaVarId::new(),
value: ConstValue::I32(42),
};
assert!(const_op.is_pure());
let call_op = SsaOp::Call {
dest: Some(SsaVarId::new()),
method: MethodRef::new(Token::new(0x06000001)),
args: vec![],
};
assert!(!call_op.is_pure());
}
// `uses()` lists both operands of an `Add` and nothing for a `Const`.
#[test]
fn test_op_uses() {
let v1 = SsaVarId::new();
let v2 = SsaVarId::new();
let dest = SsaVarId::new();
let op = SsaOp::Add {
dest,
left: v1,
right: v2,
};
let uses = op.uses();
assert_eq!(uses.len(), 2);
assert!(uses.contains(&v1));
assert!(uses.contains(&v2));
let const_op = SsaOp::Const {
dest,
value: ConstValue::I32(42),
};
assert!(const_op.uses().is_empty());
}
// A single block that only returns yields an empty loop forest.
#[test]
fn test_no_loops() {
let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| {
let _ = b.const_i32(42);
b.ret();
});
});
let forest = LoopAnalyzer::new(&ssa).analyze();
assert!(forest.is_empty());
}
// Block 3 branches back to itself and is entered from both block 1 and
// block 2; the analysis is expected to report no preheader for that loop
// (presumably because there is more than one entry edge from outside).
#[test]
fn test_loop_without_preheader() {
let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| {
let cond = b.const_true();
b.branch(cond, 1, 2);
});
f.block(1, |b| b.jump(3));
f.block(2, |b| b.jump(3));
f.block(3, |b| {
let cond = b.const_true();
b.branch(cond, 3, 4); });
f.block(4, |b| b.ret());
});
let forest = LoopAnalyzer::new(&ssa).analyze();
assert!(!forest.is_empty());
let loop_info = &forest.loops()[0];
assert!(!loop_info.has_preheader());
}
// Block 1 branches back to itself and is entered only from block 0; the
// analysis is expected to find exactly one loop and report a preheader.
#[test]
fn test_simple_loop_has_preheader() {
let ssa = SsaFunctionBuilder::new(0, 0).build_with(|f| {
f.block(0, |b| b.jump(1));
f.block(1, |b| {
let cond = b.const_true();
b.branch(cond, 1, 2);
});
f.block(2, |b| b.ret());
});
let forest = LoopAnalyzer::new(&ssa).analyze();
assert_eq!(forest.len(), 1);
let loop_info = &forest.loops()[0];
assert!(loop_info.has_preheader());
}
}