Skip to main content

cairo_lang_lowering/optimizations/
return_optimization.rs

// Unit tests for this optimization; compiled only in test builds and loaded from the sibling
// file `return_optimization_test.rs`.
#[cfg(test)]
#[path = "return_optimization_test.rs"]
mod test;
4
5use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeId, TypeLongId};
6use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
7use cairo_lang_utils::{Intern, require};
8use id_arena::Arena;
9use semantic::MatchArmSelector;
10
11use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
12use crate::db::LoweringGroup;
13use crate::ids::LocationId;
14use crate::{
15    Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
16    StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
17    VarUsage, Variable, VariableId,
18};
19
20/// Adds early returns when applicable.
21///
22/// This optimization does backward analysis from return statement and keeps track of
23/// each returned value (see `ValueInfo`), whenever all the returned values are available at a block
24/// end and there were no side effects later, the end is replaced with a return statement.
25pub fn return_optimization(db: &dyn LoweringGroup, lowered: &mut Lowered) {
26    if lowered.blocks.is_empty() {
27        return;
28    }
29    let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
30    let mut analysis = BackAnalysis::new(lowered, ctx);
31    analysis.get_root_info();
32    let ctx = analysis.analyzer;
33
34    let ReturnOptimizerContext { fixes, .. } = ctx;
35    for FixInfo { location: (block_id, statement_idx), return_info } in fixes {
36        let block = &mut lowered.blocks[block_id];
37        block.statements.truncate(statement_idx);
38        let mut ctx = EarlyReturnContext {
39            db,
40            constructed: UnorderedHashMap::default(),
41            variables: &mut lowered.variables,
42            statements: &mut block.statements,
43            location: return_info.location,
44        };
45        let vars = ctx.prepare_early_return_vars(&return_info.returned_vars);
46        block.end = BlockEnd::Return(vars, return_info.location)
47    }
48}
49
50/// Context for applying an early return to a block.
51struct EarlyReturnContext<'a> {
52    /// The lowering database.
53    db: &'a dyn LoweringGroup,
54    /// A map from (type, inputs) to the variable_id for Structs/Enums that were created
55    /// while processing the early return.
56    constructed: UnorderedHashMap<(TypeId, Vec<VariableId>), VariableId>,
57    /// A variable allocator.
58    variables: &'a mut Arena<Variable>,
59    /// The statements in the block where the early return is going to happen.
60    statements: &'a mut Vec<Statement>,
61    /// The location associated with the early return.
62    location: LocationId,
63}
64
65impl EarlyReturnContext<'_> {
66    /// Return a vector of VarUsage's based on the input `ret_infos`.
67    /// Adds `StructConstruct` and `EnumConstruct` statements to the block as needed.
68    /// Assumes that early return is possible for the given `ret_infos`.
69    fn prepare_early_return_vars(&mut self, ret_infos: &[ValueInfo]) -> Vec<VarUsage> {
70        let mut res = vec![];
71
72        for var_info in ret_infos.iter() {
73            match var_info {
74                ValueInfo::Var(var_usage) => {
75                    res.push(*var_usage);
76                }
77                ValueInfo::StructConstruct { ty, var_infos } => {
78                    let inputs = self.prepare_early_return_vars(var_infos);
79                    let output = *self
80                        .constructed
81                        .entry((*ty, inputs.iter().map(|var_usage| var_usage.var_id).collect()))
82                        .or_insert_with(|| {
83                            let output = self.variables.alloc(Variable::with_default_context(
84                                self.db,
85                                *ty,
86                                self.location,
87                            ));
88                            self.statements.push(Statement::StructConstruct(
89                                StatementStructConstruct { inputs, output },
90                            ));
91                            output
92                        });
93                    res.push(VarUsage { var_id: output, location: self.location });
94                }
95                ValueInfo::EnumConstruct { var_info, variant } => {
96                    let input = self.prepare_early_return_vars(std::slice::from_ref(var_info))[0];
97
98                    let ty = TypeLongId::Concrete(ConcreteTypeId::Enum(variant.concrete_enum_id))
99                        .intern(self.db);
100
101                    let output =
102                        *self.constructed.entry((ty, vec![input.var_id])).or_insert_with(|| {
103                            let output = self.variables.alloc(Variable::with_default_context(
104                                self.db,
105                                ty,
106                                self.location,
107                            ));
108                            self.statements.push(Statement::EnumConstruct(
109                                StatementEnumConstruct { variant: *variant, input, output },
110                            ));
111                            output
112                        });
113                    res.push(VarUsage { var_id: output, location: self.location });
114                }
115                ValueInfo::Interchangeable(_) => {
116                    unreachable!("early_return_possible should have prevented this.")
117                }
118            }
119        }
120
121        res
122    }
123}
124
/// The analyzer driving the return optimization.
///
/// It is run as a backward analysis over `lowered` and collects, in `fixes`, the places where a
/// block can be truncated and replaced by an early return.
pub struct ReturnOptimizerContext<'a> {
    /// The lowering database.
    db: &'a dyn LoweringGroup,
    /// The lowered function being analyzed (read-only during the analysis).
    lowered: &'a Lowered,

    /// The list of fixes that should be applied.
    fixes: Vec<FixInfo>,
}
132impl ReturnOptimizerContext<'_> {
133    /// Given a VarUsage, returns the ValueInfo that corresponds to it.
134    fn get_var_info(&self, var_usage: &VarUsage) -> ValueInfo {
135        let var_ty = &self.lowered.variables[var_usage.var_id].ty;
136        if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
137            ValueInfo::Interchangeable(*var_ty)
138        } else {
139            ValueInfo::Var(*var_usage)
140        }
141    }
142
143    /// Returns true if the variable is droppable
144    fn is_droppable(&self, var_id: VariableId) -> bool {
145        self.lowered.variables[var_id].info.droppable.is_ok()
146    }
147
148    /// Helper function for `merge_match`.
149    /// Returns `Option<ReturnInfo>` rather than `AnalyzerInfo` to simplify early return.
150    fn try_merge_match(
151        &mut self,
152        match_info: &MatchInfo,
153        infos: impl Iterator<Item = AnalyzerInfo>,
154    ) -> Option<ReturnInfo> {
155        let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
156            return None;
157        };
158        require(!arms.is_empty())?;
159
160        let input_info = self.get_var_info(input);
161        let mut opt_last_info = None;
162        for (arm, info) in arms.iter().zip(infos) {
163            let mut curr_info = info.clone();
164            curr_info.apply_match_arm(self.is_droppable(input.var_id), &input_info, arm);
165
166            match curr_info.try_get_early_return_info() {
167                Some(return_info)
168                    if opt_last_info
169                        .map(|x: ReturnInfo| x.returned_vars == return_info.returned_vars)
170                        .unwrap_or(true) =>
171                {
172                    // If this is the first iteration or the returned var are the same as the
173                    // previous iteration, then the optimization is still applicable.
174                    opt_last_info = Some(return_info.clone())
175                }
176                _ => return None,
177            }
178        }
179
180        Some(opt_last_info.unwrap())
181    }
182}
183
/// Information about a fix that should be applied to the lowering.
///
/// Applying the fix truncates the block's statements at `location` and replaces the block end
/// with a return of the values described by `return_info`.
pub struct FixInfo {
    /// The `(block id, statement index)` from which the block can return early; the statements
    /// starting at this index are removed.
    location: StatementLocation,
    /// The return info at the fix location.
    return_info: ReturnInfo,
}
191
/// Information about the value that should be returned from the function.
///
/// A value is described as a tree. Its leaves are variables (`Var`) or values replaceable by
/// any value of their type (`Interchangeable`). Its inner nodes are struct/enum constructions,
/// which are re-emitted at the early-return point when a fix is applied.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ValueInfo {
    /// The value is available through the given var usage.
    Var(VarUsage),
    /// The value can be replaced with other values of the same type.
    Interchangeable(semantic::TypeId),
    /// The value is the result of a StructConstruct statement.
    StructConstruct {
        /// The type of the struct.
        ty: semantic::TypeId,
        /// The inputs to the StructConstruct statement.
        var_infos: Vec<ValueInfo>,
    },
    /// The value is the result of an EnumConstruct statement.
    EnumConstruct {
        /// The input to the EnumConstruct.
        var_info: Box<ValueInfo>,
        /// The constructed variant.
        variant: semantic::ConcreteVariant,
    },
}
214
/// The result of applying an operation to a ValueInfo.
enum OpResult {
    /// The input of the operation was consumed: the value now refers to the operation's input.
    InputConsumed,
    /// One of the values is produced by the operation and therefore it is invalid before the
    /// operation.
    ValueInvalidated,
    /// The operation did not change the value info.
    NoChange,
}
224
impl ValueInfo {
    /// Applies the given function to the value info.
    ///
    /// Every `ValueInfo::Var` leaf, including those nested in `StructConstruct` and
    /// `EnumConstruct`, is replaced by `f(var_usage)`; `Interchangeable` leaves are left as is.
    fn apply<F>(&mut self, f: &F)
    where
        F: Fn(&VarUsage) -> ValueInfo,
    {
        match self {
            ValueInfo::Var(var_usage) => *self = f(var_usage),
            ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
                for var_info in var_infos.iter_mut() {
                    var_info.apply(f);
                }
            }
            ValueInfo::EnumConstruct { ref mut var_info, .. } => {
                var_info.apply(f);
            }
            ValueInfo::Interchangeable(_) => {}
        }
    }

    /// Updates the value to the state before the StructDeconstruct statement.
    ///
    /// Cancellation: `self` may be a `StructConstruct` of the destructured type. If its inputs
    /// are exactly the destructure outputs (or interchangeable values of the matching output
    /// types), the pair cancels out and `self` becomes `ValueInfo::Var(stmt.input)`.
    ///
    /// Returns `OpResult::InputConsumed` if the value now refers to `stmt.input`.
    /// Returns `OpResult::ValueInvalidated` if the value refers to one of `stmt.outputs`
    /// outside a cancelling construction. Returns `OpResult::NoChange` otherwise.
    fn apply_deconstruct(
        &mut self,
        ctx: &ReturnOptimizerContext<'_>,
        stmt: &StatementStructDestructure,
    ) -> OpResult {
        match self {
            ValueInfo::Var(var_usage) => {
                // A destructure output does not exist before the destructure.
                if stmt.outputs.contains(&var_usage.var_id) {
                    OpResult::ValueInvalidated
                } else {
                    OpResult::NoChange
                }
            }
            ValueInfo::StructConstruct { ty, var_infos } => {
                let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
                    && var_infos.len() == stmt.outputs.len();
                for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
                    if !cancels_out {
                        break;
                    }

                    match var_info {
                        ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
                        ValueInfo::Interchangeable(ty)
                            if &ctx.lowered.variables[*output].ty == ty => {}
                        _ => cancels_out = false,
                    }
                }

                if cancels_out {
                    // If the StructDeconstruct cancels out the StructConstruct, then we don't need
                    // to `apply_deconstruct` to the inner var infos.
                    *self = ValueInfo::Var(stmt.input);
                    return OpResult::InputConsumed;
                }

                let mut input_consumed = false;
                for var_info in var_infos.iter_mut() {
                    match var_info.apply_deconstruct(ctx, stmt) {
                        OpResult::InputConsumed => {
                            input_consumed = true;
                        }
                        OpResult::ValueInvalidated => {
                            // If one of the values is invalidated the optimization is no longer
                            // applicable.
                            return OpResult::ValueInvalidated;
                        }
                        OpResult::NoChange => {}
                    }
                }

                match input_consumed {
                    true => OpResult::InputConsumed,
                    false => OpResult::NoChange,
                }
            }
            ValueInfo::EnumConstruct { ref mut var_info, .. } => {
                var_info.apply_deconstruct(ctx, stmt)
            }
            ValueInfo::Interchangeable(_) => OpResult::NoChange,
        }
    }

    /// Updates the value to the expected value before the match arm.
    ///
    /// `input` is the `ValueInfo` of the matched variable. The value may be an `EnumConstruct`
    /// of the arm's variant whose inner value is the arm's bound variable (or is
    /// interchangeable). In that case the arm merely recreates the matched value, and `self` is
    /// replaced by `input`.
    ///
    /// Returns `OpResult::InputConsumed` if the value now refers to `input`.
    /// Returns `OpResult::ValueInvalidated` if the value refers to the variable bound by the
    /// arm outside such a construction. Returns `OpResult::NoChange` otherwise.
    fn apply_match_arm(&mut self, input: &ValueInfo, arm: &MatchArm) -> OpResult {
        match self {
            ValueInfo::Var(var_usage) => {
                // The variable bound by the arm does not exist before the arm.
                if arm.var_ids == [var_usage.var_id] {
                    OpResult::ValueInvalidated
                } else {
                    OpResult::NoChange
                }
            }
            ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
                let mut input_consumed = false;
                for var_info in var_infos.iter_mut() {
                    match var_info.apply_match_arm(input, arm) {
                        OpResult::InputConsumed => {
                            input_consumed = true;
                        }
                        OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
                        OpResult::NoChange => {}
                    }
                }

                if input_consumed {
                    return OpResult::InputConsumed;
                }
                OpResult::NoChange
            }
            ValueInfo::EnumConstruct { ref mut var_info, variant } => {
                let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
                    panic!("Enum construct should not appear in value match");
                };

                if *variant == *arm_variant {
                    let cancels_out = match **var_info {
                        ValueInfo::Interchangeable(_) => true,
                        ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
                        _ => false,
                    };

                    if cancels_out {
                        // The arm recreates the enum variant it matched on, so before the match
                        // the value is simply the match input. This is only valid for the whole
                        // match if the other arms also cancel out, which `try_merge_match`
                        // verifies by requiring identical returned values from all arms.
                        *self = input.clone();
                        return OpResult::InputConsumed;
                    }
                }

                var_info.apply_match_arm(input, arm)
            }
            ValueInfo::Interchangeable(_) => OpResult::NoChange,
        }
    }
}
364
/// Information about the current state of the analyzer.
/// Used to track the values that should be returned from the function at the current
/// analysis point.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ReturnInfo {
    /// The values to return, one per variable returned by the original return statement.
    returned_vars: Vec<ValueInfo>,
    /// The location of the original return statement, reused for the early return.
    location: LocationId,
}
373
/// A wrapper around `ReturnInfo` that makes it optional.
///
/// None indicates that the return info is unknown.
/// If `try_get_early_return_info()` returns `Some`, the function can return early as the return
/// value is already known.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AnalyzerInfo {
    /// The tracked return info, or `None` once the optimization is known not to be applicable
    /// at this point.
    opt_return_info: Option<ReturnInfo>,
}
383
384impl AnalyzerInfo {
385    /// Creates a state of the analyzer where the return optimization is not applicable.
386    fn invalidated() -> Self {
387        AnalyzerInfo { opt_return_info: None }
388    }
389
390    /// Invalidates the state of the analyzer, identifying early return is no longer possible.
391    fn invalidate(&mut self) {
392        *self = Self::invalidated();
393    }
394
395    /// Applies the given function to the returned_vars
396    fn apply<F>(&mut self, f: &F)
397    where
398        F: Fn(&VarUsage) -> ValueInfo,
399    {
400        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
401            return;
402        };
403
404        for var_info in returned_vars.iter_mut() {
405            var_info.apply(f)
406        }
407    }
408
409    /// Replaces occurrences of `var_id` with `var_info`.
410    fn replace(&mut self, var_id: VariableId, var_info: ValueInfo) {
411        self.apply(&|var_usage| {
412            if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
413        });
414    }
415
416    /// Updates the info to the state before the StructDeconstruct statement.
417    fn apply_deconstruct(
418        &mut self,
419        ctx: &ReturnOptimizerContext<'_>,
420        stmt: &StatementStructDestructure,
421    ) {
422        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
423
424        let mut input_consumed = false;
425        for var_info in returned_vars.iter_mut() {
426            match var_info.apply_deconstruct(ctx, stmt) {
427                OpResult::InputConsumed => {
428                    input_consumed = true;
429                }
430                OpResult::ValueInvalidated => {
431                    self.invalidate();
432                    return;
433                }
434                OpResult::NoChange => {}
435            };
436        }
437
438        if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
439            self.invalidate();
440        }
441    }
442
443    /// Updates the info to the state before match arm.
444    fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo, arm: &MatchArm) {
445        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
446
447        let mut input_consumed = false;
448        for var_info in returned_vars.iter_mut() {
449            match var_info.apply_match_arm(input, arm) {
450                OpResult::InputConsumed => {
451                    input_consumed = true;
452                }
453                OpResult::ValueInvalidated => {
454                    self.invalidate();
455                    return;
456                }
457                OpResult::NoChange => {}
458            };
459        }
460
461        if !(input_consumed || is_droppable) {
462            self.invalidate();
463        }
464    }
465
466    /// Returns a vector of ValueInfos for the returns or None.
467    fn try_get_early_return_info(&self) -> Option<&ReturnInfo> {
468        let return_info = self.opt_return_info.as_ref()?;
469
470        let mut stack = return_info.returned_vars.clone();
471        while let Some(var_info) = stack.pop() {
472            match var_info {
473                ValueInfo::Var(_) => {}
474                ValueInfo::StructConstruct { ty: _, var_infos } => stack.extend(var_infos),
475                ValueInfo::EnumConstruct { var_info, variant: _ } => stack.push(*var_info),
476                ValueInfo::Interchangeable(_) => return None,
477            }
478        }
479
480        Some(return_info)
481    }
482}
483
484impl<'a> Analyzer<'a> for ReturnOptimizerContext<'_> {
485    type Info = AnalyzerInfo;
486
487    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block) {
488        if let Some(return_info) = info.try_get_early_return_info() {
489            self.fixes.push(FixInfo { location: (block_id, 0), return_info: return_info.clone() });
490        }
491    }
492
493    fn visit_stmt(
494        &mut self,
495        info: &mut Self::Info,
496        (block_idx, statement_idx): StatementLocation,
497        stmt: &'a Statement,
498    ) {
499        let opt_early_return_info = info.try_get_early_return_info().cloned();
500
501        match stmt {
502            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
503                // Note that the ValueInfo::StructConstruct can only be removed by
504                // a StructDeconstruct statement that produces its non-interchangeable inputs so
505                // allowing undroppable inputs is ok here.
506                info.replace(
507                    *output,
508                    ValueInfo::StructConstruct {
509                        ty: self.lowered.variables[*output].ty,
510                        var_infos: inputs.iter().map(|input| self.get_var_info(input)).collect(),
511                    },
512                );
513            }
514
515            Statement::StructDestructure(stmt) => info.apply_deconstruct(self, stmt),
516            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
517                info.replace(
518                    *output,
519                    ValueInfo::EnumConstruct {
520                        var_info: Box::new(self.get_var_info(input)),
521                        variant: *variant,
522                    },
523                );
524            }
525            _ => info.invalidate(),
526        }
527
528        if let Some(early_return_info) = opt_early_return_info {
529            if info.try_get_early_return_info().is_none() {
530                self.fixes.push(FixInfo {
531                    location: (block_idx, statement_idx + 1),
532                    return_info: early_return_info,
533                });
534            }
535        }
536    }
537
538    fn visit_goto(
539        &mut self,
540        info: &mut Self::Info,
541        _statement_location: StatementLocation,
542        _target_block_id: BlockId,
543        remapping: &VarRemapping,
544    ) {
545        info.apply(&|var_usage| {
546            if let Some(usage) = remapping.get(&var_usage.var_id) {
547                ValueInfo::Var(*usage)
548            } else {
549                ValueInfo::Var(*var_usage)
550            }
551        });
552    }
553
554    fn merge_match(
555        &mut self,
556        _statement_location: StatementLocation,
557        match_info: &'a MatchInfo,
558        infos: impl Iterator<Item = Self::Info>,
559    ) -> Self::Info {
560        Self::Info { opt_return_info: self.try_merge_match(match_info, infos) }
561    }
562
563    fn info_from_return(
564        &mut self,
565        (block_id, _statement_idx): StatementLocation,
566        vars: &'a [VarUsage],
567    ) -> Self::Info {
568        let location = match &self.lowered.blocks[block_id].end {
569            BlockEnd::Return(_vars, location) => *location,
570            _ => unreachable!(),
571        };
572
573        // Note that `self.get_var_info` is not used here because ValueInfo::Interchangeable is
574        // supported only inside other ValueInfo variants.
575        AnalyzerInfo {
576            opt_return_info: Some(ReturnInfo {
577                returned_vars: vars.iter().map(|var_usage| ValueInfo::Var(*var_usage)).collect(),
578                location,
579            }),
580        }
581    }
582}