// Unit tests for this module live in the sibling file `match_optimizer_test.rs`.
#[cfg(test)]
#[path = "match_optimizer_test.rs"]
mod test;
4
use cairo_lang_semantic::MatchArmSelector;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use id_arena::Arena;
use itertools::{Itertools, zip_eq};

use super::var_renamer::VarRenamer;
use crate::borrow_check::Demand;
use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
use crate::borrow_check::demand::EmptyDemandReporter;
use crate::utils::RebuilderEx;
use crate::{
    Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
    StatementEnumConstruct, VarRemapping, VarUsage, Variable, VariableId,
};
21
22pub type MatchOptimizerDemand = Demand<VariableId, (), ()>;
23
24impl MatchOptimizerDemand {
25 fn update(&mut self, statement: &Statement) {
26 self.variables_introduced(&mut EmptyDemandReporter {}, statement.outputs(), ());
27 self.variables_used(
28 &mut EmptyDemandReporter {},
29 statement.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
30 );
31 }
32}
33
34pub fn optimize_matches(lowered: &mut Lowered) {
57 if lowered.blocks.is_empty() {
58 return;
59 }
60 let ctx = MatchOptimizerContext { fixes: vec![] };
61 let mut analysis = BackAnalysis::new(lowered, ctx);
62 analysis.get_root_info();
63 let ctx = analysis.analyzer;
64
65 let mut new_blocks = vec![];
66 let mut next_block_id = BlockId(lowered.blocks.len());
67
68 let mut var_renaming = UnorderedHashMap::<(VariableId, usize), VariableId>::default();
91
92 for fix in ctx.fixes {
96 let mut new_remapping = fix.remapping.clone();
99 let mut renamed_vars = OrderedHashMap::<VariableId, VariableId>::default();
100 for (var, dst) in fix.additional_remappings.iter() {
101 let new_var = *var_renaming
103 .entry((*var, fix.arm_idx))
104 .or_insert_with(|| lowered.variables.alloc(lowered.variables[*var].clone()));
105 new_remapping.insert(new_var, *dst);
106 renamed_vars.insert(*var, new_var);
107 }
108
109 let block = &mut lowered.blocks[fix.statement_location.0];
110 assert_eq!(
111 block.statements.len() - 1,
112 fix.statement_location.1 + fix.n_same_block_statement,
113 "Unexpected number of statements in block."
114 );
115
116 if fix.remove_enum_construct {
117 block.statements.remove(fix.statement_location.1);
118 }
119
120 handle_additional_statements(
121 &mut lowered.variables,
122 &mut var_renaming,
123 &mut new_remapping,
124 &mut renamed_vars,
125 block,
126 &fix,
127 );
128
129 block.end = BlockEnd::Goto(fix.target_block, new_remapping);
130 if fix.statement_location.0 == fix.match_block {
131 assert!(fix.additional_remappings.remapping.is_empty());
134 continue;
135 }
136
137 let block = &mut lowered.blocks[fix.match_block];
138 let BlockEnd::Match { info: MatchInfo::Enum(MatchEnumInfo { arms, location, .. }) } =
139 &mut block.end
140 else {
141 unreachable!("match block should end with a match.");
142 };
143
144 let arm = arms.get_mut(fix.arm_idx).unwrap();
145 if fix.target_block != arm.block_id {
146 continue;
148 }
149
150 let arm_var = arm.var_ids.get_mut(0).unwrap();
153 let orig_var = *arm_var;
154 *arm_var = lowered.variables.alloc(lowered.variables[orig_var].clone());
155 let mut new_block_remapping: VarRemapping = Default::default();
156
157 new_block_remapping.insert(orig_var, VarUsage { var_id: *arm_var, location: *location });
158 for (var, new_var) in renamed_vars.iter() {
159 new_block_remapping.insert(*new_var, VarUsage { var_id: *var, location: *location });
160 }
161
162 new_blocks.push(Block {
163 statements: vec![],
164 end: BlockEnd::Goto(arm.block_id, new_block_remapping),
165 });
166 arm.block_id = next_block_id;
167 next_block_id = next_block_id.next_block_id();
168
169 let mut var_renamer = VarRenamer { renamed_vars: renamed_vars.into_iter().collect() };
170 for block_id in fix.reachable_blocks {
172 let block = &mut lowered.blocks[block_id];
173 *block = var_renamer.rebuild_block(block);
174 }
175 }
176
177 for block in new_blocks {
178 lowered.blocks.push(block);
179 }
180}
181
182fn handle_additional_statements(
190 variables: &mut Arena<Variable>,
191 var_renaming: &mut UnorderedHashMap<(VariableId, usize), VariableId>,
192 new_remapping: &mut VarRemapping,
193 renamed_vars: &mut OrderedHashMap<VariableId, VariableId>,
194 block: &mut Block,
195 fix: &FixInfo,
196) {
197 if fix.additional_stmts.is_empty() {
198 return;
199 }
200
201 let mut inputs_remapping = UnorderedHashMap::<VariableId, VariableId>::from_iter(
205 fix.additional_remappings.iter().map(|(k, v)| (*k, v.var_id)),
206 );
207 for mut stmt in fix.additional_stmts.iter().cloned() {
208 for input in stmt.inputs_mut() {
209 if let Some(orig_var) = inputs_remapping.get(&input.var_id) {
210 input.var_id = *orig_var;
211 }
212 }
213
214 for output in stmt.outputs_mut() {
215 let orig_output = *output;
216 *output = variables.alloc(variables[*output].clone());
218 inputs_remapping.insert(orig_output, *output);
219
220 let new_output = *var_renaming
222 .entry((orig_output, fix.arm_idx))
223 .or_insert_with(|| variables.alloc(variables[*output].clone()));
224 let location = variables[*output].location;
225 new_remapping.insert(new_output, VarUsage { var_id: *output, location });
226 renamed_vars.insert(orig_output, new_output);
227 }
228
229 block.statements.push(stmt);
230 }
231}
232
233fn try_get_fix_info(
237 StatementEnumConstruct { variant, input, output }: &StatementEnumConstruct,
238 info: &mut AnalysisInfo<'_>,
239 candidate: &mut OptimizationCandidate<'_>,
240 statement_location: (BlockId, usize),
241) -> Option<FixInfo> {
242 let (arm_idx, arm) = candidate
243 .match_arms
244 .iter()
245 .find_position(
246 |arm| matches!(&arm.arm_selector, MatchArmSelector::VariantId(v) if v == variant),
247 )
248 .expect("arm not found.");
249
250 let [var_id] = arm.var_ids.as_slice() else {
251 panic!("An arm of an EnumMatch should produce a single variable.");
252 };
253
254 let mut remapping = VarRemapping::default();
257 remapping.insert(*var_id, *input);
258
259 let mut demand = std::mem::take(&mut candidate.arm_demands[arm_idx]);
263
264 let additional_stmts = candidate
265 .statement_rev
266 .iter()
267 .rev()
268 .skip(candidate.n_same_block_statement)
269 .cloned()
270 .cloned()
271 .collect_vec();
272 for stmt in &additional_stmts {
273 demand.update(stmt);
274 }
275
276 demand
277 .apply_remapping(&mut EmptyDemandReporter {}, [(var_id, (&input.var_id, ()))].into_iter());
278
279 let additional_remappings = match candidate.remapping {
280 Some(remappings) => {
281 VarRemapping {
283 remapping: OrderedHashMap::from_iter(remappings.iter().filter_map(|(dst, src)| {
284 if demand.vars.contains_key(dst) { Some((*dst, *src)) } else { None }
285 })),
286 }
287 }
288 None => VarRemapping::default(),
289 };
290
291 if !additional_remappings.is_empty() && candidate.future_merge {
292 return None;
294 }
295
296 demand.apply_remapping(
297 &mut EmptyDemandReporter {},
298 additional_remappings.iter().map(|(dst, src_var_usage)| (dst, (&src_var_usage.var_id, ()))),
299 );
300
301 for stmt in candidate.statement_rev.iter().rev() {
302 demand.update(stmt);
303 }
304 info.demand = demand;
305 info.reachable_blocks = std::mem::take(&mut candidate.arm_reachable_blocks[arm_idx]);
306
307 Some(FixInfo {
308 statement_location,
309 match_block: candidate.match_block,
310 arm_idx,
311 target_block: arm.block_id,
312 remapping,
313 reachable_blocks: info.reachable_blocks.clone(),
314 additional_remappings,
315 n_same_block_statement: candidate.n_same_block_statement,
316 remove_enum_construct: !info.demand.vars.contains_key(output),
317 additional_stmts,
318 })
319}
320
321pub struct FixInfo {
322 statement_location: (BlockId, usize),
324 match_block: BlockId,
326 arm_idx: usize,
328 target_block: BlockId,
330 remapping: VarRemapping,
332 reachable_blocks: OrderedHashSet<BlockId>,
334 additional_remappings: VarRemapping,
336 n_same_block_statement: usize,
338 remove_enum_construct: bool,
340 additional_stmts: Vec<Statement>,
343}
344
345#[derive(Clone)]
346struct OptimizationCandidate<'a> {
347 match_variable: VariableId,
349
350 match_arms: &'a [MatchArm],
352
353 match_block: BlockId,
355
356 arm_demands: Vec<MatchOptimizerDemand>,
358
359 future_merge: bool,
361
362 arm_reachable_blocks: Vec<OrderedHashSet<BlockId>>,
364
365 remapping: Option<&'a VarRemapping>,
370
371 statement_rev: Vec<&'a Statement>,
373
374 n_same_block_statement: usize,
376}
377
378pub struct MatchOptimizerContext {
379 fixes: Vec<FixInfo>,
380}
381
382#[derive(Clone)]
383pub struct AnalysisInfo<'a> {
384 candidate: Option<OptimizationCandidate<'a>>,
385 demand: MatchOptimizerDemand,
386 reachable_blocks: OrderedHashSet<BlockId>,
388}
389
390impl<'a> Analyzer<'a> for MatchOptimizerContext {
391 type Info = AnalysisInfo<'a>;
392
393 fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block) {
394 info.reachable_blocks.insert(block_id);
395 }
396
397 fn visit_stmt(
398 &mut self,
399 info: &mut Self::Info,
400 statement_location: StatementLocation,
401 stmt: &'a Statement,
402 ) {
403 if let Some(mut candidate) = info.candidate.take() {
404 match stmt {
405 Statement::EnumConstruct(enum_construct_stmt)
406 if enum_construct_stmt.output == candidate.match_variable =>
407 {
408 if let Some(fix_info) = try_get_fix_info(
409 enum_construct_stmt,
410 info,
411 &mut candidate,
412 statement_location,
413 ) {
414 self.fixes.push(fix_info);
415 return;
416 }
417
418 info.candidate = None;
421 }
422 _ => {
423 candidate.statement_rev.push(stmt);
424 candidate.n_same_block_statement += 1;
425 info.candidate = Some(candidate);
426 }
427 }
428 }
429
430 info.demand.update(stmt);
431 }
432
433 fn visit_goto(
434 &mut self,
435 info: &mut Self::Info,
436 _statement_location: StatementLocation,
437 _target_block_id: BlockId,
438 remapping: &'a VarRemapping,
439 ) {
440 info.demand.apply_remapping(
441 &mut EmptyDemandReporter {},
442 remapping.iter().map(|(dst, src)| (dst, (&src.var_id, ()))),
443 );
444
445 let Some(ref mut candidate) = &mut info.candidate else {
446 return;
447 };
448
449 candidate.n_same_block_statement = 0;
452
453 if candidate.future_merge
454 && candidate.statement_rev.iter().any(|stmt| !stmt.outputs().is_empty())
455 {
456 info.candidate = None;
459 return;
460 }
461
462 if remapping.is_empty() {
463 return;
464 }
465
466 if candidate.remapping.is_some() {
467 info.candidate = None;
468 return;
469 }
470
471 candidate.remapping = Some(remapping);
473 if let Some(var_usage) = remapping.get(&candidate.match_variable) {
474 candidate.match_variable = var_usage.var_id;
475 }
476 }
477
478 fn merge_match(
479 &mut self,
480 (block_id, _statement_idx): StatementLocation,
481 match_info: &'a MatchInfo,
482 infos: impl Iterator<Item = Self::Info>,
483 ) -> Self::Info {
484 let (arm_demands, arm_reachable_blocks): (Vec<_>, Vec<_>) =
485 infos.map(|info| (info.demand, info.reachable_blocks)).unzip();
486
487 let arm_demands_without_arm_var = zip_eq(match_info.arms(), &arm_demands)
488 .map(|(arm, demand)| {
489 let mut demand = demand.clone();
490 demand.variables_introduced(&mut EmptyDemandReporter {}, &arm.var_ids, ());
492
493 (demand, ())
494 })
495 .collect_vec();
496 let mut demand = MatchOptimizerDemand::merge_demands(
497 &arm_demands_without_arm_var,
498 &mut EmptyDemandReporter {},
499 );
500
501 let mut reachable_blocks = OrderedHashSet::default();
503 let mut max_possible_size = 0;
504 for cur_reachable_blocks in &arm_reachable_blocks {
505 reachable_blocks.extend(cur_reachable_blocks.iter().cloned());
506 max_possible_size += cur_reachable_blocks.len();
507 }
508 let found_collision = reachable_blocks.len() < max_possible_size;
511
512 let candidate = match match_info {
513 MatchInfo::Enum(MatchEnumInfo { input, arms, .. })
516 if !demand.vars.contains_key(&input.var_id) =>
517 {
518 Some(OptimizationCandidate {
519 match_variable: input.var_id,
520 match_arms: arms,
521 match_block: block_id,
522 arm_demands,
523 future_merge: found_collision,
524 arm_reachable_blocks,
525 remapping: None,
526 statement_rev: vec![],
527 n_same_block_statement: 0,
528 })
529 }
530 _ => None,
531 };
532
533 demand.variables_used(
534 &mut EmptyDemandReporter {},
535 match_info.inputs().iter().map(|VarUsage { var_id, .. }| (var_id, ())),
536 );
537
538 Self::Info { candidate, demand, reachable_blocks }
539 }
540
541 fn info_from_return(
542 &mut self,
543 _statement_location: StatementLocation,
544 vars: &[VarUsage],
545 ) -> Self::Info {
546 let mut demand = MatchOptimizerDemand::default();
547 demand.variables_used(
548 &mut EmptyDemandReporter {},
549 vars.iter().map(|VarUsage { var_id, .. }| (var_id, ())),
550 );
551 Self::Info { candidate: None, demand, reachable_blocks: Default::default() }
552 }
553}