cairo_lang_lowering/optimizations/
dedup_blocks.rs

1#[cfg(test)]
2#[path = "dedup_blocks_test.rs"]
3mod test;
4
5use cairo_lang_semantic::items::constant::ConstValueId;
6use cairo_lang_semantic::{ConcreteVariant, TypeId};
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::{self, UnorderedHashMap};
9use itertools::{Itertools, zip_eq};
10
11use crate::ids::FunctionId;
12use crate::utils::{Rebuilder, RebuilderEx};
13use crate::{
14    Block, BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementConst, StatementDesnap,
15    StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
16    StatementStructDestructure, VarRemapping, VarUsage, VariableArena, VariableId,
17};
18
19/// A canonic representation of a block (used to find duplicated blocks).
20/// Currently only blocks that end with return are supported.
21#[derive(Hash, PartialEq, Eq)]
22struct CanonicBlock<'db> {
23    /// Canonic representation of the statements in the block.
24    stmts: Vec<CanonicStatement<'db>>,
25    /// The types of variables introduced in the block.
26    types: Vec<TypeId<'db>>,
27    /// variables returned by the block.
28    returns: Vec<CanonicVar>,
29}
30
/// A variable inside a canonic block, identified by an index instead of its original id.
#[derive(PartialEq, Eq, Hash)]
struct CanonicVar(usize);
34
35/// A canonic representation of a statement in a canonic block.
36#[derive(Hash, PartialEq, Eq)]
37enum CanonicStatement<'db> {
38    Const {
39        value: ConstValueId<'db>,
40        output: CanonicVar,
41        boxed: bool,
42    },
43    Call {
44        function: FunctionId<'db>,
45        inputs: Vec<CanonicVar>,
46        with_coupon: bool,
47        outputs: Vec<CanonicVar>,
48    },
49    StructConstruct {
50        inputs: Vec<CanonicVar>,
51        output: CanonicVar,
52    },
53    StructDestructure {
54        input: CanonicVar,
55        outputs: Vec<CanonicVar>,
56    },
57    EnumConstruct {
58        variant: ConcreteVariant<'db>,
59        input: CanonicVar,
60        output: CanonicVar,
61    },
62
63    Snapshot {
64        input: CanonicVar,
65        outputs: [CanonicVar; 2],
66    },
67    Desnap {
68        input: CanonicVar,
69        output: CanonicVar,
70    },
71}
72
73struct CanonicBlockBuilder<'db, 'a> {
74    variable: &'a VariableArena<'db>,
75    vars: UnorderedHashMap<VariableId, usize>,
76    types: Vec<TypeId<'db>>,
77    inputs: Vec<VarUsage<'db>>,
78}
79
80impl<'db, 'a> CanonicBlockBuilder<'db, 'a> {
81    fn new(variable: &'a VariableArena<'db>) -> CanonicBlockBuilder<'db, 'a> {
82        CanonicBlockBuilder {
83            variable,
84            vars: Default::default(),
85            types: vec![],
86            inputs: Default::default(),
87        }
88    }
89
90    /// Converts an input var to a CanonicVar.
91    fn handle_input(&mut self, var_usage: &VarUsage<'db>) -> CanonicVar {
92        let v = var_usage.var_id;
93
94        CanonicVar(match self.vars.entry(v) {
95            std::collections::hash_map::Entry::Occupied(e) => *e.get(),
96            std::collections::hash_map::Entry::Vacant(e) => {
97                self.types.push(self.variable[v].ty);
98                let new_id = *e.insert(self.types.len() - 1);
99                self.inputs.push(*var_usage);
100                new_id
101            }
102        })
103    }
104
105    /// Converts an output var to a CanonicVar.
106    fn handle_output(&mut self, v: &VariableId) -> CanonicVar {
107        CanonicVar(match self.vars.entry(*v) {
108            std::collections::hash_map::Entry::Occupied(e) => *e.get(),
109            std::collections::hash_map::Entry::Vacant(e) => {
110                self.types.push(self.variable[*v].ty);
111                *e.insert(self.types.len() - 1)
112            }
113        })
114    }
115
116    /// Converts a statement to a canonic statement.
117    fn handle_statement(&mut self, statement: &Statement<'db>) -> CanonicStatement<'db> {
118        match statement {
119            Statement::Const(StatementConst { value, boxed, output }) => CanonicStatement::Const {
120                value: *value,
121                output: self.handle_output(output),
122                boxed: *boxed,
123            },
124            Statement::Call(StatementCall {
125                function,
126                inputs,
127                with_coupon,
128                outputs,
129                location: _,
130                is_specialization_base_call: _,
131            }) => CanonicStatement::Call {
132                function: *function,
133                inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
134                with_coupon: *with_coupon,
135                outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
136            },
137            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
138                CanonicStatement::StructConstruct {
139                    inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
140                    output: self.handle_output(output),
141                }
142            }
143            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
144                CanonicStatement::StructDestructure {
145                    input: self.handle_input(input),
146                    outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
147                }
148            }
149            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
150                CanonicStatement::EnumConstruct {
151                    variant: *variant,
152                    input: self.handle_input(input),
153                    output: self.handle_output(output),
154                }
155            }
156            Statement::Snapshot(StatementSnapshot { input, outputs }) => {
157                CanonicStatement::Snapshot {
158                    input: self.handle_input(input),
159                    outputs: outputs.map(|output| self.handle_output(&output)),
160                }
161            }
162            Statement::Desnap(StatementDesnap { input, output }) => CanonicStatement::Desnap {
163                input: self.handle_input(input),
164                output: self.handle_output(output),
165            },
166        }
167    }
168}
169
170impl<'db> CanonicBlock<'db> {
171    /// Tries to create a canonic block from a flat block.
172    /// Return the canonic representation of the block and the external inputs used in the block.
173    /// Blocks that do not end in return do not have a canonic representation.
174    fn try_from_block(
175        variable: &VariableArena<'db>,
176        block: &Block<'db>,
177    ) -> Option<(CanonicBlock<'db>, Vec<VarUsage<'db>>)> {
178        let BlockEnd::Return(returned_vars, _) = &block.end else {
179            return None;
180        };
181
182        if block.statements.is_empty() {
183            // Skip deduplication for empty blocks.
184            return None;
185        }
186
187        let mut builder = CanonicBlockBuilder::new(variable);
188
189        let stmts = block
190            .statements
191            .iter()
192            .map(|statement| builder.handle_statement(statement))
193            .collect_vec();
194
195        let returns = returned_vars.iter().map(|input| builder.handle_input(input)).collect();
196
197        Some((CanonicBlock { stmts, types: builder.types, returns }, builder.inputs))
198    }
199}
200/// Helper class to reassign variable ids.
201pub struct VarReassigner<'db, 'a> {
202    pub variables: &'a mut VariableArena<'db>,
203
204    // Maps old var_id to new_var_id
205    pub vars: UnorderedHashMap<VariableId, VariableId>,
206}
207
208impl<'db, 'a> VarReassigner<'db, 'a> {
209    pub fn new(variables: &'a mut VariableArena<'db>) -> Self {
210        Self { variables, vars: UnorderedHashMap::default() }
211    }
212}
213
214impl<'db, 'a> Rebuilder<'db> for VarReassigner<'db, 'a> {
215    fn map_var_id(&mut self, var: VariableId) -> VariableId {
216        *self.vars.entry(var).or_insert_with(|| self.variables.alloc(self.variables[var].clone()))
217    }
218}
219
220#[derive(Default)]
221struct DedupContext<'db> {
222    /// Maps a CanonicBlock to a reference block that matches it.
223    canonic_blocks: UnorderedHashMap<CanonicBlock<'db>, BlockId>,
224
225    /// Maps a block to the inputs that are needed for it to be shared by multiple flows.
226    block_id_to_inputs: UnorderedHashMap<BlockId, Vec<VarUsage<'db>>>,
227}
228
229/// Given a block and a set of inputs, assigns new ids to all the variables in the block, returning
230/// the new block and the new inputs.
231fn rebuild_block_and_inputs<'db>(
232    variables: &mut VariableArena<'db>,
233    block: &Block<'db>,
234    inputs: &[VarUsage<'db>],
235) -> (Block<'db>, Vec<VarUsage<'db>>) {
236    let mut var_reassigner = VarReassigner::new(variables);
237    (
238        var_reassigner.rebuild_block(block),
239        inputs.iter().map(|var_usage| var_reassigner.map_var_usage(*var_usage)).collect(),
240    )
241}
242
243/// Deduplicates blocks by redirecting goto's and match arms to one of the duplicates.
244/// The duplicate blocks will be removed later by `reorganize_blocks`.
245pub fn dedup_blocks<'db>(lowered: &mut Lowered<'db>) {
246    if lowered.blocks.has_root().is_err() {
247        return;
248    }
249
250    let mut ctx = DedupContext::default();
251    // Maps duplicated blocks to the new shared block and the inputs that need to be remapped for
252    // the block.
253    let mut duplicates: UnorderedHashMap<BlockId, (BlockId, Vec<VarUsage<'_>>)> =
254        Default::default();
255
256    let mut new_blocks = vec![];
257    let mut next_block_id = BlockId(lowered.blocks.len());
258
259    for (block_id, block) in lowered.blocks.iter() {
260        let Some((canonical_block, inputs)) =
261            CanonicBlock::try_from_block(&lowered.variables, block)
262        else {
263            continue;
264        };
265
266        match ctx.canonic_blocks.entry(canonical_block) {
267            unordered_hash_map::Entry::Occupied(e) => {
268                let block_and_inputs = duplicates
269                    .entry(*e.get())
270                    .or_insert_with(|| {
271                        let (block, new_inputs) =
272                            rebuild_block_and_inputs(&mut lowered.variables, block, &inputs);
273                        new_blocks.push(block);
274                        let new_block_id = next_block_id;
275                        next_block_id = next_block_id.next_block_id();
276
277                        (new_block_id, new_inputs)
278                    })
279                    .clone();
280
281                duplicates.insert(block_id, block_and_inputs);
282            }
283            unordered_hash_map::Entry::Vacant(e) => {
284                e.insert(block_id);
285            }
286        };
287
288        ctx.block_id_to_inputs.insert(block_id, inputs);
289    }
290
291    let mut new_goto_block =
292        |block_id, inputs: &Vec<VarUsage<'db>>, target_inputs: &Vec<VarUsage<'db>>| {
293            new_blocks.push(Block {
294                statements: vec![],
295                end: BlockEnd::Goto(
296                    block_id,
297                    VarRemapping {
298                        remapping: OrderedHashMap::from_iter(zip_eq(
299                            target_inputs.iter().map(|var_usage| var_usage.var_id),
300                            inputs.iter().cloned(),
301                        )),
302                    },
303                ),
304            });
305
306            let new_block_id = next_block_id;
307            next_block_id = next_block_id.next_block_id();
308            new_block_id
309        };
310
311    // Note that the loop below can't be merged with the loop above as a block might be marked as
312    // dup after we already visiting an arm that goes to it.
313    for block in lowered.blocks.iter_mut() {
314        match &mut block.end {
315            BlockEnd::Goto(target_block, remappings) => {
316                let Some((block_id, target_inputs)) = duplicates.get(target_block) else {
317                    continue;
318                };
319
320                let inputs = ctx.block_id_to_inputs.get(target_block).unwrap();
321                let mut inputs_remapping = VarRemapping {
322                    remapping: OrderedHashMap::from_iter(zip_eq(
323                        target_inputs.iter().map(|var_usage| var_usage.var_id),
324                        inputs.iter().cloned(),
325                    )),
326                };
327                for (_, src) in inputs_remapping.iter_mut() {
328                    if let Some(src_before_remapping) = remappings.get(&src.var_id) {
329                        *src = *src_before_remapping;
330                    }
331                }
332
333                *target_block = *block_id;
334                *remappings = inputs_remapping;
335            }
336            BlockEnd::Match { info } => {
337                for arm in info.arms_mut() {
338                    let Some((block_id, target_inputs)) = duplicates.get(&arm.block_id) else {
339                        continue;
340                    };
341
342                    let inputs = &ctx.block_id_to_inputs[&arm.block_id];
343                    arm.block_id = new_goto_block(*block_id, inputs, target_inputs);
344                }
345            }
346            _ => {}
347        }
348    }
349
350    for block in new_blocks {
351        lowered.blocks.push(block);
352    }
353}