// cairo_lang_lowering/optimizations/dedup_blocks.rs

1#[cfg(test)]
2#[path = "dedup_blocks_test.rs"]
3mod test;
4
5use cairo_lang_semantic::items::constant::ConstValue;
6use cairo_lang_semantic::{ConcreteVariant, TypeId};
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::{self, UnorderedHashMap};
9use id_arena::Arena;
10use itertools::{Itertools, zip_eq};
11
12use crate::ids::FunctionId;
13use crate::utils::{Rebuilder, RebuilderEx};
14use crate::{
15    Block, BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementConst, StatementDesnap,
16    StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
17    StatementStructDestructure, VarRemapping, VarUsage, Variable, VariableId,
18};
19
20/// A canonic representation of a block (used to find duplicated blocks).
21/// Currently only blocks that end with return are supported.
22#[derive(Hash, PartialEq, Eq)]
23struct CanonicBlock {
24    /// Canonic representation of the statements in the block.
25    stmts: Vec<CanonicStatement>,
26    /// The types of variables introduced in the block.
27    types: Vec<TypeId>,
28    /// variables returned by the block.
29    returns: Vec<CanonicVar>,
30}
31
/// The canonic name of a variable inside a [`CanonicBlock`]: its index in the order of first
/// appearance within the block.
#[derive(PartialEq, Eq, Hash)]
struct CanonicVar(usize);
35
36/// A canonic representation of a statement in a canonic block.
37#[derive(Hash, PartialEq, Eq)]
38enum CanonicStatement {
39    Const {
40        value: ConstValue,
41        output: CanonicVar,
42    },
43    Call {
44        function: FunctionId,
45        inputs: Vec<CanonicVar>,
46        with_coupon: bool,
47        outputs: Vec<CanonicVar>,
48    },
49    StructConstruct {
50        inputs: Vec<CanonicVar>,
51        output: CanonicVar,
52    },
53    StructDestructure {
54        input: CanonicVar,
55        outputs: Vec<CanonicVar>,
56    },
57    EnumConstruct {
58        variant: ConcreteVariant,
59        input: CanonicVar,
60        output: CanonicVar,
61    },
62
63    Snapshot {
64        input: CanonicVar,
65        outputs: [CanonicVar; 2],
66    },
67    Desnap {
68        input: CanonicVar,
69        output: CanonicVar,
70    },
71}
72
73struct CanonicBlockBuilder<'a> {
74    variable: &'a Arena<Variable>,
75    vars: UnorderedHashMap<VariableId, usize>,
76    types: Vec<TypeId>,
77    inputs: Vec<VarUsage>,
78}
79
80impl CanonicBlockBuilder<'_> {
81    fn new(variable: &Arena<Variable>) -> CanonicBlockBuilder<'_> {
82        CanonicBlockBuilder {
83            variable,
84            vars: Default::default(),
85            types: vec![],
86            inputs: Default::default(),
87        }
88    }
89
90    /// Converts an input var to a CanonicVar.
91    fn handle_input(&mut self, var_usage: &VarUsage) -> CanonicVar {
92        let v = var_usage.var_id;
93
94        CanonicVar(match self.vars.entry(v) {
95            std::collections::hash_map::Entry::Occupied(e) => *e.get(),
96            std::collections::hash_map::Entry::Vacant(e) => {
97                self.types.push(self.variable[v].ty);
98                let new_id = *e.insert(self.types.len() - 1);
99                self.inputs.push(*var_usage);
100                new_id
101            }
102        })
103    }
104
105    /// Converts an output var to a CanonicVar.
106    fn handle_output(&mut self, v: &VariableId) -> CanonicVar {
107        CanonicVar(match self.vars.entry(*v) {
108            std::collections::hash_map::Entry::Occupied(e) => *e.get(),
109            std::collections::hash_map::Entry::Vacant(e) => {
110                self.types.push(self.variable[*v].ty);
111                *e.insert(self.types.len() - 1)
112            }
113        })
114    }
115
116    /// Converts a statement to a cononic statement.
117    fn handle_statement(&mut self, statement: &Statement) -> CanonicStatement {
118        match statement {
119            Statement::Const(StatementConst { value, output }) => {
120                CanonicStatement::Const { value: value.clone(), output: self.handle_output(output) }
121            }
122            Statement::Call(StatementCall {
123                function,
124                inputs,
125                with_coupon,
126                outputs,
127                location: _,
128            }) => CanonicStatement::Call {
129                function: *function,
130                inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
131                with_coupon: *with_coupon,
132                outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
133            },
134            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
135                CanonicStatement::StructConstruct {
136                    inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
137                    output: self.handle_output(output),
138                }
139            }
140            Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
141                CanonicStatement::StructDestructure {
142                    input: self.handle_input(input),
143                    outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
144                }
145            }
146            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
147                CanonicStatement::EnumConstruct {
148                    variant: *variant,
149                    input: self.handle_input(input),
150                    output: self.handle_output(output),
151                }
152            }
153            Statement::Snapshot(StatementSnapshot { input, outputs }) => {
154                CanonicStatement::Snapshot {
155                    input: self.handle_input(input),
156                    outputs: outputs.map(|output| self.handle_output(&output)),
157                }
158            }
159            Statement::Desnap(StatementDesnap { input, output }) => CanonicStatement::Desnap {
160                input: self.handle_input(input),
161                output: self.handle_output(output),
162            },
163        }
164    }
165}
166
167impl CanonicBlock {
168    /// Tries to create a canonic block from a flat block.
169    /// Return the canonic representation of the block and the external inputs used in the block.
170    /// Blocks that do not end in return do not have a canonic representation.
171    fn try_from_block(
172        variable: &Arena<Variable>,
173        block: &Block,
174    ) -> Option<(CanonicBlock, Vec<VarUsage>)> {
175        let BlockEnd::Return(returned_vars, _) = &block.end else {
176            return None;
177        };
178
179        if block.statements.is_empty() {
180            // Skip deduplication for empty blocks.
181            return None;
182        }
183
184        let mut builder = CanonicBlockBuilder::new(variable);
185
186        let stmts = block
187            .statements
188            .iter()
189            .map(|statement| builder.handle_statement(statement))
190            .collect_vec();
191
192        let returns = returned_vars.iter().map(|input| builder.handle_input(input)).collect();
193
194        Some((CanonicBlock { stmts, types: builder.types, returns }, builder.inputs))
195    }
196}
197/// Helper class to reassign variable ids.
198pub struct VarReassigner<'a> {
199    pub variables: &'a mut Arena<Variable>,
200
201    // Maps old var_id to new_var_id
202    pub vars: UnorderedHashMap<VariableId, VariableId>,
203}
204
205impl<'a> VarReassigner<'a> {
206    pub fn new(variables: &'a mut Arena<Variable>) -> Self {
207        Self { variables, vars: UnorderedHashMap::default() }
208    }
209}
210
211impl Rebuilder for VarReassigner<'_> {
212    fn map_var_id(&mut self, var: VariableId) -> VariableId {
213        *self.vars.entry(var).or_insert_with(|| self.variables.alloc(self.variables[var].clone()))
214    }
215}
216
217#[derive(Default)]
218struct DedupContext {
219    /// Maps a CanonicBlock to a reference block that matches it.
220    canonic_blocks: UnorderedHashMap<CanonicBlock, BlockId>,
221
222    /// Maps a block to the inputs that are needed for it to be shared by multiple flows.
223    block_id_to_inputs: UnorderedHashMap<BlockId, Vec<VarUsage>>,
224}
225
226/// Given a block and a set of inputs, assigns new ids to all the variables in the block, returning
227/// the new block and the new inputs.
228fn rebuild_block_and_inputs(
229    variables: &mut Arena<Variable>,
230    block: &Block,
231    inputs: &[VarUsage],
232) -> (Block, Vec<VarUsage>) {
233    let mut var_reassigner = VarReassigner::new(variables);
234    (
235        var_reassigner.rebuild_block(block),
236        inputs.iter().map(|var_usage| var_reassigner.map_var_usage(*var_usage)).collect(),
237    )
238}
239
240/// Deduplicates blocks by redirecting goto's and match arms to one of the duplicates.
241/// The duplicate blocks will be removed later by `reorganize_blocks`.
242pub fn dedup_blocks(lowered: &mut Lowered) {
243    if lowered.blocks.has_root().is_err() {
244        return;
245    }
246
247    let mut ctx = DedupContext::default();
248    // Maps duplicated blocks to the new shared block and the inputs that need to be remapped for
249    // the block.
250    let mut duplicates: UnorderedHashMap<BlockId, (BlockId, Vec<VarUsage>)> = Default::default();
251
252    let mut new_blocks = vec![];
253    let mut next_block_id = BlockId(lowered.blocks.len());
254
255    for (block_id, block) in lowered.blocks.iter() {
256        let Some((canonical_block, inputs)) =
257            CanonicBlock::try_from_block(&lowered.variables, block)
258        else {
259            continue;
260        };
261
262        match ctx.canonic_blocks.entry(canonical_block) {
263            unordered_hash_map::Entry::Occupied(e) => {
264                let block_and_inputs = duplicates
265                    .entry(*e.get())
266                    .or_insert_with(|| {
267                        let (block, new_inputs) =
268                            rebuild_block_and_inputs(&mut lowered.variables, block, &inputs);
269                        new_blocks.push(block);
270                        let new_block_id = next_block_id;
271                        next_block_id = next_block_id.next_block_id();
272
273                        (new_block_id, new_inputs)
274                    })
275                    .clone();
276
277                duplicates.insert(block_id, block_and_inputs);
278            }
279            unordered_hash_map::Entry::Vacant(e) => {
280                e.insert(block_id);
281            }
282        };
283
284        ctx.block_id_to_inputs.insert(block_id, inputs);
285    }
286
287    let mut new_goto_block = |block_id, inputs: &Vec<VarUsage>, target_inputs: &Vec<VarUsage>| {
288        new_blocks.push(Block {
289            statements: vec![],
290            end: BlockEnd::Goto(
291                block_id,
292                VarRemapping {
293                    remapping: OrderedHashMap::from_iter(zip_eq(
294                        target_inputs.iter().map(|var_usage| var_usage.var_id),
295                        inputs.iter().cloned(),
296                    )),
297                },
298            ),
299        });
300
301        let new_block_id = next_block_id;
302        next_block_id = next_block_id.next_block_id();
303        new_block_id
304    };
305
306    // Note that the loop below can't be merged with the loop above as a block might be marked as
307    // dup after we already visiting an arm that goes to it.
308    for block in lowered.blocks.iter_mut() {
309        match &mut block.end {
310            BlockEnd::Goto(target_block, remappings) => {
311                let Some((block_id, target_inputs)) = duplicates.get(target_block) else {
312                    continue;
313                };
314
315                let inputs = ctx.block_id_to_inputs.get(target_block).unwrap();
316                let mut inputs_remapping = VarRemapping {
317                    remapping: OrderedHashMap::from_iter(zip_eq(
318                        target_inputs.iter().map(|var_usage| var_usage.var_id),
319                        inputs.iter().cloned(),
320                    )),
321                };
322                for (_, src) in inputs_remapping.iter_mut() {
323                    if let Some(src_before_remapping) = remappings.get(&src.var_id) {
324                        *src = *src_before_remapping;
325                    }
326                }
327
328                *target_block = *block_id;
329                *remappings = inputs_remapping;
330            }
331            BlockEnd::Match { info } => {
332                for arm in info.arms_mut() {
333                    let Some((block_id, target_inputs)) = duplicates.get(&arm.block_id) else {
334                        continue;
335                    };
336
337                    let inputs = &ctx.block_id_to_inputs[&arm.block_id];
338                    arm.block_id = new_goto_block(*block_id, inputs, target_inputs);
339                }
340            }
341            _ => {}
342        }
343    }
344
345    for block in new_blocks {
346        lowered.blocks.push(block);
347    }
348}