1#[cfg(test)]
2#[path = "return_optimization_test.rs"]
3mod test;
4
5use cairo_lang_semantic::types::TypesSemantic;
6use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeId, TypeLongId};
7use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
8use cairo_lang_utils::{Intern, require};
9use salsa::Database;
10use semantic::MatchArmSelector;
11
12use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
13use crate::ids::LocationId;
14use crate::{
15 Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
16 StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
17 VarUsage, Variable, VariableArena, VariableId,
18};
19
20pub fn return_optimization<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>) {
26 if lowered.blocks.is_empty() {
27 return;
28 }
29 let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
30 let mut analysis = BackAnalysis::new(lowered, ctx);
31 analysis.get_root_info();
32 let ctx = analysis.analyzer;
33
34 let ReturnOptimizerContext { fixes, .. } = ctx;
35 for FixInfo { location: (block_id, statement_idx), return_info } in fixes {
36 let block = &mut lowered.blocks[block_id];
37 block.statements.truncate(statement_idx);
38 let mut ctx = EarlyReturnContext {
39 db,
40 constructed: UnorderedHashMap::default(),
41 variables: &mut lowered.variables,
42 statements: &mut block.statements,
43 location: return_info.location,
44 };
45 let vars = ctx.prepare_early_return_vars(&return_info.returned_vars);
46 block.end = BlockEnd::Return(vars, return_info.location)
47 }
48}
49
50struct EarlyReturnContext<'db, 'a> {
52 db: &'db dyn Database,
54 constructed: UnorderedHashMap<(TypeId<'db>, Vec<VariableId>), VariableId>,
57 variables: &'a mut VariableArena<'db>,
59 statements: &'a mut Vec<Statement<'db>>,
61 location: LocationId<'db>,
63}
64
65impl<'db, 'a> EarlyReturnContext<'db, 'a> {
66 fn prepare_early_return_vars(&mut self, ret_infos: &[ValueInfo<'db>]) -> Vec<VarUsage<'db>> {
70 let mut res = vec![];
71
72 for var_info in ret_infos.iter() {
73 match var_info {
74 ValueInfo::Var(var_usage) => {
75 res.push(*var_usage);
76 }
77 ValueInfo::StructConstruct { ty, var_infos } => {
78 let inputs = self.prepare_early_return_vars(var_infos);
79 let output = *self
80 .constructed
81 .entry((*ty, inputs.iter().map(|var_usage| var_usage.var_id).collect()))
82 .or_insert_with(|| {
83 let output = self.variables.alloc(Variable::with_default_context(
84 self.db,
85 *ty,
86 self.location,
87 ));
88 self.statements.push(Statement::StructConstruct(
89 StatementStructConstruct { inputs, output },
90 ));
91 output
92 });
93 res.push(VarUsage { var_id: output, location: self.location });
94 }
95 ValueInfo::EnumConstruct { var_info, variant } => {
96 let input = self.prepare_early_return_vars(std::slice::from_ref(var_info))[0];
97
98 let ty = TypeLongId::Concrete(ConcreteTypeId::Enum(variant.concrete_enum_id))
99 .intern(self.db);
100
101 let output =
102 *self.constructed.entry((ty, vec![input.var_id])).or_insert_with(|| {
103 let output = self.variables.alloc(Variable::with_default_context(
104 self.db,
105 ty,
106 self.location,
107 ));
108 self.statements.push(Statement::EnumConstruct(
109 StatementEnumConstruct { variant: *variant, input, output },
110 ));
111 output
112 });
113 res.push(VarUsage { var_id: output, location: self.location });
114 }
115 ValueInfo::Interchangeable(_) => {
116 unreachable!("early_return_possible should have prevented this.")
117 }
118 }
119 }
120
121 res
122 }
123}
124
125pub struct ReturnOptimizerContext<'db, 'a> {
126 db: &'db dyn Database,
127 lowered: &'a Lowered<'db>,
128
129 fixes: Vec<FixInfo<'db>>,
131}
132impl<'db, 'a> ReturnOptimizerContext<'db, 'a> {
133 fn get_var_info(&self, var_usage: &VarUsage<'db>) -> ValueInfo<'db> {
135 let var_ty = &self.lowered.variables[var_usage.var_id].ty;
136 if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
137 ValueInfo::Interchangeable(*var_ty)
138 } else {
139 ValueInfo::Var(*var_usage)
140 }
141 }
142
143 fn is_droppable(&self, var_id: VariableId) -> bool {
145 self.lowered.variables[var_id].info.droppable.is_ok()
146 }
147
148 fn try_merge_match(
151 &mut self,
152 match_info: &MatchInfo<'db>,
153 infos: impl Iterator<Item = AnalyzerInfo<'db>>,
154 ) -> Option<ReturnInfo<'db>> {
155 let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
156 return None;
157 };
158 require(!arms.is_empty())?;
159
160 let input_info = self.get_var_info(input);
161 let mut opt_last_info = None;
162 for (arm, info) in arms.iter().zip(infos) {
163 let mut curr_info = info.clone();
164 curr_info.apply_match_arm(self.is_droppable(input.var_id), &input_info, arm);
165
166 match curr_info.try_get_early_return_info() {
167 Some(return_info)
168 if opt_last_info
169 .map(|x: ReturnInfo<'_>| x.returned_vars == return_info.returned_vars)
170 .unwrap_or(true) =>
171 {
172 opt_last_info = Some(return_info.clone())
175 }
176 _ => return None,
177 }
178 }
179
180 Some(opt_last_info.unwrap())
181 }
182}
183
184pub struct FixInfo<'db> {
186 location: StatementLocation,
188 return_info: ReturnInfo<'db>,
190}
191
192#[derive(Clone, Debug, PartialEq, Eq)]
194pub enum ValueInfo<'db> {
195 Var(VarUsage<'db>),
197 Interchangeable(semantic::TypeId<'db>),
199 StructConstruct {
201 ty: semantic::TypeId<'db>,
203 var_infos: Vec<ValueInfo<'db>>,
205 },
206 EnumConstruct {
208 var_info: Box<ValueInfo<'db>>,
210 variant: semantic::ConcreteVariant<'db>,
212 },
213}
214
/// Outcome of undoing one statement (or one match arm) on a [`ValueInfo`].
enum OpResult {
    /// Nothing tracked in the value was affected.
    NoChange,
    /// The operation's input was absorbed into the value, so the return accounts for it.
    InputConsumed,
    /// The value depends on an output of the operation in a way that cannot be undone, so the
    /// early return is no longer possible.
    ValueInvalidated,
}
224
225impl<'db> ValueInfo<'db> {
226 fn apply<F>(&mut self, f: &F)
228 where
229 F: Fn(&VarUsage<'db>) -> ValueInfo<'db>,
230 {
231 match self {
232 ValueInfo::Var(var_usage) => *self = f(var_usage),
233 ValueInfo::StructConstruct { ty: _, var_infos } => {
234 for var_info in var_infos.iter_mut() {
235 var_info.apply(f);
236 }
237 }
238 ValueInfo::EnumConstruct { var_info, .. } => {
239 var_info.apply(f);
240 }
241 ValueInfo::Interchangeable(_) => {}
242 }
243 }
244
245 fn apply_deconstruct(
248 &mut self,
249 ctx: &ReturnOptimizerContext<'db, '_>,
250 stmt: &StatementStructDestructure<'db>,
251 ) -> OpResult {
252 match self {
253 ValueInfo::Var(var_usage) => {
254 if stmt.outputs.contains(&var_usage.var_id) {
255 OpResult::ValueInvalidated
256 } else {
257 OpResult::NoChange
258 }
259 }
260 ValueInfo::StructConstruct { ty, var_infos } => {
261 let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
262 && var_infos.len() == stmt.outputs.len();
263 for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
264 if !cancels_out {
265 break;
266 }
267
268 match var_info {
269 ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
270 ValueInfo::Interchangeable(ty)
271 if &ctx.lowered.variables[*output].ty == ty => {}
272 _ => cancels_out = false,
273 }
274 }
275
276 if cancels_out {
277 *self = ValueInfo::Var(stmt.input);
280 return OpResult::InputConsumed;
281 }
282
283 let mut input_consumed = false;
284 for var_info in var_infos.iter_mut() {
285 match var_info.apply_deconstruct(ctx, stmt) {
286 OpResult::InputConsumed => {
287 input_consumed = true;
288 }
289 OpResult::ValueInvalidated => {
290 return OpResult::ValueInvalidated;
293 }
294 OpResult::NoChange => {}
295 }
296 }
297
298 match input_consumed {
299 true => OpResult::InputConsumed,
300 false => OpResult::NoChange,
301 }
302 }
303 ValueInfo::EnumConstruct { var_info, .. } => var_info.apply_deconstruct(ctx, stmt),
304 ValueInfo::Interchangeable(_) => OpResult::NoChange,
305 }
306 }
307
308 fn apply_match_arm(&mut self, input: &ValueInfo<'db>, arm: &MatchArm<'db>) -> OpResult {
311 match self {
312 ValueInfo::Var(var_usage) => {
313 if arm.var_ids == [var_usage.var_id] {
314 OpResult::ValueInvalidated
315 } else {
316 OpResult::NoChange
317 }
318 }
319 ValueInfo::StructConstruct { ty: _, var_infos } => {
320 let mut input_consumed = false;
321 for var_info in var_infos.iter_mut() {
322 match var_info.apply_match_arm(input, arm) {
323 OpResult::InputConsumed => {
324 input_consumed = true;
325 }
326 OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
327 OpResult::NoChange => {}
328 }
329 }
330
331 if input_consumed {
332 return OpResult::InputConsumed;
333 }
334 OpResult::NoChange
335 }
336 ValueInfo::EnumConstruct { var_info, variant } => {
337 let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
338 panic!("Enum construct should not appear in value match");
339 };
340
341 if *variant == *arm_variant {
342 let cancels_out = match **var_info {
343 ValueInfo::Interchangeable(_) => true,
344 ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
345 _ => false,
346 };
347
348 if cancels_out {
349 *self = input.clone();
352 return OpResult::InputConsumed;
353 }
354 }
355
356 var_info.apply_match_arm(input, arm)
357 }
358 ValueInfo::Interchangeable(_) => OpResult::NoChange,
359 }
360 }
361}
362
363#[derive(Clone, Debug, PartialEq, Eq)]
367pub struct ReturnInfo<'db> {
368 returned_vars: Vec<ValueInfo<'db>>,
369 location: LocationId<'db>,
370}
371
372#[derive(Clone, Debug, PartialEq, Eq)]
378pub struct AnalyzerInfo<'db> {
379 opt_return_info: Option<ReturnInfo<'db>>,
380}
381
382impl<'db> AnalyzerInfo<'db> {
383 fn invalidated() -> Self {
385 AnalyzerInfo { opt_return_info: None }
386 }
387
388 fn invalidate(&mut self) {
390 *self = Self::invalidated();
391 }
392
393 fn apply<F>(&mut self, f: &F)
395 where
396 F: Fn(&VarUsage<'db>) -> ValueInfo<'db>,
397 {
398 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
399 return;
400 };
401
402 for var_info in returned_vars.iter_mut() {
403 var_info.apply(f)
404 }
405 }
406
407 fn replace(&mut self, var_id: VariableId, var_info: ValueInfo<'db>) {
409 self.apply(&|var_usage| {
410 if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
411 });
412 }
413
414 fn apply_deconstruct(
416 &mut self,
417 ctx: &ReturnOptimizerContext<'db, '_>,
418 stmt: &StatementStructDestructure<'db>,
419 ) {
420 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
421
422 let mut input_consumed = false;
423 for var_info in returned_vars.iter_mut() {
424 match var_info.apply_deconstruct(ctx, stmt) {
425 OpResult::InputConsumed => {
426 input_consumed = true;
427 }
428 OpResult::ValueInvalidated => {
429 self.invalidate();
430 return;
431 }
432 OpResult::NoChange => {}
433 };
434 }
435
436 if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
437 self.invalidate();
438 }
439 }
440
441 fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo<'db>, arm: &MatchArm<'db>) {
443 let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
444
445 let mut input_consumed = false;
446 for var_info in returned_vars.iter_mut() {
447 match var_info.apply_match_arm(input, arm) {
448 OpResult::InputConsumed => {
449 input_consumed = true;
450 }
451 OpResult::ValueInvalidated => {
452 self.invalidate();
453 return;
454 }
455 OpResult::NoChange => {}
456 };
457 }
458
459 if !(input_consumed || is_droppable) {
460 self.invalidate();
461 }
462 }
463
464 fn try_get_early_return_info(&self) -> Option<&ReturnInfo<'db>> {
466 let return_info = self.opt_return_info.as_ref()?;
467
468 let mut stack = return_info.returned_vars.clone();
469 while let Some(var_info) = stack.pop() {
470 match var_info {
471 ValueInfo::Var(_) => {}
472 ValueInfo::StructConstruct { ty: _, var_infos } => stack.extend(var_infos),
473 ValueInfo::EnumConstruct { var_info, variant: _ } => stack.push(*var_info),
474 ValueInfo::Interchangeable(_) => return None,
475 }
476 }
477
478 Some(return_info)
479 }
480}
481
482impl<'db, 'a> Analyzer<'db, 'a> for ReturnOptimizerContext<'db, 'a> {
483 type Info = AnalyzerInfo<'db>;
484
485 fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
486 if let Some(return_info) = info.try_get_early_return_info() {
487 self.fixes.push(FixInfo { location: (block_id, 0), return_info: return_info.clone() });
488 }
489 }
490
491 fn visit_stmt(
492 &mut self,
493 info: &mut Self::Info,
494 (block_idx, statement_idx): StatementLocation,
495 stmt: &'a Statement<'db>,
496 ) {
497 let opt_early_return_info = info.try_get_early_return_info().cloned();
498
499 match stmt {
500 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
501 info.replace(
505 *output,
506 ValueInfo::StructConstruct {
507 ty: self.lowered.variables[*output].ty,
508 var_infos: inputs.iter().map(|input| self.get_var_info(input)).collect(),
509 },
510 );
511 }
512
513 Statement::StructDestructure(stmt) => info.apply_deconstruct(self, stmt),
514 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
515 info.replace(
516 *output,
517 ValueInfo::EnumConstruct {
518 var_info: Box::new(self.get_var_info(input)),
519 variant: *variant,
520 },
521 );
522 }
523 _ => info.invalidate(),
524 }
525
526 if let Some(early_return_info) = opt_early_return_info
527 && info.try_get_early_return_info().is_none()
528 {
529 self.fixes.push(FixInfo {
530 location: (block_idx, statement_idx + 1),
531 return_info: early_return_info,
532 });
533 }
534 }
535
536 fn visit_goto(
537 &mut self,
538 info: &mut Self::Info,
539 _statement_location: StatementLocation,
540 _target_block_id: BlockId,
541 remapping: &VarRemapping<'db>,
542 ) {
543 info.apply(&|var_usage| {
544 if let Some(usage) = remapping.get(&var_usage.var_id) {
545 ValueInfo::Var(*usage)
546 } else {
547 ValueInfo::Var(*var_usage)
548 }
549 });
550 }
551
552 fn merge_match(
553 &mut self,
554 _statement_location: StatementLocation,
555 match_info: &'a MatchInfo<'db>,
556 infos: impl Iterator<Item = Self::Info>,
557 ) -> Self::Info {
558 Self::Info { opt_return_info: self.try_merge_match(match_info, infos) }
559 }
560
561 fn info_from_return(
562 &mut self,
563 (block_id, _statement_idx): StatementLocation,
564 vars: &'a [VarUsage<'db>],
565 ) -> Self::Info {
566 let location = match &self.lowered.blocks[block_id].end {
567 BlockEnd::Return(_vars, location) => *location,
568 _ => unreachable!(),
569 };
570
571 AnalyzerInfo {
574 opt_return_info: Some(ReturnInfo {
575 returned_vars: vars.iter().map(|var_usage| ValueInfo::Var(*var_usage)).collect(),
576 location,
577 }),
578 }
579 }
580}