1#[cfg(test)]
2#[path = "match_optimizer_test.rs"]
3mod test;
4
5use cairo_lang_semantic::{ConcreteVariant, MatchArmSelector};
6use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
7use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use itertools::{Itertools, zip_eq};
10use salsa::Database;
11
12use super::var_renamer::VarRenamer;
13use crate::analysis::def_site::{DefSiteAnalysis, DefSites};
14use crate::analysis::dominator::Dominators;
15use crate::analysis::equality_analysis::{EqualityAnalysis, EqualityState};
16use crate::analysis::{Analyzer, BackAnalysis, DefLocation, StatementLocation};
17use crate::borrow_check::Demand;
18use crate::borrow_check::demand::EmptyDemandReporter;
19use crate::ids::LocationId;
20use crate::utils::RebuilderEx;
21use crate::{
22 Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement, VarRemapping,
23 VarUsage, VariableArena, VariableId,
24};
25
26pub type MatchOptimizerDemand<'db> = Demand<VariableId, (), ()>;
27
28impl<'db> MatchOptimizerDemand<'db> {
29 fn update(&mut self, statement: &Statement<'db>) {
30 self.variables_introduced(&mut EmptyDemandReporter {}, statement.outputs(), ());
31 self.variables_used(
32 &mut EmptyDemandReporter {},
33 statement.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
34 );
35 }
36}
37
38pub fn optimize_matches<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>) {
61 if lowered.blocks.is_empty() {
62 return;
63 }
64 let ctx = MatchOptimizerContext::new(db, lowered);
65 let mut analysis = BackAnalysis::new(&*lowered, ctx);
66 analysis.get_root_info();
67
68 let mut new_blocks = vec![];
69 let mut next_block_id = BlockId(lowered.blocks.len());
70
71 let mut var_renaming = UnorderedHashMap::<(VariableId, usize), VariableId>::default();
94
95 for fix in analysis.analyzer.fixes {
99 let mut new_remapping = fix.remapping.clone();
102 let mut renamed_vars = OrderedHashMap::<VariableId, VariableId>::default();
103 for (var, dst) in fix.additional_remappings.iter() {
104 let new_var = *var_renaming
106 .entry((*var, fix.arm_idx))
107 .or_insert_with(|| lowered.variables.alloc(lowered.variables[*var].clone()));
108 new_remapping.insert(new_var, *dst);
109 renamed_vars.insert(*var, new_var);
110 }
111
112 let block = &mut lowered.blocks[fix.fix_block];
113 if let Some(UpstreamEnumConstruct { stmt_idx, n_same_block_statement, remove }) =
114 &fix.enum_construct
115 {
116 assert_eq!(
117 block.statements.len() - 1,
118 stmt_idx + n_same_block_statement,
119 "Unexpected number of statements in block."
120 );
121 if *remove {
122 block.statements.remove(*stmt_idx);
123 }
124 }
125
126 handle_additional_statements(
127 &mut lowered.variables,
128 &mut var_renaming,
129 &mut new_remapping,
130 &mut renamed_vars,
131 block,
132 &fix,
133 );
134
135 block.end = BlockEnd::Goto(fix.target_block, new_remapping);
136 if fix.fix_block == fix.match_block {
137 assert!(fix.additional_remappings.remapping.is_empty());
140 continue;
141 }
142
143 let block = &mut lowered.blocks[fix.match_block];
144 let BlockEnd::Match { info: MatchInfo::Enum(MatchEnumInfo { arms, location, .. }) } =
145 &mut block.end
146 else {
147 unreachable!("match block should end with a match.");
148 };
149
150 let arm = arms.get_mut(fix.arm_idx).unwrap();
151 if fix.target_block != arm.block_id {
152 continue;
154 }
155
156 let arm_var = arm.var_ids.get_mut(0).unwrap();
159 let orig_var = *arm_var;
160 *arm_var = lowered.variables.alloc(lowered.variables[orig_var].clone());
161 let mut new_block_remapping: VarRemapping<'_> = Default::default();
162
163 new_block_remapping.insert(orig_var, VarUsage { var_id: *arm_var, location: *location });
164 for (var, new_var) in renamed_vars.iter() {
165 new_block_remapping.insert(*new_var, VarUsage { var_id: *var, location: *location });
166 }
167
168 new_blocks.push(Block {
169 statements: vec![],
170 end: BlockEnd::Goto(arm.block_id, new_block_remapping),
171 });
172 arm.block_id = next_block_id;
173 next_block_id = next_block_id.next_block_id();
174
175 let mut var_renamer = VarRenamer { renamed_vars: renamed_vars.into_iter().collect() };
176 for block_id in fix.reachable_blocks {
178 let block = &mut lowered.blocks[block_id];
179 *block = var_renamer.rebuild_block(block);
180 }
181 }
182
183 for block in new_blocks {
184 lowered.blocks.push(block);
185 }
186}
187
188fn handle_additional_statements<'db>(
196 variables: &mut VariableArena<'db>,
197 var_renaming: &mut UnorderedHashMap<(VariableId, usize), VariableId>,
198 new_remapping: &mut VarRemapping<'db>,
199 renamed_vars: &mut OrderedHashMap<VariableId, VariableId>,
200 block: &mut Block<'db>,
201 fix: &FixInfo<'db>,
202) {
203 if fix.additional_stmts.is_empty() {
204 return;
205 }
206
207 let mut inputs_remapping = UnorderedHashMap::<VariableId, VariableId>::from_iter(
211 fix.additional_remappings.iter().map(|(k, v)| (*k, v.var_id)),
212 );
213 for mut stmt in fix.additional_stmts.iter().cloned() {
214 for input in stmt.inputs_mut() {
215 if let Some(orig_var) = inputs_remapping.get(&input.var_id) {
216 input.var_id = *orig_var;
217 }
218 }
219
220 for output in stmt.outputs_mut() {
221 let orig_output = *output;
222 *output = variables.alloc(variables[*output].clone());
224 inputs_remapping.insert(orig_output, *output);
225
226 let new_output = *var_renaming
228 .entry((orig_output, fix.arm_idx))
229 .or_insert_with(|| variables.alloc(variables[*output].clone()));
230 let location = variables[*output].location;
231 new_remapping.insert(new_output, VarUsage { var_id: *output, location });
232 renamed_vars.insert(orig_output, new_output);
233 }
234
235 block.statements.push(stmt);
236 }
237}
238
239fn try_get_fix_info<'db>(
249 variant: &ConcreteVariant<'db>,
250 inner: VarUsage<'db>,
251 enum_construct: Option<(usize, VariableId)>,
252 fix_block: BlockId,
253 info: &mut AnalysisInfo<'db, '_>,
254 mut candidate: OptimizationCandidate<'db, '_>,
255) -> Option<FixInfo<'db>> {
256 let (arm_idx, arm) = candidate
257 .match_arms
258 .iter()
259 .find_position(
260 |arm| matches!(&arm.arm_selector, MatchArmSelector::VariantId(v) if v == variant),
261 )
262 .expect("arm not found.");
263
264 let [var_id] = arm.var_ids.as_slice() else {
265 panic!("An arm of an EnumMatch should produce a single variable.");
266 };
267
268 let mut remapping = VarRemapping::default();
271 remapping.insert(*var_id, inner);
272
273 let mut demand = std::mem::take(&mut candidate.arm_demands[arm_idx]);
277
278 let additional_stmts = candidate
279 .statement_rev
280 .iter()
281 .rev()
282 .skip(candidate.n_same_block_statement)
283 .cloned()
284 .cloned()
285 .collect_vec();
286 for stmt in &additional_stmts {
287 demand.update(stmt);
288 }
289
290 demand
291 .apply_remapping(&mut EmptyDemandReporter {}, [(var_id, (&inner.var_id, ()))].into_iter());
292
293 let additional_remappings = match candidate.remapping {
294 Some(remappings) => {
295 VarRemapping {
297 remapping: OrderedHashMap::from_iter(remappings.iter().filter_map(|(dst, src)| {
298 if demand.vars.contains_key(dst) { Some((*dst, *src)) } else { None }
299 })),
300 }
301 }
302 None => VarRemapping::default(),
303 };
304
305 if !additional_remappings.is_empty() && candidate.future_merge {
306 return None;
308 }
309
310 demand.apply_remapping(
311 &mut EmptyDemandReporter {},
312 additional_remappings.iter().map(|(dst, src_var_usage)| (dst, (&src_var_usage.var_id, ()))),
313 );
314
315 for stmt in candidate.statement_rev.iter().rev() {
316 demand.update(stmt);
317 }
318 info.demand = demand;
319 info.reachable_blocks = std::mem::take(&mut candidate.arm_reachable_blocks[arm_idx]);
320
321 let enum_construct = enum_construct.map(|(stmt_idx, output)| UpstreamEnumConstruct {
322 stmt_idx,
323 n_same_block_statement: candidate.n_same_block_statement,
324 remove: !info.demand.vars.contains_key(&output),
325 });
326
327 Some(FixInfo {
328 fix_block,
329 match_block: candidate.match_block,
330 arm_idx,
331 target_block: arm.block_id,
332 remapping,
333 reachable_blocks: info.reachable_blocks.clone(),
334 additional_remappings,
335 enum_construct,
336 additional_stmts,
337 })
338}
339
340pub struct FixInfo<'db> {
341 fix_block: BlockId,
343 match_block: BlockId,
345 arm_idx: usize,
347 target_block: BlockId,
349 remapping: VarRemapping<'db>,
351 reachable_blocks: OrderedHashSet<BlockId>,
353 additional_remappings: VarRemapping<'db>,
355 enum_construct: Option<UpstreamEnumConstruct>,
358 additional_stmts: Vec<Statement<'db>>,
361}
362
/// Describes the enum-construct statement a fix was created from.
struct UpstreamEnumConstruct {
    /// Index of the construct statement inside the fix block.
    stmt_idx: usize,
    /// Number of statements that follow the construct inside the fix block.
    n_same_block_statement: usize,
    /// Whether the construct statement should be removed when the fix is applied (its output
    /// was not demanded when the fix was created).
    remove: bool,
}
375
376#[derive(Clone)]
377struct OptimizationCandidate<'db, 'a> {
378 match_variable: VariableId,
380
381 match_arms: &'a [MatchArm<'db>],
383
384 match_block: BlockId,
386
387 match_location: LocationId<'db>,
389
390 arm_demands: Vec<MatchOptimizerDemand<'db>>,
392
393 future_merge: bool,
395
396 arm_reachable_blocks: Vec<OrderedHashSet<BlockId>>,
398
399 remapping: Option<&'a VarRemapping<'db>>,
404
405 statement_rev: Vec<&'a Statement<'db>>,
407
408 n_same_block_statement: usize,
410}
411
412pub struct MatchOptimizerContext<'db, 'a> {
413 fixes: Vec<FixInfo<'db>>,
414 equalities: Vec<Option<EqualityState<'db>>>,
415 dominators: Dominators,
416 def_sites: DefSites,
417 variables: &'a VariableArena<'db>,
420}
421
422impl<'db, 'a> MatchOptimizerContext<'db, 'a> {
423 fn new(db: &'db dyn Database, lowered: &'a Lowered<'db>) -> Self {
424 Self {
425 fixes: vec![],
426 equalities: EqualityAnalysis::analyze(db, lowered),
427 dominators: Dominators::analyze(lowered),
428 def_sites: DefSiteAnalysis::analyze(lowered),
429 variables: &lowered.variables,
430 }
431 }
432
433 fn get_enum_variant(
435 &mut self,
436 block_id: BlockId,
437 var: VariableId,
438 ) -> Option<(ConcreteVariant<'db>, VariableId)> {
439 self.equalities[block_id.0].as_mut()?.get_enum_construct(var)
440 }
441
442 fn is_visible_at_block_end(&self, var_id: VariableId, at: BlockId) -> bool {
444 let Some(loc) = self.def_sites[var_id.index()] else { return false };
445 let block_id = match loc {
446 DefLocation::Statement((b, _)) | DefLocation::BlockEntry(b) => b,
447 };
448 self.dominators.dominates(block_id, at)
449 }
450}
451
452#[derive(Clone)]
453pub struct AnalysisInfo<'db, 'a> {
454 candidate: Option<OptimizationCandidate<'db, 'a>>,
455 demand: MatchOptimizerDemand<'db>,
456 reachable_blocks: OrderedHashSet<BlockId>,
458}
459
460impl<'db: 'a, 'a> Analyzer<'db, 'a> for MatchOptimizerContext<'db, 'a> {
461 type Info = AnalysisInfo<'db, 'a>;
462
463 fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
464 info.reachable_blocks.insert(block_id);
465 }
466
467 fn visit_stmt(
468 &mut self,
469 info: &mut Self::Info,
470 statement_location: StatementLocation,
471 stmt: &'a Statement<'db>,
472 ) {
473 if let Some(mut candidate) = info.candidate.take() {
474 match stmt {
475 Statement::EnumConstruct(enum_construct_stmt)
476 if enum_construct_stmt.output == candidate.match_variable =>
477 {
478 if let Some(fix_info) = try_get_fix_info(
479 &enum_construct_stmt.variant,
480 enum_construct_stmt.input,
481 Some((statement_location.1, enum_construct_stmt.output)),
482 statement_location.0,
483 info,
484 candidate,
485 ) {
486 self.fixes.push(fix_info);
487 return;
488 }
489 }
490 _ => {
491 candidate.statement_rev.push(stmt);
492 candidate.n_same_block_statement += 1;
493 info.candidate = Some(candidate);
494 }
495 }
496 }
497
498 info.demand.update(stmt);
499 }
500
501 fn visit_goto(
502 &mut self,
503 info: &mut Self::Info,
504 _statement_location: StatementLocation,
505 _target_block_id: BlockId,
506 remapping: &'a VarRemapping<'db>,
507 ) {
508 info.demand.apply_remapping(
509 &mut EmptyDemandReporter {},
510 remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))),
511 );
512
513 let Some(candidate) = &mut info.candidate else {
514 return;
515 };
516
517 candidate.n_same_block_statement = 0;
520
521 if candidate.future_merge
522 && candidate.statement_rev.iter().any(|stmt| !stmt.outputs().is_empty())
523 {
524 info.candidate = None;
527 return;
528 }
529
530 if remapping.is_empty() {
531 return;
532 }
533
534 if candidate.remapping.is_some() {
535 info.candidate = None;
536 return;
537 }
538
539 candidate.remapping = Some(remapping);
541 if let Some(var_usage) = remapping.get(&candidate.match_variable) {
542 candidate.match_variable = var_usage.var_id;
543 }
544 }
545
546 fn merge_match(
547 &mut self,
548 (block_id, _statement_idx): StatementLocation,
549 match_info: &'a MatchInfo<'db>,
550 infos: impl Iterator<Item = Self::Info>,
551 ) -> Self::Info {
552 let (arm_demands, arm_reachable_blocks): (Vec<_>, Vec<_>) =
553 infos.map(|info| (info.demand, info.reachable_blocks)).unzip();
554
555 let arm_demands_without_arm_var = zip_eq(match_info.arms(), &arm_demands)
556 .map(|(arm, demand)| {
557 let mut demand = demand.clone();
558 demand.variables_introduced(&mut EmptyDemandReporter {}, &arm.var_ids, ());
560
561 (demand, ())
562 })
563 .collect_vec();
564 let mut demand = MatchOptimizerDemand::merge_demands(
565 &arm_demands_without_arm_var,
566 &mut EmptyDemandReporter {},
567 );
568
569 let mut reachable_blocks = OrderedHashSet::default();
571 let mut max_possible_size = 0;
572 for cur_reachable_blocks in &arm_reachable_blocks {
573 reachable_blocks.extend(cur_reachable_blocks.iter().cloned());
574 max_possible_size += cur_reachable_blocks.len();
575 }
576 let found_collision = reachable_blocks.len() < max_possible_size;
579
580 let candidate = match match_info {
587 MatchInfo::Enum(MatchEnumInfo { input, arms, location, .. })
588 if !demand.vars.contains_key(&input.var_id)
589 || self.variables[input.var_id].info.droppable.is_ok() =>
590 {
591 let candidate = OptimizationCandidate {
592 match_variable: input.var_id,
593 match_arms: arms,
594 match_block: block_id,
595 match_location: *location,
596 arm_demands,
597 future_merge: found_collision,
598 arm_reachable_blocks,
599 remapping: None,
600 statement_rev: vec![],
601 n_same_block_statement: 0,
602 };
603 if let Some((variant, payload_var)) =
608 self.get_enum_variant(block_id, candidate.match_variable)
609 && self.variables[candidate.match_variable].info.droppable.is_ok()
612 && self.variables[payload_var].info.copyable.is_ok()
614 {
615 assert!(self.is_visible_at_block_end(payload_var, block_id));
618 let mut info = Self::Info {
619 candidate: None,
620 demand: MatchOptimizerDemand::default(),
621 reachable_blocks: reachable_blocks.clone(),
622 };
623 let fix_info = try_get_fix_info(
627 &variant,
628 VarUsage { var_id: payload_var, location: candidate.match_location },
629 None,
630 block_id,
631 &mut info,
632 candidate,
633 )
634 .expect("equality-fold has no additional remappings, cannot fail");
635 self.fixes.push(fix_info);
636 return info;
637 }
638 Some(candidate)
639 }
640 _ => None,
641 };
642
643 demand.variables_used(
644 &mut EmptyDemandReporter {},
645 match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
646 );
647
648 Self::Info { candidate, demand, reachable_blocks }
649 }
650
651 fn info_from_return(
652 &mut self,
653 _statement_location: StatementLocation,
654 vars: &[VarUsage<'db>],
655 ) -> Self::Info {
656 let mut demand = MatchOptimizerDemand::default();
657 demand.variables_used(
658 &mut EmptyDemandReporter {},
659 vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())),
660 );
661 Self::Info { candidate: None, demand, reachable_blocks: Default::default() }
662 }
663}