Skip to main content

cairo_lang_lowering/optimizations/
match_optimizer.rs

1#[cfg(test)]
2#[path = "match_optimizer_test.rs"]
3mod test;
4
5use cairo_lang_semantic::MatchArmSelector;
6use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
7use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use id_arena::Arena;
10use itertools::{Itertools, zip_eq};
11
12use super::var_renamer::VarRenamer;
13use crate::borrow_check::Demand;
14use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
15use crate::borrow_check::demand::EmptyDemandReporter;
16use crate::utils::RebuilderEx;
17use crate::{
18    Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
19    StatementEnumConstruct, VarRemapping, VarUsage, Variable, VariableId,
20};
21
22pub type MatchOptimizerDemand = Demand<VariableId, (), ()>;
23
24impl MatchOptimizerDemand {
25    fn update(&mut self, statement: &Statement) {
26        self.variables_introduced(&mut EmptyDemandReporter {}, statement.outputs(), ());
27        self.variables_used(
28            &mut EmptyDemandReporter {},
29            statement.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
30        );
31    }
32}
33
34/// Optimizes Statement::EnumConstruct that is followed by a match to jump to the target of the
35/// relevant match arm.
36///
37/// For example, given:
38///
39/// ```plain
40/// blk0:
41/// Statements:
42/// (v1: core::option::Option::<core::integer::u32>) <- Option::Some(v0)
43/// End:
44/// Goto(blk1, {v1-> v2})
45///
46/// blk1:
47/// Statements:
48/// End:
49/// Match(match_enum(v2) {
50///   Option::Some(v3) => blk4,
51///   Option::None(v4) => blk5,
52/// })
53/// ```
54///
55/// Change `blk0` to jump directly to `blk4`.
56pub fn optimize_matches(lowered: &mut Lowered) {
57    if lowered.blocks.is_empty() {
58        return;
59    }
60    let ctx = MatchOptimizerContext { fixes: vec![] };
61    let mut analysis = BackAnalysis::new(lowered, ctx);
62    analysis.get_root_info();
63    let ctx = analysis.analyzer;
64
65    let mut new_blocks = vec![];
66    let mut next_block_id = BlockId(lowered.blocks.len());
67
68    // Track variable renaming that results from applying the fixes below.
69    // For each (variable_id, arm_idx) pair that is remapped (prior to the match),
70    // we assign a new variable (to satisfy the SSA requirement).
71    //
72    // For example, consider the following blocks:
73    //   blk0:
74    //   Statements:
75    //   (v0: test::Color) <- Color::Red(v5)
76    //   End:
77    //   Goto(blk1, {v1 -> v2, v0 -> v3})
78    //
79    //   blk1:
80    //   Statements:
81    //   End:
82    //   Match(match_enum(v3) {
83    //     Color::Red(v4) => blk2,
84    //   })
85    //
86    // When the optimization is applied, block0 will jump directly to blk2. Since the definition of
87    // v2 is at blk1, we must map v1 to a new variable.
88    //
89    // If there is another fix for the same match arm, the same variable will be used.
90    let mut var_renaming = UnorderedHashMap::<(VariableId, usize), VariableId>::default();
91
92    // Fixes were added in reverse order and need to be applied in that order.
93    // This is because `additional_remappings` in later blocks may need to be renamed by fixes from
94    // earlier blocks.
95    for fix in ctx.fixes {
96        // Choose new variables for each destination of the additional remappings (see comment
97        // above).
98        let mut new_remapping = fix.remapping.clone();
99        let mut renamed_vars = OrderedHashMap::<VariableId, VariableId>::default();
100        for (var, dst) in fix.additional_remappings.iter() {
101            // Allocate a new variable, if it was not allocated before.
102            let new_var = *var_renaming
103                .entry((*var, fix.arm_idx))
104                .or_insert_with(|| lowered.variables.alloc(lowered.variables[*var].clone()));
105            new_remapping.insert(new_var, *dst);
106            renamed_vars.insert(*var, new_var);
107        }
108
109        let block = &mut lowered.blocks[fix.statement_location.0];
110        assert_eq!(
111            block.statements.len() - 1,
112            fix.statement_location.1 + fix.n_same_block_statement,
113            "Unexpected number of statements in block."
114        );
115
116        if fix.remove_enum_construct {
117            block.statements.remove(fix.statement_location.1);
118        }
119
120        handle_additional_statements(
121            &mut lowered.variables,
122            &mut var_renaming,
123            &mut new_remapping,
124            &mut renamed_vars,
125            block,
126            &fix,
127        );
128
129        block.end = BlockEnd::Goto(fix.target_block, new_remapping);
130        if fix.statement_location.0 == fix.match_block {
131            // The match was removed (by the assignment of `block.end` above), no need to fix it.
132            // Sanity check: there should be no additional remapping in this case.
133            assert!(fix.additional_remappings.remapping.is_empty());
134            continue;
135        }
136
137        let block = &mut lowered.blocks[fix.match_block];
138        let BlockEnd::Match { info: MatchInfo::Enum(MatchEnumInfo { arms, location, .. }) } =
139            &mut block.end
140        else {
141            unreachable!("match block should end with a match.");
142        };
143
144        let arm = arms.get_mut(fix.arm_idx).unwrap();
145        if fix.target_block != arm.block_id {
146            // The match arm was already fixed, no need to fix it again.
147            continue;
148        }
149
150        // Fix match arm not to jump directly to a block that has an incoming gotos and add
151        // remapping that matches the goto above.
152        let arm_var = arm.var_ids.get_mut(0).unwrap();
153        let orig_var = *arm_var;
154        *arm_var = lowered.variables.alloc(lowered.variables[orig_var].clone());
155        let mut new_block_remapping: VarRemapping = Default::default();
156
157        new_block_remapping.insert(orig_var, VarUsage { var_id: *arm_var, location: *location });
158        for (var, new_var) in renamed_vars.iter() {
159            new_block_remapping.insert(*new_var, VarUsage { var_id: *var, location: *location });
160        }
161
162        new_blocks.push(Block {
163            statements: vec![],
164            end: BlockEnd::Goto(arm.block_id, new_block_remapping),
165        });
166        arm.block_id = next_block_id;
167        next_block_id = next_block_id.next_block_id();
168
169        let mut var_renamer = VarRenamer { renamed_vars: renamed_vars.into_iter().collect() };
170        // Apply the variable renaming to the reachable blocks.
171        for block_id in fix.reachable_blocks {
172            let block = &mut lowered.blocks[block_id];
173            *block = var_renamer.rebuild_block(block);
174        }
175    }
176
177    for block in new_blocks {
178        lowered.blocks.push(block);
179    }
180}
181
182/// Handles the additional statements in the fix.
183///
184/// The additional statements are not in the same block as the enum construct so they need
185/// to be copied the current block with new outputs to keep the SSA property.
186/// further we need to remap and rename the outputs for the merge in `fix.target_block`.
187///
188/// Note that since the statements are copied this might increase the code size.
189fn handle_additional_statements(
190    variables: &mut Arena<Variable>,
191    var_renaming: &mut UnorderedHashMap<(VariableId, usize), VariableId>,
192    new_remapping: &mut VarRemapping,
193    renamed_vars: &mut OrderedHashMap<VariableId, VariableId>,
194    block: &mut Block,
195    fix: &FixInfo,
196) {
197    if fix.additional_stmts.is_empty() {
198        return;
199    }
200
201    // Maps input in the original lowering to the inputs after the optimization.
202    // Since the statement are copied from after `additional_remappings` to before it,
203    // `inputs_remapping` is initialized with `additional_remapping`.
204    let mut inputs_remapping = UnorderedHashMap::<VariableId, VariableId>::from_iter(
205        fix.additional_remappings.iter().map(|(k, v)| (*k, v.var_id)),
206    );
207    for mut stmt in fix.additional_stmts.iter().cloned() {
208        for input in stmt.inputs_mut() {
209            if let Some(orig_var) = inputs_remapping.get(&input.var_id) {
210                input.var_id = *orig_var;
211            }
212        }
213
214        for output in stmt.outputs_mut() {
215            let orig_output = *output;
216            // Allocate a new variable for the output in the fixed block.
217            *output = variables.alloc(variables[*output].clone());
218            inputs_remapping.insert(orig_output, *output);
219
220            // Allocate a new post remapping output, if it was not allocated before.
221            let new_output = *var_renaming
222                .entry((orig_output, fix.arm_idx))
223                .or_insert_with(|| variables.alloc(variables[*output].clone()));
224            let location = variables[*output].location;
225            new_remapping.insert(new_output, VarUsage { var_id: *output, location });
226            renamed_vars.insert(orig_output, new_output);
227        }
228
229        block.statements.push(stmt);
230    }
231}
232
233/// Try to apply the optimization at the given statement.
234/// If the optimization can be applied, return the fix information and updates the analysis info
235/// accordingly.
236fn try_get_fix_info(
237    StatementEnumConstruct { variant, input, output }: &StatementEnumConstruct,
238    info: &mut AnalysisInfo<'_>,
239    candidate: &mut OptimizationCandidate<'_>,
240    statement_location: (BlockId, usize),
241) -> Option<FixInfo> {
242    let (arm_idx, arm) = candidate
243        .match_arms
244        .iter()
245        .find_position(
246            |arm| matches!(&arm.arm_selector, MatchArmSelector::VariantId(v) if v == variant),
247        )
248        .expect("arm not found.");
249
250    let [var_id] = arm.var_ids.as_slice() else {
251        panic!("An arm of an EnumMatch should produce a single variable.");
252    };
253
254    // Prepare a remapping object for the input of the EnumConstruct, which will be used as `var_id`
255    // in `arm.block_id`.
256    let mut remapping = VarRemapping::default();
257    remapping.insert(*var_id, *input);
258
259    // Compute the demand based on the demand of the specific arm, rather than the current demand
260    // (which contains the union of the demands from all the arms).
261    // Apply the remapping of the input variable and the additional remappings if exist.
262    let mut demand = std::mem::take(&mut candidate.arm_demands[arm_idx]);
263
264    let additional_stmts = candidate
265        .statement_rev
266        .iter()
267        .rev()
268        .skip(candidate.n_same_block_statement)
269        .cloned()
270        .cloned()
271        .collect_vec();
272    for stmt in &additional_stmts {
273        demand.update(stmt);
274    }
275
276    demand
277        .apply_remapping(&mut EmptyDemandReporter {}, [(var_id, (&input.var_id, ()))].into_iter());
278
279    let additional_remappings = match candidate.remapping {
280        Some(remappings) => {
281            // Filter the additional remappings to only include those that are used in relevant arm.
282            VarRemapping {
283                remapping: OrderedHashMap::from_iter(remappings.iter().filter_map(|(dst, src)| {
284                    if demand.vars.contains_key(dst) { Some((*dst, *src)) } else { None }
285                })),
286            }
287        }
288        None => VarRemapping::default(),
289    };
290
291    if !additional_remappings.is_empty() && candidate.future_merge {
292        // If there are additional_remappings and a future merge we cannot apply the optimization.
293        return None;
294    }
295
296    demand.apply_remapping(
297        &mut EmptyDemandReporter {},
298        additional_remappings.iter().map(|(dst, src_var_usage)| (dst, (&src_var_usage.var_id, ()))),
299    );
300
301    for stmt in candidate.statement_rev.iter().rev() {
302        demand.update(stmt);
303    }
304    info.demand = demand;
305    info.reachable_blocks = std::mem::take(&mut candidate.arm_reachable_blocks[arm_idx]);
306
307    Some(FixInfo {
308        statement_location,
309        match_block: candidate.match_block,
310        arm_idx,
311        target_block: arm.block_id,
312        remapping,
313        reachable_blocks: info.reachable_blocks.clone(),
314        additional_remappings,
315        n_same_block_statement: candidate.n_same_block_statement,
316        remove_enum_construct: !info.demand.vars.contains_key(output),
317        additional_stmts,
318    })
319}
320
321pub struct FixInfo {
322    /// The location that needs to be fixed,
323    statement_location: (BlockId, usize),
324    /// The block with the match statement that we want to jump over.
325    match_block: BlockId,
326    /// The index of the arm that we want to jump to.
327    arm_idx: usize,
328    /// The target block to jump to.
329    target_block: BlockId,
330    /// Remaps the input of the enum construct to the variable that is introduced by the match arm.
331    remapping: VarRemapping,
332    /// The blocks that can be reached from the relevant arm of the match.
333    reachable_blocks: OrderedHashSet<BlockId>,
334    /// Additional remappings that appeared in a `Goto` leading to the match.
335    additional_remappings: VarRemapping,
336    /// The number of statement in the in the same block as the enum construct.
337    n_same_block_statement: usize,
338    /// Indicated that the enum construct statement can be removed.
339    remove_enum_construct: bool,
340    /// Additional statement that appear before the match but not in the same block as the enum
341    /// construct.
342    additional_stmts: Vec<Statement>,
343}
344
345#[derive(Clone)]
346struct OptimizationCandidate<'a> {
347    /// The variable that is matched.
348    match_variable: VariableId,
349
350    /// The match arms of the extern match that we are optimizing.
351    match_arms: &'a [MatchArm],
352
353    /// The block that the match is in.
354    match_block: BlockId,
355
356    /// The demands at the arms.
357    arm_demands: Vec<MatchOptimizerDemand>,
358
359    /// Whether there is a future merge between the match arms.
360    future_merge: bool,
361
362    /// The blocks that can be reached from each of the arms.
363    arm_reachable_blocks: Vec<OrderedHashSet<BlockId>>,
364
365    /// A remappings that appeared in a `Goto` leading to the match.
366    /// Only one such remapping is allowed as this is typically the case
367    /// after running `optimize_remapping` and `reorder_statements` and it simplifies the
368    /// optimization.
369    remapping: Option<&'a VarRemapping>,
370
371    /// The statements before the match in reverse order.
372    statement_rev: Vec<&'a Statement>,
373
374    /// The number of statement in the in the same block as the enum construct.
375    n_same_block_statement: usize,
376}
377
378pub struct MatchOptimizerContext {
379    fixes: Vec<FixInfo>,
380}
381
382#[derive(Clone)]
383pub struct AnalysisInfo<'a> {
384    candidate: Option<OptimizationCandidate<'a>>,
385    demand: MatchOptimizerDemand,
386    /// Blocks that can be reached from the current block.
387    reachable_blocks: OrderedHashSet<BlockId>,
388}
389
390impl<'a> Analyzer<'a> for MatchOptimizerContext {
391    type Info = AnalysisInfo<'a>;
392
393    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block) {
394        info.reachable_blocks.insert(block_id);
395    }
396
397    fn visit_stmt(
398        &mut self,
399        info: &mut Self::Info,
400        statement_location: StatementLocation,
401        stmt: &'a Statement,
402    ) {
403        if let Some(mut candidate) = info.candidate.take() {
404            match stmt {
405                Statement::EnumConstruct(enum_construct_stmt)
406                    if enum_construct_stmt.output == candidate.match_variable =>
407                {
408                    if let Some(fix_info) = try_get_fix_info(
409                        enum_construct_stmt,
410                        info,
411                        &mut candidate,
412                        statement_location,
413                    ) {
414                        self.fixes.push(fix_info);
415                        return;
416                    }
417
418                    // Since `candidate.match_variable` was introduced, the candidate is no longer
419                    // applicable.
420                    info.candidate = None;
421                }
422                _ => {
423                    candidate.statement_rev.push(stmt);
424                    candidate.n_same_block_statement += 1;
425                    info.candidate = Some(candidate);
426                }
427            }
428        }
429
430        info.demand.update(stmt);
431    }
432
433    fn visit_goto(
434        &mut self,
435        info: &mut Self::Info,
436        _statement_location: StatementLocation,
437        _target_block_id: BlockId,
438        remapping: &'a VarRemapping,
439    ) {
440        info.demand.apply_remapping(
441            &mut EmptyDemandReporter {},
442            remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))),
443        );
444
445        let Some(ref mut candidate) = &mut info.candidate else {
446            return;
447        };
448
449        // New block, so reset the number of statements that are in the same block as the enum
450        // construct.
451        candidate.n_same_block_statement = 0;
452
453        if candidate.future_merge
454            && candidate.statement_rev.iter().any(|stmt| !stmt.outputs().is_empty())
455        {
456            // If we have a future merge and a statement not in the same block as the enum construct
457            // has an output, we cannot apply the optimization.
458            info.candidate = None;
459            return;
460        }
461
462        if remapping.is_empty() {
463            return;
464        }
465
466        if candidate.remapping.is_some() {
467            info.candidate = None;
468            return;
469        }
470
471        // Store the goto's remapping.
472        candidate.remapping = Some(remapping);
473        if let Some(var_usage) = remapping.get(&candidate.match_variable) {
474            candidate.match_variable = var_usage.var_id;
475        }
476    }
477
478    fn merge_match(
479        &mut self,
480        (block_id, _statement_idx): StatementLocation,
481        match_info: &'a MatchInfo,
482        infos: impl Iterator<Item = Self::Info>,
483    ) -> Self::Info {
484        let (arm_demands, arm_reachable_blocks): (Vec<_>, Vec<_>) =
485            infos.map(|info| (info.demand, info.reachable_blocks)).unzip();
486
487        let arm_demands_without_arm_var = zip_eq(match_info.arms(), &arm_demands)
488            .map(|(arm, demand)| {
489                let mut demand = demand.clone();
490                // Remove the variable that is introduced by the match arm.
491                demand.variables_introduced(&mut EmptyDemandReporter {}, &arm.var_ids, ());
492
493                (demand, ())
494            })
495            .collect_vec();
496        let mut demand = MatchOptimizerDemand::merge_demands(
497            &arm_demands_without_arm_var,
498            &mut EmptyDemandReporter {},
499        );
500
501        // Union the reachable blocks for all the infos.
502        let mut reachable_blocks = OrderedHashSet::default();
503        let mut max_possible_size = 0;
504        for cur_reachable_blocks in &arm_reachable_blocks {
505            reachable_blocks.extend(cur_reachable_blocks.iter().cloned());
506            max_possible_size += cur_reachable_blocks.len();
507        }
508        // If the size of `reachable_blocks` is less than the sum of the sizes of the
509        // `arm_reachable_blocks`, then there was a collision.
510        let found_collision = reachable_blocks.len() < max_possible_size;
511
512        let candidate = match match_info {
513            // A match is a candidate for the optimization if it is a match on an Enum
514            // and its input is unused after the match.
515            MatchInfo::Enum(MatchEnumInfo { input, arms, .. })
516                if !demand.vars.contains_key(&input.var_id) =>
517            {
518                Some(OptimizationCandidate {
519                    match_variable: input.var_id,
520                    match_arms: arms,
521                    match_block: block_id,
522                    arm_demands,
523                    future_merge: found_collision,
524                    arm_reachable_blocks,
525                    remapping: None,
526                    statement_rev: vec![],
527                    n_same_block_statement: 0,
528                })
529            }
530            _ => None,
531        };
532
533        demand.variables_used(
534            &mut EmptyDemandReporter {},
535            match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
536        );
537
538        Self::Info { candidate, demand, reachable_blocks }
539    }
540
541    fn info_from_return(
542        &mut self,
543        _statement_location: StatementLocation,
544        vars: &[VarUsage],
545    ) -> Self::Info {
546        let mut demand = MatchOptimizerDemand::default();
547        demand.variables_used(
548            &mut EmptyDemandReporter {},
549            vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())),
550        );
551        Self::Info { candidate: None, demand, reachable_blocks: Default::default() }
552    }
553}