Skip to main content

cairo_lang_lowering/optimizations/
dedup_blocks.rs

1#[cfg(test)]
2#[path = "dedup_blocks_test.rs"]
3mod test;
4
5use cairo_lang_semantic::items::constant::ConstValueId;
6use cairo_lang_semantic::{ConcreteVariant, TypeId};
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::{self, UnorderedHashMap};
9use itertools::{Itertools, zip_eq};
10
11use crate::ids::FunctionId;
12use crate::utils::{Rebuilder, RebuilderEx};
13use crate::{
14    Block, BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementConst, StatementDesnap,
15    StatementEnumConstruct, StatementIntoBox, StatementSnapshot, StatementStructConstruct,
16    StatementStructDestructure, StatementUnbox, VarRemapping, VarUsage, VariableArena, VariableId,
17};
18
19/// A canonic representation of a block (used to find duplicated blocks).
20/// Currently only blocks that end with return are supported.
21#[derive(Hash, PartialEq, Eq)]
22struct CanonicBlock<'db> {
23    /// Canonic representation of the statements in the block.
24    stmts: Vec<CanonicStatement<'db>>,
25    /// The types of variables introduced in the block.
26    types: Vec<TypeId<'db>>,
27    /// variables returned by the block.
28    returns: Vec<CanonicVar>,
29}
30
/// A variable of a `CanonicBlock`, identified by the position at which it first appears in the
/// block rather than by its `VariableId`.
#[derive(Eq, Hash, PartialEq)]
struct CanonicVar(usize);
34
35/// A canonic representation of a statement in a canonic block.
36#[derive(Hash, PartialEq, Eq)]
37enum CanonicStatement<'db> {
38    Const {
39        value: ConstValueId<'db>,
40        output: CanonicVar,
41        boxed: bool,
42    },
43    Call {
44        function: FunctionId<'db>,
45        inputs: Vec<CanonicVar>,
46        with_coupon: bool,
47        outputs: Vec<CanonicVar>,
48    },
49    StructConstruct {
50        inputs: Vec<CanonicVar>,
51        output: CanonicVar,
52    },
53    StructDestructure {
54        input: CanonicVar,
55        outputs: Vec<CanonicVar>,
56    },
57    EnumConstruct {
58        variant: ConcreteVariant<'db>,
59        input: CanonicVar,
60        output: CanonicVar,
61    },
62    IntoBox {
63        input: CanonicVar,
64        output: CanonicVar,
65    },
66    Unbox {
67        input: CanonicVar,
68        output: CanonicVar,
69    },
70    Snapshot {
71        input: CanonicVar,
72        outputs: [CanonicVar; 2],
73    },
74    Desnap {
75        input: CanonicVar,
76        output: CanonicVar,
77    },
78}
79
80struct CanonicBlockBuilder<'db, 'a> {
81    variable: &'a VariableArena<'db>,
82    vars: UnorderedHashMap<VariableId, usize>,
83    types: Vec<TypeId<'db>>,
84    inputs: Vec<VarUsage<'db>>,
85}
86
87impl<'db, 'a> CanonicBlockBuilder<'db, 'a> {
88    fn new(variable: &'a VariableArena<'db>) -> CanonicBlockBuilder<'db, 'a> {
89        CanonicBlockBuilder {
90            variable,
91            vars: Default::default(),
92            types: vec![],
93            inputs: Default::default(),
94        }
95    }
96
97    /// Converts an input var to a CanonicVar.
98    fn handle_input(&mut self, var_usage: &VarUsage<'db>) -> CanonicVar {
99        let v = var_usage.var_id;
100
101        CanonicVar(match self.vars.entry(v) {
102            std::collections::hash_map::Entry::Occupied(e) => *e.get(),
103            std::collections::hash_map::Entry::Vacant(e) => {
104                self.types.push(self.variable[v].ty);
105                let new_id = *e.insert(self.types.len() - 1);
106                self.inputs.push(*var_usage);
107                new_id
108            }
109        })
110    }
111
112    /// Converts an output var to a CanonicVar.
113    fn handle_output(&mut self, v: &VariableId) -> CanonicVar {
114        CanonicVar(match self.vars.entry(*v) {
115            std::collections::hash_map::Entry::Occupied(e) => *e.get(),
116            std::collections::hash_map::Entry::Vacant(e) => {
117                self.types.push(self.variable[*v].ty);
118                *e.insert(self.types.len() - 1)
119            }
120        })
121    }
122
123    /// Converts a statement to a canonic statement.
124    fn handle_statement(&mut self, statement: &Statement<'db>) -> CanonicStatement<'db> {
125        match statement {
126            Statement::Const(StatementConst { value, boxed, output }) => CanonicStatement::Const {
127                value: *value,
128                output: self.handle_output(output),
129                boxed: *boxed,
130            },
131            Statement::Call(StatementCall {
132                function,
133                inputs,
134                with_coupon,
135                outputs,
136                location: _,
137                is_specialization_base_call: _,
138            }) => CanonicStatement::Call {
139                function: *function,
140                inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
141                with_coupon: *with_coupon,
142                outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
143            },
144            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
145                CanonicStatement::StructConstruct {
146                    inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
147                    output: self.handle_output(output),
148                }
149            }
150            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
151                CanonicStatement::StructDestructure {
152                    input: self.handle_input(input),
153                    outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
154                }
155            }
156            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
157                CanonicStatement::EnumConstruct {
158                    variant: *variant,
159                    input: self.handle_input(input),
160                    output: self.handle_output(output),
161                }
162            }
163            Statement::Snapshot(StatementSnapshot { input, outputs }) => {
164                CanonicStatement::Snapshot {
165                    input: self.handle_input(input),
166                    outputs: outputs.map(|output| self.handle_output(&output)),
167                }
168            }
169            Statement::Desnap(StatementDesnap { input, output }) => CanonicStatement::Desnap {
170                input: self.handle_input(input),
171                output: self.handle_output(output),
172            },
173            Statement::IntoBox(StatementIntoBox { input, output }) => CanonicStatement::IntoBox {
174                input: self.handle_input(input),
175                output: self.handle_output(output),
176            },
177            Statement::Unbox(StatementUnbox { input, output }) => CanonicStatement::Unbox {
178                input: self.handle_input(input),
179                output: self.handle_output(output),
180            },
181        }
182    }
183}
184
185impl<'db> CanonicBlock<'db> {
186    /// Tries to create a canonic block from a flat block.
187    /// Returns the canonic representation of the block and the external inputs used in the block.
188    /// Blocks that do not end in return do not have a canonic representation.
189    fn try_from_block(
190        variable: &VariableArena<'db>,
191        block: &Block<'db>,
192    ) -> Option<(CanonicBlock<'db>, Vec<VarUsage<'db>>)> {
193        let BlockEnd::Return(returned_vars, _) = &block.end else {
194            return None;
195        };
196
197        if block.statements.is_empty() {
198            // Skip deduplication for empty blocks.
199            return None;
200        }
201
202        let mut builder = CanonicBlockBuilder::new(variable);
203
204        let stmts = block
205            .statements
206            .iter()
207            .map(|statement| builder.handle_statement(statement))
208            .collect_vec();
209
210        let returns = returned_vars.iter().map(|input| builder.handle_input(input)).collect();
211
212        Some((CanonicBlock { stmts, types: builder.types, returns }, builder.inputs))
213    }
214}
215/// Helper class to reassign variable ids.
216pub struct VarReassigner<'db, 'a> {
217    pub variables: &'a mut VariableArena<'db>,
218
219    // Maps old var_id to new_var_id
220    pub vars: UnorderedHashMap<VariableId, VariableId>,
221}
222
223impl<'db, 'a> VarReassigner<'db, 'a> {
224    pub fn new(variables: &'a mut VariableArena<'db>) -> Self {
225        Self { variables, vars: UnorderedHashMap::default() }
226    }
227}
228
229impl<'db, 'a> Rebuilder<'db> for VarReassigner<'db, 'a> {
230    fn map_var_id(&mut self, var: VariableId) -> VariableId {
231        *self.vars.entry(var).or_insert_with(|| self.variables.alloc(self.variables[var].clone()))
232    }
233}
234
235#[derive(Default)]
236struct DedupContext<'db> {
237    /// Maps a CanonicBlock to a reference block that matches it.
238    canonic_blocks: UnorderedHashMap<CanonicBlock<'db>, BlockId>,
239
240    /// Maps a block to the inputs that are needed for it to be shared by multiple flows.
241    block_id_to_inputs: UnorderedHashMap<BlockId, Vec<VarUsage<'db>>>,
242}
243
244/// Given a block and a set of inputs, assigns new ids to all the variables in the block, returning
245/// the new block and the new inputs.
246fn rebuild_block_and_inputs<'db>(
247    variables: &mut VariableArena<'db>,
248    block: &Block<'db>,
249    inputs: &[VarUsage<'db>],
250) -> (Block<'db>, Vec<VarUsage<'db>>) {
251    let mut var_reassigner = VarReassigner::new(variables);
252    (
253        var_reassigner.rebuild_block(block),
254        inputs.iter().map(|var_usage| var_reassigner.map_var_usage(*var_usage)).collect(),
255    )
256}
257
258/// Deduplicates blocks by redirecting goto's and match arms to one of the duplicates.
259/// The duplicate blocks will be removed later by `reorganize_blocks`.
260pub fn dedup_blocks<'db>(lowered: &mut Lowered<'db>) {
261    if lowered.blocks.has_root().is_err() {
262        return;
263    }
264
265    let mut ctx = DedupContext::default();
266    // Maps duplicated blocks to the new shared block and the inputs that need to be remapped for
267    // the block.
268    let mut duplicates: UnorderedHashMap<BlockId, (BlockId, Vec<VarUsage<'_>>)> =
269        Default::default();
270
271    let mut new_blocks = vec![];
272    let mut next_block_id = BlockId(lowered.blocks.len());
273
274    for (block_id, block) in lowered.blocks.iter() {
275        let Some((canonical_block, inputs)) =
276            CanonicBlock::try_from_block(&lowered.variables, block)
277        else {
278            continue;
279        };
280
281        match ctx.canonic_blocks.entry(canonical_block) {
282            unordered_hash_map::Entry::Occupied(e) => {
283                let block_and_inputs = duplicates
284                    .entry(*e.get())
285                    .or_insert_with(|| {
286                        let (block, new_inputs) =
287                            rebuild_block_and_inputs(&mut lowered.variables, block, &inputs);
288                        new_blocks.push(block);
289                        let new_block_id = next_block_id;
290                        next_block_id = next_block_id.next_block_id();
291
292                        (new_block_id, new_inputs)
293                    })
294                    .clone();
295
296                duplicates.insert(block_id, block_and_inputs);
297            }
298            unordered_hash_map::Entry::Vacant(e) => {
299                e.insert(block_id);
300            }
301        };
302
303        ctx.block_id_to_inputs.insert(block_id, inputs);
304    }
305
306    let mut new_goto_block =
307        |block_id, inputs: &[VarUsage<'db>], target_inputs: &[VarUsage<'db>]| {
308            new_blocks.push(Block {
309                statements: vec![],
310                end: BlockEnd::Goto(
311                    block_id,
312                    VarRemapping {
313                        remapping: OrderedHashMap::from_iter(zip_eq(
314                            target_inputs.iter().map(|var_usage| var_usage.var_id),
315                            inputs.iter().cloned(),
316                        )),
317                    },
318                ),
319            });
320
321            let new_block_id = next_block_id;
322            next_block_id = next_block_id.next_block_id();
323            new_block_id
324        };
325
326    // Note that the loop below can't be merged with the loop above as a block might be marked as
327    // dup after we already visiting an arm that goes to it.
328    for block in lowered.blocks.iter_mut() {
329        match &mut block.end {
330            BlockEnd::Goto(target_block, remappings) => {
331                let Some((block_id, target_inputs)) = duplicates.get(target_block) else {
332                    continue;
333                };
334
335                let inputs = ctx.block_id_to_inputs.get(target_block).unwrap();
336                let mut inputs_remapping = VarRemapping {
337                    remapping: OrderedHashMap::from_iter(zip_eq(
338                        target_inputs.iter().map(|var_usage| var_usage.var_id),
339                        inputs.iter().cloned(),
340                    )),
341                };
342                for (_, src) in inputs_remapping.iter_mut() {
343                    if let Some(src_before_remapping) = remappings.get(&src.var_id) {
344                        *src = *src_before_remapping;
345                    }
346                }
347
348                *target_block = *block_id;
349                *remappings = inputs_remapping;
350            }
351            BlockEnd::Match { info } => {
352                for arm in info.arms_mut() {
353                    let Some((block_id, target_inputs)) = duplicates.get(&arm.block_id) else {
354                        continue;
355                    };
356
357                    let inputs = &ctx.block_id_to_inputs[&arm.block_id];
358                    arm.block_id = new_goto_block(*block_id, inputs, target_inputs);
359                }
360            }
361            _ => {}
362        }
363    }
364
365    for block in new_blocks {
366        lowered.blocks.push(block);
367    }
368}