midenc_hir_transform/
sccp.rs

1use midenc_hir::{
2    pass::{Pass, PassExecutionState},
3    patterns::NoopRewriterListener,
4    BlockRef, Builder, EntityMut, OpBuilder, Operation, OperationFolder, OperationName, RegionList,
5    Report, SmallVec, ValueRef,
6};
7use midenc_hir_analysis::{
8    analyses::{constant_propagation::ConstantValue, DeadCodeAnalysis, SparseConstantPropagation},
9    DataFlowSolver, Lattice,
10};
11
/// This pass implements a general algorithm for sparse conditional constant propagation (SCCP).
///
/// This algorithm detects values that are known to be constant and optimistically propagates
/// that knowledge throughout the IR. Any values proven to be constant have their uses replaced,
/// and the operations that produced them are removed if possible.
///
/// The analysis is performed by a [`DataFlowSolver`] loaded with both [`DeadCodeAnalysis`] and
/// [`SparseConstantPropagation`]; the IR is then rewritten based on the solver's results.
///
/// This implementation is based on the algorithm described by Wegman and Zadeck in
/// [“Constant Propagation with Conditional Branches”](https://dl.acm.org/doi/10.1145/103135.103136)
/// (1991).
pub struct SparseConditionalConstantPropagation;
21
22impl Pass for SparseConditionalConstantPropagation {
23    type Target = Operation;
24
25    fn name(&self) -> &'static str {
26        "sparse-conditional-constant-propagation"
27    }
28
29    fn argument(&self) -> &'static str {
30        "sparse-conditional-constant-propagation"
31    }
32
33    fn can_schedule_on(&self, _name: &OperationName) -> bool {
34        true
35    }
36
37    fn run_on_operation(
38        &mut self,
39        mut op: EntityMut<'_, Self::Target>,
40        state: &mut PassExecutionState,
41    ) -> Result<(), Report> {
42        // Run sparse constant propagation + dead code analysis
43        let mut solver = DataFlowSolver::default();
44        solver.load::<DeadCodeAnalysis>();
45        solver.load::<SparseConstantPropagation>();
46        solver.initialize_and_run(&op, state.analysis_manager().clone())?;
47
48        // Rewrite based on results of analysis
49        self.rewrite(&mut op, state, &solver)
50    }
51}
52
53impl SparseConditionalConstantPropagation {
54    /// Rewrite the given regions using the computing analysis. This replaces the uses of all values
55    /// that have been computed to be constant, and erases as many newly dead operations.
56    fn rewrite(
57        &mut self,
58        op: &mut Operation,
59        state: &mut PassExecutionState,
60        solver: &DataFlowSolver,
61    ) -> Result<(), Report> {
62        let mut worklist = SmallVec::<[BlockRef; 8]>::default();
63
64        let add_to_worklist = |regions: &RegionList, worklist: &mut SmallVec<[BlockRef; 8]>| {
65            for region in regions {
66                for block in region.body().iter().rev() {
67                    worklist.push(block.as_block_ref());
68                }
69            }
70        };
71
72        // An operation folder used to create and unique constants.
73        let context = op.context_rc();
74        let mut folder = OperationFolder::new(context.clone(), None::<NoopRewriterListener>);
75        let mut builder = OpBuilder::new(context.clone());
76
77        add_to_worklist(op.regions(), &mut worklist);
78
79        let mut replaced_any = false;
80        while let Some(mut block) = worklist.pop() {
81            let mut block = block.borrow_mut();
82            let body = block.body_mut();
83            let mut ops = body.front();
84
85            while let Some(mut op) = ops.as_pointer() {
86                ops.move_next();
87
88                builder.set_insertion_point_after(op);
89
90                // Replace any result with constants.
91                let num_results = op.borrow().num_results();
92                let mut replaced_all = num_results != 0;
93                for index in 0..num_results {
94                    let result = { op.borrow().get_result(index).borrow().as_value_ref() };
95                    let replaced = replace_with_constant(solver, &mut builder, &mut folder, result);
96
97                    replaced_any |= replaced;
98                    replaced_all &= replaced;
99                }
100
101                // If all of the results of the operation were replaced, try to erase the operation
102                // completely.
103                let mut op = op.borrow_mut();
104                if replaced_all && op.would_be_trivially_dead() {
105                    assert!(!op.is_used(), "expected all uses to be replaced");
106                    op.erase();
107                    continue;
108                }
109
110                // Add any of the regions of this operation to the worklist
111                add_to_worklist(op.regions(), &mut worklist);
112            }
113
114            // Replace any block arguments with constants
115            builder.set_insertion_point_to_start(block.as_block_ref());
116
117            for arg in block.arguments() {
118                replaced_any |= replace_with_constant(
119                    solver,
120                    &mut builder,
121                    &mut folder,
122                    arg.borrow().as_value_ref(),
123                );
124            }
125        }
126
127        state.set_post_pass_status(replaced_any.into());
128
129        Ok(())
130    }
131}
132
133/// Replace the given value with a constant if the corresponding lattice represents a constant.
134///
135/// Returns success if the value was replaced, failure otherwise.
136fn replace_with_constant(
137    solver: &DataFlowSolver,
138    builder: &mut OpBuilder,
139    folder: &mut OperationFolder,
140    mut value: ValueRef,
141) -> bool {
142    let Some(lattice) = solver.get::<Lattice<ConstantValue>, _>(&value) else {
143        return false;
144    };
145    if lattice.value().is_uninitialized() {
146        return false;
147    }
148
149    let Some(constant_value) = lattice.value().constant_value() else {
150        return false;
151    };
152
153    // Attempt to materialize a constant for the given value.
154    let dialect = lattice.value().constant_dialect().unwrap();
155    let constant = folder.get_or_create_constant(
156        builder.insertion_block().unwrap(),
157        dialect,
158        constant_value,
159        value.borrow().ty().clone(),
160    );
161    if let Some(constant) = constant {
162        value.borrow_mut().replace_all_uses_with(constant);
163        true
164    } else {
165        false
166    }
167}