cairo_lang_lowering/optimizations/
return_optimization.rs

1#[cfg(test)]
2#[path = "return_optimization_test.rs"]
3mod test;
4
5use cairo_lang_semantic::types::TypesSemantic;
6use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeId, TypeLongId};
7use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
8use cairo_lang_utils::{Intern, require};
9use salsa::Database;
10use semantic::MatchArmSelector;
11
12use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
13use crate::ids::LocationId;
14use crate::{
15    Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
16    StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
17    VarUsage, Variable, VariableArena, VariableId,
18};
19
20/// Adds early returns when applicable.
21///
22/// This optimization does backward analysis from return statement and keeps track of
23/// each returned value (see `ValueInfo`), whenever all the returned values are available at a block
24/// end and there were no side effects later, the end is replaced with a return statement.
25pub fn return_optimization<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>) {
26    if lowered.blocks.is_empty() {
27        return;
28    }
29    let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
30    let mut analysis = BackAnalysis::new(lowered, ctx);
31    analysis.get_root_info();
32    let ctx = analysis.analyzer;
33
34    let ReturnOptimizerContext { fixes, .. } = ctx;
35    for FixInfo { location: (block_id, statement_idx), return_info } in fixes {
36        let block = &mut lowered.blocks[block_id];
37        block.statements.truncate(statement_idx);
38        let mut ctx = EarlyReturnContext {
39            db,
40            constructed: UnorderedHashMap::default(),
41            variables: &mut lowered.variables,
42            statements: &mut block.statements,
43            location: return_info.location,
44        };
45        let vars = ctx.prepare_early_return_vars(&return_info.returned_vars);
46        block.end = BlockEnd::Return(vars, return_info.location)
47    }
48}
49
50/// Context for applying an early return to a block.
51struct EarlyReturnContext<'db, 'a> {
52    /// The lowering database.
53    db: &'db dyn Database,
54    /// A map from (type, inputs) to the variable_id for Structs/Enums that were created
55    /// while processing the early return.
56    constructed: UnorderedHashMap<(TypeId<'db>, Vec<VariableId>), VariableId>,
57    /// A variable allocator.
58    variables: &'a mut VariableArena<'db>,
59    /// The statements in the block where the early return is going to happen.
60    statements: &'a mut Vec<Statement<'db>>,
61    /// The location associated with the early return.
62    location: LocationId<'db>,
63}
64
65impl<'db, 'a> EarlyReturnContext<'db, 'a> {
66    /// Return a vector of VarUsage's based on the input `ret_infos`.
67    /// Adds `StructConstruct` and `EnumConstruct` statements to the block as needed.
68    /// Assumes that early return is possible for the given `ret_infos`.
69    fn prepare_early_return_vars(&mut self, ret_infos: &[ValueInfo<'db>]) -> Vec<VarUsage<'db>> {
70        let mut res = vec![];
71
72        for var_info in ret_infos.iter() {
73            match var_info {
74                ValueInfo::Var(var_usage) => {
75                    res.push(*var_usage);
76                }
77                ValueInfo::StructConstruct { ty, var_infos } => {
78                    let inputs = self.prepare_early_return_vars(var_infos);
79                    let output = *self
80                        .constructed
81                        .entry((*ty, inputs.iter().map(|var_usage| var_usage.var_id).collect()))
82                        .or_insert_with(|| {
83                            let output = self.variables.alloc(Variable::with_default_context(
84                                self.db,
85                                *ty,
86                                self.location,
87                            ));
88                            self.statements.push(Statement::StructConstruct(
89                                StatementStructConstruct { inputs, output },
90                            ));
91                            output
92                        });
93                    res.push(VarUsage { var_id: output, location: self.location });
94                }
95                ValueInfo::EnumConstruct { var_info, variant } => {
96                    let input = self.prepare_early_return_vars(std::slice::from_ref(var_info))[0];
97
98                    let ty = TypeLongId::Concrete(ConcreteTypeId::Enum(variant.concrete_enum_id))
99                        .intern(self.db);
100
101                    let output =
102                        *self.constructed.entry((ty, vec![input.var_id])).or_insert_with(|| {
103                            let output = self.variables.alloc(Variable::with_default_context(
104                                self.db,
105                                ty,
106                                self.location,
107                            ));
108                            self.statements.push(Statement::EnumConstruct(
109                                StatementEnumConstruct { variant: *variant, input, output },
110                            ));
111                            output
112                        });
113                    res.push(VarUsage { var_id: output, location: self.location });
114                }
115                ValueInfo::Interchangeable(_) => {
116                    unreachable!("early_return_possible should have prevented this.")
117                }
118            }
119        }
120
121        res
122    }
123}
124
125pub struct ReturnOptimizerContext<'db, 'a> {
126    db: &'db dyn Database,
127    lowered: &'a Lowered<'db>,
128
129    /// The list of fixes that should be applied.
130    fixes: Vec<FixInfo<'db>>,
131}
132impl<'db, 'a> ReturnOptimizerContext<'db, 'a> {
133    /// Given a VarUsage, returns the ValueInfo that corresponds to it.
134    fn get_var_info(&self, var_usage: &VarUsage<'db>) -> ValueInfo<'db> {
135        let var_ty = &self.lowered.variables[var_usage.var_id].ty;
136        if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
137            ValueInfo::Interchangeable(*var_ty)
138        } else {
139            ValueInfo::Var(*var_usage)
140        }
141    }
142
143    /// Returns true if the variable is droppable
144    fn is_droppable(&self, var_id: VariableId) -> bool {
145        self.lowered.variables[var_id].info.droppable.is_ok()
146    }
147
148    /// Helper function for `merge_match`.
149    /// Returns `Option<ReturnInfo>` rather than `AnalyzerInfo` to simplify early return.
150    fn try_merge_match(
151        &mut self,
152        match_info: &MatchInfo<'db>,
153        infos: impl Iterator<Item = AnalyzerInfo<'db>>,
154    ) -> Option<ReturnInfo<'db>> {
155        let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
156            return None;
157        };
158        require(!arms.is_empty())?;
159
160        let input_info = self.get_var_info(input);
161        let mut opt_last_info = None;
162        for (arm, info) in arms.iter().zip(infos) {
163            let mut curr_info = info.clone();
164            curr_info.apply_match_arm(self.is_droppable(input.var_id), &input_info, arm);
165
166            match curr_info.try_get_early_return_info() {
167                Some(return_info)
168                    if opt_last_info
169                        .map(|x: ReturnInfo<'_>| x.returned_vars == return_info.returned_vars)
170                        .unwrap_or(true) =>
171                {
172                    // If this is the first iteration or the returned var are the same as the
173                    // previous iteration, then the optimization is still applicable.
174                    opt_last_info = Some(return_info.clone())
175                }
176                _ => return None,
177            }
178        }
179
180        Some(opt_last_info.unwrap())
181    }
182}
183
184/// Information about a fix that should be applied to the lowering.
185pub struct FixInfo<'db> {
186    /// A location where we `return_vars` can be returned.
187    location: StatementLocation,
188    /// The return info at the fix location.
189    return_info: ReturnInfo<'db>,
190}
191
192/// Information about the value that should be returned from the function.
193#[derive(Clone, Debug, PartialEq, Eq)]
194pub enum ValueInfo<'db> {
195    /// The value is available through the given var usage.
196    Var(VarUsage<'db>),
197    /// The value can be replaced with other values of the same type.
198    Interchangeable(semantic::TypeId<'db>),
199    /// The value is the result of a StructConstruct statement.
200    StructConstruct {
201        /// The type of the struct.
202        ty: semantic::TypeId<'db>,
203        /// The inputs to the StructConstruct statement.
204        var_infos: Vec<ValueInfo<'db>>,
205    },
206    /// The value is the result of an EnumConstruct statement.
207    EnumConstruct {
208        /// The input to the EnumConstruct.
209        var_info: Box<ValueInfo<'db>>,
210        /// The constructed variant.
211        variant: semantic::ConcreteVariant<'db>,
212    },
213}
214
/// The result of applying an operation to a ValueInfo.
enum OpResult {
    /// The value was rewritten in terms of the input of the operation, i.e. the input was
    /// consumed.
    InputConsumed,
    /// One of the values is produced by the operation, and is therefore not available before
    /// the operation.
    ValueInvalidated,
    /// The operation did not change the value info.
    NoChange,
}
224
225impl<'db> ValueInfo<'db> {
226    /// Applies the given function to the value info.
227    fn apply<F>(&mut self, f: &F)
228    where
229        F: Fn(&VarUsage<'db>) -> ValueInfo<'db>,
230    {
231        match self {
232            ValueInfo::Var(var_usage) => *self = f(var_usage),
233            ValueInfo::StructConstruct { ty: _, var_infos } => {
234                for var_info in var_infos.iter_mut() {
235                    var_info.apply(f);
236                }
237            }
238            ValueInfo::EnumConstruct { var_info, .. } => {
239                var_info.apply(f);
240            }
241            ValueInfo::Interchangeable(_) => {}
242        }
243    }
244
245    /// Updates the value to the state before the StructDeconstruct statement.
246    /// Returns OpResult.
247    fn apply_deconstruct(
248        &mut self,
249        ctx: &ReturnOptimizerContext<'db, '_>,
250        stmt: &StatementStructDestructure<'db>,
251    ) -> OpResult {
252        match self {
253            ValueInfo::Var(var_usage) => {
254                if stmt.outputs.contains(&var_usage.var_id) {
255                    OpResult::ValueInvalidated
256                } else {
257                    OpResult::NoChange
258                }
259            }
260            ValueInfo::StructConstruct { ty, var_infos } => {
261                let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
262                    && var_infos.len() == stmt.outputs.len();
263                for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
264                    if !cancels_out {
265                        break;
266                    }
267
268                    match var_info {
269                        ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
270                        ValueInfo::Interchangeable(ty)
271                            if &ctx.lowered.variables[*output].ty == ty => {}
272                        _ => cancels_out = false,
273                    }
274                }
275
276                if cancels_out {
277                    // If the StructDeconstruct cancels out the StructConstruct, then we don't need
278                    // to `apply_deconstruct` to the inner var infos.
279                    *self = ValueInfo::Var(stmt.input);
280                    return OpResult::InputConsumed;
281                }
282
283                let mut input_consumed = false;
284                for var_info in var_infos.iter_mut() {
285                    match var_info.apply_deconstruct(ctx, stmt) {
286                        OpResult::InputConsumed => {
287                            input_consumed = true;
288                        }
289                        OpResult::ValueInvalidated => {
290                            // If one of the values is invalidated the optimization is no longer
291                            // applicable.
292                            return OpResult::ValueInvalidated;
293                        }
294                        OpResult::NoChange => {}
295                    }
296                }
297
298                match input_consumed {
299                    true => OpResult::InputConsumed,
300                    false => OpResult::NoChange,
301                }
302            }
303            ValueInfo::EnumConstruct { var_info, .. } => var_info.apply_deconstruct(ctx, stmt),
304            ValueInfo::Interchangeable(_) => OpResult::NoChange,
305        }
306    }
307
308    /// Updates the value to the expected value before the match arm.
309    /// Returns OpResult.
310    fn apply_match_arm(&mut self, input: &ValueInfo<'db>, arm: &MatchArm<'db>) -> OpResult {
311        match self {
312            ValueInfo::Var(var_usage) => {
313                if arm.var_ids == [var_usage.var_id] {
314                    OpResult::ValueInvalidated
315                } else {
316                    OpResult::NoChange
317                }
318            }
319            ValueInfo::StructConstruct { ty: _, var_infos } => {
320                let mut input_consumed = false;
321                for var_info in var_infos.iter_mut() {
322                    match var_info.apply_match_arm(input, arm) {
323                        OpResult::InputConsumed => {
324                            input_consumed = true;
325                        }
326                        OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
327                        OpResult::NoChange => {}
328                    }
329                }
330
331                if input_consumed {
332                    return OpResult::InputConsumed;
333                }
334                OpResult::NoChange
335            }
336            ValueInfo::EnumConstruct { var_info, variant } => {
337                let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
338                    panic!("Enum construct should not appear in value match");
339                };
340
341                if *variant == *arm_variant {
342                    let cancels_out = match **var_info {
343                        ValueInfo::Interchangeable(_) => true,
344                        ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
345                        _ => false,
346                    };
347
348                    if cancels_out {
349                        // If the arm recreates the relevant enum variant, then the arm
350                        // assuming the other arms also cancel out.
351                        *self = input.clone();
352                        return OpResult::InputConsumed;
353                    }
354                }
355
356                var_info.apply_match_arm(input, arm)
357            }
358            ValueInfo::Interchangeable(_) => OpResult::NoChange,
359        }
360    }
361}
362
363/// Information about the current state of the analyzer.
364/// Used to track the value that should be returned from the function at the current
365/// analysis point
366#[derive(Clone, Debug, PartialEq, Eq)]
367pub struct ReturnInfo<'db> {
368    returned_vars: Vec<ValueInfo<'db>>,
369    location: LocationId<'db>,
370}
371
372/// A wrapper around `ReturnInfo` that makes it optional.
373///
374/// None indicates that the return info is unknown.
375/// If early_return_possible() returns true, the function can return early as the return value is
376/// already known.
377#[derive(Clone, Debug, PartialEq, Eq)]
378pub struct AnalyzerInfo<'db> {
379    opt_return_info: Option<ReturnInfo<'db>>,
380}
381
382impl<'db> AnalyzerInfo<'db> {
383    /// Creates a state of the analyzer where the return optimization is not applicable.
384    fn invalidated() -> Self {
385        AnalyzerInfo { opt_return_info: None }
386    }
387
388    /// Invalidates the state of the analyzer, identifying early return is no longer possible.
389    fn invalidate(&mut self) {
390        *self = Self::invalidated();
391    }
392
393    /// Applies the given function to the returned_vars
394    fn apply<F>(&mut self, f: &F)
395    where
396        F: Fn(&VarUsage<'db>) -> ValueInfo<'db>,
397    {
398        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
399            return;
400        };
401
402        for var_info in returned_vars.iter_mut() {
403            var_info.apply(f)
404        }
405    }
406
407    /// Replaces occurrences of `var_id` with `var_info`.
408    fn replace(&mut self, var_id: VariableId, var_info: ValueInfo<'db>) {
409        self.apply(&|var_usage| {
410            if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
411        });
412    }
413
414    /// Updates the info to the state before the StructDeconstruct statement.
415    fn apply_deconstruct(
416        &mut self,
417        ctx: &ReturnOptimizerContext<'db, '_>,
418        stmt: &StatementStructDestructure<'db>,
419    ) {
420        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
421
422        let mut input_consumed = false;
423        for var_info in returned_vars.iter_mut() {
424            match var_info.apply_deconstruct(ctx, stmt) {
425                OpResult::InputConsumed => {
426                    input_consumed = true;
427                }
428                OpResult::ValueInvalidated => {
429                    self.invalidate();
430                    return;
431                }
432                OpResult::NoChange => {}
433            };
434        }
435
436        if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
437            self.invalidate();
438        }
439    }
440
441    /// Updates the info to the state before match arm.
442    fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo<'db>, arm: &MatchArm<'db>) {
443        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
444
445        let mut input_consumed = false;
446        for var_info in returned_vars.iter_mut() {
447            match var_info.apply_match_arm(input, arm) {
448                OpResult::InputConsumed => {
449                    input_consumed = true;
450                }
451                OpResult::ValueInvalidated => {
452                    self.invalidate();
453                    return;
454                }
455                OpResult::NoChange => {}
456            };
457        }
458
459        if !(input_consumed || is_droppable) {
460            self.invalidate();
461        }
462    }
463
464    /// Returns a vector of ValueInfos for the returns or None.
465    fn try_get_early_return_info(&self) -> Option<&ReturnInfo<'db>> {
466        let return_info = self.opt_return_info.as_ref()?;
467
468        let mut stack = return_info.returned_vars.clone();
469        while let Some(var_info) = stack.pop() {
470            match var_info {
471                ValueInfo::Var(_) => {}
472                ValueInfo::StructConstruct { ty: _, var_infos } => stack.extend(var_infos),
473                ValueInfo::EnumConstruct { var_info, variant: _ } => stack.push(*var_info),
474                ValueInfo::Interchangeable(_) => return None,
475            }
476        }
477
478        Some(return_info)
479    }
480}
481
482impl<'db, 'a> Analyzer<'db, 'a> for ReturnOptimizerContext<'db, 'a> {
483    type Info = AnalyzerInfo<'db>;
484
485    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
486        if let Some(return_info) = info.try_get_early_return_info() {
487            self.fixes.push(FixInfo { location: (block_id, 0), return_info: return_info.clone() });
488        }
489    }
490
491    fn visit_stmt(
492        &mut self,
493        info: &mut Self::Info,
494        (block_idx, statement_idx): StatementLocation,
495        stmt: &'a Statement<'db>,
496    ) {
497        let opt_early_return_info = info.try_get_early_return_info().cloned();
498
499        match stmt {
500            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
501                // Note that the ValueInfo::StructConstruct can only be removed by
502                // a StructDeconstruct statement that produces its non-interchangeable inputs so
503                // allowing undroppable inputs is ok here.
504                info.replace(
505                    *output,
506                    ValueInfo::StructConstruct {
507                        ty: self.lowered.variables[*output].ty,
508                        var_infos: inputs.iter().map(|input| self.get_var_info(input)).collect(),
509                    },
510                );
511            }
512
513            Statement::StructDestructure(stmt) => info.apply_deconstruct(self, stmt),
514            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
515                info.replace(
516                    *output,
517                    ValueInfo::EnumConstruct {
518                        var_info: Box::new(self.get_var_info(input)),
519                        variant: *variant,
520                    },
521                );
522            }
523            _ => info.invalidate(),
524        }
525
526        if let Some(early_return_info) = opt_early_return_info
527            && info.try_get_early_return_info().is_none()
528        {
529            self.fixes.push(FixInfo {
530                location: (block_idx, statement_idx + 1),
531                return_info: early_return_info,
532            });
533        }
534    }
535
536    fn visit_goto(
537        &mut self,
538        info: &mut Self::Info,
539        _statement_location: StatementLocation,
540        _target_block_id: BlockId,
541        remapping: &VarRemapping<'db>,
542    ) {
543        info.apply(&|var_usage| {
544            if let Some(usage) = remapping.get(&var_usage.var_id) {
545                ValueInfo::Var(*usage)
546            } else {
547                ValueInfo::Var(*var_usage)
548            }
549        });
550    }
551
552    fn merge_match(
553        &mut self,
554        _statement_location: StatementLocation,
555        match_info: &'a MatchInfo<'db>,
556        infos: impl Iterator<Item = Self::Info>,
557    ) -> Self::Info {
558        Self::Info { opt_return_info: self.try_merge_match(match_info, infos) }
559    }
560
561    fn info_from_return(
562        &mut self,
563        (block_id, _statement_idx): StatementLocation,
564        vars: &'a [VarUsage<'db>],
565    ) -> Self::Info {
566        let location = match &self.lowered.blocks[block_id].end {
567            BlockEnd::Return(_vars, location) => *location,
568            _ => unreachable!(),
569        };
570
571        // Note that `self.get_var_info` is not used here because ValueInfo::Interchangeable is
572        // supported only inside other ValueInfo variants.
573        AnalyzerInfo {
574            opt_return_info: Some(ReturnInfo {
575                returned_vars: vars.iter().map(|var_usage| ValueInfo::Var(*var_usage)).collect(),
576                location,
577            }),
578        }
579    }
580}