cairo_lang_lowering/optimizations/
return_optimization.rs

1#[cfg(test)]
2#[path = "return_optimization_test.rs"]
3mod test;
4
5use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeId, TypeLongId};
6use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
7use cairo_lang_utils::{Intern, require};
8use semantic::MatchArmSelector;
9
10use crate::borrow_check::analysis::{Analyzer, BackAnalysis, StatementLocation};
11use crate::db::LoweringGroup;
12use crate::ids::{ConcreteFunctionWithBodyId, LocationId};
13use crate::lower::context::{VarRequest, VariableAllocator};
14use crate::{
15    BlockId, FlatBlock, FlatBlockEnd, FlatLowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
16    StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
17    VarUsage, VariableId,
18};
19
20/// Adds early returns when applicable.
21///
22/// This optimization does backward analysis from return statement and keeps track of
23/// each returned value (see `ValueInfo`), whenever all the returned values are available at a block
24/// end and there was no side effects later, the end is replaced with a return statement.
25pub fn return_optimization(
26    db: &dyn LoweringGroup,
27    function_id: ConcreteFunctionWithBodyId,
28    lowered: &mut FlatLowered,
29) {
30    if lowered.blocks.is_empty() {
31        return;
32    }
33    let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
34    let mut analysis = BackAnalysis::new(lowered, ctx);
35    analysis.get_root_info();
36    let ctx = analysis.analyzer;
37
38    let mut variables = VariableAllocator::new(
39        db,
40        function_id.function_with_body_id(db).base_semantic_function(db),
41        lowered.variables.clone(),
42    )
43    .unwrap();
44
45    for FixInfo { location: (block_id, statement_idx), return_info } in ctx.fixes {
46        let block = &mut lowered.blocks[block_id];
47        block.statements.truncate(statement_idx);
48        let mut ctx = EarlyReturnContext {
49            constructed: UnorderedHashMap::default(),
50            variables: &mut variables,
51            statements: &mut block.statements,
52            location: return_info.location,
53        };
54        let vars = ctx.prepare_early_return_vars(&return_info.returned_vars);
55        block.end = FlatBlockEnd::Return(vars, return_info.location)
56    }
57
58    lowered.variables = variables.variables;
59}
60
/// Context for applying an early return to a block.
///
/// Holds the state needed to materialize the returned values (described by `ValueInfo`s) at
/// the point where the block is cut short.
struct EarlyReturnContext<'a, 'b> {
    /// A map from (type, input variables) to the variable holding the struct/enum that was
    /// created from them while processing the early return, so that an identical value is
    /// constructed only once.
    constructed: UnorderedHashMap<(TypeId, Vec<VariableId>), VariableId>,
    /// A variable allocator, used for the outputs of newly added construct statements.
    variables: &'a mut VariableAllocator<'b>,
    /// The statements in the block where the early return is going to happen; new construct
    /// statements are appended to it.
    statements: &'a mut Vec<Statement>,
    /// The location associated with the early return (and with the variables it creates).
    location: LocationId,
}
73
74impl EarlyReturnContext<'_, '_> {
75    /// Return a vector of VarUsage's based on the input `ret_infos`.
76    /// Adds `StructConstruct` and `EnumConstruct` statements to the block as needed.
77    /// Assumes that early return is possible for the given `ret_infos`.
78    fn prepare_early_return_vars(&mut self, ret_infos: &[ValueInfo]) -> Vec<VarUsage> {
79        let mut res = vec![];
80
81        for var_info in ret_infos.iter() {
82            match var_info {
83                ValueInfo::Var(var_usage) => {
84                    res.push(*var_usage);
85                }
86                ValueInfo::StructConstruct { ty, var_infos } => {
87                    let inputs = self.prepare_early_return_vars(var_infos);
88                    let output = *self
89                        .constructed
90                        .entry((*ty, inputs.iter().map(|var_usage| var_usage.var_id).collect()))
91                        .or_insert_with(|| {
92                            let output = self
93                                .variables
94                                .new_var(VarRequest { ty: *ty, location: self.location });
95                            self.statements.push(Statement::StructConstruct(
96                                StatementStructConstruct { inputs, output },
97                            ));
98                            output
99                        });
100                    res.push(VarUsage { var_id: output, location: self.location });
101                }
102                ValueInfo::EnumConstruct { var_info, variant } => {
103                    let input = self.prepare_early_return_vars(std::slice::from_ref(var_info))[0];
104
105                    let ty = TypeLongId::Concrete(ConcreteTypeId::Enum(variant.concrete_enum_id))
106                        .intern(self.variables.db);
107
108                    let output =
109                        *self.constructed.entry((ty, vec![input.var_id])).or_insert_with(|| {
110                            let output =
111                                self.variables.new_var(VarRequest { ty, location: self.location });
112                            self.statements.push(Statement::EnumConstruct(
113                                StatementEnumConstruct { variant: variant.clone(), input, output },
114                            ));
115                            output
116                        });
117                    res.push(VarUsage { var_id: output, location: self.location });
118                }
119                ValueInfo::Interchangeable(_) => {
120                    unreachable!("early_return_possible should have prevented this.")
121                }
122            }
123        }
124
125        res
126    }
127}
128
/// The analyzer of the return optimization.
///
/// Runs as a backward analysis over `lowered` and records the places where an early return
/// can be inserted.
pub struct ReturnOptimizerContext<'a> {
    /// The lowering database, used for type queries (e.g. `single_value_type`).
    db: &'a dyn LoweringGroup,
    /// The lowered function being analyzed.
    lowered: &'a FlatLowered,

    /// The list of fixes that should be applied.
    fixes: Vec<FixInfo>,
}
136impl ReturnOptimizerContext<'_> {
137    /// Given a VarUsage, returns the ValueInfo that corresponds to it.
138    fn get_var_info(&self, var_usage: &VarUsage) -> ValueInfo {
139        let var_ty = &self.lowered.variables[var_usage.var_id].ty;
140        if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
141            ValueInfo::Interchangeable(*var_ty)
142        } else {
143            ValueInfo::Var(*var_usage)
144        }
145    }
146
147    /// Returns true if the variable is droppable
148    fn is_droppable(&self, var_id: VariableId) -> bool {
149        self.lowered.variables[var_id].droppable.is_ok()
150    }
151
152    /// Helper function for `merge_match`.
153    /// Returns `Option<ReturnInfo>` rather than `AnalyzerInfo` to simplify early return.
154    fn try_merge_match(
155        &mut self,
156        match_info: &MatchInfo,
157        infos: &[AnalyzerInfo],
158    ) -> Option<ReturnInfo> {
159        let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
160            return None;
161        };
162        require(!arms.is_empty())?;
163
164        let input_info = self.get_var_info(input);
165        let mut opt_last_info = None;
166        for (arm, info) in arms.iter().zip(infos) {
167            let mut curr_info = info.clone();
168            curr_info.apply_match_arm(self.is_droppable(input.var_id), &input_info, arm);
169
170            match curr_info.try_get_early_return_info() {
171                Some(return_info)
172                    if opt_last_info
173                        .map(|x: ReturnInfo| x.returned_vars == return_info.returned_vars)
174                        .unwrap_or(true) =>
175                {
176                    // If this is the first iteration or the returned var are the same as the
177                    // previous iteration, then the optimization is still applicable.
178                    opt_last_info = Some(return_info.clone())
179                }
180                _ => return None,
181            }
182        }
183
184        Some(opt_last_info.unwrap())
185    }
186}
187
/// Information about a fix that should be applied to the lowering.
///
/// Applying the fix truncates the block's statements at `location` and replaces the block's
/// end with a return of the values described by `return_info`.
pub struct FixInfo {
    /// The location (block, statement index) at which all the returned values are available,
    /// so the function can return from there.
    location: StatementLocation,
    /// The return info at the fix location.
    return_info: ReturnInfo,
}
195
/// Information about a value that should be returned from the function: how it can be
/// obtained at the current analysis point.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ValueInfo {
    /// The value is available through the given var usage.
    Var(VarUsage),
    /// The value can be replaced with any other value of the same type (e.g. a droppable
    /// variable of a single-value type).
    Interchangeable(semantic::TypeId),
    /// The value is the result of a StructConstruct statement.
    StructConstruct {
        /// The type of the struct.
        ty: semantic::TypeId,
        /// The inputs to the StructConstruct statement.
        var_infos: Vec<ValueInfo>,
    },
    /// The value is the result of an EnumConstruct statement.
    EnumConstruct {
        /// The input to the EnumConstruct.
        var_info: Box<ValueInfo>,
        /// The constructed variant.
        variant: semantic::ConcreteVariant,
    },
}
218
/// The result of applying an operation to a ValueInfo.
enum OpResult {
    /// The input of the operation was consumed: the value info now refers to the operation's
    /// input instead of its outputs.
    InputConsumed,
    /// One of the values is produced by the operation and therefore does not exist before the
    /// operation.
    ValueInvalidated,
    /// The operation did not change the value info.
    NoChange,
}
228
229impl ValueInfo {
230    /// Applies the given function to the value info.
231    fn apply<F>(&mut self, f: &F)
232    where
233        F: Fn(&VarUsage) -> ValueInfo,
234    {
235        match self {
236            ValueInfo::Var(var_usage) => *self = f(var_usage),
237            ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
238                for var_info in var_infos.iter_mut() {
239                    var_info.apply(f);
240                }
241            }
242            ValueInfo::EnumConstruct { ref mut var_info, .. } => {
243                var_info.apply(f);
244            }
245            ValueInfo::Interchangeable(_) => {}
246        }
247    }
248
249    /// Updates the value to the state before the StructDeconstruct statement.
250    /// Returns OpResult.
251    fn apply_deconstruct(
252        &mut self,
253        ctx: &ReturnOptimizerContext<'_>,
254        stmt: &StatementStructDestructure,
255    ) -> OpResult {
256        match self {
257            ValueInfo::Var(var_usage) => {
258                if stmt.outputs.contains(&var_usage.var_id) {
259                    OpResult::ValueInvalidated
260                } else {
261                    OpResult::NoChange
262                }
263            }
264            ValueInfo::StructConstruct { ty, var_infos } => {
265                let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
266                    && var_infos.len() == stmt.outputs.len();
267                for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
268                    if !cancels_out {
269                        break;
270                    }
271
272                    match var_info {
273                        ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
274                        ValueInfo::Interchangeable(ty)
275                            if &ctx.lowered.variables[*output].ty == ty => {}
276                        _ => cancels_out = false,
277                    }
278                }
279
280                if cancels_out {
281                    // If the StructDeconstruct cancels out the StructConstruct, then we don't need
282                    // to `apply_deconstruct` to the inner var infos.
283                    *self = ValueInfo::Var(stmt.input);
284                    return OpResult::InputConsumed;
285                }
286
287                let mut input_consumed = false;
288                for var_info in var_infos.iter_mut() {
289                    match var_info.apply_deconstruct(ctx, stmt) {
290                        OpResult::InputConsumed => {
291                            input_consumed = true;
292                        }
293                        OpResult::ValueInvalidated => {
294                            // If one of the values is invalidated the optimization is no longer
295                            // applicable.
296                            return OpResult::ValueInvalidated;
297                        }
298                        OpResult::NoChange => {}
299                    }
300                }
301
302                match input_consumed {
303                    true => OpResult::InputConsumed,
304                    false => OpResult::NoChange,
305                }
306            }
307            ValueInfo::EnumConstruct { ref mut var_info, .. } => {
308                var_info.apply_deconstruct(ctx, stmt)
309            }
310            ValueInfo::Interchangeable(_) => OpResult::NoChange,
311        }
312    }
313
314    /// Updates the value to the expected value before the match arm.
315    /// Returns OpResult.
316    fn apply_match_arm(&mut self, input: &ValueInfo, arm: &MatchArm) -> OpResult {
317        match self {
318            ValueInfo::Var(var_usage) => {
319                if arm.var_ids == [var_usage.var_id] {
320                    OpResult::ValueInvalidated
321                } else {
322                    OpResult::NoChange
323                }
324            }
325            ValueInfo::StructConstruct { ty: _, ref mut var_infos } => {
326                let mut input_consumed = false;
327                for var_info in var_infos.iter_mut() {
328                    match var_info.apply_match_arm(input, arm) {
329                        OpResult::InputConsumed => {
330                            input_consumed = true;
331                        }
332                        OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
333                        OpResult::NoChange => {}
334                    }
335                }
336
337                if input_consumed {
338                    return OpResult::InputConsumed;
339                }
340                OpResult::NoChange
341            }
342            ValueInfo::EnumConstruct { ref mut var_info, variant } => {
343                let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
344                    panic!("Enum construct should not appear in value match");
345                };
346
347                if *variant == *arm_variant {
348                    let cancels_out = match **var_info {
349                        ValueInfo::Interchangeable(_) => true,
350                        ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
351                        _ => false,
352                    };
353
354                    if cancels_out {
355                        // If the arm recreates the relevant enum variant, then the arm
356                        // assuming the other arms also cancel out.
357                        *self = input.clone();
358                        return OpResult::InputConsumed;
359                    }
360                }
361
362                var_info.apply_match_arm(input, arm)
363            }
364            ValueInfo::Interchangeable(_) => OpResult::NoChange,
365        }
366    }
367}
368
/// Information about the current state of the analyzer.
/// Used to track the values that should be returned from the function at the current
/// analysis point.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ReturnInfo {
    /// How each of the returned values can be obtained at this point.
    returned_vars: Vec<ValueInfo>,
    /// The location taken from the original return statement; it is used for the inserted
    /// early return.
    location: LocationId,
}
377
/// A wrapper around `ReturnInfo` that makes it optional.
///
/// `None` indicates that the return info is unknown, so no early return is possible.
/// If `try_get_early_return_info()` returns `Some`, the function can return early, as the
/// returned values are already known and available.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AnalyzerInfo {
    opt_return_info: Option<ReturnInfo>,
}
387
388impl AnalyzerInfo {
389    /// Creates a state of the analyzer where the return optimization is not applicable.
390    fn invalidated() -> Self {
391        AnalyzerInfo { opt_return_info: None }
392    }
393
394    /// Invalidates the state of the analyzer, identifying early return is no longer possible.
395    fn invalidate(&mut self) {
396        *self = Self::invalidated();
397    }
398
399    /// Applies the given function to the returned_vars
400    fn apply<F>(&mut self, f: &F)
401    where
402        F: Fn(&VarUsage) -> ValueInfo,
403    {
404        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
405            return;
406        };
407
408        for var_info in returned_vars.iter_mut() {
409            var_info.apply(f)
410        }
411    }
412
413    /// Replaces occurrences of `var_id` with `var_info`.
414    fn replace(&mut self, var_id: VariableId, var_info: ValueInfo) {
415        self.apply(&|var_usage| {
416            if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
417        });
418    }
419
420    /// Updates the info to the state before the StructDeconstruct statement.
421    fn apply_deconstruct(
422        &mut self,
423        ctx: &ReturnOptimizerContext<'_>,
424        stmt: &StatementStructDestructure,
425    ) {
426        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
427
428        let mut input_consumed = false;
429        for var_info in returned_vars.iter_mut() {
430            match var_info.apply_deconstruct(ctx, stmt) {
431                OpResult::InputConsumed => {
432                    input_consumed = true;
433                }
434                OpResult::ValueInvalidated => {
435                    self.invalidate();
436                    return;
437                }
438                OpResult::NoChange => {}
439            };
440        }
441
442        if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
443            self.invalidate();
444        }
445    }
446
447    /// Updates the info to the state before match arm.
448    fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo, arm: &MatchArm) {
449        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
450
451        let mut input_consumed = false;
452        for var_info in returned_vars.iter_mut() {
453            match var_info.apply_match_arm(input, arm) {
454                OpResult::InputConsumed => {
455                    input_consumed = true;
456                }
457                OpResult::ValueInvalidated => {
458                    self.invalidate();
459                    return;
460                }
461                OpResult::NoChange => {}
462            };
463        }
464
465        if !(input_consumed || is_droppable) {
466            self.invalidate();
467        }
468    }
469
470    /// Returns a vector of ValueInfos for the returns or None.
471    fn try_get_early_return_info(&self) -> Option<&ReturnInfo> {
472        let return_info = self.opt_return_info.as_ref()?;
473
474        let mut stack = return_info.returned_vars.clone();
475        while let Some(var_info) = stack.pop() {
476            match var_info {
477                ValueInfo::Var(_) => {}
478                ValueInfo::StructConstruct { ty: _, var_infos } => stack.extend(var_infos),
479                ValueInfo::EnumConstruct { var_info, variant: _ } => stack.push(*var_info),
480                ValueInfo::Interchangeable(_) => return None,
481            }
482        }
483
484        Some(return_info)
485    }
486}
487
488impl<'a> Analyzer<'a> for ReturnOptimizerContext<'_> {
489    type Info = AnalyzerInfo;
490
491    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &FlatBlock) {
492        if let Some(return_info) = info.try_get_early_return_info() {
493            self.fixes.push(FixInfo { location: (block_id, 0), return_info: return_info.clone() });
494        }
495    }
496
497    fn visit_stmt(
498        &mut self,
499        info: &mut Self::Info,
500        (block_idx, statement_idx): StatementLocation,
501        stmt: &'a Statement,
502    ) {
503        let opt_early_return_info = info.try_get_early_return_info().cloned();
504
505        match stmt {
506            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
507                // Note that the ValueInfo::StructConstruct can only be removed by
508                // a StructDeconstruct statement that produces its non-interchangeable inputs so
509                // allowing undroppable inputs is ok here.
510                info.replace(
511                    *output,
512                    ValueInfo::StructConstruct {
513                        ty: self.lowered.variables[*output].ty,
514                        var_infos: inputs.iter().map(|input| self.get_var_info(input)).collect(),
515                    },
516                );
517            }
518
519            Statement::StructDestructure(stmt) => info.apply_deconstruct(self, stmt),
520            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
521                info.replace(
522                    *output,
523                    ValueInfo::EnumConstruct {
524                        var_info: Box::new(self.get_var_info(input)),
525                        variant: variant.clone(),
526                    },
527                );
528            }
529            _ => info.invalidate(),
530        }
531
532        if let Some(early_return_info) = opt_early_return_info {
533            if info.try_get_early_return_info().is_none() {
534                self.fixes.push(FixInfo {
535                    location: (block_idx, statement_idx + 1),
536                    return_info: early_return_info,
537                });
538            }
539        }
540    }
541
542    fn visit_goto(
543        &mut self,
544        info: &mut Self::Info,
545        _statement_location: StatementLocation,
546        _target_block_id: BlockId,
547        remapping: &VarRemapping,
548    ) {
549        info.apply(&|var_usage| {
550            if let Some(usage) = remapping.get(&var_usage.var_id) {
551                ValueInfo::Var(*usage)
552            } else {
553                ValueInfo::Var(*var_usage)
554            }
555        });
556    }
557
558    fn merge_match(
559        &mut self,
560        _statement_location: StatementLocation,
561        match_info: &'a MatchInfo,
562        infos: impl Iterator<Item = Self::Info>,
563    ) -> Self::Info {
564        let infos: Vec<_> = infos.collect();
565        let opt_return_info = self.try_merge_match(match_info, &infos);
566        Self::Info { opt_return_info }
567    }
568
569    fn info_from_return(
570        &mut self,
571        (block_id, _statement_idx): StatementLocation,
572        vars: &'a [VarUsage],
573    ) -> Self::Info {
574        let location = match &self.lowered.blocks[block_id].end {
575            FlatBlockEnd::Return(_vars, location) => *location,
576            _ => unreachable!(),
577        };
578
579        // Note that `self.get_var_info` is not used here because ValueInfo::Interchangeable is
580        // supported only inside other ValueInfo variants.
581        AnalyzerInfo {
582            opt_return_info: Some(ReturnInfo {
583                returned_vars: vars.iter().map(|var_usage| ValueInfo::Var(*var_usage)).collect(),
584                location,
585            }),
586        }
587    }
588}