#[cfg(test)]
#[path = "gas_redeposit_test.rs"]
mod test;
use cairo_lang_filesystem::flag::FlagsGroup;
use cairo_lang_filesystem::ids::SmolStrId;
use cairo_lang_semantic::{ConcreteVariant, corelib};
use itertools::Itertools;
use salsa::Database;
use crate::analysis::core::StatementLocation;
use crate::analysis::{DataflowAnalyzer, DataflowBackAnalysis, Direction, Edge};
use crate::ids::{ConcreteFunctionWithBodyId, LocationId, SemanticFunctionIdEx};
use crate::implicits::FunctionImplicitsTrait;
use crate::panic::PanicSignatureInfo;
use crate::{
BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementEnumConstruct, VarUsage,
VariableId,
};
/// Inserts `core::gas::redeposit_gas` calls at the start of match-arm blocks that need them.
///
/// Runs a backward dataflow analysis, driven by [`GasRedepositContext`], over `lowered`. Every
/// match arm reached in state [`RedepositState::Required`] or [`RedepositState::Return`] gets a
/// `redeposit_gas()` call prepended to its block. That call carries the location of the owning
/// match. An arm whose continuation returns a value built with the panic error variant is
/// skipped.
///
/// The pass does nothing when any of the following holds:
/// - `lowered` has no blocks;
/// - the `add_withdraw_gas` flag is off;
/// - the function's implicits could be computed and do not include `GasBuiltin`;
/// - the function always panics;
/// - the analysis finds no arm that needs a redeposit.
///
/// # Panics
///
/// - If a parameter already has type `GasBuiltin`, which means this stage ran after
///   `LowerImplicits`.
/// - If the function's semantic signature cannot be computed.
pub fn gas_redeposit<'db>(
    db: &'db dyn Database,
    function_id: ConcreteFunctionWithBodyId<'db>,
    lowered: &mut Lowered<'db>,
) {
    if lowered.blocks.is_empty() || !db.flag_add_withdraw_gas() {
        return;
    }
    let gb_ty = corelib::get_core_ty_by_name(db, SmolStrId::from(db, "GasBuiltin"), vec![]);
    // Only bail out when the implicits are known. If computing them failed, run the pass anyway.
    if let Ok(implicits) = db.function_with_body_implicits(function_id)
        && !implicits.into_iter().contains(&gb_ty)
    {
        return;
    }
    // A `GasBuiltin` parameter means implicits were already made explicit by `LowerImplicits`.
    assert!(
        lowered.parameters.iter().all(|p| lowered.variables[*p].ty != gb_ty),
        "`GasRedeposit` stage must be called before `LowerImplicits` stage"
    );

    let signature = function_id
        .signature(db)
        .expect("a function with a lowered body must have a valid signature");
    let panic_sig = PanicSignatureInfo::new(db, &signature);
    if panic_sig.always_panic {
        return;
    }

    let mut ctx = GasRedepositContext { err_variant: panic_sig.err_variant, fixes: vec![] };
    DataflowBackAnalysis::new(lowered, &mut ctx).run();
    if ctx.fixes.is_empty() {
        // Nothing to insert, so there is no need to resolve `redeposit_gas`.
        return;
    }

    let redeposit_gas = corelib::get_function_id(
        db,
        corelib::core_submodule(db, SmolStrId::from(db, "gas")),
        SmolStrId::from(db, "redeposit_gas"),
        vec![],
    )
    .lowered(db);
    for (block_id, location) in ctx.fixes {
        // Prepend so the redeposit happens as soon as the arm is entered.
        lowered.blocks[block_id].statements.insert(
            0,
            Statement::Call(StatementCall {
                function: redeposit_gas,
                inputs: vec![],
                with_coupon: false,
                outputs: vec![],
                location,
                is_specialization_base_call: false,
            }),
        );
    }
}
/// Mutable state threaded through the gas-redeposit backward dataflow analysis.
pub struct GasRedepositContext<'db> {
    /// Match-arm blocks that need a `redeposit_gas` call prepended. Each block is paired with
    /// the location of the owning match, which the inserted call reuses.
    pub fixes: Vec<(BlockId, LocationId<'db>)>,
    /// The panic error variant (`PanicSignatureInfo::err_variant`). A returned value built with
    /// this variant marks its path as not needing a redeposit.
    pub err_variant: ConcreteVariant<'db>,
}
/// Per-program-point state of the gas-redeposit backward analysis.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum RedepositState {
    /// A `redeposit_gas` call must be inserted when this state reaches a match arm.
    Required,
    /// No call needs to be inserted. Either one was already scheduled further down, or the
    /// path returns a panic.
    Unnecessary,
    /// The path ends by returning this variable. The redeposit becomes unnecessary if the
    /// variable is constructed with the panic error variant before the next match arm.
    Return(VariableId),
}
impl<'db, 'a> DataflowAnalyzer<'db, 'a> for GasRedepositContext<'db> {
type Info = RedepositState;
const DIRECTION: Direction = Direction::Backward;
fn initial_info(&mut self, _block_id: BlockId, block_end: &'a BlockEnd<'db>) -> Self::Info {
match block_end {
BlockEnd::Return(vars, _) => match vars.last() {
Some(VarUsage { var_id, location: _ }) => RedepositState::Return(*var_id),
None => RedepositState::Required,
},
_ => RedepositState::Unnecessary,
}
}
fn merge(
&mut self,
_lowered: &Lowered<'db>,
_statement_location: StatementLocation,
_info1: Self::Info,
_info2: Self::Info,
) -> Self::Info {
RedepositState::Unnecessary
}
fn transfer_stmt(
&mut self,
info: &mut Self::Info,
_statement_location: StatementLocation,
stmt: &'a Statement<'db>,
) {
let RedepositState::Return(var_id) = *info else {
return;
};
let Statement::EnumConstruct(StatementEnumConstruct { variant, input: _, output }) = stmt
else {
return;
};
if *output == var_id && *variant == self.err_variant {
*info = RedepositState::Unnecessary;
}
}
fn transfer_edge(&mut self, info: &Self::Info, edge: &Edge<'db, 'a>) -> Self::Info {
match edge {
Edge::Goto { .. } => {
RedepositState::Required
}
Edge::MatchArm { arm, match_info } => {
if let RedepositState::Return(_) | RedepositState::Required = *info {
self.fixes.push((arm.block_id, *match_info.location()));
}
RedepositState::Unnecessary
}
_ => *info,
}
}
}