1#[cfg(test)]
2#[path = "match_optimizer_test.rs"]
3mod test;
4
5use cairo_lang_semantic::MatchArmSelector;
6use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
7use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use itertools::{Itertools, zip_eq};
10
11use super::var_renamer::VarRenamer;
12use crate::borrow_check::Demand;
13use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
14use crate::borrow_check::demand::EmptyDemandReporter;
15use crate::utils::RebuilderEx;
16use crate::{
17 Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
18 StatementEnumConstruct, VarRemapping, VarUsage, VariableArena, VariableId,
19};
20
21pub type MatchOptimizerDemand<'db> = Demand<VariableId, (), ()>;
22
23impl<'db> MatchOptimizerDemand<'db> {
24 fn update(&mut self, statement: &Statement<'db>) {
25 self.variables_introduced(&mut EmptyDemandReporter {}, statement.outputs(), ());
26 self.variables_used(
27 &mut EmptyDemandReporter {},
28 statement.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
29 );
30 }
31}
32
33pub fn optimize_matches<'db>(lowered: &mut Lowered<'db>) {
56 if lowered.blocks.is_empty() {
57 return;
58 }
59 let ctx = MatchOptimizerContext { fixes: vec![] };
60 let mut analysis = BackAnalysis::new(lowered, ctx);
61 analysis.get_root_info();
62 let ctx = analysis.analyzer;
63
64 let mut new_blocks = vec![];
65 let mut next_block_id = BlockId(lowered.blocks.len());
66
67 let mut var_renaming = UnorderedHashMap::<(VariableId, usize), VariableId>::default();
90
91 for fix in ctx.fixes {
95 let mut new_remapping = fix.remapping.clone();
98 let mut renamed_vars = OrderedHashMap::<VariableId, VariableId>::default();
99 for (var, dst) in fix.additional_remappings.iter() {
100 let new_var = *var_renaming
102 .entry((*var, fix.arm_idx))
103 .or_insert_with(|| lowered.variables.alloc(lowered.variables[*var].clone()));
104 new_remapping.insert(new_var, *dst);
105 renamed_vars.insert(*var, new_var);
106 }
107
108 let block = &mut lowered.blocks[fix.statement_location.0];
109 assert_eq!(
110 block.statements.len() - 1,
111 fix.statement_location.1 + fix.n_same_block_statement,
112 "Unexpected number of statements in block."
113 );
114
115 if fix.remove_enum_construct {
116 block.statements.remove(fix.statement_location.1);
117 }
118
119 handle_additional_statements(
120 &mut lowered.variables,
121 &mut var_renaming,
122 &mut new_remapping,
123 &mut renamed_vars,
124 block,
125 &fix,
126 );
127
128 block.end = BlockEnd::Goto(fix.target_block, new_remapping);
129 if fix.statement_location.0 == fix.match_block {
130 assert!(fix.additional_remappings.remapping.is_empty());
133 continue;
134 }
135
136 let block = &mut lowered.blocks[fix.match_block];
137 let BlockEnd::Match { info: MatchInfo::Enum(MatchEnumInfo { arms, location, .. }) } =
138 &mut block.end
139 else {
140 unreachable!("match block should end with a match.");
141 };
142
143 let arm = arms.get_mut(fix.arm_idx).unwrap();
144 if fix.target_block != arm.block_id {
145 continue;
147 }
148
149 let arm_var = arm.var_ids.get_mut(0).unwrap();
152 let orig_var = *arm_var;
153 *arm_var = lowered.variables.alloc(lowered.variables[orig_var].clone());
154 let mut new_block_remapping: VarRemapping<'_> = Default::default();
155
156 new_block_remapping.insert(orig_var, VarUsage { var_id: *arm_var, location: *location });
157 for (var, new_var) in renamed_vars.iter() {
158 new_block_remapping.insert(*new_var, VarUsage { var_id: *var, location: *location });
159 }
160
161 new_blocks.push(Block {
162 statements: vec![],
163 end: BlockEnd::Goto(arm.block_id, new_block_remapping),
164 });
165 arm.block_id = next_block_id;
166 next_block_id = next_block_id.next_block_id();
167
168 let mut var_renamer = VarRenamer { renamed_vars: renamed_vars.into_iter().collect() };
169 for block_id in fix.reachable_blocks {
171 let block = &mut lowered.blocks[block_id];
172 *block = var_renamer.rebuild_block(block);
173 }
174 }
175
176 for block in new_blocks {
177 lowered.blocks.push(block);
178 }
179}
180
181fn handle_additional_statements<'db>(
189 variables: &mut VariableArena<'db>,
190 var_renaming: &mut UnorderedHashMap<(VariableId, usize), VariableId>,
191 new_remapping: &mut VarRemapping<'db>,
192 renamed_vars: &mut OrderedHashMap<VariableId, VariableId>,
193 block: &mut Block<'db>,
194 fix: &FixInfo<'db>,
195) {
196 if fix.additional_stmts.is_empty() {
197 return;
198 }
199
200 let mut inputs_remapping = UnorderedHashMap::<VariableId, VariableId>::from_iter(
204 fix.additional_remappings.iter().map(|(k, v)| (*k, v.var_id)),
205 );
206 for mut stmt in fix.additional_stmts.iter().cloned() {
207 for input in stmt.inputs_mut() {
208 if let Some(orig_var) = inputs_remapping.get(&input.var_id) {
209 input.var_id = *orig_var;
210 }
211 }
212
213 for output in stmt.outputs_mut() {
214 let orig_output = *output;
215 *output = variables.alloc(variables[*output].clone());
217 inputs_remapping.insert(orig_output, *output);
218
219 let new_output = *var_renaming
221 .entry((orig_output, fix.arm_idx))
222 .or_insert_with(|| variables.alloc(variables[*output].clone()));
223 let location = variables[*output].location;
224 new_remapping.insert(new_output, VarUsage { var_id: *output, location });
225 renamed_vars.insert(orig_output, new_output);
226 }
227
228 block.statements.push(stmt);
229 }
230}
231
232fn try_get_fix_info<'db>(
236 StatementEnumConstruct { variant, input, output }: &StatementEnumConstruct<'db>,
237 info: &mut AnalysisInfo<'db, '_>,
238 candidate: &mut OptimizationCandidate<'db, '_>,
239 statement_location: (BlockId, usize),
240) -> Option<FixInfo<'db>> {
241 let (arm_idx, arm) = candidate
242 .match_arms
243 .iter()
244 .find_position(
245 |arm| matches!(&arm.arm_selector, MatchArmSelector::VariantId(v) if v == variant),
246 )
247 .expect("arm not found.");
248
249 let [var_id] = arm.var_ids.as_slice() else {
250 panic!("An arm of an EnumMatch should produce a single variable.");
251 };
252
253 let mut remapping = VarRemapping::default();
256 remapping.insert(*var_id, *input);
257
258 let mut demand = std::mem::take(&mut candidate.arm_demands[arm_idx]);
262
263 let additional_stmts = candidate
264 .statement_rev
265 .iter()
266 .rev()
267 .skip(candidate.n_same_block_statement)
268 .cloned()
269 .cloned()
270 .collect_vec();
271 for stmt in &additional_stmts {
272 demand.update(stmt);
273 }
274
275 demand
276 .apply_remapping(&mut EmptyDemandReporter {}, [(var_id, (&input.var_id, ()))].into_iter());
277
278 let additional_remappings = match candidate.remapping {
279 Some(remappings) => {
280 VarRemapping {
282 remapping: OrderedHashMap::from_iter(remappings.iter().filter_map(|(dst, src)| {
283 if demand.vars.contains_key(dst) { Some((*dst, *src)) } else { None }
284 })),
285 }
286 }
287 None => VarRemapping::default(),
288 };
289
290 if !additional_remappings.is_empty() && candidate.future_merge {
291 return None;
293 }
294
295 demand.apply_remapping(
296 &mut EmptyDemandReporter {},
297 additional_remappings.iter().map(|(dst, src_var_usage)| (dst, (&src_var_usage.var_id, ()))),
298 );
299
300 for stmt in candidate.statement_rev.iter().rev() {
301 demand.update(stmt);
302 }
303 info.demand = demand;
304 info.reachable_blocks = std::mem::take(&mut candidate.arm_reachable_blocks[arm_idx]);
305
306 Some(FixInfo {
307 statement_location,
308 match_block: candidate.match_block,
309 arm_idx,
310 target_block: arm.block_id,
311 remapping,
312 reachable_blocks: info.reachable_blocks.clone(),
313 additional_remappings,
314 n_same_block_statement: candidate.n_same_block_statement,
315 remove_enum_construct: !info.demand.vars.contains_key(output),
316 additional_stmts,
317 })
318}
319
320pub struct FixInfo<'db> {
321 statement_location: (BlockId, usize),
323 match_block: BlockId,
325 arm_idx: usize,
327 target_block: BlockId,
329 remapping: VarRemapping<'db>,
331 reachable_blocks: OrderedHashSet<BlockId>,
333 additional_remappings: VarRemapping<'db>,
335 n_same_block_statement: usize,
337 remove_enum_construct: bool,
339 additional_stmts: Vec<Statement<'db>>,
342}
343
344#[derive(Clone)]
345struct OptimizationCandidate<'db, 'a> {
346 match_variable: VariableId,
348
349 match_arms: &'a [MatchArm<'db>],
351
352 match_block: BlockId,
354
355 arm_demands: Vec<MatchOptimizerDemand<'db>>,
357
358 future_merge: bool,
360
361 arm_reachable_blocks: Vec<OrderedHashSet<BlockId>>,
363
364 remapping: Option<&'a VarRemapping<'db>>,
369
370 statement_rev: Vec<&'a Statement<'db>>,
372
373 n_same_block_statement: usize,
375}
376
377pub struct MatchOptimizerContext<'db> {
378 fixes: Vec<FixInfo<'db>>,
379}
380
381#[derive(Clone)]
382pub struct AnalysisInfo<'db, 'a> {
383 candidate: Option<OptimizationCandidate<'db, 'a>>,
384 demand: MatchOptimizerDemand<'db>,
385 reachable_blocks: OrderedHashSet<BlockId>,
387}
388
389impl<'db: 'a, 'a> Analyzer<'db, 'a> for MatchOptimizerContext<'db> {
390 type Info = AnalysisInfo<'db, 'a>;
391
392 fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
393 info.reachable_blocks.insert(block_id);
394 }
395
396 fn visit_stmt(
397 &mut self,
398 info: &mut Self::Info,
399 statement_location: StatementLocation,
400 stmt: &'a Statement<'db>,
401 ) {
402 if let Some(mut candidate) = info.candidate.take() {
403 match stmt {
404 Statement::EnumConstruct(enum_construct_stmt)
405 if enum_construct_stmt.output == candidate.match_variable =>
406 {
407 if let Some(fix_info) = try_get_fix_info(
408 enum_construct_stmt,
409 info,
410 &mut candidate,
411 statement_location,
412 ) {
413 self.fixes.push(fix_info);
414 return;
415 }
416
417 info.candidate = None;
420 }
421 _ => {
422 candidate.statement_rev.push(stmt);
423 candidate.n_same_block_statement += 1;
424 info.candidate = Some(candidate);
425 }
426 }
427 }
428
429 info.demand.update(stmt);
430 }
431
432 fn visit_goto(
433 &mut self,
434 info: &mut Self::Info,
435 _statement_location: StatementLocation,
436 _target_block_id: BlockId,
437 remapping: &'a VarRemapping<'db>,
438 ) {
439 info.demand.apply_remapping(
440 &mut EmptyDemandReporter {},
441 remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))),
442 );
443
444 let Some(candidate) = &mut info.candidate else {
445 return;
446 };
447
448 candidate.n_same_block_statement = 0;
451
452 if candidate.future_merge
453 && candidate.statement_rev.iter().any(|stmt| !stmt.outputs().is_empty())
454 {
455 info.candidate = None;
458 return;
459 }
460
461 if remapping.is_empty() {
462 return;
463 }
464
465 if candidate.remapping.is_some() {
466 info.candidate = None;
467 return;
468 }
469
470 candidate.remapping = Some(remapping);
472 if let Some(var_usage) = remapping.get(&candidate.match_variable) {
473 candidate.match_variable = var_usage.var_id;
474 }
475 }
476
477 fn merge_match(
478 &mut self,
479 (block_id, _statement_idx): StatementLocation,
480 match_info: &'a MatchInfo<'db>,
481 infos: impl Iterator<Item = Self::Info>,
482 ) -> Self::Info {
483 let (arm_demands, arm_reachable_blocks): (Vec<_>, Vec<_>) =
484 infos.map(|info| (info.demand, info.reachable_blocks)).unzip();
485
486 let arm_demands_without_arm_var = zip_eq(match_info.arms(), &arm_demands)
487 .map(|(arm, demand)| {
488 let mut demand = demand.clone();
489 demand.variables_introduced(&mut EmptyDemandReporter {}, &arm.var_ids, ());
491
492 (demand, ())
493 })
494 .collect_vec();
495 let mut demand = MatchOptimizerDemand::merge_demands(
496 &arm_demands_without_arm_var,
497 &mut EmptyDemandReporter {},
498 );
499
500 let mut reachable_blocks = OrderedHashSet::default();
502 let mut max_possible_size = 0;
503 for cur_reachable_blocks in &arm_reachable_blocks {
504 reachable_blocks.extend(cur_reachable_blocks.iter().cloned());
505 max_possible_size += cur_reachable_blocks.len();
506 }
507 let found_collision = reachable_blocks.len() < max_possible_size;
510
511 let candidate = match match_info {
512 MatchInfo::Enum(MatchEnumInfo { input, arms, .. })
515 if !demand.vars.contains_key(&input.var_id) =>
516 {
517 Some(OptimizationCandidate {
518 match_variable: input.var_id,
519 match_arms: arms,
520 match_block: block_id,
521 arm_demands,
522 future_merge: found_collision,
523 arm_reachable_blocks,
524 remapping: None,
525 statement_rev: vec![],
526 n_same_block_statement: 0,
527 })
528 }
529 _ => None,
530 };
531
532 demand.variables_used(
533 &mut EmptyDemandReporter {},
534 match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
535 );
536
537 Self::Info { candidate, demand, reachable_blocks }
538 }
539
540 fn info_from_return(
541 &mut self,
542 _statement_location: StatementLocation,
543 vars: &[VarUsage<'db>],
544 ) -> Self::Info {
545 let mut demand = MatchOptimizerDemand::default();
546 demand.variables_used(
547 &mut EmptyDemandReporter {},
548 vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())),
549 );
550 Self::Info { candidate: None, demand, reachable_blocks: Default::default() }
551 }
552}