midenc_hir_transform/
sccp.rs1use midenc_hir::{
2 pass::{Pass, PassExecutionState},
3 patterns::NoopRewriterListener,
4 BlockRef, Builder, EntityMut, OpBuilder, Operation, OperationFolder, OperationName, RegionList,
5 Report, SmallVec, ValueRef,
6};
7use midenc_hir_analysis::{
8 analyses::{constant_propagation::ConstantValue, DeadCodeAnalysis, SparseConstantPropagation},
9 DataFlowSolver, Lattice,
10};
11
/// Sparse conditional constant propagation (SCCP).
///
/// Runs a data-flow solver loaded with dead-code analysis and sparse constant propagation over
/// the operation it is scheduled on. Every operation result and block argument that the solver
/// proves constant then has its uses redirected to a materialized constant. Operations whose
/// results were all replaced, and which are then trivially dead, are erased.
pub struct SparseConditionalConstantPropagation;
21
22impl Pass for SparseConditionalConstantPropagation {
23 type Target = Operation;
24
25 fn name(&self) -> &'static str {
26 "sparse-conditional-constant-propagation"
27 }
28
29 fn argument(&self) -> &'static str {
30 "sparse-conditional-constant-propagation"
31 }
32
33 fn can_schedule_on(&self, _name: &OperationName) -> bool {
34 true
35 }
36
37 fn run_on_operation(
38 &mut self,
39 mut op: EntityMut<'_, Self::Target>,
40 state: &mut PassExecutionState,
41 ) -> Result<(), Report> {
42 let mut solver = DataFlowSolver::default();
44 solver.load::<DeadCodeAnalysis>();
45 solver.load::<SparseConstantPropagation>();
46 solver.initialize_and_run(&op, state.analysis_manager().clone())?;
47
48 self.rewrite(&mut op, state, &solver)
50 }
51}
52
53impl SparseConditionalConstantPropagation {
54 fn rewrite(
57 &mut self,
58 op: &mut Operation,
59 state: &mut PassExecutionState,
60 solver: &DataFlowSolver,
61 ) -> Result<(), Report> {
62 let mut worklist = SmallVec::<[BlockRef; 8]>::default();
63
64 let add_to_worklist = |regions: &RegionList, worklist: &mut SmallVec<[BlockRef; 8]>| {
65 for region in regions {
66 for block in region.body().iter().rev() {
67 worklist.push(block.as_block_ref());
68 }
69 }
70 };
71
72 let context = op.context_rc();
74 let mut folder = OperationFolder::new(context.clone(), None::<NoopRewriterListener>);
75 let mut builder = OpBuilder::new(context.clone());
76
77 add_to_worklist(op.regions(), &mut worklist);
78
79 let mut replaced_any = false;
80 while let Some(mut block) = worklist.pop() {
81 let mut block = block.borrow_mut();
82 let body = block.body_mut();
83 let mut ops = body.front();
84
85 while let Some(mut op) = ops.as_pointer() {
86 ops.move_next();
87
88 builder.set_insertion_point_after(op);
89
90 let num_results = op.borrow().num_results();
92 let mut replaced_all = num_results != 0;
93 for index in 0..num_results {
94 let result = { op.borrow().get_result(index).borrow().as_value_ref() };
95 let replaced = replace_with_constant(solver, &mut builder, &mut folder, result);
96
97 replaced_any |= replaced;
98 replaced_all &= replaced;
99 }
100
101 let mut op = op.borrow_mut();
104 if replaced_all && op.would_be_trivially_dead() {
105 assert!(!op.is_used(), "expected all uses to be replaced");
106 op.erase();
107 continue;
108 }
109
110 add_to_worklist(op.regions(), &mut worklist);
112 }
113
114 builder.set_insertion_point_to_start(block.as_block_ref());
116
117 for arg in block.arguments() {
118 replaced_any |= replace_with_constant(
119 solver,
120 &mut builder,
121 &mut folder,
122 arg.borrow().as_value_ref(),
123 );
124 }
125 }
126
127 state.set_post_pass_status(replaced_any.into());
128
129 Ok(())
130 }
131}
132
133fn replace_with_constant(
137 solver: &DataFlowSolver,
138 builder: &mut OpBuilder,
139 folder: &mut OperationFolder,
140 mut value: ValueRef,
141) -> bool {
142 let Some(lattice) = solver.get::<Lattice<ConstantValue>, _>(&value) else {
143 return false;
144 };
145 if lattice.value().is_uninitialized() {
146 return false;
147 }
148
149 let Some(constant_value) = lattice.value().constant_value() else {
150 return false;
151 };
152
153 let dialect = lattice.value().constant_dialect().unwrap();
155 let constant = folder.get_or_create_constant(
156 builder.insertion_block().unwrap(),
157 dialect,
158 constant_value,
159 value.borrow().ty().clone(),
160 );
161 if let Some(constant) = constant {
162 value.borrow_mut().replace_all_uses_with(constant);
163 true
164 } else {
165 false
166 }
167}