// cairo_lang_lowering/optimizations/gas_redeposit.rs

#[cfg(test)]
#[path = "gas_redeposit_test.rs"]
mod test;
4
5use cairo_lang_filesystem::flag::FlagsGroup;
6use cairo_lang_filesystem::ids::SmolStrId;
7use cairo_lang_semantic::{ConcreteVariant, corelib};
8use itertools::{Itertools, zip_eq};
9use salsa::Database;
10
11use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
12use crate::ids::{ConcreteFunctionWithBodyId, LocationId, SemanticFunctionIdEx};
13use crate::implicits::FunctionImplicitsTrait;
14use crate::panic::PanicSignatureInfo;
15use crate::{
16 BlockId, Lowered, MatchInfo, Statement, StatementCall, StatementEnumConstruct, VarRemapping,
17 VarUsage, VariableId,
18};
19
20pub fn gas_redeposit<'db>(
32 db: &'db dyn Database,
33 function_id: ConcreteFunctionWithBodyId<'db>,
34 lowered: &mut Lowered<'db>,
35) {
36 if lowered.blocks.is_empty() {
37 return;
38 }
39 if !db.flag_add_withdraw_gas() {
40 return;
41 }
42 let gb_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "GasBuiltin"), vec![]);
43 if let Ok(implicits) = db.function_with_body_implicits(function_id)
45 && !implicits.into_iter().contains(&gb_ty)
46 {
47 return;
48 }
49 assert!(
50 lowered.parameters.iter().all(|p| lowered.variables[*p].ty != gb_ty),
51 "`GasRedeposit` stage must be called before `LowerImplicits` stage"
52 );
53
54 let panic_sig = PanicSignatureInfo::new(db, &function_id.signature(db).unwrap());
55 if panic_sig.always_panic {
56 return;
57 }
58 let ctx = GasRedepositContext { fixes: vec![], err_variant: panic_sig.err_variant };
59 let mut analysis = BackAnalysis::new(lowered, ctx);
60 analysis.get_root_info();
61
62 let redeposit_gas = corelib::get_function_id(
63 db,
64 corelib::core_submodule(db, SmolStrId::from(db, "gas")),
65 SmolStrId::from(db, "redeposit_gas"),
66 vec![],
67 )
68 .lowered(db);
69 for (block_id, location) in analysis.analyzer.fixes {
70 let block = &mut lowered.blocks[block_id];
71
72 block.statements.insert(
75 0,
76 Statement::Call(StatementCall {
77 function: redeposit_gas,
78 inputs: vec![],
79 with_coupon: false,
80 outputs: vec![],
81 location,
82 is_specialization_base_call: false,
83 }),
84 );
85 }
86}
87
88pub struct GasRedepositContext<'db> {
89 fixes: Vec<(BlockId, LocationId<'db>)>,
91 pub err_variant: ConcreteVariant<'db>,
93}
94
95#[derive(Clone, PartialEq, Debug)]
96pub enum RedepositState {
97 Required,
99 Unnecessary,
102 Return(VariableId),
105}
106
107impl<'db> Analyzer<'db, '_> for GasRedepositContext<'db> {
108 type Info = RedepositState;
109
110 fn visit_stmt(
111 &mut self,
112 info: &mut Self::Info,
113 _statement_location: StatementLocation,
114 stmt: &Statement<'db>,
115 ) {
116 let RedepositState::Return(var_id) = info else {
117 return;
118 };
119
120 let Statement::EnumConstruct(StatementEnumConstruct { variant, input: _, output }) = stmt
121 else {
122 return;
123 };
124
125 if output == var_id && *variant == self.err_variant {
126 *info = RedepositState::Unnecessary;
127 }
128 }
129
130 fn visit_goto(
131 &mut self,
132 info: &mut Self::Info,
133 _statement_location: StatementLocation,
134 _target_block_id: BlockId,
135 _remapping: &VarRemapping<'db>,
136 ) {
137 *info = RedepositState::Required
140 }
141
142 fn merge_match(
143 &mut self,
144 _st: StatementLocation,
145 match_info: &MatchInfo<'db>,
146 infos: impl Iterator<Item = Self::Info>,
147 ) -> Self::Info {
148 for (info, arm) in zip_eq(infos, match_info.arms()) {
149 match info {
150 RedepositState::Return(_) | RedepositState::Required => {
151 self.fixes.push((arm.block_id, *match_info.location()));
152 }
153 RedepositState::Unnecessary => {}
154 }
155 }
156
157 RedepositState::Unnecessary
159 }
160
161 fn info_from_return(&mut self, _: StatementLocation, vars: &[VarUsage<'db>]) -> Self::Info {
162 match vars.last() {
167 Some(VarUsage { var_id, location: _ }) => RedepositState::Return(*var_id),
168 None => RedepositState::Required,
169 }
170 }
171}