Skip to main content

cairo_lang_lowering/optimizations/
gas_redeposit.rs

1#[cfg(test)]
2#[path = "gas_redeposit_test.rs"]
3mod test;
4
5use cairo_lang_filesystem::flag::FlagsGroup;
6use cairo_lang_filesystem::ids::SmolStrId;
7use cairo_lang_semantic::{ConcreteVariant, corelib};
8use itertools::{Itertools, zip_eq};
9use salsa::Database;
10
11use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
12use crate::ids::{ConcreteFunctionWithBodyId, LocationId, SemanticFunctionIdEx};
13use crate::implicits::FunctionImplicitsTrait;
14use crate::panic::PanicSignatureInfo;
15use crate::{
16    BlockId, Lowered, MatchInfo, Statement, StatementCall, StatementEnumConstruct, VarRemapping,
17    VarUsage, VariableId,
18};
19
20/// Adds redeposit gas actions.
21///
22/// The algorithm is as follows:
23/// Checks if the function will have the `GasBuiltin` implicit after the lower_implicits stage.
24/// If so, after every block that ends with match, add a call to `redeposit_gas` in every arm
25/// that is followed by a convergence point or a return.
26/// Note that assuming `reorganize_blocks` stage is applied before this stage, every `goto`
27/// statement is a convergence point.
28///
29/// Note that for implementation simplicity this stage must be applied before `LowerImplicits`
30/// stage.
31pub fn gas_redeposit<'db>(
32    db: &'db dyn Database,
33    function_id: ConcreteFunctionWithBodyId<'db>,
34    lowered: &mut Lowered<'db>,
35) {
36    if lowered.blocks.is_empty() {
37        return;
38    }
39    if !db.flag_add_withdraw_gas() {
40        return;
41    }
42    let gb_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "GasBuiltin"), vec![]);
43    // Checking if the implicits of this function past lowering includes `GasBuiltin`.
44    if let Ok(implicits) = db.function_with_body_implicits(function_id)
45        && !implicits.into_iter().contains(&gb_ty)
46    {
47        return;
48    }
49    assert!(
50        lowered.parameters.iter().all(|p| lowered.variables[*p].ty != gb_ty),
51        "`GasRedeposit` stage must be called before `LowerImplicits` stage"
52    );
53
54    let panic_sig = PanicSignatureInfo::new(db, &function_id.signature(db).unwrap());
55    if panic_sig.always_panic {
56        return;
57    }
58    let ctx = GasRedepositContext { fixes: vec![], err_variant: panic_sig.err_variant };
59    let mut analysis = BackAnalysis::new(lowered, ctx);
60    analysis.get_root_info();
61
62    let redeposit_gas = corelib::get_function_id(
63        db,
64        corelib::core_submodule(db, SmolStrId::from(db, "gas")),
65        SmolStrId::from(db, "redeposit_gas"),
66        vec![],
67    )
68    .lowered(db);
69    for (block_id, location) in analysis.analyzer.fixes {
70        let block = &mut lowered.blocks[block_id];
71
72        // The `redeposit_gas` function is added at the beginning of the block as it result in
73        // smaller code when the GasBuiltin is revoked during the block.
74        block.statements.insert(
75            0,
76            Statement::Call(StatementCall {
77                function: redeposit_gas,
78                inputs: vec![],
79                with_coupon: false,
80                outputs: vec![],
81                location,
82                is_specialization_base_call: false,
83            }),
84        );
85    }
86}
87
88pub struct GasRedepositContext<'db> {
89    /// The list of blocks where we need to insert redeposit_gas.
90    fixes: Vec<(BlockId, LocationId<'db>)>,
91    /// The panic error variant.
92    pub err_variant: ConcreteVariant<'db>,
93}
94
95#[derive(Clone, PartialEq, Debug)]
96pub enum RedepositState {
97    /// Gas might be burned if we don't redeposit.
98    Required,
99    /// Redeposit is not necessary. This may occur if it has already been handled
100    /// or if the flow is terminating due to a panic.
101    Unnecessary,
102    /// The flow returns the given variable, redeposit is required unless the return var is of the
103    /// error variant.
104    Return(VariableId),
105}
106
107impl<'db> Analyzer<'db, '_> for GasRedepositContext<'db> {
108    type Info = RedepositState;
109
110    fn visit_stmt(
111        &mut self,
112        info: &mut Self::Info,
113        _statement_location: StatementLocation,
114        stmt: &Statement<'db>,
115    ) {
116        let RedepositState::Return(var_id) = info else {
117            return;
118        };
119
120        let Statement::EnumConstruct(StatementEnumConstruct { variant, input: _, output }) = stmt
121        else {
122            return;
123        };
124
125        if output == var_id && *variant == self.err_variant {
126            *info = RedepositState::Unnecessary;
127        }
128    }
129
130    fn visit_goto(
131        &mut self,
132        info: &mut Self::Info,
133        _statement_location: StatementLocation,
134        _target_block_id: BlockId,
135        _remapping: &VarRemapping<'db>,
136    ) {
137        // A goto is a convergence point, gas will get burned unless it is redeposited before the
138        // convergence.
139        *info = RedepositState::Required
140    }
141
142    fn merge_match(
143        &mut self,
144        _st: StatementLocation,
145        match_info: &MatchInfo<'db>,
146        infos: impl Iterator<Item = Self::Info>,
147    ) -> Self::Info {
148        for (info, arm) in zip_eq(infos, match_info.arms()) {
149            match info {
150                RedepositState::Return(_) | RedepositState::Required => {
151                    self.fixes.push((arm.block_id, *match_info.location()));
152                }
153                RedepositState::Unnecessary => {}
154            }
155        }
156
157        // `redeposit_gas` was added, no need to add it until the next convergence point.
158        RedepositState::Unnecessary
159    }
160
161    fn info_from_return(&mut self, _: StatementLocation, vars: &[VarUsage<'db>]) -> Self::Info {
162        // If the function has multiple returns with different gas costs, gas will get burned unless
163        // we redeposit it.
164        // If however, this return corresponds to a panic, we don't redeposit due to code size
165        // concerns.
166        match vars.last() {
167            Some(VarUsage { var_id, location: _ }) => RedepositState::Return(*var_id),
168            None => RedepositState::Required,
169        }
170    }
171}