cairo_lang_lowering/optimizations/
dedup_blocks.rs

// Test module: compiled only in test builds, loaded from `dedup_blocks_test.rs` next to this file.
#[cfg(test)]
#[path = "dedup_blocks_test.rs"]
mod test;
4
5use cairo_lang_semantic::items::constant::ConstValueId;
6use cairo_lang_semantic::{ConcreteVariant, TypeId};
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::{self, UnorderedHashMap};
9use itertools::{Itertools, zip_eq};
10
11use crate::ids::FunctionId;
12use crate::utils::{Rebuilder, RebuilderEx};
13use crate::{
14    Block, BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementConst, StatementDesnap,
15    StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
16    StatementStructDestructure, VarRemapping, VarUsage, VariableArena, VariableId,
17};
18
19/// A canonic representation of a block (used to find duplicated blocks).
20/// Currently only blocks that end with return are supported.
21#[derive(Hash, PartialEq, Eq)]
22struct CanonicBlock<'db> {
23    /// Canonic representation of the statements in the block.
24    stmts: Vec<CanonicStatement<'db>>,
25    /// The types of variables introduced in the block.
26    types: Vec<TypeId<'db>>,
27    /// variables returned by the block.
28    returns: Vec<CanonicVar>,
29}
30
/// A variable inside a [`CanonicBlock`], identified by the order in which it first appears in the
/// block; the same number is its index into the block's `types`.
#[derive(Hash, PartialEq, Eq)]
struct CanonicVar(usize);
34
35/// A canonic representation of a statement in a canonic block.
36#[derive(Hash, PartialEq, Eq)]
37enum CanonicStatement<'db> {
38    Const {
39        value: ConstValueId<'db>,
40        output: CanonicVar,
41        boxed: bool,
42    },
43    Call {
44        function: FunctionId<'db>,
45        inputs: Vec<CanonicVar>,
46        with_coupon: bool,
47        outputs: Vec<CanonicVar>,
48    },
49    StructConstruct {
50        inputs: Vec<CanonicVar>,
51        output: CanonicVar,
52    },
53    StructDestructure {
54        input: CanonicVar,
55        outputs: Vec<CanonicVar>,
56    },
57    EnumConstruct {
58        variant: ConcreteVariant<'db>,
59        input: CanonicVar,
60        output: CanonicVar,
61    },
62
63    Snapshot {
64        input: CanonicVar,
65        outputs: [CanonicVar; 2],
66    },
67    Desnap {
68        input: CanonicVar,
69        output: CanonicVar,
70    },
71}
72
73struct CanonicBlockBuilder<'db, 'a> {
74    variable: &'a VariableArena<'db>,
75    vars: UnorderedHashMap<VariableId, usize>,
76    types: Vec<TypeId<'db>>,
77    inputs: Vec<VarUsage<'db>>,
78}
79
80impl<'db, 'a> CanonicBlockBuilder<'db, 'a> {
81    fn new(variable: &'a VariableArena<'db>) -> CanonicBlockBuilder<'db, 'a> {
82        CanonicBlockBuilder {
83            variable,
84            vars: Default::default(),
85            types: vec![],
86            inputs: Default::default(),
87        }
88    }
89
90    /// Converts an input var to a CanonicVar.
91    fn handle_input(&mut self, var_usage: &VarUsage<'db>) -> CanonicVar {
92        let v = var_usage.var_id;
93
94        CanonicVar(match self.vars.entry(v) {
95            std::collections::hash_map::Entry::Occupied(e) => *e.get(),
96            std::collections::hash_map::Entry::Vacant(e) => {
97                self.types.push(self.variable[v].ty);
98                let new_id = *e.insert(self.types.len() - 1);
99                self.inputs.push(*var_usage);
100                new_id
101            }
102        })
103    }
104
105    /// Converts an output var to a CanonicVar.
106    fn handle_output(&mut self, v: &VariableId) -> CanonicVar {
107        CanonicVar(match self.vars.entry(*v) {
108            std::collections::hash_map::Entry::Occupied(e) => *e.get(),
109            std::collections::hash_map::Entry::Vacant(e) => {
110                self.types.push(self.variable[*v].ty);
111                *e.insert(self.types.len() - 1)
112            }
113        })
114    }
115
116    /// Converts a statement to a canonic statement.
117    fn handle_statement(&mut self, statement: &Statement<'db>) -> CanonicStatement<'db> {
118        match statement {
119            Statement::Const(StatementConst { value, boxed, output }) => CanonicStatement::Const {
120                value: *value,
121                output: self.handle_output(output),
122                boxed: *boxed,
123            },
124            Statement::Call(StatementCall {
125                function,
126                inputs,
127                with_coupon,
128                outputs,
129                location: _,
130            }) => CanonicStatement::Call {
131                function: *function,
132                inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
133                with_coupon: *with_coupon,
134                outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
135            },
136            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
137                CanonicStatement::StructConstruct {
138                    inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
139                    output: self.handle_output(output),
140                }
141            }
142            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
143                CanonicStatement::StructDestructure {
144                    input: self.handle_input(input),
145                    outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
146                }
147            }
148            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
149                CanonicStatement::EnumConstruct {
150                    variant: *variant,
151                    input: self.handle_input(input),
152                    output: self.handle_output(output),
153                }
154            }
155            Statement::Snapshot(StatementSnapshot { input, outputs }) => {
156                CanonicStatement::Snapshot {
157                    input: self.handle_input(input),
158                    outputs: outputs.map(|output| self.handle_output(&output)),
159                }
160            }
161            Statement::Desnap(StatementDesnap { input, output }) => CanonicStatement::Desnap {
162                input: self.handle_input(input),
163                output: self.handle_output(output),
164            },
165        }
166    }
167}
168
169impl<'db> CanonicBlock<'db> {
170    /// Tries to create a canonic block from a flat block.
171    /// Return the canonic representation of the block and the external inputs used in the block.
172    /// Blocks that do not end in return do not have a canonic representation.
173    fn try_from_block(
174        variable: &VariableArena<'db>,
175        block: &Block<'db>,
176    ) -> Option<(CanonicBlock<'db>, Vec<VarUsage<'db>>)> {
177        let BlockEnd::Return(returned_vars, _) = &block.end else {
178            return None;
179        };
180
181        if block.statements.is_empty() {
182            // Skip deduplication for empty blocks.
183            return None;
184        }
185
186        let mut builder = CanonicBlockBuilder::new(variable);
187
188        let stmts = block
189            .statements
190            .iter()
191            .map(|statement| builder.handle_statement(statement))
192            .collect_vec();
193
194        let returns = returned_vars.iter().map(|input| builder.handle_input(input)).collect();
195
196        Some((CanonicBlock { stmts, types: builder.types, returns }, builder.inputs))
197    }
198}
199/// Helper class to reassign variable ids.
200pub struct VarReassigner<'db, 'a> {
201    pub variables: &'a mut VariableArena<'db>,
202
203    // Maps old var_id to new_var_id
204    pub vars: UnorderedHashMap<VariableId, VariableId>,
205}
206
207impl<'db, 'a> VarReassigner<'db, 'a> {
208    pub fn new(variables: &'a mut VariableArena<'db>) -> Self {
209        Self { variables, vars: UnorderedHashMap::default() }
210    }
211}
212
213impl<'db, 'a> Rebuilder<'db> for VarReassigner<'db, 'a> {
214    fn map_var_id(&mut self, var: VariableId) -> VariableId {
215        *self.vars.entry(var).or_insert_with(|| self.variables.alloc(self.variables[var].clone()))
216    }
217}
218
219#[derive(Default)]
220struct DedupContext<'db> {
221    /// Maps a CanonicBlock to a reference block that matches it.
222    canonic_blocks: UnorderedHashMap<CanonicBlock<'db>, BlockId>,
223
224    /// Maps a block to the inputs that are needed for it to be shared by multiple flows.
225    block_id_to_inputs: UnorderedHashMap<BlockId, Vec<VarUsage<'db>>>,
226}
227
228/// Given a block and a set of inputs, assigns new ids to all the variables in the block, returning
229/// the new block and the new inputs.
230fn rebuild_block_and_inputs<'db>(
231    variables: &mut VariableArena<'db>,
232    block: &Block<'db>,
233    inputs: &[VarUsage<'db>],
234) -> (Block<'db>, Vec<VarUsage<'db>>) {
235    let mut var_reassigner = VarReassigner::new(variables);
236    (
237        var_reassigner.rebuild_block(block),
238        inputs.iter().map(|var_usage| var_reassigner.map_var_usage(*var_usage)).collect(),
239    )
240}
241
242/// Deduplicates blocks by redirecting goto's and match arms to one of the duplicates.
243/// The duplicate blocks will be removed later by `reorganize_blocks`.
244pub fn dedup_blocks<'db>(lowered: &mut Lowered<'db>) {
245    if lowered.blocks.has_root().is_err() {
246        return;
247    }
248
249    let mut ctx = DedupContext::default();
250    // Maps duplicated blocks to the new shared block and the inputs that need to be remapped for
251    // the block.
252    let mut duplicates: UnorderedHashMap<BlockId, (BlockId, Vec<VarUsage<'_>>)> =
253        Default::default();
254
255    let mut new_blocks = vec![];
256    let mut next_block_id = BlockId(lowered.blocks.len());
257
258    for (block_id, block) in lowered.blocks.iter() {
259        let Some((canonical_block, inputs)) =
260            CanonicBlock::try_from_block(&lowered.variables, block)
261        else {
262            continue;
263        };
264
265        match ctx.canonic_blocks.entry(canonical_block) {
266            unordered_hash_map::Entry::Occupied(e) => {
267                let block_and_inputs = duplicates
268                    .entry(*e.get())
269                    .or_insert_with(|| {
270                        let (block, new_inputs) =
271                            rebuild_block_and_inputs(&mut lowered.variables, block, &inputs);
272                        new_blocks.push(block);
273                        let new_block_id = next_block_id;
274                        next_block_id = next_block_id.next_block_id();
275
276                        (new_block_id, new_inputs)
277                    })
278                    .clone();
279
280                duplicates.insert(block_id, block_and_inputs);
281            }
282            unordered_hash_map::Entry::Vacant(e) => {
283                e.insert(block_id);
284            }
285        };
286
287        ctx.block_id_to_inputs.insert(block_id, inputs);
288    }
289
290    let mut new_goto_block =
291        |block_id, inputs: &Vec<VarUsage<'db>>, target_inputs: &Vec<VarUsage<'db>>| {
292            new_blocks.push(Block {
293                statements: vec![],
294                end: BlockEnd::Goto(
295                    block_id,
296                    VarRemapping {
297                        remapping: OrderedHashMap::from_iter(zip_eq(
298                            target_inputs.iter().map(|var_usage| var_usage.var_id),
299                            inputs.iter().cloned(),
300                        )),
301                    },
302                ),
303            });
304
305            let new_block_id = next_block_id;
306            next_block_id = next_block_id.next_block_id();
307            new_block_id
308        };
309
310    // Note that the loop below can't be merged with the loop above as a block might be marked as
311    // dup after we already visiting an arm that goes to it.
312    for block in lowered.blocks.iter_mut() {
313        match &mut block.end {
314            BlockEnd::Goto(target_block, remappings) => {
315                let Some((block_id, target_inputs)) = duplicates.get(target_block) else {
316                    continue;
317                };
318
319                let inputs = ctx.block_id_to_inputs.get(target_block).unwrap();
320                let mut inputs_remapping = VarRemapping {
321                    remapping: OrderedHashMap::from_iter(zip_eq(
322                        target_inputs.iter().map(|var_usage| var_usage.var_id),
323                        inputs.iter().cloned(),
324                    )),
325                };
326                for (_, src) in inputs_remapping.iter_mut() {
327                    if let Some(src_before_remapping) = remappings.get(&src.var_id) {
328                        *src = *src_before_remapping;
329                    }
330                }
331
332                *target_block = *block_id;
333                *remappings = inputs_remapping;
334            }
335            BlockEnd::Match { info } => {
336                for arm in info.arms_mut() {
337                    let Some((block_id, target_inputs)) = duplicates.get(&arm.block_id) else {
338                        continue;
339                    };
340
341                    let inputs = &ctx.block_id_to_inputs[&arm.block_id];
342                    arm.block_id = new_goto_block(*block_id, inputs, target_inputs);
343                }
344            }
345            _ => {}
346        }
347    }
348
349    for block in new_blocks {
350        lowered.blocks.push(block);
351    }
352}