1#[cfg(test)]
2#[path = "return_optimization_test.rs"]
3mod test;
4
5use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeId, TypeLongId};
6use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
7use cairo_lang_utils::{Intern, require};
8use id_arena::Arena;
9use semantic::MatchArmSelector;
10
11use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
12use crate::db::LoweringGroup;
13use crate::ids::LocationId;
14use crate::{
15 Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
16 StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
17 VarUsage, Variable, VariableId,
18};
19
20pub fn return_optimization(db: &dyn LoweringGroup, lowered: &mut Lowered) {
26 if lowered.blocks.is_empty() {
27 return;
28 }
29 let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
30 let mut analysis = BackAnalysis::new(lowered, ctx);
31 analysis.get_root_info();
32 let ctx = analysis.analyzer;
33
34 let ReturnOptimizerContext { fixes, .. } = ctx;
35 for FixInfo { location: (block_id, statement_idx), return_info } in fixes {
36 let block = &mut lowered.blocks[block_id];
37 block.statements.truncate(statement_idx);
38 let mut ctx = EarlyReturnContext {
39 db,
40 constructed: UnorderedHashMap::default(),
41 variables: &mut lowered.variables,
42 statements: &mut block.statements,
43 location: return_info.location,
44 };
45 let vars = ctx.prepare_early_return_vars(&return_info.returned_vars);
46 block.end = BlockEnd::Return(vars, return_info.location)
47 }
48}
49
50struct EarlyReturnContext<'a> {
52 db: &'a dyn LoweringGroup,
54 constructed: UnorderedHashMap<(TypeId, Vec<VariableId>), VariableId>,
57 variables: &'a mut Arena<Variable>,
59 statements: &'a mut Vec<Statement>,
61 location: LocationId,
63}
64
65impl EarlyReturnContext<'_> {
66 fn prepare_early_return_vars(&mut self, ret_infos: &[ValueInfo]) -> Vec<VarUsage> {
70 let mut res = vec![];
71
72 for var_info in ret_infos.iter() {
73 match var_info {
74 ValueInfo::Var(var_usage) => {
75 res.push(*var_usage);
76 }
77 ValueInfo::StructConstruct { ty, var_infos } => {
78 let inputs = self.prepare_early_return_vars(var_infos);
79 let output = *self
80 .constructed
81 .entry((*ty, inputs.iter().map(|var_usage| var_usage.var_id).collect()))
82 .or_insert_with(|| {
83 let output = self.variables.alloc(Variable::with_default_context(
84 self.db,
85 *ty,
86 self.location,
87 ));
88 self.statements.push(Statement::StructConstruct(
89 StatementStructConstruct { inputs, output },
90 ));
91 output
92 });
93 res.push(VarUsage { var_id: output, location: self.location });
94 }
95 ValueInfo::EnumConstruct { var_info, variant } => {
96 let input = self.prepare_early_return_vars(std::slice::from_ref(var_info))[0];
97
98 let ty = TypeLongId::Concrete(ConcreteTypeId::Enum(variant.concrete_enum_id))
99 .intern(self.db);
100
101 let output =
102 *self.constructed.entry((ty, vec![input.var_id])).or_insert_with(|| {
103 let output = self.variables.alloc(Variable::with_default_context(
104 self.db,
105 ty,
106 self.location,
107 ));
108 self.statements.push(Statement::EnumConstruct(
109 StatementEnumConstruct { variant: *variant, input, output },
110 ));
111 output
112 });
113 res.push(VarUsage { var_id: output, location: self.location });
114 }
115 ValueInfo::Interchangeable(_) => {
116 unreachable!("early_return_possible should have prevented this.")
117 }
118 }
119 }
120
121 res
122 }
123}
124
125pub struct ReturnOptimizerContext<'a> {
126 db: &'a dyn LoweringGroup,
127 lowered: &'a Lowered,
128
129 fixes: Vec<FixInfo>,
131}
132impl ReturnOptimizerContext<'_> {
133 fn get_var_info(&self, var_usage: &VarUsage) -> ValueInfo {
135 let var_ty = &self.lowered.variables[var_usage.var_id].ty;
136 if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
137 ValueInfo::Interchangeable(*var_ty)
138 } else {
139 ValueInfo::Var(*var_usage)
140 }
141 }
142
143 fn is_droppable(&self, var_id: VariableId) -> bool {
145 self.lowered.variables[var_id].info.droppable.is_ok()
146 }
147
148 fn try_merge_match(
151 &mut self,
152 match_info: &MatchInfo,
153 infos: impl Iterator<Item = AnalyzerInfo>,
154 ) -> Option<ReturnInfo> {
155 let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
156 return None;
157 };
158 require(!arms.is_empty())?;
159
160 let input_info = self.get_var_info(input);
161 let mut opt_last_info = None;
162 for (arm, info) in arms.iter().zip(infos) {
163 let mut curr_info = info.clone();
164 curr_info.apply_match_arm(self.is_droppable(input.var_id), &input_info, arm);
165
166 match curr_info.try_get_early_return_info() {
167 Some(return_info)
168 if opt_last_info
169 .map(|x: ReturnInfo| x.returned_vars == return_info.returned_vars)
170 .unwrap_or(true) =>
171 {
172 opt_last_info = Some(return_info.clone())
175 }
176 _ => return None,
177 }
178 }
179
180 Some(opt_last_info.unwrap())
181 }
182}
183
184pub struct FixInfo {
186 location: StatementLocation,
188 return_info: ReturnInfo,
190}
191
192#[derive(Clone, Debug, PartialEq, Eq)]
194pub enum ValueInfo {
195 Var(VarUsage),
197 Interchangeable(semantic::TypeId),
199 StructConstruct {
201 ty: semantic::TypeId,
203 var_infos: Vec<ValueInfo>,
205 },
206 EnumConstruct {
208 var_info: Box<ValueInfo>,
210 variant: semantic::ConcreteVariant,
212 },
213}
214
/// The outcome of rewinding a `ValueInfo` over an operation (a struct destructure or a match
/// arm).
enum OpResult {
    /// The operation cancelled out with a construction in the value info, which now refers to
    /// the operation's input instead.
    InputConsumed,
    /// The value info refers to a variable that the operation introduces, so the value is not
    /// available before the operation.
    ValueInvalidated,
    /// The operation did not affect the value info.
    NoChange,
}
224
225impl ValueInfo {
226 fn apply<F>(&mut self, f: &F)
228 where
229 F: Fn(&VarUsage) -> ValueInfo,
230 {
231 match self {
232 ValueInfo::Var(var_usage) => *self = f(var_usage),
233 ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
234 for var_info in var_infos.iter_mut() {
235 var_info.apply(f);
236 }
237 }
238 ValueInfo::EnumConstruct { ref mut var_info, .. } => {
239 var_info.apply(f);
240 }
241 ValueInfo::Interchangeable(_) => {}
242 }
243 }
244
245 fn apply_deconstruct(
248 &mut self,
249 ctx: &ReturnOptimizerContext<'_>,
250 stmt: &StatementStructDestructure,
251 ) -> OpResult {
252 match self {
253 ValueInfo::Var(var_usage) => {
254 if stmt.outputs.contains(&var_usage.var_id) {
255 OpResult::ValueInvalidated
256 } else {
257 OpResult::NoChange
258 }
259 }
260 ValueInfo::StructConstruct { ty, var_infos } => {
261 let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
262 && var_infos.len() == stmt.outputs.len();
263 for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
264 if !cancels_out {
265 break;
266 }
267
268 match var_info {
269 ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
270 ValueInfo::Interchangeable(ty)
271 if &ctx.lowered.variables[*output].ty == ty => {}
272 _ => cancels_out = false,
273 }
274 }
275
276 if cancels_out {
277 *self = ValueInfo::Var(stmt.input);
280 return OpResult::InputConsumed;
281 }
282
283 let mut input_consumed = false;
284 for var_info in var_infos.iter_mut() {
285 match var_info.apply_deconstruct(ctx, stmt) {
286 OpResult::InputConsumed => {
287 input_consumed = true;
288 }
289 OpResult::ValueInvalidated => {
290 return OpResult::ValueInvalidated;
293 }
294 OpResult::NoChange => {}
295 }
296 }
297
298 match input_consumed {
299 true => OpResult::InputConsumed,
300 false => OpResult::NoChange,
301 }
302 }
303 ValueInfo::EnumConstruct { ref mut var_info, .. } => {
304 var_info.apply_deconstruct(ctx, stmt)
305 }
306 ValueInfo::Interchangeable(_) => OpResult::NoChange,
307 }
308 }
309
310 fn apply_match_arm(&mut self, input: &ValueInfo, arm: &MatchArm) -> OpResult {
313 match self {
314 ValueInfo::Var(var_usage) => {
315 if arm.var_ids == [var_usage.var_id] {
316 OpResult::ValueInvalidated
317 } else {
318 OpResult::NoChange
319 }
320 }
321 ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
322 let mut input_consumed = false;
323 for var_info in var_infos.iter_mut() {
324 match var_info.apply_match_arm(input, arm) {
325 OpResult::InputConsumed => {
326 input_consumed = true;
327 }
328 OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
329 OpResult::NoChange => {}
330 }
331 }
332
333 if input_consumed {
334 return OpResult::InputConsumed;
335 }
336 OpResult::NoChange
337 }
338 ValueInfo::EnumConstruct { ref mut var_info, variant } => {
339 let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
340 panic!("Enum construct should not appear in value match");
341 };
342
343 if *variant == *arm_variant {
344 let cancels_out = match **var_info {
345 ValueInfo::Interchangeable(_) => true,
346 ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
347 _ => false,
348 };
349
350 if cancels_out {
351 *self = input.clone();
354 return OpResult::InputConsumed;
355 }
356 }
357
358 var_info.apply_match_arm(input, arm)
359 }
360 ValueInfo::Interchangeable(_) => OpResult::NoChange,
361 }
362 }
363}
364
365#[derive(Clone, Debug, PartialEq, Eq)]
369pub struct ReturnInfo {
370 returned_vars: Vec<ValueInfo>,
371 location: LocationId,
372}
373
374#[derive(Clone, Debug, PartialEq, Eq)]
380pub struct AnalyzerInfo {
381 opt_return_info: Option<ReturnInfo>,
382}
383
384impl AnalyzerInfo {
385 fn invalidated() -> Self {
387 AnalyzerInfo { opt_return_info: None }
388 }
389
390 fn invalidate(&mut self) {
392 *self = Self::invalidated();
393 }
394
395 fn apply<F>(&mut self, f: &F)
397 where
398 F: Fn(&VarUsage) -> ValueInfo,
399 {
400 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
401 return;
402 };
403
404 for var_info in returned_vars.iter_mut() {
405 var_info.apply(f)
406 }
407 }
408
409 fn replace(&mut self, var_id: VariableId, var_info: ValueInfo) {
411 self.apply(&|var_usage| {
412 if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
413 });
414 }
415
416 fn apply_deconstruct(
418 &mut self,
419 ctx: &ReturnOptimizerContext<'_>,
420 stmt: &StatementStructDestructure,
421 ) {
422 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
423
424 let mut input_consumed = false;
425 for var_info in returned_vars.iter_mut() {
426 match var_info.apply_deconstruct(ctx, stmt) {
427 OpResult::InputConsumed => {
428 input_consumed = true;
429 }
430 OpResult::ValueInvalidated => {
431 self.invalidate();
432 return;
433 }
434 OpResult::NoChange => {}
435 };
436 }
437
438 if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
439 self.invalidate();
440 }
441 }
442
443 fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo, arm: &MatchArm) {
445 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
446
447 let mut input_consumed = false;
448 for var_info in returned_vars.iter_mut() {
449 match var_info.apply_match_arm(input, arm) {
450 OpResult::InputConsumed => {
451 input_consumed = true;
452 }
453 OpResult::ValueInvalidated => {
454 self.invalidate();
455 return;
456 }
457 OpResult::NoChange => {}
458 };
459 }
460
461 if !(input_consumed || is_droppable) {
462 self.invalidate();
463 }
464 }
465
466 fn try_get_early_return_info(&self) -> Option<&ReturnInfo> {
468 let return_info = self.opt_return_info.as_ref()?;
469
470 let mut stack = return_info.returned_vars.clone();
471 while let Some(var_info) = stack.pop() {
472 match var_info {
473 ValueInfo::Var(_) => {}
474 ValueInfo::StructConstruct { ty: _, var_infos } => stack.extend(var_infos),
475 ValueInfo::EnumConstruct { var_info, variant: _ } => stack.push(*var_info),
476 ValueInfo::Interchangeable(_) => return None,
477 }
478 }
479
480 Some(return_info)
481 }
482}
483
484impl<'a> Analyzer<'a> for ReturnOptimizerContext<'_> {
485 type Info = AnalyzerInfo;
486
487 fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block) {
488 if let Some(return_info) = info.try_get_early_return_info() {
489 self.fixes.push(FixInfo { location: (block_id, 0), return_info: return_info.clone() });
490 }
491 }
492
493 fn visit_stmt(
494 &mut self,
495 info: &mut Self::Info,
496 (block_idx, statement_idx): StatementLocation,
497 stmt: &'a Statement,
498 ) {
499 let opt_early_return_info = info.try_get_early_return_info().cloned();
500
501 match stmt {
502 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
503 info.replace(
507 *output,
508 ValueInfo::StructConstruct {
509 ty: self.lowered.variables[*output].ty,
510 var_infos: inputs.iter().map(|input| self.get_var_info(input)).collect(),
511 },
512 );
513 }
514
515 Statement::StructDestructure(stmt) => info.apply_deconstruct(self, stmt),
516 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
517 info.replace(
518 *output,
519 ValueInfo::EnumConstruct {
520 var_info: Box::new(self.get_var_info(input)),
521 variant: *variant,
522 },
523 );
524 }
525 _ => info.invalidate(),
526 }
527
528 if let Some(early_return_info) = opt_early_return_info {
529 if info.try_get_early_return_info().is_none() {
530 self.fixes.push(FixInfo {
531 location: (block_idx, statement_idx + 1),
532 return_info: early_return_info,
533 });
534 }
535 }
536 }
537
538 fn visit_goto(
539 &mut self,
540 info: &mut Self::Info,
541 _statement_location: StatementLocation,
542 _target_block_id: BlockId,
543 remapping: &VarRemapping,
544 ) {
545 info.apply(&|var_usage| {
546 if let Some(usage) = remapping.get(&var_usage.var_id) {
547 ValueInfo::Var(*usage)
548 } else {
549 ValueInfo::Var(*var_usage)
550 }
551 });
552 }
553
554 fn merge_match(
555 &mut self,
556 _statement_location: StatementLocation,
557 match_info: &'a MatchInfo,
558 infos: impl Iterator<Item = Self::Info>,
559 ) -> Self::Info {
560 Self::Info { opt_return_info: self.try_merge_match(match_info, infos) }
561 }
562
563 fn info_from_return(
564 &mut self,
565 (block_id, _statement_idx): StatementLocation,
566 vars: &'a [VarUsage],
567 ) -> Self::Info {
568 let location = match &self.lowered.blocks[block_id].end {
569 BlockEnd::Return(_vars, location) => *location,
570 _ => unreachable!(),
571 };
572
573 AnalyzerInfo {
576 opt_return_info: Some(ReturnInfo {
577 returned_vars: vars.iter().map(|var_usage| ValueInfo::Var(*var_usage)).collect(),
578 location,
579 }),
580 }
581 }
582}