#[cfg(test)]
#[path = "gas_redeposit_test.rs"]
mod test;
use cairo_lang_filesystem::flag::FlagsGroup;
use cairo_lang_filesystem::ids::SmolStrId;
use cairo_lang_semantic::{ConcreteVariant, corelib};
use itertools::{Itertools, zip_eq};
use salsa::Database;
use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
use crate::ids::{ConcreteFunctionWithBodyId, LocationId, SemanticFunctionIdEx};
use crate::implicits::FunctionImplicitsTrait;
use crate::panic::PanicSignatureInfo;
use crate::{
BlockId, Lowered, MatchInfo, Statement, StatementCall, StatementEnumConstruct, VarRemapping,
VarUsage, VariableId,
};
/// Prepends calls to `core::gas::redeposit_gas` to selected match-arm blocks of `lowered`.
///
/// The pass does nothing when any of the following holds:
/// - `lowered` has no blocks;
/// - `db.flag_add_withdraw_gas()` is false;
/// - the implicits computed for `function_id` are available and do not contain `GasBuiltin`
///   (if that query returns an error, the pass proceeds anyway);
/// - the function's [`PanicSignatureInfo`] reports `always_panic`.
///
/// Otherwise a [`BackAnalysis`] driven by [`GasRedepositContext`] collects `(block, location)`
/// pairs, and each listed block gets a `redeposit_gas` call (with no inputs or outputs, carrying
/// the recorded location) inserted as its first statement.
///
/// # Panics
///
/// - If any parameter of `lowered` already has type `GasBuiltin`; the assertion message states
///   that this stage must run before `LowerImplicits`.
/// - If `function_id.signature(db)` fails (it is unwrapped).
pub fn gas_redeposit<'db>(
    db: &'db dyn Database,
    function_id: ConcreteFunctionWithBodyId<'db>,
    lowered: &mut Lowered<'db>,
) {
    if lowered.blocks.is_empty() || !db.flag_add_withdraw_gas() {
        return;
    }

    let gas_builtin_ty =
        corelib::get_core_ty_by_name(db, SmolStrId::from(db, "GasBuiltin"), vec![]);

    // Skip functions whose implicits are known to exclude `GasBuiltin`. A failed implicits
    // query does not rule `GasBuiltin` out, so in that case we keep going.
    let gas_builtin_excluded = db
        .function_with_body_implicits(function_id)
        .is_ok_and(|implicits| !implicits.into_iter().contains(&gas_builtin_ty));
    if gas_builtin_excluded {
        return;
    }

    // A `GasBuiltin`-typed parameter means this pass is running in the wrong order.
    let has_gas_builtin_param = lowered
        .parameters
        .iter()
        .any(|&param| lowered.variables[param].ty == gas_builtin_ty);
    assert!(
        !has_gas_builtin_param,
        "`GasRedeposit` stage must be called before `LowerImplicits` stage"
    );

    let panic_info = PanicSignatureInfo::new(db, &function_id.signature(db).unwrap());
    if panic_info.always_panic {
        return;
    }

    // The analysis only reads `lowered`; keep it scoped so the blocks can be mutated afterwards.
    let fixes = {
        let context = GasRedepositContext { fixes: vec![], err_variant: panic_info.err_variant };
        let mut analysis = BackAnalysis::new(lowered, context);
        analysis.get_root_info();
        analysis.analyzer.fixes
    };

    let redeposit_fn = corelib::get_function_id(
        db,
        corelib::core_submodule(db, SmolStrId::from(db, "gas")),
        SmolStrId::from(db, "redeposit_gas"),
        vec![],
    )
    .lowered(db);

    for (block_id, location) in fixes {
        let redeposit_call = Statement::Call(StatementCall {
            function: redeposit_fn,
            inputs: vec![],
            with_coupon: false,
            outputs: vec![],
            location,
            is_specialization_base_call: false,
        });
        lowered.blocks[block_id].statements.insert(0, redeposit_call);
    }
}
/// Analyzer state for the backward analysis run by [`gas_redeposit`].
pub struct GasRedepositContext<'db> {
    /// The error variant taken from the function's `PanicSignatureInfo`. When a path returns a
    /// value constructed with this variant, no redeposit is recorded for it.
    pub err_variant: ConcreteVariant<'db>,
    /// Blocks that should start with a `redeposit_gas` call, each paired with the location to
    /// attach to that call. The analyzer appends to this list.
    fixes: Vec<(BlockId, LocationId<'db>)>,
}
/// The `Info` value propagated backwards by the gas-redeposit analyzer.
#[derive(Debug, Clone, PartialEq)]
pub enum RedepositState {
    /// The path from this point needs a `redeposit_gas` call (e.g. it ends in a `goto`, or in
    /// a return with no values).
    Required,
    /// The path from this point does not need a `redeposit_gas` call.
    Unnecessary,
    /// The path ends by returning the given variable as its last value. It is treated as
    /// needing a redeposit unless that variable is found to be built with the panic error
    /// variant.
    Return(VariableId),
}
impl<'db> Analyzer<'db, '_> for GasRedepositContext<'db> {
    type Info = RedepositState;

    /// Acts only while the state is `Return(var)`. If this statement builds `var` with the
    /// panic error variant, the state becomes `Unnecessary`. Any other statement leaves the
    /// state untouched.
    fn visit_stmt(
        &mut self,
        info: &mut Self::Info,
        _location: StatementLocation,
        stmt: &Statement<'db>,
    ) {
        let RedepositState::Return(returned_var) = *info else {
            return;
        };
        let builds_error_result = matches!(
            stmt,
            Statement::EnumConstruct(StatementEnumConstruct { variant, output, .. })
                if *output == returned_var && *variant == self.err_variant
        );
        if builds_error_result {
            *info = RedepositState::Unnecessary;
        }
    }

    /// Any path ending in a `goto` is conservatively marked `Required`; the target block's
    /// state is not consulted.
    fn visit_goto(
        &mut self,
        info: &mut Self::Info,
        _location: StatementLocation,
        _target: BlockId,
        _remapping: &VarRemapping<'db>,
    ) {
        *info = RedepositState::Required;
    }

    /// Records a fix, tagged with the match's location, for every arm whose state is not
    /// `Unnecessary`. Because those arms are now handled, the state before the match is
    /// `Unnecessary`.
    ///
    /// `zip_eq` panics if the number of infos differs from the number of arms.
    fn merge_match(
        &mut self,
        _: StatementLocation,
        match_info: &MatchInfo<'db>,
        infos: impl Iterator<Item = Self::Info>,
    ) -> Self::Info {
        let arms_needing_fix =
            zip_eq(infos, match_info.arms()).filter_map(|(state, arm)| match state {
                RedepositState::Required | RedepositState::Return(_) => {
                    Some((arm.block_id, *match_info.location()))
                }
                RedepositState::Unnecessary => None,
            });
        self.fixes.extend(arms_needing_fix);
        RedepositState::Unnecessary
    }

    /// A return whose last value is `v` starts in `Return(v)`. A return with no values starts
    /// in `Required`.
    fn info_from_return(&mut self, _: StatementLocation, vars: &[VarUsage<'db>]) -> Self::Info {
        vars.last().map_or(RedepositState::Required, |usage| RedepositState::Return(usage.var_id))
    }
}