cairo_lang_lowering/optimizations/
gas_redeposit.rs1#[cfg(test)]
2#[path = "gas_redeposit_test.rs"]
3mod test;
4
5use cairo_lang_filesystem::flag::FlagsGroup;
6use cairo_lang_filesystem::ids::SmolStrId;
7use cairo_lang_semantic::{ConcreteVariant, corelib};
8use itertools::Itertools;
9use salsa::Database;
10
11use crate::analysis::core::StatementLocation;
12use crate::analysis::{DataflowAnalyzer, DataflowBackAnalysis, Direction, Edge};
13use crate::ids::{ConcreteFunctionWithBodyId, LocationId, SemanticFunctionIdEx};
14use crate::implicits::FunctionImplicitsTrait;
15use crate::panic::PanicSignatureInfo;
16use crate::{
17 BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementEnumConstruct, VarUsage,
18 VariableId,
19};
20
21pub fn gas_redeposit<'db>(
33 db: &'db dyn Database,
34 function_id: ConcreteFunctionWithBodyId<'db>,
35 lowered: &mut Lowered<'db>,
36) {
37 if lowered.blocks.is_empty() {
38 return;
39 }
40 if !db.flag_add_withdraw_gas() {
41 return;
42 }
43 let gb_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "GasBuiltin"), vec![]);
44 if let Ok(implicits) = db.function_with_body_implicits(function_id)
46 && !implicits.into_iter().contains(&gb_ty)
47 {
48 return;
49 }
50 assert!(
51 lowered.parameters.iter().all(|p| lowered.variables[*p].ty != gb_ty),
52 "`GasRedeposit` stage must be called before `LowerImplicits` stage"
53 );
54
55 let panic_sig = PanicSignatureInfo::new(db, &function_id.signature(db).unwrap());
56 if panic_sig.always_panic {
57 return;
58 }
59 let mut ctx = GasRedepositContext { err_variant: panic_sig.err_variant, fixes: vec![] };
60 DataflowBackAnalysis::new(lowered, &mut ctx).run();
61
62 let redeposit_gas = corelib::get_function_id(
63 db,
64 corelib::core_submodule(db, SmolStrId::from(db, "gas")),
65 SmolStrId::from(db, "redeposit_gas"),
66 vec![],
67 )
68 .lowered(db);
69 for (block_id, location) in ctx.fixes {
70 let block = &mut lowered.blocks[block_id];
71
72 block.statements.insert(
75 0,
76 Statement::Call(StatementCall {
77 function: redeposit_gas,
78 inputs: vec![],
79 with_coupon: false,
80 outputs: vec![],
81 location,
82 is_specialization_base_call: false,
83 }),
84 );
85 }
86}
87
88pub struct GasRedepositContext<'db> {
89 pub err_variant: ConcreteVariant<'db>,
91 pub fixes: Vec<(BlockId, LocationId<'db>)>,
93}
94
95#[derive(Clone, Copy, PartialEq, Debug)]
97pub enum RedepositState {
98 Required,
100 Unnecessary,
103 Return(VariableId),
106}
107
108impl<'db, 'a> DataflowAnalyzer<'db, 'a> for GasRedepositContext<'db> {
109 type Info = RedepositState;
110 const DIRECTION: Direction = Direction::Backward;
111
112 fn initial_info(&mut self, _block_id: BlockId, block_end: &'a BlockEnd<'db>) -> Self::Info {
113 match block_end {
118 BlockEnd::Return(vars, _) => match vars.last() {
119 Some(VarUsage { var_id, location: _ }) => RedepositState::Return(*var_id),
120 None => RedepositState::Required,
121 },
122 _ => RedepositState::Unnecessary,
123 }
124 }
125
126 fn merge(
127 &mut self,
128 _lowered: &Lowered<'db>,
129 _statement_location: StatementLocation,
130 _info1: Self::Info,
131 _info2: Self::Info,
132 ) -> Self::Info {
133 RedepositState::Unnecessary
135 }
136
137 fn transfer_stmt(
138 &mut self,
139 info: &mut Self::Info,
140 _statement_location: StatementLocation,
141 stmt: &'a Statement<'db>,
142 ) {
143 let RedepositState::Return(var_id) = *info else {
144 return;
145 };
146
147 let Statement::EnumConstruct(StatementEnumConstruct { variant, input: _, output }) = stmt
148 else {
149 return;
150 };
151
152 if *output == var_id && *variant == self.err_variant {
153 *info = RedepositState::Unnecessary;
154 }
155 }
156
157 fn transfer_edge(&mut self, info: &Self::Info, edge: &Edge<'db, 'a>) -> Self::Info {
158 match edge {
159 Edge::Goto { .. } => {
160 RedepositState::Required
163 }
164 Edge::MatchArm { arm, match_info } => {
165 if let RedepositState::Return(_) | RedepositState::Required = *info {
166 self.fixes.push((arm.block_id, *match_info.location()));
167 }
168 RedepositState::Unnecessary
169 }
170 _ => *info,
171 }
172 }
173}