Skip to main content

wave_compiler/optimize/
licm.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! Loop-invariant code motion pass.
5//!
6//! Identifies instructions inside loops whose operands do not change
7//! across iterations and moves them to the loop preheader.
8
9use std::collections::HashSet;
10
11use super::pass::Pass;
12use crate::analysis::cfg::Cfg;
13use crate::analysis::dominance::DomTree;
14use crate::analysis::loop_analysis::LoopInfo;
15use crate::mir::function::MirFunction;
16use crate::mir::instruction::MirInst;
17use crate::mir::value::ValueId;
18
/// Loop-invariant code motion pass.
///
/// The pass carries no state of its own: every analysis it needs (CFG,
/// dominator tree, loop info) is rebuilt inside `run`, so one value can be
/// reused for any number of functions.
#[derive(Debug, Default, Clone, Copy)]
pub struct Licm;
21
22impl Pass for Licm {
23    fn name(&self) -> &'static str {
24        "licm"
25    }
26
27    fn run(&self, func: &mut MirFunction) -> bool {
28        let cfg = Cfg::build(func);
29        let dom = DomTree::compute(&cfg);
30        let loop_info = LoopInfo::compute(&cfg, &dom);
31
32        let mut changed = false;
33
34        for natural_loop in &loop_info.loops {
35            let mut defs_in_loop: HashSet<ValueId> = HashSet::new();
36            for &bid in &natural_loop.body {
37                if let Some(block) = func.block(bid) {
38                    for inst in &block.instructions {
39                        if let Some(dest) = inst.dest() {
40                            defs_in_loop.insert(dest);
41                        }
42                    }
43                }
44            }
45
46            let mut invariant_insts: Vec<(usize, MirInst)> = Vec::new();
47
48            for &bid in &natural_loop.body {
49                if let Some(block) = func.block(bid) {
50                    for (idx, inst) in block.instructions.iter().enumerate() {
51                        if inst.has_side_effects() {
52                            continue;
53                        }
54                        let all_operands_invariant =
55                            inst.operands().iter().all(|op| !defs_in_loop.contains(op));
56                        if all_operands_invariant {
57                            if let Some(dest) = inst.dest() {
58                                invariant_insts.push((idx, inst.clone()));
59                                defs_in_loop.remove(&dest);
60                            }
61                        }
62                    }
63                }
64            }
65
66            if !invariant_insts.is_empty() {
67                let preds = cfg.preds(natural_loop.header);
68                let preheader = preds
69                    .iter()
70                    .find(|p| !natural_loop.body.contains(p))
71                    .copied();
72
73                if let Some(pre_bid) = preheader {
74                    if let Some(pre_block) = func.block_mut(pre_bid) {
75                        for (_, inst) in &invariant_insts {
76                            pre_block.instructions.push(inst.clone());
77                        }
78                        changed = true;
79                    }
80                }
81            }
82        }
83
84        changed
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hir::expr::BinOp;
    use crate::mir::basic_block::{BasicBlock, Terminator};
    use crate::mir::instruction::{ConstValue, MirInst};
    use crate::mir::types::MirType;
    use crate::mir::value::BlockId;

    /// Builds a block with no instructions and the given terminator.
    fn empty_block(id: BlockId, terminator: Terminator) -> BasicBlock {
        let mut block = BasicBlock::new(id);
        block.terminator = terminator;
        block
    }

    /// A single straight-line block contains no loop, so the pass must report
    /// that nothing changed.
    #[test]
    fn test_licm_no_loops_no_change() {
        let mut entry = empty_block(BlockId(0), Terminator::Return);
        entry.instructions.push(MirInst::Const {
            dest: ValueId(0),
            value: ConstValue::I32(42),
        });

        let mut func = MirFunction::new("test".into(), BlockId(0));
        func.blocks.push(entry);

        assert!(!Licm.run(&mut func));
    }

    /// CFG: bb0 -> bb1 -> {bb2, bb3}, with bb2 branching back to bb1.
    /// `v2 = v0 + v1` in bb1 only reads values not defined in the loop, so
    /// the pass is expected to place it in bb0.
    #[test]
    fn test_licm_hoists_invariant() {
        let preheader = empty_block(BlockId(0), Terminator::Branch { target: BlockId(1) });

        let mut header = empty_block(
            BlockId(1),
            Terminator::CondBranch {
                cond: ValueId(2),
                true_target: BlockId(2),
                false_target: BlockId(3),
            },
        );
        header.instructions.push(MirInst::BinOp {
            dest: ValueId(2),
            op: BinOp::Add,
            lhs: ValueId(0),
            rhs: ValueId(1),
            ty: MirType::I32,
        });

        let latch = empty_block(BlockId(2), Terminator::Branch { target: BlockId(1) });
        let exit = BasicBlock::new(BlockId(3));

        let mut func = MirFunction::new("test".into(), BlockId(0));
        func.blocks.extend([preheader, header, latch, exit]);

        assert!(Licm.run(&mut func));
        assert!(!func.blocks[0].instructions.is_empty());
    }
}