Skip to main content

cairo_lang_lowering/optimizations/
return_optimization.rs

1#[cfg(test)]
2#[path = "return_optimization_test.rs"]
3mod test;
4
5use cairo_lang_semantic::types::TypesSemantic;
6use cairo_lang_semantic::{self as semantic, ConcreteTypeId, TypeId, TypeLongId};
7use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
8use cairo_lang_utils::{Intern, require};
9use salsa::Database;
10use semantic::MatchArmSelector;
11
12use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
13use crate::ids::LocationId;
14use crate::{
15    Block, BlockEnd, BlockId, Lowered, MatchArm, MatchEnumInfo, MatchInfo, Statement,
16    StatementEnumConstruct, StatementStructConstruct, StatementStructDestructure, VarRemapping,
17    VarUsage, Variable, VariableArena, VariableId,
18};
19
20/// Adds early returns when applicable.
21///
22/// This optimization does backward analysis from return statement and keeps track of
23/// each returned value (see `ValueInfo`), whenever all the returned values are available at a block
24/// end and there were no side effects later, the end is replaced with a return statement.
25pub fn return_optimization<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>) {
26    if lowered.blocks.is_empty() {
27        return;
28    }
29    let ctx = ReturnOptimizerContext { db, lowered, fixes: vec![] };
30    let mut analysis = BackAnalysis::new(lowered, ctx);
31    analysis.get_root_info();
32    let ctx = analysis.analyzer;
33
34    let ReturnOptimizerContext { fixes, .. } = ctx;
35    for FixInfo { location: (block_id, statement_idx), return_info } in fixes {
36        let block = &mut lowered.blocks[block_id];
37        block.statements.truncate(statement_idx);
38        let mut ctx = EarlyReturnContext {
39            db,
40            constructed: UnorderedHashMap::default(),
41            variables: &mut lowered.variables,
42            statements: &mut block.statements,
43            location: return_info.location,
44        };
45        let vars = ctx.prepare_early_return_vars(&return_info.returned_vars);
46        block.end = BlockEnd::Return(vars, return_info.location)
47    }
48}
49
50/// Context for applying an early return to a block.
51struct EarlyReturnContext<'db, 'a> {
52    /// The lowering database.
53    db: &'db dyn Database,
54    /// A map from a `Construction` to the variable_id for Structs that were created
55    /// while processing the early return.
56    constructed: UnorderedHashMap<Construction<'db>, VariableId>,
57    /// A variable allocator.
58    variables: &'a mut VariableArena<'db>,
59    /// The statements in the block where the early return is going to happen.
60    statements: &'a mut Vec<Statement<'db>>,
61    /// The location associated with the early return.
62    location: LocationId<'db>,
63}
64
65/// A `Construction` represents a struct or enum construction that was created while processing
66/// the early return.
67#[derive(Debug, Clone, PartialEq, Eq, Hash)]
68enum Construction<'db> {
69    /// A construction of a struct.
70    Struct(TypeId<'db>, Vec<VariableId>),
71    /// A construction of an enum.
72    Enum(semantic::ConcreteVariant<'db>, VariableId),
73}
74
75impl<'db, 'a> EarlyReturnContext<'db, 'a> {
76    /// Returns a vector of VarUsage's based on the input `ret_infos`.
77    /// Adds `StructConstruct` and `EnumConstruct` statements to the block as needed.
78    /// Assumes that early return is possible for the given `ret_infos`.
79    fn prepare_early_return_vars(&mut self, ret_infos: &[ValueInfo<'db>]) -> Vec<VarUsage<'db>> {
80        let mut res = vec![];
81
82        for var_info in ret_infos.iter() {
83            match var_info {
84                ValueInfo::Var(var_usage) => {
85                    res.push(*var_usage);
86                }
87                ValueInfo::StructConstruct { ty, var_infos } => {
88                    let inputs = self.prepare_early_return_vars(var_infos);
89                    let output = *self
90                        .constructed
91                        .entry(Construction::Struct(
92                            *ty,
93                            inputs.iter().map(|var_usage| var_usage.var_id).collect(),
94                        ))
95                        .or_insert_with(|| {
96                            let output = self.variables.alloc(Variable::with_default_context(
97                                self.db,
98                                *ty,
99                                self.location,
100                            ));
101                            self.statements.push(Statement::StructConstruct(
102                                StatementStructConstruct { inputs, output },
103                            ));
104                            output
105                        });
106                    res.push(VarUsage { var_id: output, location: self.location });
107                }
108                ValueInfo::EnumConstruct { var_info, variant } => {
109                    let input = self.prepare_early_return_vars(std::slice::from_ref(var_info))[0];
110
111                    let ty = TypeLongId::Concrete(ConcreteTypeId::Enum(variant.concrete_enum_id))
112                        .intern(self.db);
113
114                    let output = *self
115                        .constructed
116                        .entry(Construction::Enum(*variant, input.var_id))
117                        .or_insert_with(|| {
118                            let output = self.variables.alloc(Variable::with_default_context(
119                                self.db,
120                                ty,
121                                self.location,
122                            ));
123                            self.statements.push(Statement::EnumConstruct(
124                                StatementEnumConstruct { variant: *variant, input, output },
125                            ));
126                            output
127                        });
128                    res.push(VarUsage { var_id: output, location: self.location });
129                }
130                ValueInfo::Interchangeable(_) => {
131                    unreachable!("early_return_possible should have prevented this.")
132                }
133            }
134        }
135
136        res
137    }
138}
139
140pub struct ReturnOptimizerContext<'db, 'a> {
141    db: &'db dyn Database,
142    lowered: &'a Lowered<'db>,
143
144    /// The list of fixes that should be applied.
145    fixes: Vec<FixInfo<'db>>,
146}
147impl<'db, 'a> ReturnOptimizerContext<'db, 'a> {
148    /// Given a VarUsage, returns the ValueInfo that corresponds to it.
149    fn get_var_info(&self, var_usage: &VarUsage<'db>) -> ValueInfo<'db> {
150        let var_ty = &self.lowered.variables[var_usage.var_id].ty;
151        if self.is_droppable(var_usage.var_id) && self.db.single_value_type(*var_ty).unwrap() {
152            ValueInfo::Interchangeable(*var_ty)
153        } else {
154            ValueInfo::Var(*var_usage)
155        }
156    }
157
158    /// Returns true if the variable is droppable.
159    fn is_droppable(&self, var_id: VariableId) -> bool {
160        self.lowered.variables[var_id].info.droppable.is_ok()
161    }
162
163    /// Helper function for `merge_match`.
164    /// Returns `Option<ReturnInfo>` rather than `AnalyzerInfo` to simplify early return.
165    fn try_merge_match(
166        &mut self,
167        match_info: &MatchInfo<'db>,
168        infos: impl Iterator<Item = AnalyzerInfo<'db>>,
169    ) -> Option<ReturnInfo<'db>> {
170        let MatchInfo::Enum(MatchEnumInfo { input, arms, .. }) = match_info else {
171            return None;
172        };
173        require(!arms.is_empty())?;
174
175        let input_info = self.get_var_info(input);
176        let mut opt_last_info = None;
177        for (arm, info) in arms.iter().zip(infos) {
178            let mut curr_info = info.clone();
179            curr_info.apply_match_arm(self.is_droppable(input.var_id), &input_info, arm);
180
181            match curr_info.try_get_early_return_info() {
182                Some(return_info)
183                    if opt_last_info
184                        .map(|x: ReturnInfo<'_>| x.returned_vars == return_info.returned_vars)
185                        .unwrap_or(true) =>
186                {
187                    // If this is the first iteration or the returned var are the same as the
188                    // previous iteration, then the optimization is still applicable.
189                    opt_last_info = Some(return_info.clone())
190                }
191                _ => return None,
192            }
193        }
194
195        Some(opt_last_info.unwrap())
196    }
197}
198
199/// Information about a fix that should be applied to the lowering.
200pub struct FixInfo<'db> {
201    /// A location where we `return_vars` can be returned.
202    location: StatementLocation,
203    /// The return info at the fix location.
204    return_info: ReturnInfo<'db>,
205}
206
207/// Information about the value that should be returned from the function.
208#[derive(Clone, Debug, PartialEq, Eq)]
209pub enum ValueInfo<'db> {
210    /// The value is available through the given var usage.
211    Var(VarUsage<'db>),
212    /// The value can be replaced with other values of the same type.
213    Interchangeable(semantic::TypeId<'db>),
214    /// The value is the result of a StructConstruct statement.
215    StructConstruct {
216        /// The type of the struct.
217        ty: semantic::TypeId<'db>,
218        /// The inputs to the StructConstruct statement.
219        var_infos: Vec<ValueInfo<'db>>,
220    },
221    /// The value is the result of an EnumConstruct statement.
222    EnumConstruct {
223        /// The input to the EnumConstruct.
224        var_info: Box<ValueInfo<'db>>,
225        /// The constructed variant.
226        variant: semantic::ConcreteVariant<'db>,
227    },
228}
229
230/// The result of applying an operation to a ValueInfo.
231enum OpResult {
232    /// The input of the operation was consumed.
233    InputConsumed,
234    /// One of the value is produced operation and therefore it is invalid before the operation.
235    ValueInvalidated,
236    /// The operation did not change the value info.
237    NoChange,
238}
239
240impl<'db> ValueInfo<'db> {
241    /// Applies the given function to the value info.
242    fn apply<F>(&mut self, f: &F)
243    where
244        F: Fn(&VarUsage<'db>) -> ValueInfo<'db>,
245    {
246        match self {
247            ValueInfo::Var(var_usage) => *self = f(var_usage),
248            ValueInfo::StructConstruct { ty: _, var_infos } => {
249                for var_info in var_infos.iter_mut() {
250                    var_info.apply(f);
251                }
252            }
253            ValueInfo::EnumConstruct { var_info, .. } => {
254                var_info.apply(f);
255            }
256            ValueInfo::Interchangeable(_) => {}
257        }
258    }
259
260    /// Updates the value to the state before the StructDeconstruct statement.
261    /// Returns OpResult.
262    fn apply_deconstruct(
263        &mut self,
264        ctx: &ReturnOptimizerContext<'db, '_>,
265        stmt: &StatementStructDestructure<'db>,
266    ) -> OpResult {
267        match self {
268            ValueInfo::Var(var_usage) => {
269                if stmt.outputs.contains(&var_usage.var_id) {
270                    OpResult::ValueInvalidated
271                } else {
272                    OpResult::NoChange
273                }
274            }
275            ValueInfo::StructConstruct { ty, var_infos } => {
276                let mut cancels_out = ty == &ctx.lowered.variables[stmt.input.var_id].ty
277                    && var_infos.len() == stmt.outputs.len();
278                for (var_info, output) in var_infos.iter().zip(stmt.outputs.iter()) {
279                    if !cancels_out {
280                        break;
281                    }
282
283                    match var_info {
284                        ValueInfo::Var(var_usage) if &var_usage.var_id == output => {}
285                        ValueInfo::Interchangeable(ty)
286                            if &ctx.lowered.variables[*output].ty == ty => {}
287                        _ => cancels_out = false,
288                    }
289                }
290
291                if cancels_out {
292                    // If the StructDeconstruct cancels out the StructConstruct, then we don't need
293                    // to `apply_deconstruct` to the inner var infos.
294                    *self = ValueInfo::Var(stmt.input);
295                    return OpResult::InputConsumed;
296                }
297
298                let mut input_consumed = false;
299                for var_info in var_infos.iter_mut() {
300                    match var_info.apply_deconstruct(ctx, stmt) {
301                        OpResult::InputConsumed => {
302                            input_consumed = true;
303                        }
304                        OpResult::ValueInvalidated => {
305                            // If one of the values is invalidated the optimization is no longer
306                            // applicable.
307                            return OpResult::ValueInvalidated;
308                        }
309                        OpResult::NoChange => {}
310                    }
311                }
312
313                match input_consumed {
314                    true => OpResult::InputConsumed,
315                    false => OpResult::NoChange,
316                }
317            }
318            ValueInfo::EnumConstruct { var_info, .. } => var_info.apply_deconstruct(ctx, stmt),
319            ValueInfo::Interchangeable(_) => OpResult::NoChange,
320        }
321    }
322
323    /// Updates the value to the expected value before the match arm.
324    /// Returns OpResult.
325    fn apply_match_arm(&mut self, input: &ValueInfo<'db>, arm: &MatchArm<'db>) -> OpResult {
326        match self {
327            ValueInfo::Var(var_usage) => {
328                if arm.var_ids == [var_usage.var_id] {
329                    OpResult::ValueInvalidated
330                } else {
331                    OpResult::NoChange
332                }
333            }
334            ValueInfo::StructConstruct { ty: _, var_infos } => {
335                let mut input_consumed = false;
336                for var_info in var_infos.iter_mut() {
337                    match var_info.apply_match_arm(input, arm) {
338                        OpResult::InputConsumed => {
339                            input_consumed = true;
340                        }
341                        OpResult::ValueInvalidated => return OpResult::ValueInvalidated,
342                        OpResult::NoChange => {}
343                    }
344                }
345
346                if input_consumed {
347                    return OpResult::InputConsumed;
348                }
349                OpResult::NoChange
350            }
351            ValueInfo::EnumConstruct { var_info, variant } => {
352                let MatchArmSelector::VariantId(arm_variant) = &arm.arm_selector else {
353                    panic!("Enum construct should not appear in value match");
354                };
355
356                if *variant == *arm_variant {
357                    let cancels_out = match **var_info {
358                        ValueInfo::Interchangeable(_) => true,
359                        ValueInfo::Var(var_usage) if arm.var_ids == [var_usage.var_id] => true,
360                        _ => false,
361                    };
362
363                    if cancels_out {
364                        // If the arm recreates the relevant enum variant, then the arm
365                        // assuming the other arms also cancel out.
366                        *self = input.clone();
367                        return OpResult::InputConsumed;
368                    }
369                }
370
371                var_info.apply_match_arm(input, arm)
372            }
373            ValueInfo::Interchangeable(_) => OpResult::NoChange,
374        }
375    }
376}
377
378/// Information about the current state of the analyzer.
379/// Used to track the value that should be returned from the function at the current
380/// analysis point.
381#[derive(Clone, Debug, PartialEq, Eq)]
382pub struct ReturnInfo<'db> {
383    returned_vars: Vec<ValueInfo<'db>>,
384    location: LocationId<'db>,
385}
386
387/// A wrapper around `ReturnInfo` that makes it optional.
388///
389/// None indicates that the return info is unknown.
390/// If early_return_possible() returns true, the function can return early as the return value is
391/// already known.
392#[derive(Clone, Debug, PartialEq, Eq)]
393pub struct AnalyzerInfo<'db> {
394    opt_return_info: Option<ReturnInfo<'db>>,
395}
396
397impl<'db> AnalyzerInfo<'db> {
398    /// Creates a state of the analyzer where the return optimization is not applicable.
399    fn invalidated() -> Self {
400        AnalyzerInfo { opt_return_info: None }
401    }
402
403    /// Invalidates the state of the analyzer, identifying early return is no longer possible.
404    fn invalidate(&mut self) {
405        *self = Self::invalidated();
406    }
407
408    /// Applies the given function to the returned_vars.
409    fn apply<F>(&mut self, f: &F)
410    where
411        F: Fn(&VarUsage<'db>) -> ValueInfo<'db>,
412    {
413        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else {
414            return;
415        };
416
417        for var_info in returned_vars.iter_mut() {
418            var_info.apply(f)
419        }
420    }
421
422    /// Replaces occurrences of `var_id` with `var_info`.
423    fn replace(&mut self, var_id: VariableId, var_info: ValueInfo<'db>) {
424        self.apply(&|var_usage| {
425            if var_usage.var_id == var_id { var_info.clone() } else { ValueInfo::Var(*var_usage) }
426        });
427    }
428
429    /// Updates the info to the state before the StructDeconstruct statement.
430    fn apply_deconstruct(
431        &mut self,
432        ctx: &ReturnOptimizerContext<'db, '_>,
433        stmt: &StatementStructDestructure<'db>,
434    ) {
435        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
436
437        let mut input_consumed = false;
438        for var_info in returned_vars.iter_mut() {
439            match var_info.apply_deconstruct(ctx, stmt) {
440                OpResult::InputConsumed => {
441                    input_consumed = true;
442                }
443                OpResult::ValueInvalidated => {
444                    self.invalidate();
445                    return;
446                }
447                OpResult::NoChange => {}
448            };
449        }
450
451        if !(input_consumed || ctx.is_droppable(stmt.input.var_id)) {
452            self.invalidate();
453        }
454    }
455
456    /// Updates the info to the state before match arm.
457    fn apply_match_arm(&mut self, is_droppable: bool, input: &ValueInfo<'db>, arm: &MatchArm<'db>) {
458        let Some(ReturnInfo { ref mut returned_vars, .. }) = self.opt_return_info else { return };
459
460        let mut input_consumed = false;
461        for var_info in returned_vars.iter_mut() {
462            match var_info.apply_match_arm(input, arm) {
463                OpResult::InputConsumed => {
464                    input_consumed = true;
465                }
466                OpResult::ValueInvalidated => {
467                    self.invalidate();
468                    return;
469                }
470                OpResult::NoChange => {}
471            };
472        }
473
474        if !(input_consumed || is_droppable) {
475            self.invalidate();
476        }
477    }
478
479    /// Returns a vector of ValueInfos for the returns or None.
480    fn try_get_early_return_info(&self) -> Option<&ReturnInfo<'db>> {
481        let return_info = self.opt_return_info.as_ref()?;
482
483        let mut stack = return_info.returned_vars.clone();
484        while let Some(var_info) = stack.pop() {
485            match var_info {
486                ValueInfo::Var(_) => {}
487                ValueInfo::StructConstruct { ty: _, var_infos } => stack.extend(var_infos),
488                ValueInfo::EnumConstruct { var_info, variant: _ } => stack.push(*var_info),
489                ValueInfo::Interchangeable(_) => return None,
490            }
491        }
492
493        Some(return_info)
494    }
495}
496
497impl<'db, 'a> Analyzer<'db, 'a> for ReturnOptimizerContext<'db, 'a> {
498    type Info = AnalyzerInfo<'db>;
499
500    fn visit_block_start(&mut self, info: &mut Self::Info, block_id: BlockId, _block: &Block<'db>) {
501        if let Some(return_info) = info.try_get_early_return_info() {
502            self.fixes.push(FixInfo { location: (block_id, 0), return_info: return_info.clone() });
503        }
504    }
505
506    fn visit_stmt(
507        &mut self,
508        info: &mut Self::Info,
509        (block_idx, statement_idx): StatementLocation,
510        stmt: &'a Statement<'db>,
511    ) {
512        let opt_early_return_info = info.try_get_early_return_info().cloned();
513
514        match stmt {
515            Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
516                // Note that the ValueInfo::StructConstruct can only be removed by
517                // a StructDeconstruct statement that produces its non-interchangeable inputs so
518                // allowing undroppable inputs is ok here.
519                info.replace(
520                    *output,
521                    ValueInfo::StructConstruct {
522                        ty: self.lowered.variables[*output].ty,
523                        var_infos: inputs.iter().map(|input| self.get_var_info(input)).collect(),
524                    },
525                );
526            }
527
528            Statement::StructDestructure(stmt) => info.apply_deconstruct(self, stmt),
529            Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
530                info.replace(
531                    *output,
532                    ValueInfo::EnumConstruct {
533                        var_info: Box::new(self.get_var_info(input)),
534                        variant: *variant,
535                    },
536                );
537            }
538            _ => info.invalidate(),
539        }
540
541        if let Some(early_return_info) = opt_early_return_info
542            && info.try_get_early_return_info().is_none()
543        {
544            self.fixes.push(FixInfo {
545                location: (block_idx, statement_idx + 1),
546                return_info: early_return_info,
547            });
548        }
549    }
550
551    fn visit_goto(
552        &mut self,
553        info: &mut Self::Info,
554        _statement_location: StatementLocation,
555        _target_block_id: BlockId,
556        remapping: &VarRemapping<'db>,
557    ) {
558        info.apply(&|var_usage| {
559            if let Some(usage) = remapping.get(&var_usage.var_id) {
560                ValueInfo::Var(*usage)
561            } else {
562                ValueInfo::Var(*var_usage)
563            }
564        });
565    }
566
567    fn merge_match(
568        &mut self,
569        _statement_location: StatementLocation,
570        match_info: &'a MatchInfo<'db>,
571        infos: impl Iterator<Item = Self::Info>,
572    ) -> Self::Info {
573        Self::Info { opt_return_info: self.try_merge_match(match_info, infos) }
574    }
575
576    fn info_from_return(
577        &mut self,
578        (block_id, _statement_idx): StatementLocation,
579        vars: &'a [VarUsage<'db>],
580    ) -> Self::Info {
581        let location = match &self.lowered.blocks[block_id].end {
582            BlockEnd::Return(_vars, location) => *location,
583            _ => unreachable!(),
584        };
585
586        // Note that `self.get_var_info` is not used here because ValueInfo::Interchangeable is
587        // supported only inside other ValueInfo variants.
588        AnalyzerInfo {
589            opt_return_info: Some(ReturnInfo {
590                returned_vars: vars.iter().map(|var_usage| ValueInfo::Var(*var_usage)).collect(),
591                location,
592            }),
593        }
594    }
595}