1#[cfg(test)]
2#[path = "return_optimization_test.rs"]
3mod test;
4
5use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeId, TypeLongId};
6use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
7use cairo_lang_utils::{Intern, require};
8use semantic::MatchArmSelector;
9
10use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
11use crate::db::LoweringGroup;
12use crate::ids::{ConcreteFunctionWithBodyId, LocationId};
13use crate::lower::context::{VarRequest, VariableAllocator};
14use crate::{
15 BlockId, FlatBlock, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
16 StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
17 VarUsage, VariableId,
18};
19
20pub fn return_optimization(
26 db: &dyn LoweringGroup,
27 function_id: ConcreteFunctionWithBodyId,
28 lowered: &mut FlatLowered,
29) {
30 if lowered.blocks.is_empty() {
31 return;
32 }
33 let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
34 let mut analysis = BackAnalysis::new(lowered, ctx);
35 analysis.get_root_info();
36 let ctx = analysis.analyzer;
37
38 let mut variables = VariableAllocator::new(
39 db,
40 function_id.function_with_body_id(db).base_semantic_function(db),
41 lowered.variables.clone(),
42 )
43 .unwrap();
44
45 for FixInfo { location: (block_id, statement_idx), return_info } in ctx.fixes {
46 let block = &mut lowered.blocks[block_id];
47 block.statements.truncate(statement_idx);
48 let mut ctx = EarlyReturnContext {
49 constructed: UnorderedHashMap::default(),
50 variables: &mut variables,
51 statements: &mut block.statements,
52 location: return_info.location,
53 };
54 let vars = ctx.prepare_early_return_vars(&return_info.returned_vars);
55 block.end = FlatBlockEnd::Return(vars, return_info.location)
56 }
57
58 lowered.variables = variables.variables;
59}
60
61struct EarlyReturnContext<'a, 'b> {
63 constructed: UnorderedHashMap<(TypeId, Vec<VariableId>), VariableId>,
66 variables: &'a mut VariableAllocator<'b>,
68 statements: &'a mut Vec<Statement>,
70 location: LocationId,
72}
73
74impl EarlyReturnContext<'_, '_> {
75 fn prepare_early_return_vars(&mut self, ret_infos: &[ValueInfo]) -> Vec<VarUsage> {
79 let mut res = vec![];
80
81 for var_info in ret_infos.iter() {
82 match var_info {
83 ValueInfo::Var(var_usage) => {
84 res.push(*var_usage);
85 }
86 ValueInfo::StructConstruct { ty, var_infos } => {
87 let inputs = self.prepare_early_return_vars(var_infos);
88 let output = *self
89 .constructed
90 .entry((*ty, inputs.iter().map(|var_usage| var_usage.var_id).collect()))
91 .or_insert_with(|| {
92 let output = self
93 .variables
94 .new_var(VarRequest { ty: *ty, location: self.location });
95 self.statements.push(Statement::StructConstruct(
96 StatementStructConstruct { inputs, output },
97 ));
98 output
99 });
100 res.push(VarUsage { var_id: output, location: self.location });
101 }
102 ValueInfo::EnumConstruct { var_info, variant } => {
103 let input = self.prepare_early_return_vars(std::slice::from_ref(var_info))[0];
104
105 let ty = TypeLongId::Concrete(ConcreteTypeId::Enum(variant.concrete_enum_id))
106 .intern(self.variables.db);
107
108 let output =
109 *self.constructed.entry((ty, vec![input.var_id])).or_insert_with(|| {
110 let output =
111 self.variables.new_var(VarRequest { ty, location: self.location });
112 self.statements.push(Statement::EnumConstruct(
113 StatementEnumConstruct { variant: variant.clone(), input, output },
114 ));
115 output
116 });
117 res.push(VarUsage { var_id: output, location: self.location });
118 }
119 ValueInfo::Interchangeable(_) => {
120 unreachable!("early_return_possible should have prevented this.")
121 }
122 }
123 }
124
125 res
126 }
127}
128
129pub struct ReturnOptimizerContext<'a> {
130 db: &'a dyn LoweringGroup,
131 lowered: &'a FlatLowered,
132
133 fixes: Vec<FixInfo>,
135}
136impl ReturnOptimizerContext<'_> {
137 fn get_var_info(&self, var_usage: &VarUsage) -> ValueInfo {
139 let var_ty = &self.lowered.variables[var_usage.var_id].ty;
140 if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
141 ValueInfo::Interchangeable(*var_ty)
142 } else {
143 ValueInfo::Var(*var_usage)
144 }
145 }
146
147 fn is_droppable(&self, var_id: VariableId) -> bool {
149 self.lowered.variables[var_id].droppable.is_ok()
150 }
151
152 fn try_merge_match(
155 &mut self,
156 match_info: &MatchInfo,
157 infos: &[AnalyzerInfo],
158 ) -> Option<ReturnInfo> {
159 let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
160 return None;
161 };
162 require(!arms.is_empty())?;
163
164 let input_info = self.get_var_info(input);
165 let mut opt_last_info = None;
166 for (arm, info) in arms.iter().zip(infos) {
167 let mut curr_info = info.clone();
168 curr_info.apply_match_arm(self.is_droppable(input.var_id), &input_info, arm);
169
170 match curr_info.try_get_early_return_info() {
171 Some(return_info)
172 if opt_last_info
173 .map(|x: ReturnInfo| x.returned_vars == return_info.returned_vars)
174 .unwrap_or(true) =>
175 {
176 opt_last_info = Some(return_info.clone())
179 }
180 _ => return None,
181 }
182 }
183
184 Some(opt_last_info.unwrap())
185 }
186}
187
188pub struct FixInfo {
190 location: StatementLocation,
192 return_info: ReturnInfo,
194}
195
196#[derive(Clone, Debug, PartialEq, Eq)]
198pub enum ValueInfo {
199 Var(VarUsage),
201 Interchangeable(semantic::TypeId),
203 StructConstruct {
205 ty: semantic::TypeId,
207 var_infos: Vec<ValueInfo>,
209 },
210 EnumConstruct {
212 var_info: Box<ValueInfo>,
214 variant: semantic::ConcreteVariant,
216 },
217}
218
/// The effect of rewinding a `ValueInfo` over a statement or a match arm.
enum OpResult {
    /// The input of the statement became part of the value (e.g. a construct was cancelled out
    /// by the matching destructure, or a match arm rebuilt the matched value).
    InputConsumed,
    /// The value depends on a variable introduced by the statement, so it is unavailable before.
    ValueInvalidated,
    /// The value is unaffected.
    NoChange,
}
228
229impl ValueInfo {
230 fn apply<F>(&mut self, f: &F)
232 where
233 F: Fn(&VarUsage) -> ValueInfo,
234 {
235 match self {
236 ValueInfo::Var(var_usage) => *self = f(var_usage),
237 ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
238 for var_info in var_infos.iter_mut() {
239 var_info.apply(f);
240 }
241 }
242 ValueInfo::EnumConstruct { ref mut var_info, .. } => {
243 var_info.apply(f);
244 }
245 ValueInfo::Interchangeable(_) => {}
246 }
247 }
248
249 fn apply_deconstruct(
252 &mut self,
253 ctx: &ReturnOptimizerContext<'_>,
254 stmt: &StatementStructDestructure,
255 ) -> OpResult {
256 match self {
257 ValueInfo::Var(var_usage) => {
258 if stmt.outputs.contains(&var_usage.var_id) {
259 OpResult::ValueInvalidated
260 } else {
261 OpResult::NoChange
262 }
263 }
264 ValueInfo::StructConstruct { ty, var_infos } => {
265 let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
266 && var_infos.len() == stmt.outputs.len();
267 for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
268 if !cancels_out {
269 break;
270 }
271
272 match var_info {
273 ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
274 ValueInfo::Interchangeable(ty)
275 if &ctx.lowered.variables[*output].ty == ty => {}
276 _ => cancels_out = false,
277 }
278 }
279
280 if cancels_out {
281 *self = ValueInfo::Var(stmt.input);
284 return OpResult::InputConsumed;
285 }
286
287 let mut input_consumed = false;
288 for var_info in var_infos.iter_mut() {
289 match var_info.apply_deconstruct(ctx, stmt) {
290 OpResult::InputConsumed => {
291 input_consumed = true;
292 }
293 OpResult::ValueInvalidated => {
294 return OpResult::ValueInvalidated;
297 }
298 OpResult::NoChange => {}
299 }
300 }
301
302 match input_consumed {
303 true => OpResult::InputConsumed,
304 false => OpResult::NoChange,
305 }
306 }
307 ValueInfo::EnumConstruct { ref mut var_info, .. } => {
308 var_info.apply_deconstruct(ctx, stmt)
309 }
310 ValueInfo::Interchangeable(_) => OpResult::NoChange,
311 }
312 }
313
314 fn apply_match_arm(&mut self, input: &ValueInfo, arm: &MatchArm) -> OpResult {
317 match self {
318 ValueInfo::Var(var_usage) => {
319 if arm.var_ids == [var_usage.var_id] {
320 OpResult::ValueInvalidated
321 } else {
322 OpResult::NoChange
323 }
324 }
325 ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
326 let mut input_consumed = false;
327 for var_info in var_infos.iter_mut() {
328 match var_info.apply_match_arm(input, arm) {
329 OpResult::InputConsumed => {
330 input_consumed = true;
331 }
332 OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
333 OpResult::NoChange => {}
334 }
335 }
336
337 if input_consumed {
338 return OpResult::InputConsumed;
339 }
340 OpResult::NoChange
341 }
342 ValueInfo::EnumConstruct { ref mut var_info, variant } => {
343 let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
344 panic!("Enum construct should not appear in value match");
345 };
346
347 if *variant == *arm_variant {
348 let cancels_out = match **var_info {
349 ValueInfo::Interchangeable(_) => true,
350 ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
351 _ => false,
352 };
353
354 if cancels_out {
355 *self = input.clone();
358 return OpResult::InputConsumed;
359 }
360 }
361
362 var_info.apply_match_arm(input, arm)
363 }
364 ValueInfo::Interchangeable(_) => OpResult::NoChange,
365 }
366 }
367}
368
369#[derive(Clone, Debug, PartialEq, Eq)]
373pub struct ReturnInfo {
374 returned_vars: Vec<ValueInfo>,
375 location: LocationId,
376}
377
378#[derive(Clone, Debug, PartialEq, Eq)]
384pub struct AnalyzerInfo {
385 opt_return_info: Option<ReturnInfo>,
386}
387
388impl AnalyzerInfo {
389 fn invalidated() -> Self {
391 AnalyzerInfo { opt_return_info: None }
392 }
393
394 fn invalidate(&mut self) {
396 *self = Self::invalidated();
397 }
398
399 fn apply<F>(&mut self, f: &F)
401 where
402 F: Fn(&VarUsage) -> ValueInfo,
403 {
404 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
405 return;
406 };
407
408 for var_info in returned_vars.iter_mut() {
409 var_info.apply(f)
410 }
411 }
412
413 fn replace(&mut self, var_id: VariableId, var_info: ValueInfo) {
415 self.apply(&|var_usage| {
416 if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
417 });
418 }
419
420 fn apply_deconstruct(
422 &mut self,
423 ctx: &ReturnOptimizerContext<'_>,
424 stmt: &StatementStructDestructure,
425 ) {
426 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
427
428 let mut input_consumed = false;
429 for var_info in returned_vars.iter_mut() {
430 match var_info.apply_deconstruct(ctx, stmt) {
431 OpResult::InputConsumed => {
432 input_consumed = true;
433 }
434 OpResult::ValueInvalidated => {
435 self.invalidate();
436 return;
437 }
438 OpResult::NoChange => {}
439 };
440 }
441
442 if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
443 self.invalidate();
444 }
445 }
446
447 fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo, arm: &MatchArm) {
449 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
450
451 let mut input_consumed = false;
452 for var_info in returned_vars.iter_mut() {
453 match var_info.apply_match_arm(input, arm) {
454 OpResult::InputConsumed => {
455 input_consumed = true;
456 }
457 OpResult::ValueInvalidated => {
458 self.invalidate();
459 return;
460 }
461 OpResult::NoChange => {}
462 };
463 }
464
465 if !(input_consumed || is_droppable) {
466 self.invalidate();
467 }
468 }
469
470 fn try_get_early_return_info(&self) -> Option<&ReturnInfo> {
472 let return_info = self.opt_return_info.as_ref()?;
473
474 let mut stack = return_info.returned_vars.clone();
475 while let Some(var_info) = stack.pop() {
476 match var_info {
477 ValueInfo::Var(_) => {}
478 ValueInfo::StructConstruct { ty: _, var_infos } => stack.extend(var_infos),
479 ValueInfo::EnumConstruct { var_info, variant: _ } => stack.push(*var_info),
480 ValueInfo::Interchangeable(_) => return None,
481 }
482 }
483
484 Some(return_info)
485 }
486}
487
488impl<'a> Analyzer<'a> for ReturnOptimizerContext<'_> {
489 type Info = AnalyzerInfo;
490
491 fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &FlatBlock) {
492 if let Some(return_info) = info.try_get_early_return_info() {
493 self.fixes.push(FixInfo { location: (block_id, 0), return_info: return_info.clone() });
494 }
495 }
496
497 fn visit_stmt(
498 &mut self,
499 info: &mut Self::Info,
500 (block_idx, statement_idx): StatementLocation,
501 stmt: &'a Statement,
502 ) {
503 let opt_early_return_info = info.try_get_early_return_info().cloned();
504
505 match stmt {
506 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
507 info.replace(
511 *output,
512 ValueInfo::StructConstruct {
513 ty: self.lowered.variables[*output].ty,
514 var_infos: inputs.iter().map(|input| self.get_var_info(input)).collect(),
515 },
516 );
517 }
518
519 Statement::StructDestructure(stmt) => info.apply_deconstruct(self, stmt),
520 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
521 info.replace(
522 *output,
523 ValueInfo::EnumConstruct {
524 var_info: Box::new(self.get_var_info(input)),
525 variant: variant.clone(),
526 },
527 );
528 }
529 _ => info.invalidate(),
530 }
531
532 if let Some(early_return_info) = opt_early_return_info {
533 if info.try_get_early_return_info().is_none() {
534 self.fixes.push(FixInfo {
535 location: (block_idx, statement_idx + 1),
536 return_info: early_return_info,
537 });
538 }
539 }
540 }
541
542 fn visit_goto(
543 &mut self,
544 info: &mut Self::Info,
545 _statement_location: StatementLocation,
546 _target_block_id: BlockId,
547 remapping: &VarRemapping,
548 ) {
549 info.apply(&|var_usage| {
550 if let Some(usage) = remapping.get(&var_usage.var_id) {
551 ValueInfo::Var(*usage)
552 } else {
553 ValueInfo::Var(*var_usage)
554 }
555 });
556 }
557
558 fn merge_match(
559 &mut self,
560 _statement_location: StatementLocation,
561 match_info: &'a MatchInfo,
562 infos: impl Iterator<Item = Self::Info>,
563 ) -> Self::Info {
564 let infos: Vec<_> = infos.collect();
565 let opt_return_info = self.try_merge_match(match_info, &infos);
566 Self::Info { opt_return_info }
567 }
568
569 fn info_from_return(
570 &mut self,
571 (block_id, _statement_idx): StatementLocation,
572 vars: &'a [VarUsage],
573 ) -> Self::Info {
574 let location = match &self.lowered.blocks[block_id].end {
575 FlatBlockEnd::Return(_vars, location) => *location,
576 _ => unreachable!(),
577 };
578
579 AnalyzerInfo {
582 opt_return_info: Some(ReturnInfo {
583 returned_vars: vars.iter().map(|var_usage| ValueInfo::Var(*var_usage)).collect(),
584 location,
585 }),
586 }
587 }
588}