// cairo_lang_lowering/optimizations/dedup_blocks.rs
//
// Lowering optimization that merges blocks which end in `return` and are identical up to the
// renaming of their variables.

#[cfg(test)]
#[path = "dedup_blocks_test.rs"]
mod test;

use cairo_lang_semantic::items::constant::ConstValue;
use cairo_lang_semantic::{ConcreteVariant, TypeId};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::unordered_hash_map::{self, UnorderedHashMap};
use id_arena::Arena;
use itertools::{Itertools, zip_eq};

use crate::ids::FunctionId;
use crate::utils::{Rebuilder, RebuilderEx};
use crate::{
    Block, BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementConst, StatementDesnap,
    StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
    StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableId,
};
19
20#[derive(Hash, PartialEq, Eq)]
23struct CanonicBlock {
24 stmts: Vec<CanonicStatement>,
26 types: Vec<TypeId>,
28 returns: Vec<CanonicVar>,
30}

/// A variable inside a [`CanonicBlock`], identified by the position at which it first appears.
#[derive(Hash, PartialEq, Eq)]
struct CanonicVar(usize);
35
36#[derive(Hash, PartialEq, Eq)]
38enum CanonicStatement {
39 Const {
40 value: ConstValue,
41 output: CanonicVar,
42 },
43 Call {
44 function: FunctionId,
45 inputs: Vec<CanonicVar>,
46 with_coupon: bool,
47 outputs: Vec<CanonicVar>,
48 },
49 StructConstruct {
50 inputs: Vec<CanonicVar>,
51 output: CanonicVar,
52 },
53 StructDestructure {
54 input: CanonicVar,
55 outputs: Vec<CanonicVar>,
56 },
57 EnumConstruct {
58 variant: ConcreteVariant,
59 input: CanonicVar,
60 output: CanonicVar,
61 },
62
63 Snapshot {
64 input: CanonicVar,
65 outputs: [CanonicVar; 2],
66 },
67 Desnap {
68 input: CanonicVar,
69 output: CanonicVar,
70 },
71}
72
73struct CanonicBlockBuilder<'a> {
74 variable: &'a Arena<Variable>,
75 vars: UnorderedHashMap<VariableId, usize>,
76 types: Vec<TypeId>,
77 inputs: Vec<VarUsage>,
78}
79
80impl CanonicBlockBuilder<'_> {
81 fn new(variable: &Arena<Variable>) -> CanonicBlockBuilder<'_> {
82 CanonicBlockBuilder {
83 variable,
84 vars: Default::default(),
85 types: vec![],
86 inputs: Default::default(),
87 }
88 }
89
90 fn handle_input(&mut self, var_usage: &VarUsage) -> CanonicVar {
92 let v = var_usage.var_id;
93
94 CanonicVar(match self.vars.entry(v) {
95 std::collections::hash_map::Entry::Occupied(e) => *e.get(),
96 std::collections::hash_map::Entry::Vacant(e) => {
97 self.types.push(self.variable[v].ty);
98 let new_id = *e.insert(self.types.len() - 1);
99 self.inputs.push(*var_usage);
100 new_id
101 }
102 })
103 }
104
105 fn handle_output(&mut self, v: &VariableId) -> CanonicVar {
107 CanonicVar(match self.vars.entry(*v) {
108 std::collections::hash_map::Entry::Occupied(e) => *e.get(),
109 std::collections::hash_map::Entry::Vacant(e) => {
110 self.types.push(self.variable[*v].ty);
111 *e.insert(self.types.len() - 1)
112 }
113 })
114 }
115
116 fn handle_statement(&mut self, statement: &Statement) -> CanonicStatement {
118 match statement {
119 Statement::Const(StatementConst { value, output }) => {
120 CanonicStatement::Const { value: value.clone(), output: self.handle_output(output) }
121 }
122 Statement::Call(StatementCall {
123 function,
124 inputs,
125 with_coupon,
126 outputs,
127 location: _,
128 }) => CanonicStatement::Call {
129 function: *function,
130 inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
131 with_coupon: *with_coupon,
132 outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
133 },
134 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
135 CanonicStatement::StructConstruct {
136 inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
137 output: self.handle_output(output),
138 }
139 }
140 Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
141 CanonicStatement::StructDestructure {
142 input: self.handle_input(input),
143 outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
144 }
145 }
146 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
147 CanonicStatement::EnumConstruct {
148 variant: *variant,
149 input: self.handle_input(input),
150 output: self.handle_output(output),
151 }
152 }
153 Statement::Snapshot(StatementSnapshot { input, outputs }) => {
154 CanonicStatement::Snapshot {
155 input: self.handle_input(input),
156 outputs: outputs.map(|output| self.handle_output(&output)),
157 }
158 }
159 Statement::Desnap(StatementDesnap { input, output }) => CanonicStatement::Desnap {
160 input: self.handle_input(input),
161 output: self.handle_output(output),
162 },
163 }
164 }
165}
166
167impl CanonicBlock {
168 fn try_from_block(
172 variable: &Arena<Variable>,
173 block: &Block,
174 ) -> Option<(CanonicBlock, Vec<VarUsage>)> {
175 let BlockEnd::Return(returned_vars, _) = &block.end else {
176 return None;
177 };
178
179 if block.statements.is_empty() {
180 return None;
182 }
183
184 let mut builder = CanonicBlockBuilder::new(variable);
185
186 let stmts = block
187 .statements
188 .iter()
189 .map(|statement| builder.handle_statement(statement))
190 .collect_vec();
191
192 let returns = returned_vars.iter().map(|input| builder.handle_input(input)).collect();
193
194 Some((CanonicBlock { stmts, types: builder.types, returns }, builder.inputs))
195 }
196}
197pub struct VarReassigner<'a> {
199 pub variables: &'a mut Arena<Variable>,
200
201 pub vars: UnorderedHashMap<VariableId, VariableId>,
203}
204
205impl<'a> VarReassigner<'a> {
206 pub fn new(variables: &'a mut Arena<Variable>) -> Self {
207 Self { variables, vars: UnorderedHashMap::default() }
208 }
209}
210
211impl Rebuilder for VarReassigner<'_> {
212 fn map_var_id(&mut self, var: VariableId) -> VariableId {
213 *self.vars.entry(var).or_insert_with(|| self.variables.alloc(self.variables[var].clone()))
214 }
215}
216
217#[derive(Default)]
218struct DedupContext {
219 canonic_blocks: UnorderedHashMap<CanonicBlock, BlockId>,
221
222 block_id_to_inputs: UnorderedHashMap<BlockId, Vec<VarUsage>>,
224}
225
226fn rebuild_block_and_inputs(
229 variables: &mut Arena<Variable>,
230 block: &Block,
231 inputs: &[VarUsage],
232) -> (Block, Vec<VarUsage>) {
233 let mut var_reassigner = VarReassigner::new(variables);
234 (
235 var_reassigner.rebuild_block(block),
236 inputs.iter().map(|var_usage| var_reassigner.map_var_usage(*var_usage)).collect(),
237 )
238}
239
240pub fn dedup_blocks(lowered: &mut Lowered) {
243 if lowered.blocks.has_root().is_err() {
244 return;
245 }
246
247 let mut ctx = DedupContext::default();
248 let mut duplicates: UnorderedHashMap<BlockId, (BlockId, Vec<VarUsage>)> = Default::default();
251
252 let mut new_blocks = vec![];
253 let mut next_block_id = BlockId(lowered.blocks.len());
254
255 for (block_id, block) in lowered.blocks.iter() {
256 let Some((canonical_block, inputs)) =
257 CanonicBlock::try_from_block(&lowered.variables, block)
258 else {
259 continue;
260 };
261
262 match ctx.canonic_blocks.entry(canonical_block) {
263 unordered_hash_map::Entry::Occupied(e) => {
264 let block_and_inputs = duplicates
265 .entry(*e.get())
266 .or_insert_with(|| {
267 let (block, new_inputs) =
268 rebuild_block_and_inputs(&mut lowered.variables, block, &inputs);
269 new_blocks.push(block);
270 let new_block_id = next_block_id;
271 next_block_id = next_block_id.next_block_id();
272
273 (new_block_id, new_inputs)
274 })
275 .clone();
276
277 duplicates.insert(block_id, block_and_inputs);
278 }
279 unordered_hash_map::Entry::Vacant(e) => {
280 e.insert(block_id);
281 }
282 };
283
284 ctx.block_id_to_inputs.insert(block_id, inputs);
285 }
286
287 let mut new_goto_block = |block_id, inputs: &Vec<VarUsage>, target_inputs: &Vec<VarUsage>| {
288 new_blocks.push(Block {
289 statements: vec![],
290 end: BlockEnd::Goto(
291 block_id,
292 VarRemapping {
293 remapping: OrderedHashMap::from_iter(zip_eq(
294 target_inputs.iter().map(|var_usage| var_usage.var_id),
295 inputs.iter().cloned(),
296 )),
297 },
298 ),
299 });
300
301 let new_block_id = next_block_id;
302 next_block_id = next_block_id.next_block_id();
303 new_block_id
304 };
305
306 for block in lowered.blocks.iter_mut() {
309 match &mut block.end {
310 BlockEnd::Goto(target_block, remappings) => {
311 let Some((block_id, target_inputs)) = duplicates.get(target_block) else {
312 continue;
313 };
314
315 let inputs = ctx.block_id_to_inputs.get(target_block).unwrap();
316 let mut inputs_remapping = VarRemapping {
317 remapping: OrderedHashMap::from_iter(zip_eq(
318 target_inputs.iter().map(|var_usage| var_usage.var_id),
319 inputs.iter().cloned(),
320 )),
321 };
322 for (_, src) in inputs_remapping.iter_mut() {
323 if let Some(src_before_remapping) = remappings.get(&src.var_id) {
324 *src = *src_before_remapping;
325 }
326 }
327
328 *target_block = *block_id;
329 *remappings = inputs_remapping;
330 }
331 BlockEnd::Match { info } => {
332 for arm in info.arms_mut() {
333 let Some((block_id, target_inputs)) = duplicates.get(&arm.block_id) else {
334 continue;
335 };
336
337 let inputs = &ctx.block_id_to_inputs[&arm.block_id];
338 arm.block_id = new_goto_block(*block_id, inputs, target_inputs);
339 }
340 }
341 _ => {}
342 }
343 }
344
345 for block in new_blocks {
346 lowered.blocks.push(block);
347 }
348}