Skip to main content

cairo_lang_lowering/optimizations/
gas_redeposit.rs

1#[cfg(test)]
2#[path = "gas_redeposit_test.rs"]
3mod test;
4
5use cairo_lang_filesystem::flag::FlagsGroup;
6use cairo_lang_filesystem::ids::SmolStrId;
7use cairo_lang_semantic::{ConcreteVariant, corelib};
8use itertools::Itertools;
9use salsa::Database;
10
11use crate::analysis::core::StatementLocation;
12use crate::analysis::{DataflowAnalyzer, DataflowBackAnalysis, Direction, Edge};
13use crate::ids::{ConcreteFunctionWithBodyId, LocationId, SemanticFunctionIdEx};
14use crate::implicits::FunctionImplicitsTrait;
15use crate::panic::PanicSignatureInfo;
16use crate::{
17    BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementEnumConstruct, VarUsage,
18    VariableId,
19};
20
21/// Adds redeposit gas actions.
22///
23/// The algorithm is as follows:
24/// Checks if the function will have the `GasBuiltin` implicit after the lower_implicits stage.
25/// If so, after every block that ends with match, add a call to `redeposit_gas` in every arm
26/// that is followed by a convergence point or a return.
27/// Note that assuming `reorganize_blocks` stage is applied before this stage, every `goto`
28/// statement is a convergence point.
29///
30/// Note that for implementation simplicity this stage must be applied before `LowerImplicits`
31/// stage.
32pub fn gas_redeposit<'db>(
33    db: &'db dyn Database,
34    function_id: ConcreteFunctionWithBodyId<'db>,
35    lowered: &mut Lowered<'db>,
36) {
37    if lowered.blocks.is_empty() {
38        return;
39    }
40    if !db.flag_add_withdraw_gas() {
41        return;
42    }
43    let gb_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "GasBuiltin"), vec![]);
44    // Checking if the implicits of this function past lowering includes `GasBuiltin`.
45    if let Ok(implicits) = db.function_with_body_implicits(function_id)
46        && !implicits.into_iter().contains(&gb_ty)
47    {
48        return;
49    }
50    assert!(
51        lowered.parameters.iter().all(|p| lowered.variables[*p].ty != gb_ty),
52        "`GasRedeposit` stage must be called before `LowerImplicits` stage"
53    );
54
55    let panic_sig = PanicSignatureInfo::new(db, &function_id.signature(db).unwrap());
56    if panic_sig.always_panic {
57        return;
58    }
59    let mut ctx = GasRedepositContext { err_variant: panic_sig.err_variant, fixes: vec![] };
60    DataflowBackAnalysis::new(lowered, &mut ctx).run();
61
62    let redeposit_gas = corelib::get_function_id(
63        db,
64        corelib::core_submodule(db, SmolStrId::from(db, "gas")),
65        SmolStrId::from(db, "redeposit_gas"),
66        vec![],
67    )
68    .lowered(db);
69    for (block_id, location) in ctx.fixes {
70        let block = &mut lowered.blocks[block_id];
71
72        // The `redeposit_gas` function is added at the beginning of the block as it result in
73        // smaller code when the GasBuiltin is revoked during the block.
74        block.statements.insert(
75            0,
76            Statement::Call(StatementCall {
77                function: redeposit_gas,
78                inputs: vec![],
79                with_coupon: false,
80                outputs: vec![],
81                location,
82                is_specialization_base_call: false,
83            }),
84        );
85    }
86}
87
88pub struct GasRedepositContext<'db> {
89    /// The panic error variant.
90    pub err_variant: ConcreteVariant<'db>,
91    /// Locations where we need to insert redeposit_gas.
92    pub fixes: Vec<(BlockId, LocationId<'db>)>,
93}
94
95/// Redeposit state for a point in the program.
96#[derive(Clone, Copy, PartialEq, Debug)]
97pub enum RedepositState {
98    /// Gas might be burned if we don't redeposit.
99    Required,
100    /// Redeposit is not necessary. This may occur if it has already been handled
101    /// or if the flow is terminating due to a panic.
102    Unnecessary,
103    /// The flow returns the given variable, redeposit is required unless the return var is of the
104    /// error variant.
105    Return(VariableId),
106}
107
108impl<'db, 'a> DataflowAnalyzer<'db, 'a> for GasRedepositContext<'db> {
109    type Info = RedepositState;
110    const DIRECTION: Direction = Direction::Backward;
111
112    fn initial_info(&mut self, _block_id: BlockId, block_end: &'a BlockEnd<'db>) -> Self::Info {
113        // If the function has multiple returns with different gas costs, gas will get burned unless
114        // we redeposit it.
115        // If however, this return corresponds to a panic, we don't redeposit due to code size
116        // concerns.
117        match block_end {
118            BlockEnd::Return(vars, _) => match vars.last() {
119                Some(VarUsage { var_id, location: _ }) => RedepositState::Return(*var_id),
120                None => RedepositState::Required,
121            },
122            _ => RedepositState::Unnecessary,
123        }
124    }
125
126    fn merge(
127        &mut self,
128        _lowered: &Lowered<'db>,
129        _statement_location: StatementLocation,
130        _info1: Self::Info,
131        _info2: Self::Info,
132    ) -> Self::Info {
133        // `redeposit_gas` was added, no need to add it until the next convergence point.
134        RedepositState::Unnecessary
135    }
136
137    fn transfer_stmt(
138        &mut self,
139        info: &mut Self::Info,
140        _statement_location: StatementLocation,
141        stmt: &'a Statement<'db>,
142    ) {
143        let RedepositState::Return(var_id) = *info else {
144            return;
145        };
146
147        let Statement::EnumConstruct(StatementEnumConstruct { variant, input: _, output }) = stmt
148        else {
149            return;
150        };
151
152        if *output == var_id && *variant == self.err_variant {
153            *info = RedepositState::Unnecessary;
154        }
155    }
156
157    fn transfer_edge(&mut self, info: &Self::Info, edge: &Edge<'db, 'a>) -> Self::Info {
158        match edge {
159            Edge::Goto { .. } => {
160                // A goto is a convergence point, gas will get burned unless it is redeposited
161                // before the convergence.
162                RedepositState::Required
163            }
164            Edge::MatchArm { arm, match_info } => {
165                if let RedepositState::Return(_) | RedepositState::Required = *info {
166                    self.fixes.push((arm.block_id, *match_info.location()));
167                }
168                RedepositState::Unnecessary
169            }
170            _ => *info,
171        }
172    }
173}