1#[cfg(test)]
2#[path = "return_optimization_test.rs"]
3mod test;
4
5use cairo_lang_semantic::types::TypesSemantic;
6use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeId, TypeLongId};
7use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
8use cairo_lang_utils::{Intern, require};
9use salsa::Database;
10use semantic::MatchArmSelector;
11
12use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
13use crate::ids::LocationId;
14use crate::{
15 Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
16 StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
17 VarUsage, Variable, VariableArena, VariableId,
18};
19
20pub fn return_optimization<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>) {
26 if lowered.blocks.is_empty() {
27 return;
28 }
29 let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
30 let mut analysis = BackAnalysis::new(lowered, ctx);
31 analysis.get_root_info();
32 let ctx = analysis.analyzer;
33
34 let ReturnOptimizerContext { fixes, .. } = ctx;
35 for FixInfo { location: (block_id, statement_idx), return_info } in fixes {
36 let block = &mut lowered.blocks[block_id];
37 block.statements.truncate(statement_idx);
38 let mut ctx = EarlyReturnContext {
39 db,
40 constructed: UnorderedHashMap::default(),
41 variables: &mut lowered.variables,
42 statements: &mut block.statements,
43 location: return_info.location,
44 };
45 let vars = ctx.prepare_early_return_vars(&return_info.returned_vars);
46 block.end = BlockEnd::Return(vars, return_info.location)
47 }
48}
49
50struct EarlyReturnContext<'db, 'a> {
52 db: &'db dyn Database,
54 constructed: UnorderedHashMap<Construction<'db>, VariableId>,
57 variables: &'a mut VariableArena<'db>,
59 statements: &'a mut Vec<Statement<'db>>,
61 location: LocationId<'db>,
63}
64
/// Key of a construction statement emitted by `EarlyReturnContext`.
///
/// Used in `EarlyReturnContext::constructed` so that an identical construction is emitted only
/// once and its output variable is shared.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum Construction<'db> {
    /// A struct of the given type built from the given input variables, in order.
    Struct(TypeId<'db>, Vec<VariableId>),
    /// A value of the given enum variant wrapping the given input variable.
    Enum(semantic::ConcreteVariant<'db>, VariableId),
}
74
75impl<'db, 'a> EarlyReturnContext<'db, 'a> {
76 fn prepare_early_return_vars(&mut self, ret_infos: &[ValueInfo<'db>]) -> Vec<VarUsage<'db>> {
80 let mut res = vec![];
81
82 for var_info in ret_infos.iter() {
83 match var_info {
84 ValueInfo::Var(var_usage) => {
85 res.push(*var_usage);
86 }
87 ValueInfo::StructConstruct { ty, var_infos } => {
88 let inputs = self.prepare_early_return_vars(var_infos);
89 let output = *self
90 .constructed
91 .entry(Construction::Struct(
92 *ty,
93 inputs.iter().map(|var_usage| var_usage.var_id).collect(),
94 ))
95 .or_insert_with(|| {
96 let output = self.variables.alloc(Variable::with_default_context(
97 self.db,
98 *ty,
99 self.location,
100 ));
101 self.statements.push(Statement::StructConstruct(
102 StatementStructConstruct { inputs, output },
103 ));
104 output
105 });
106 res.push(VarUsage { var_id: output, location: self.location });
107 }
108 ValueInfo::EnumConstruct { var_info, variant } => {
109 let input = self.prepare_early_return_vars(std::slice::from_ref(var_info))[0];
110
111 let ty = TypeLongId::Concrete(ConcreteTypeId::Enum(variant.concrete_enum_id))
112 .intern(self.db);
113
114 let output = *self
115 .constructed
116 .entry(Construction::Enum(*variant, input.var_id))
117 .or_insert_with(|| {
118 let output = self.variables.alloc(Variable::with_default_context(
119 self.db,
120 ty,
121 self.location,
122 ));
123 self.statements.push(Statement::EnumConstruct(
124 StatementEnumConstruct { variant: *variant, input, output },
125 ));
126 output
127 });
128 res.push(VarUsage { var_id: output, location: self.location });
129 }
130 ValueInfo::Interchangeable(_) => {
131 unreachable!("early_return_possible should have prevented this.")
132 }
133 }
134 }
135
136 res
137 }
138}
139
140pub struct ReturnOptimizerContext<'db, 'a> {
141 db: &'db dyn Database,
142 lowered: &'a Lowered<'db>,
143
144 fixes: Vec<FixInfo<'db>>,
146}
147impl<'db, 'a> ReturnOptimizerContext<'db, 'a> {
148 fn get_var_info(&self, var_usage: &VarUsage<'db>) -> ValueInfo<'db> {
150 let var_ty = &self.lowered.variables[var_usage.var_id].ty;
151 if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
152 ValueInfo::Interchangeable(*var_ty)
153 } else {
154 ValueInfo::Var(*var_usage)
155 }
156 }
157
158 fn is_droppable(&self, var_id: VariableId) -> bool {
160 self.lowered.variables[var_id].info.droppable.is_ok()
161 }
162
163 fn try_merge_match(
166 &mut self,
167 match_info: &MatchInfo<'db>,
168 infos: impl Iterator<Item = AnalyzerInfo<'db>>,
169 ) -> Option<ReturnInfo<'db>> {
170 let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
171 return None;
172 };
173 require(!arms.is_empty())?;
174
175 let input_info = self.get_var_info(input);
176 let mut opt_last_info = None;
177 for (arm, info) in arms.iter().zip(infos) {
178 let mut curr_info = info.clone();
179 curr_info.apply_match_arm(self.is_droppable(input.var_id), &input_info, arm);
180
181 match curr_info.try_get_early_return_info() {
182 Some(return_info)
183 if opt_last_info
184 .map(|x: ReturnInfo<'_>| x.returned_vars == return_info.returned_vars)
185 .unwrap_or(true) =>
186 {
187 opt_last_info = Some(return_info.clone())
190 }
191 _ => return None,
192 }
193 }
194
195 Some(opt_last_info.unwrap())
196 }
197}
198
199pub struct FixInfo<'db> {
201 location: StatementLocation,
203 return_info: ReturnInfo<'db>,
205}
206
/// Describes how one returned value can be produced at the current point of the backward
/// analysis.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ValueInfo<'db> {
    /// The value is exactly this variable usage.
    Var(VarUsage<'db>),
    /// Any value of this type is equivalent to the returned one.
    ///
    /// `get_var_info` produces it for droppable variables whose type is a single-value type.
    /// Such values cannot be materialized by an early return, so their presence blocks one.
    Interchangeable(semantic::TypeId<'db>),
    /// The value is a struct of type `ty` built from the given members.
    StructConstruct {
        /// The constructed struct type.
        ty: semantic::TypeId<'db>,
        /// How to produce each member, in construction order.
        var_infos: Vec<ValueInfo<'db>>,
    },
    /// The value is the enum `variant` wrapping the given payload.
    EnumConstruct {
        /// How to produce the variant's payload.
        var_info: Box<ValueInfo<'db>>,
        /// The constructed variant.
        variant: semantic::ConcreteVariant<'db>,
    },
}
229
/// Outcome of stepping a `ValueInfo` backward over a statement or a match arm.
enum OpResult {
    /// The operation's input was absorbed: the value, or a part of it, was rewritten in terms of
    /// that input.
    InputConsumed,
    /// The value refers to a variable the operation defines and cannot be expressed before it.
    ValueInvalidated,
    /// The operation does not affect the value.
    NoChange,
}
239
impl<'db> ValueInfo<'db> {
    /// Recursively replaces every `ValueInfo::Var` leaf with `f(var_usage)`.
    ///
    /// Struct members and enum payloads are traversed; `Interchangeable` leaves are left as-is.
    fn apply<F>(&mut self, f: &F)
    where
        F: Fn(&VarUsage<'db>) -> ValueInfo<'db>,
    {
        match self {
            ValueInfo::Var(var_usage) => *self = f(var_usage),
            ValueInfo::StructConstruct { ty: _, var_infos } => {
                for var_info in var_infos.iter_mut() {
                    var_info.apply(f);
                }
            }
            ValueInfo::EnumConstruct { var_info, .. } => {
                var_info.apply(f);
            }
            ValueInfo::Interchangeable(_) => {}
        }
    }

    /// Steps the value backward over the struct destructure `stmt`.
    ///
    /// Returns:
    /// - `InputConsumed` if the value, or a nested part of it, rebuilds exactly the destructured
    ///   struct. That part is replaced by `ValueInfo::Var(stmt.input)`.
    /// - `ValueInvalidated` if a `Var` leaf not absorbed this way refers to one of
    ///   `stmt.outputs`. The value may be left partially rewritten; callers discard it.
    /// - `NoChange` otherwise.
    fn apply_deconstruct(
        &mut self,
        ctx: &ReturnOptimizerContext<'db, '_>,
        stmt: &StatementStructDestructure<'db>,
    ) -> OpResult {
        match self {
            ValueInfo::Var(var_usage) => {
                // Destructure outputs do not exist before the statement.
                if stmt.outputs.contains(&var_usage.var_id) {
                    OpResult::ValueInvalidated
                } else {
                    OpResult::NoChange
                }
            }
            ValueInfo::StructConstruct { ty, var_infos } => {
                // The construct undoes the destructure when all of these hold:
                // - it has the input's type and the same number of members;
                // - each member is, in order, either the matching output variable or an
                //   interchangeable value of that output's type.
                let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
                    && var_infos.len() == stmt.outputs.len();
                for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
                    if !cancels_out {
                        break;
                    }

                    match var_info {
                        ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
                        ValueInfo::Interchangeable(ty)
                            if &ctx.lowered.variables[*output].ty == ty => {}
                        _ => cancels_out = false,
                    }
                }

                if cancels_out {
                    // The construct just rebuilds the destructured struct; use the input directly.
                    *self = ValueInfo::Var(stmt.input);
                    return OpResult::InputConsumed;
                }

                // Otherwise step each member back independently.
                let mut input_consumed = false;
                for var_info in var_infos.iter_mut() {
                    match var_info.apply_deconstruct(ctx, stmt) {
                        OpResult::InputConsumed => {
                            input_consumed = true;
                        }
                        OpResult::ValueInvalidated => {
                            // One invalid member invalidates the whole struct value.
                            return OpResult::ValueInvalidated;
                        }
                        OpResult::NoChange => {}
                    }
                }

                match input_consumed {
                    true => OpResult::InputConsumed,
                    false => OpResult::NoChange,
                }
            }
            ValueInfo::EnumConstruct { var_info, .. } => var_info.apply_deconstruct(ctx, stmt),
            ValueInfo::Interchangeable(_) => OpResult::NoChange,
        }
    }

    /// Steps the value backward over entering `arm` of an enum match whose input is described
    /// by `input`.
    ///
    /// Returns:
    /// - `InputConsumed` if an `EnumConstruct` of the arm's variant wraps the arm's single bound
    ///   variable, or an interchangeable payload. That construct rebuilds the matched value and
    ///   is replaced by `input`.
    /// - `ValueInvalidated` if a `Var` leaf not absorbed this way is the arm's single bound
    ///   variable.
    /// - `NoChange` otherwise.
    ///
    /// # Panics
    /// Panics if an `EnumConstruct` is reached while `arm.arm_selector` is not a
    /// `MatchArmSelector::VariantId`.
    fn apply_match_arm(&mut self, input: &ValueInfo<'db>, arm: &MatchArm<'db>) -> OpResult {
        match self {
            ValueInfo::Var(var_usage) => {
                // The arm's bound variable does not exist before the arm.
                if arm.var_ids == [var_usage.var_id] {
                    OpResult::ValueInvalidated
                } else {
                    OpResult::NoChange
                }
            }
            ValueInfo::StructConstruct { ty: _, var_infos } => {
                let mut input_consumed = false;
                for var_info in var_infos.iter_mut() {
                    match var_info.apply_match_arm(input, arm) {
                        OpResult::InputConsumed => {
                            input_consumed = true;
                        }
                        OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
                        OpResult::NoChange => {}
                    }
                }

                if input_consumed {
                    return OpResult::InputConsumed;
                }
                OpResult::NoChange
            }
            ValueInfo::EnumConstruct { var_info, variant } => {
                let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
                    panic!("Enum construct should not appear in value match");
                };

                if *variant == *arm_variant {
                    let cancels_out = match **var_info {
                        ValueInfo::Interchangeable(_) => true,
                        ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
                        _ => false,
                    };

                    if cancels_out {
                        // The construct recreates the matched value; use the match input instead.
                        *self = input.clone();
                        return OpResult::InputConsumed;
                    }
                }

                var_info.apply_match_arm(input, arm)
            }
            ValueInfo::Interchangeable(_) => OpResult::NoChange,
        }
    }
}
377
/// A candidate return: the values it returns and the location of the original return.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ReturnInfo<'db> {
    /// One entry per returned value, describing how to produce it at the current point.
    returned_vars: Vec<ValueInfo<'db>>,
    /// Location of the originating `BlockEnd::Return`, reused for the emitted early return.
    location: LocationId<'db>,
}
386
/// Per-point state of the backward analysis.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AnalyzerInfo<'db> {
    /// The return being tracked backward from a `Return` block end.
    ///
    /// Becomes `None` once it is invalidated, e.g. by a statement other than a struct/enum
    /// construction or a struct destructure.
    opt_return_info: Option<ReturnInfo<'db>>,
}
396
397impl<'db> AnalyzerInfo<'db> {
398 fn invalidated() -> Self {
400 AnalyzerInfo { opt_return_info: None }
401 }
402
403 fn invalidate(&mut self) {
405 *self = Self::invalidated();
406 }
407
408 fn apply<F>(&mut self, f: &F)
410 where
411 F: Fn(&VarUsage<'db>) -> ValueInfo<'db>,
412 {
413 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
414 return;
415 };
416
417 for var_info in returned_vars.iter_mut() {
418 var_info.apply(f)
419 }
420 }
421
422 fn replace(&mut self, var_id: VariableId, var_info: ValueInfo<'db>) {
424 self.apply(&|var_usage| {
425 if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
426 });
427 }
428
429 fn apply_deconstruct(
431 &mut self,
432 ctx: &ReturnOptimizerContext<'db, '_>,
433 stmt: &StatementStructDestructure<'db>,
434 ) {
435 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
436
437 let mut input_consumed = false;
438 for var_info in returned_vars.iter_mut() {
439 match var_info.apply_deconstruct(ctx, stmt) {
440 OpResult::InputConsumed => {
441 input_consumed = true;
442 }
443 OpResult::ValueInvalidated => {
444 self.invalidate();
445 return;
446 }
447 OpResult::NoChange => {}
448 };
449 }
450
451 if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
452 self.invalidate();
453 }
454 }
455
456 fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo<'db>, arm: &MatchArm<'db>) {
458 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
459
460 let mut input_consumed = false;
461 for var_info in returned_vars.iter_mut() {
462 match var_info.apply_match_arm(input, arm) {
463 OpResult::InputConsumed => {
464 input_consumed = true;
465 }
466 OpResult::ValueInvalidated => {
467 self.invalidate();
468 return;
469 }
470 OpResult::NoChange => {}
471 };
472 }
473
474 if !(input_consumed || is_droppable) {
475 self.invalidate();
476 }
477 }
478
479 fn try_get_early_return_info(&self) -> Option<&ReturnInfo<'db>> {
481 let return_info = self.opt_return_info.as_ref()?;
482
483 let mut stack = return_info.returned_vars.clone();
484 while let Some(var_info) = stack.pop() {
485 match var_info {
486 ValueInfo::Var(_) => {}
487 ValueInfo::StructConstruct { ty: _, var_infos } => stack.extend(var_infos),
488 ValueInfo::EnumConstruct { var_info, variant: _ } => stack.push(*var_info),
489 ValueInfo::Interchangeable(_) => return None,
490 }
491 }
492
493 Some(return_info)
494 }
495}
496
497impl<'db, 'a> Analyzer<'db, 'a> for ReturnOptimizerContext<'db, 'a> {
498 type Info = AnalyzerInfo<'db>;
499
500 fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
501 if let Some(return_info) = info.try_get_early_return_info() {
502 self.fixes.push(FixInfo { location: (block_id, 0), return_info: return_info.clone() });
503 }
504 }
505
506 fn visit_stmt(
507 &mut self,
508 info: &mut Self::Info,
509 (block_idx, statement_idx): StatementLocation,
510 stmt: &'a Statement<'db>,
511 ) {
512 let opt_early_return_info = info.try_get_early_return_info().cloned();
513
514 match stmt {
515 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
516 info.replace(
520 *output,
521 ValueInfo::StructConstruct {
522 ty: self.lowered.variables[*output].ty,
523 var_infos: inputs.iter().map(|input| self.get_var_info(input)).collect(),
524 },
525 );
526 }
527
528 Statement::StructDestructure(stmt) => info.apply_deconstruct(self, stmt),
529 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
530 info.replace(
531 *output,
532 ValueInfo::EnumConstruct {
533 var_info: Box::new(self.get_var_info(input)),
534 variant: *variant,
535 },
536 );
537 }
538 _ => info.invalidate(),
539 }
540
541 if let Some(early_return_info) = opt_early_return_info
542 && info.try_get_early_return_info().is_none()
543 {
544 self.fixes.push(FixInfo {
545 location: (block_idx, statement_idx + 1),
546 return_info: early_return_info,
547 });
548 }
549 }
550
551 fn visit_goto(
552 &mut self,
553 info: &mut Self::Info,
554 _statement_location: StatementLocation,
555 _target_block_id: BlockId,
556 remapping: &VarRemapping<'db>,
557 ) {
558 info.apply(&|var_usage| {
559 if let Some(usage) = remapping.get(&var_usage.var_id) {
560 ValueInfo::Var(*usage)
561 } else {
562 ValueInfo::Var(*var_usage)
563 }
564 });
565 }
566
567 fn merge_match(
568 &mut self,
569 _statement_location: StatementLocation,
570 match_info: &'a MatchInfo<'db>,
571 infos: impl Iterator<Item = Self::Info>,
572 ) -> Self::Info {
573 Self::Info { opt_return_info: self.try_merge_match(match_info, infos) }
574 }
575
576 fn info_from_return(
577 &mut self,
578 (block_id, _statement_idx): StatementLocation,
579 vars: &'a [VarUsage<'db>],
580 ) -> Self::Info {
581 let location = match &self.lowered.blocks[block_id].end {
582 BlockEnd::Return(_vars, location) => *location,
583 _ => unreachable!(),
584 };
585
586 AnalyzerInfo {
589 opt_return_info: Some(ReturnInfo {
590 returned_vars: vars.iter().map(|var_usage| ValueInfo::Var(*var_usage)).collect(),
591 location,
592 }),
593 }
594 }
595}