use cairo_lang_semantic::TypeId;
use cairo_lang_utils::casts::IntoOrPanic;
use salsa::Database;
use crate::db::LoweringGroup;
use crate::{BlockEnd, Lowered, Statement, VarUsage, VariableId};
/// A cost model that assigns integer weights to the parts of a lowered function.
///
/// Implementors supply the three per-part weights; [`InlineWeight::lowered_weight`]
/// combines them into a single total.
// NOTE(review): judging by the name, the total is presumably consumed by the inlining
// heuristics; the consumers are not visible in this file.
pub trait InlineWeight<'db> {
    /// Returns the weight charged once for `lowered` as a whole, on top of the weights
    /// of its statements and block ends.
    fn calling_weight(&self, lowered: &Lowered<'db>) -> isize;

    /// Returns the weight charged for a single `statement`.
    fn statement_weight(&self, statement: &Statement<'db>) -> isize;

    /// Returns the weight charged for the end of a block.
    fn block_end_weight(&self, block_end: &BlockEnd<'db>) -> isize;

    /// Returns the total weight of `lowered`.
    ///
    /// The total is `calling_weight(lowered)` plus, for every block, the sum of
    /// `statement_weight` over the block's statements and the `block_end_weight` of
    /// the block's end.
    fn lowered_weight(&self, lowered: &Lowered<'db>) -> isize {
        // The calling weight is queried first, before any per-block weight.
        let call_cost = self.calling_weight(lowered);

        let mut blocks_cost: isize = 0;
        for (_, block) in lowered.blocks.iter() {
            let mut statements_cost: isize = 0;
            for stmt in block.statements.iter() {
                statements_cost += self.statement_weight(stmt);
            }
            // The block end is weighed after all statements of the same block.
            blocks_cost += statements_cost + self.block_end_weight(&block.end);
        }

        call_cost + blocks_cost
    }
}
/// A trivial [`InlineWeight`]: every statement and every block end weighs 1, and the
/// call itself weighs 0.
pub struct SimpleInlineWeight;

impl SimpleInlineWeight {
    /// Weight charged for the call itself.
    const CALL_WEIGHT: isize = 0;
    /// Flat weight charged for each statement and for each block end.
    const UNIT_WEIGHT: isize = 1;
}

impl<'db> InlineWeight<'db> for SimpleInlineWeight {
    /// Always returns 0, regardless of the lowered function.
    fn calling_weight(&self, _lowered: &Lowered<'db>) -> isize {
        Self::CALL_WEIGHT
    }

    /// Always returns 1, regardless of the statement kind.
    fn statement_weight(&self, _statement: &Statement<'db>) -> isize {
        Self::UNIT_WEIGHT
    }

    /// Always returns 1, regardless of the block end kind.
    fn block_end_weight(&self, _block_end: &BlockEnd<'db>) -> isize {
        Self::UNIT_WEIGHT
    }
}
/// An [`InlineWeight`] helper that measures cost through the sizes of the types of the
/// variables involved, as reported by the database's `type_size` query.
// NOTE(review): the name suggests the sizes approximate generated CASM code size;
// nothing in this file confirms that.
pub struct ApproxCasmInlineWeight<'db> {
    /// Database used to query type sizes.
    db: &'db dyn Database,
    /// Lowered function whose `variables` are consulted to find each variable's type.
    lowered: &'db Lowered<'db>,
}

impl<'db> ApproxCasmInlineWeight<'db> {
    /// Creates a weight calculator that resolves variables through `lowered` and type
    /// sizes through `db`.
    pub fn new(db: &'db dyn Database, lowered: &'db Lowered<'db>) -> Self {
        Self { db, lowered }
    }

    /// Returns the sum of `type_size` over all of `tys` (0 for an empty iterator).
    fn tys_total_size(&self, tys: impl IntoIterator<Item = TypeId<'db>>) -> usize {
        tys.into_iter().map(|ty| self.db.type_size(ty)).sum()
    }

    /// Returns the total size of the types of the given variables.
    ///
    /// Each id is looked up in `self.lowered.variables`. The borrow lifetime `'a` of
    /// the ids is deliberately independent of `'db`: callers only hold short-lived
    /// borrows into a statement or block end, and the ids are copied out immediately.
    fn vars_size<'a, I: IntoIterator<Item = &'a VariableId>>(&self, vars: I) -> usize {
        self.tys_total_size(vars.into_iter().map(|v| self.lowered.variables[*v].ty))
    }

    /// Returns the total size of the types of the variables referenced by the given
    /// usages.
    ///
    /// As with [`Self::vars_size`], the usages may be borrowed for any lifetime `'a`;
    /// only their `var_id` is read.
    fn inputs_size<'a, I: IntoIterator<Item = &'a VarUsage<'db>>>(&self, vars: I) -> usize {
        self.vars_size(vars.into_iter().map(|v| &v.var_id))
    }
}
impl<'db> InlineWeight<'db> for ApproxCasmInlineWeight<'db> {
    /// Always returns 0; the call itself is not charged.
    fn calling_weight(&self, _lowered: &Lowered<'db>) -> isize {
        0
    }

    /// Weighs a statement by the total size of its inputs.
    ///
    /// Only `Statement::Call` contributes; every other statement kind weighs 0.
    ///
    /// # Panics
    ///
    /// Panics if the computed size does not fit in an `isize`.
    fn statement_weight(&self, statement: &Statement<'db>) -> isize {
        let size: usize = if let Statement::Call(call) = statement {
            self.inputs_size(&call.inputs)
        } else {
            0
        };
        size.into_or_panic()
    }

    /// Weighs the end of a block.
    ///
    /// * `Return` weighs 0.
    /// * `Goto` weighs the total size of the variables that are keys of its remapping.
    /// * `Match` weighs the number of arms plus the total size of its inputs.
    ///
    /// # Panics
    ///
    /// Panics on `BlockEnd::Panic` and `BlockEnd::NotSet`, which are treated as
    /// unreachable here, and if the computed size does not fit in an `isize`.
    // NOTE(review): presumably those two variants cannot occur at the stage where
    // weights are computed; confirm against the callers.
    fn block_end_weight(&self, block_end: &BlockEnd<'db>) -> isize {
        let size: usize = match block_end {
            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
            BlockEnd::Return(..) => 0,
            BlockEnd::Goto(_, remapping) => self.vars_size(remapping.keys()),
            BlockEnd::Match { info } => {
                let arm_count = info.arms().len();
                arm_count + self.inputs_size(info.inputs())
            }
        };
        size.into_or_panic()
    }
}