cairo_lang_lowering/optimizations/
match_optimizer.rs

1#[cfg(test)]
2#[path = "match_optimizer_test.rs"]
3mod test;
4
5use cairo_lang_semantic::MatchArmSelector;
6use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
7use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use itertools::{Itertools, zip_eq};
10
11use super::var_renamer::VarRenamer;
12use crate::borrow_check::Demand;
13use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
14use crate::borrow_check::demand::EmptyDemandReporter;
15use crate::utils::RebuilderEx;
16use crate::{
17    Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
18    StatementEnumConstruct, VarRemapping, VarUsage, VariableArena, VariableId,
19};
20
/// The demand type used by the match optimizer: a `Demand` keyed by [`VariableId`], with `()`
/// for the two remaining type parameters of `Demand`.
pub type MatchOptimizerDemand<'db> = Demand<VariableId, (), ()>;
22
23impl<'db> MatchOptimizerDemand<'db> {
24    fn update(&mut self, statement: &Statement<'db>) {
25        self.variables_introduced(&mut EmptyDemandReporter {}, statement.outputs(), ());
26        self.variables_used(
27            &mut EmptyDemandReporter {},
28            statement.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
29        );
30    }
31}
32
/// Optimizes Statement::EnumConstruct that is followed by a match to jump to the target of the
/// relevant match arm.
///
/// For example, given:
///
/// ```plain
/// blk0:
/// Statements:
/// (v1: core::option::Option::<core::integer::u32>) <- Option::Some(v0)
/// End:
/// Goto(blk1, {v1-> v2})
///
/// blk1:
/// Statements:
/// End:
/// Match(match_enum(v2) {
///   Option::Some(v3) => blk4,
///   Option::None(v4) => blk5,
/// })
/// ```
///
/// Change `blk0` to jump directly to `blk4`.
///
/// Does nothing if `lowered` has no blocks. Otherwise, the fixes are collected by running a
/// `BackAnalysis` with a [`MatchOptimizerContext`] and are then applied to `lowered` in place.
/// Blocks that are created while applying the fixes are appended to `lowered.blocks`.
///
/// # Panics
///
/// Panics if a collected fix is inconsistent with the lowering: the number of statements in the
/// fixed block is not the expected one, a fix in the match block itself carries additional
/// remappings, the match block does not end with an enum match, or the arm index / arm variable
/// of the fix does not exist.
pub fn optimize_matches<'db>(lowered: &mut Lowered<'db>) {
    if lowered.blocks.is_empty() {
        return;
    }
    // Collect the fixes using a backward analysis over the lowering.
    let ctx = MatchOptimizerContext { fixes: vec![] };
    let mut analysis = BackAnalysis::new(lowered, ctx);
    analysis.get_root_info();
    let ctx = analysis.analyzer;

    // New blocks are collected here and appended to `lowered.blocks` only after all the fixes
    // were applied, so their ids start right after the ids of the existing blocks.
    let mut new_blocks = vec![];
    let mut next_block_id = BlockId(lowered.blocks.len());

    // Track variable renaming that results from applying the fixes below.
    // For each (variable_id, arm_idx) pair that is remapped (prior to the match),
    // we assign a new variable (to satisfy the SSA requirement).
    //
    // For example, consider the following blocks:
    //   blk0:
    //   Statements:
    //   (v0: test::Color) <- Color::Red(v5)
    //   End:
    //   Goto(blk1, {v1 -> v2, v0 -> v3})
    //
    //   blk1:
    //   Statements:
    //   End:
    //   Match(match_enum(v3) {
    //     Color::Red(v4) => blk2,
    //   })
    //
    // When the optimization is applied, block0 will jump directly to blk2. Since the definition of
    // v2 is at blk1, we must map v1 to a new variable.
    //
    // If there is another fix for the same match arm, the same variable will be used.
    let mut var_renaming = UnorderedHashMap::<(VariableId, usize), VariableId>::default();

    // Fixes were added in reverse order and need to be applied in that order.
    // This is because `additional_remappings` in later blocks may need to be renamed by fixes from
    // earlier blocks.
    for fix in ctx.fixes {
        // Choose new variables for each destination of the additional remappings (see comment
        // above).
        let mut new_remapping = fix.remapping.clone();
        // Maps each original variable to the new variable chosen for it in this fix.
        let mut renamed_vars = OrderedHashMap::<VariableId, VariableId>::default();
        for (var, dst) in fix.additional_remappings.iter() {
            // Allocate a new variable, if it was not allocated before.
            let new_var = *var_renaming
                .entry((*var, fix.arm_idx))
                .or_insert_with(|| lowered.variables.alloc(lowered.variables[*var].clone()));
            new_remapping.insert(new_var, *dst);
            renamed_vars.insert(*var, new_var);
        }

        let block = &mut lowered.blocks[fix.statement_location.0];
        // Sanity check: the enum construct is followed by exactly `n_same_block_statement`
        // statements in its block.
        assert_eq!(
            block.statements.len() - 1,
            fix.statement_location.1 + fix.n_same_block_statement,
            "Unexpected number of statements in block."
        );

        if fix.remove_enum_construct {
            block.statements.remove(fix.statement_location.1);
        }

        // Copy the statements that originally appeared between the goto and the match into the
        // fixed block.
        handle_additional_statements(
            &mut lowered.variables,
            &mut var_renaming,
            &mut new_remapping,
            &mut renamed_vars,
            block,
            &fix,
        );

        // Jump directly to the target of the relevant arm.
        block.end = BlockEnd::Goto(fix.target_block, new_remapping);
        if fix.statement_location.0 == fix.match_block {
            // The match was removed (by the assignment of `block.end` above), no need to fix it.
            // Sanity check: there should be no additional remapping in this case.
            assert!(fix.additional_remappings.remapping.is_empty());
            continue;
        }

        let block = &mut lowered.blocks[fix.match_block];
        let BlockEnd::Match { info: MatchInfo::Enum(MatchEnumInfo { arms, location, .. }) } =
            &mut block.end
        else {
            unreachable!("match block should end with a match.");
        };

        let arm = arms.get_mut(fix.arm_idx).unwrap();
        if fix.target_block != arm.block_id {
            // The match arm was already fixed, no need to fix it again.
            continue;
        }

        // Fix the match arm so it does not jump directly to a block that has incoming gotos, and
        // add a remapping that matches the goto above.
        let arm_var = arm.var_ids.get_mut(0).unwrap();
        let orig_var = *arm_var;
        // The arm introduces a new variable, which is remapped to the original arm variable in
        // the new intermediate block.
        *arm_var = lowered.variables.alloc(lowered.variables[orig_var].clone());
        let mut new_block_remapping: VarRemapping<'_> = Default::default();

        new_block_remapping.insert(orig_var, VarUsage { var_id: *arm_var, location: *location });
        for (var, new_var) in renamed_vars.iter() {
            new_block_remapping.insert(*new_var, VarUsage { var_id: *var, location: *location });
        }

        // The new block only forwards to the original arm target, applying the remapping above.
        new_blocks.push(Block {
            statements: vec![],
            end: BlockEnd::Goto(arm.block_id, new_block_remapping),
        });
        arm.block_id = next_block_id;
        next_block_id = next_block_id.next_block_id();

        let mut var_renamer = VarRenamer { renamed_vars: renamed_vars.into_iter().collect() };
        // Apply the variable renaming to the reachable blocks.
        for block_id in fix.reachable_blocks {
            let block = &mut lowered.blocks[block_id];
            *block = var_renamer.rebuild_block(block);
        }
    }

    // Append the blocks that were created while applying the fixes.
    for block in new_blocks {
        lowered.blocks.push(block);
    }
}
180
181/// Handles the additional statements in the fix.
182///
183/// The additional statements are not in the same block as the enum construct so they need
184/// to be copied the current block with new outputs to keep the SSA property.
185/// further we need to remap and rename the outputs for the merge in `fix.target_block`.
186///
187/// Note that since the statements are copied this might increase the code size.
188fn handle_additional_statements<'db>(
189    variables: &mut VariableArena<'db>,
190    var_renaming: &mut UnorderedHashMap<(VariableId, usize), VariableId>,
191    new_remapping: &mut VarRemapping<'db>,
192    renamed_vars: &mut OrderedHashMap<VariableId, VariableId>,
193    block: &mut Block<'db>,
194    fix: &FixInfo<'db>,
195) {
196    if fix.additional_stmts.is_empty() {
197        return;
198    }
199
200    // Maps input in the original lowering to the inputs after the optimization.
201    // Since the statement are copied from after `additional_remappings` to before it,
202    // `inputs_remapping` is initialized with `additional_remapping`.
203    let mut inputs_remapping = UnorderedHashMap::<VariableId, VariableId>::from_iter(
204        fix.additional_remappings.iter().map(|(k, v)| (*k, v.var_id)),
205    );
206    for mut stmt in fix.additional_stmts.iter().cloned() {
207        for input in stmt.inputs_mut() {
208            if let Some(orig_var) = inputs_remapping.get(&input.var_id) {
209                input.var_id = *orig_var;
210            }
211        }
212
213        for output in stmt.outputs_mut() {
214            let orig_output = *output;
215            // Allocate a new variable for the output in the fixed block.
216            *output = variables.alloc(variables[*output].clone());
217            inputs_remapping.insert(orig_output, *output);
218
219            // Allocate a new post remapping output, if it was not allocated before.
220            let new_output = *var_renaming
221                .entry((orig_output, fix.arm_idx))
222                .or_insert_with(|| variables.alloc(variables[*output].clone()));
223            let location = variables[*output].location;
224            new_remapping.insert(new_output, VarUsage { var_id: *output, location });
225            renamed_vars.insert(orig_output, new_output);
226        }
227
228        block.statements.push(stmt);
229    }
230}
231
/// Tries to apply the optimization at the given statement.
/// If the optimization can be applied, returns the fix information and updates the analysis info
/// accordingly.
///
/// The first parameter is the enum construct statement whose output is the matched variable of
/// `candidate`; `statement_location` is the location of that statement.
///
/// Returns `None` if there are additional remappings that are used by the relevant arm and the
/// candidate has a future merge. In this case `info` is left unchanged, but the demand of the
/// relevant arm has already been taken out of `candidate.arm_demands`.
///
/// On success, `info.demand` and `info.reachable_blocks` are replaced by the values computed for
/// the relevant arm (the reachable blocks are taken out of `candidate.arm_reachable_blocks`).
///
/// # Panics
///
/// Panics if no arm of the match selects `variant`, or if the selected arm does not introduce
/// exactly one variable.
fn try_get_fix_info<'db>(
    StatementEnumConstruct { variant, input, output }: &StatementEnumConstruct<'db>,
    info: &mut AnalysisInfo<'db, '_>,
    candidate: &mut OptimizationCandidate<'db, '_>,
    statement_location: (BlockId, usize),
) -> Option<FixInfo<'db>> {
    // Find the arm that corresponds to the constructed variant.
    let (arm_idx, arm) = candidate
        .match_arms
        .iter()
        .find_position(
            |arm| matches!(&arm.arm_selector, MatchArmSelector::VariantId(v) if v == variant),
        )
        .expect("arm not found.");

    let [var_id] = arm.var_ids.as_slice() else {
        panic!("An arm of an EnumMatch should produce a single variable.");
    };

    // Prepare a remapping object for the input of the EnumConstruct, which will be used as `var_id`
    // in `arm.block_id`.
    let mut remapping = VarRemapping::default();
    remapping.insert(*var_id, *input);

    // Compute the demand based on the demand of the specific arm, rather than the current demand
    // (which contains the union of the demands from all the arms).
    // Apply the remapping of the input variable and the additional remappings if exist.
    let mut demand = std::mem::take(&mut candidate.arm_demands[arm_idx]);

    // The statements before the match that are not in the same block as the enum construct, in
    // their original order: `statement_rev` is reversed back, and the first
    // `n_same_block_statement` statements (the ones in the block of the enum construct) are
    // skipped.
    let additional_stmts = candidate
        .statement_rev
        .iter()
        .rev()
        .skip(candidate.n_same_block_statement)
        .cloned()
        .cloned()
        .collect_vec();
    for stmt in &additional_stmts {
        demand.update(stmt);
    }

    // The arm variable is replaced by the input of the enum construct.
    demand
        .apply_remapping(&mut EmptyDemandReporter {}, [(var_id, (&input.var_id, ()))].into_iter());

    let additional_remappings = match candidate.remapping {
        Some(remappings) => {
            // Filter the additional remappings to only include those that are used in relevant arm.
            VarRemapping {
                remapping: OrderedHashMap::from_iter(remappings.iter().filter_map(|(dst, src)| {
                    if demand.vars.contains_key(dst) { Some((*dst, *src)) } else { None }
                })),
            }
        }
        None => VarRemapping::default(),
    };

    if !additional_remappings.is_empty() && candidate.future_merge {
        // If there are additional_remappings and a future merge we cannot apply the optimization.
        return None;
    }

    demand.apply_remapping(
        &mut EmptyDemandReporter {},
        additional_remappings.iter().map(|(dst, src_var_usage)| (dst, (&src_var_usage.var_id, ()))),
    );

    // Account for all the statements that were seen between the enum construct and the match.
    for stmt in candidate.statement_rev.iter().rev() {
        demand.update(stmt);
    }
    // From this point backwards, the analysis continues with the info of the relevant arm only.
    info.demand = demand;
    info.reachable_blocks = std::mem::take(&mut candidate.arm_reachable_blocks[arm_idx]);

    Some(FixInfo {
        statement_location,
        match_block: candidate.match_block,
        arm_idx,
        target_block: arm.block_id,
        remapping,
        reachable_blocks: info.reachable_blocks.clone(),
        additional_remappings,
        n_same_block_statement: candidate.n_same_block_statement,
        // The enum construct can be removed only if its output is not demanded.
        remove_enum_construct: !info.demand.vars.contains_key(output),
        additional_stmts,
    })
}
319
/// The information needed to apply a single fix: making the block of an enum construct jump
/// directly to the target of the relevant arm of the match that follows it.
///
/// Created by `try_get_fix_info` and applied by `optimize_matches`.
pub struct FixInfo<'db> {
    /// The location that needs to be fixed (the location of the enum construct statement).
    statement_location: (BlockId, usize),
    /// The block with the match statement that we want to jump over.
    match_block: BlockId,
    /// The index of the arm that we want to jump to.
    arm_idx: usize,
    /// The target block to jump to.
    target_block: BlockId,
    /// Remaps the input of the enum construct to the variable that is introduced by the match arm.
    remapping: VarRemapping<'db>,
    /// The blocks that can be reached from the relevant arm of the match.
    reachable_blocks: OrderedHashSet<BlockId>,
    /// Additional remappings that appeared in a `Goto` leading to the match.
    additional_remappings: VarRemapping<'db>,
    /// The number of statements that follow the enum construct in its block.
    n_same_block_statement: usize,
    /// Indicates that the enum construct statement can be removed.
    remove_enum_construct: bool,
    /// Additional statements that appear before the match but not in the same block as the enum
    /// construct.
    additional_stmts: Vec<Statement<'db>>,
}
343
/// A match that may be optimized, together with the information collected about it while the
/// analysis walks backwards from the match looking for the enum construct of the matched
/// variable.
#[derive(Clone)]
struct OptimizationCandidate<'db, 'a> {
    /// The variable that is matched.
    match_variable: VariableId,

    /// The match arms of the enum match that we are optimizing.
    match_arms: &'a [MatchArm<'db>],

    /// The block that the match is in.
    match_block: BlockId,

    /// The demands at the arms.
    arm_demands: Vec<MatchOptimizerDemand<'db>>,

    /// Whether there is a future merge between the match arms.
    future_merge: bool,

    /// The blocks that can be reached from each of the arms.
    arm_reachable_blocks: Vec<OrderedHashSet<BlockId>>,

    /// A remapping that appeared in a `Goto` leading to the match.
    /// Only one such remapping is allowed as this is typically the case
    /// after running `optimize_remapping` and `reorder_statements` and it simplifies the
    /// optimization.
    remapping: Option<&'a VarRemapping<'db>>,

    /// The statements before the match in reverse order.
    statement_rev: Vec<&'a Statement<'db>>,

    /// The number of statements seen so far in the current block (reset whenever a `Goto` is
    /// visited). When the enum construct is found, this is the number of statements that follow
    /// it in its block.
    n_same_block_statement: usize,
}
376
/// The analyzer of the match optimization; accumulates the fixes found during the analysis.
pub struct MatchOptimizerContext<'db> {
    /// The fixes that were found so far, in the order in which the analysis found them.
    fixes: Vec<FixInfo<'db>>,
}
380
/// The information that the match optimizer analysis maintains at each point of the lowering.
#[derive(Clone)]
pub struct AnalysisInfo<'db, 'a> {
    /// The match that is currently considered for the optimization, if any.
    candidate: Option<OptimizationCandidate<'db, 'a>>,
    /// The variable demand at the current point.
    demand: MatchOptimizerDemand<'db>,
    /// Blocks that can be reached from the current block.
    reachable_blocks: OrderedHashSet<BlockId>,
}
388
389impl<'db: 'a, 'a> Analyzer<'db, 'a> for MatchOptimizerContext<'db> {
390    type Info = AnalysisInfo<'db, 'a>;
391
392    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
393        info.reachable_blocks.insert(block_id);
394    }
395
396    fn visit_stmt(
397        &mut self,
398        info: &mut Self::Info,
399        statement_location: StatementLocation,
400        stmt: &'a Statement<'db>,
401    ) {
402        if let Some(mut candidate) = info.candidate.take() {
403            match stmt {
404                Statement::EnumConstruct(enum_construct_stmt)
405                    if enum_construct_stmt.output == candidate.match_variable =>
406                {
407                    if let Some(fix_info) = try_get_fix_info(
408                        enum_construct_stmt,
409                        info,
410                        &mut candidate,
411                        statement_location,
412                    ) {
413                        self.fixes.push(fix_info);
414                        return;
415                    }
416
417                    // Since `candidate.match_variable` was introduced, the candidate is no longer
418                    // applicable.
419                    info.candidate = None;
420                }
421                _ => {
422                    candidate.statement_rev.push(stmt);
423                    candidate.n_same_block_statement += 1;
424                    info.candidate = Some(candidate);
425                }
426            }
427        }
428
429        info.demand.update(stmt);
430    }
431
432    fn visit_goto(
433        &mut self,
434        info: &mut Self::Info,
435        _statement_location: StatementLocation,
436        _target_block_id: BlockId,
437        remapping: &'a VarRemapping<'db>,
438    ) {
439        info.demand.apply_remapping(
440            &mut EmptyDemandReporter {},
441            remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))),
442        );
443
444        let Some(candidate) = &mut info.candidate else {
445            return;
446        };
447
448        // New block, so reset the number of statements that are in the same block as the enum
449        // construct.
450        candidate.n_same_block_statement = 0;
451
452        if candidate.future_merge
453            && candidate.statement_rev.iter().any(|stmt| !stmt.outputs().is_empty())
454        {
455            // If we have a future merge and a statement not in the same block as the enum construct
456            // has an output, we cannot apply the optimization.
457            info.candidate = None;
458            return;
459        }
460
461        if remapping.is_empty() {
462            return;
463        }
464
465        if candidate.remapping.is_some() {
466            info.candidate = None;
467            return;
468        }
469
470        // Store the goto's remapping.
471        candidate.remapping = Some(remapping);
472        if let Some(var_usage) = remapping.get(&candidate.match_variable) {
473            candidate.match_variable = var_usage.var_id;
474        }
475    }
476
477    fn merge_match(
478        &mut self,
479        (block_id, _statement_idx): StatementLocation,
480        match_info: &'a MatchInfo<'db>,
481        infos: impl Iterator<Item = Self::Info>,
482    ) -> Self::Info {
483        let (arm_demands, arm_reachable_blocks): (Vec<_>, Vec<_>) =
484            infos.map(|info| (info.demand, info.reachable_blocks)).unzip();
485
486        let arm_demands_without_arm_var = zip_eq(match_info.arms(), &arm_demands)
487            .map(|(arm, demand)| {
488                let mut demand = demand.clone();
489                // Remove the variable that is introduced by the match arm.
490                demand.variables_introduced(&mut EmptyDemandReporter {}, &arm.var_ids, ());
491
492                (demand, ())
493            })
494            .collect_vec();
495        let mut demand = MatchOptimizerDemand::merge_demands(
496            &arm_demands_without_arm_var,
497            &mut EmptyDemandReporter {},
498        );
499
500        // Union the reachable blocks for all the infos.
501        let mut reachable_blocks = OrderedHashSet::default();
502        let mut max_possible_size = 0;
503        for cur_reachable_blocks in &arm_reachable_blocks {
504            reachable_blocks.extend(cur_reachable_blocks.iter().cloned());
505            max_possible_size += cur_reachable_blocks.len();
506        }
507        // If the size of `reachable_blocks` is less than the sum of the sizes of the
508        // `arm_reachable_blocks`, then there was a collision.
509        let found_collision = reachable_blocks.len() < max_possible_size;
510
511        let candidate = match match_info {
512            // A match is a candidate for the optimization if it is a match on an Enum
513            // and its input is unused after the match.
514            MatchInfo::Enum(MatchEnumInfo { input, arms, .. })
515                if !demand.vars.contains_key(&input.var_id) =>
516            {
517                Some(OptimizationCandidate {
518                    match_variable: input.var_id,
519                    match_arms: arms,
520                    match_block: block_id,
521                    arm_demands,
522                    future_merge: found_collision,
523                    arm_reachable_blocks,
524                    remapping: None,
525                    statement_rev: vec![],
526                    n_same_block_statement: 0,
527                })
528            }
529            _ => None,
530        };
531
532        demand.variables_used(
533            &mut EmptyDemandReporter {},
534            match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
535        );
536
537        Self::Info { candidate, demand, reachable_blocks }
538    }
539
540    fn info_from_return(
541        &mut self,
542        _statement_location: StatementLocation,
543        vars: &[VarUsage<'db>],
544    ) -> Self::Info {
545        let mut demand = MatchOptimizerDemand::default();
546        demand.variables_used(
547            &mut EmptyDemandReporter {},
548            vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())),
549        );
550        Self::Info { candidate: None, demand, reachable_blocks: Default::default() }
551    }
552}