#[cfg(test)]
mod test;
pub mod statements_weights;
use cairo_lang_defs::diagnostic_utils::StableLocation;
use cairo_lang_defs::ids::LanguageElementId;
use cairo_lang_diagnostics::{Diagnostics, Maybe};
use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
use cairo_lang_semantic::items::functions::InlineConfiguration;
use cairo_lang_utils::casts::IntoOrPanic;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use itertools::{Itertools, zip_eq};
use salsa::Database;
use crate::blocks::{Blocks, BlocksBuilder};
use crate::db::LoweringGroup;
use crate::diagnostic::{
LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
};
use crate::ids::{
ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionWithBodyId,
FunctionWithBodyLongId, LocationId,
};
use crate::optimizations::const_folding::ConstFoldingContext;
use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
use crate::{
Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
VarRemapping, Variable, VariableArena, VariableId,
};
pub fn get_inline_diagnostics<'db>(
db: &'db dyn Database,
function_id: FunctionWithBodyId<'db>,
) -> Maybe<Diagnostics<'db, LoweringDiagnostic<'db>>> {
let inline_config = match function_id.long(db) {
FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(*id)?,
FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
};
let mut diagnostics = LoweringDiagnostics::default();
if let InlineConfiguration::Always(_) = inline_config
&& db.in_cycle(function_id, crate::DependencyType::Call)?
{
diagnostics.report(
function_id.base_semantic_function(db).untyped_stable_ptr(db),
LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
);
}
Ok(diagnostics.build())
}
#[salsa::tracked]
pub fn priv_should_inline<'db>(
db: &'db dyn Database,
function_id: ConcreteFunctionWithBodyId<'db>,
) -> Maybe<bool> {
if db.priv_never_inline(function_id)? {
return Ok(false);
}
let base = match function_id.long(db) {
ConcreteFunctionWithBodyLongId::Semantic(_)
| ConcreteFunctionWithBodyLongId::Generated(_) => function_id,
ConcreteFunctionWithBodyLongId::Specialized(specialized) => specialized.long(db).base,
};
if db.concrete_in_cycle(base, DependencyType::Call, LoweringStage::Monomorphized)? {
return Ok(false);
}
match (db.optimizations().inlining_strategy(), function_inline_config(db, function_id)?) {
(_, InlineConfiguration::Always(_)) => Ok(true),
(InliningStrategy::Avoid, _) | (_, InlineConfiguration::Never(_)) => Ok(false),
(_, InlineConfiguration::Should(_)) => Ok(true),
(InliningStrategy::Default, InlineConfiguration::None) => {
const DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD: usize = 120;
should_inline_lowered(db, function_id, DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)
}
(InliningStrategy::InlineSmallFunctions(threshold), InlineConfiguration::None) => {
should_inline_lowered(db, function_id, threshold)
}
}
}
#[salsa::tracked]
pub fn priv_never_inline<'db>(
db: &'db dyn Database,
function_id: ConcreteFunctionWithBodyId<'db>,
) -> Maybe<bool> {
Ok(matches!(function_inline_config(db, function_id)?, InlineConfiguration::Never(_)))
}
fn function_inline_config<'db>(
db: &'db dyn Database,
function_id: ConcreteFunctionWithBodyId<'db>,
) -> Maybe<InlineConfiguration<'db>> {
match function_id.long(db) {
ConcreteFunctionWithBodyLongId::Semantic(id) => {
db.function_declaration_inline_config(id.function_with_body_id(db))
}
ConcreteFunctionWithBodyLongId::Generated(_) => Ok(InlineConfiguration::None),
ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
function_inline_config(db, specialized.long(db).base)
}
}
}
fn should_inline_lowered(
db: &dyn Database,
function_id: ConcreteFunctionWithBodyId<'_>,
inline_small_functions_threshold: usize,
) -> Maybe<bool> {
let weight_of_blocks = db.estimate_size(function_id)?;
Ok(weight_of_blocks < inline_small_functions_threshold.into_or_panic())
}
pub struct Mapper<'db, 'mt, 'l> {
db: &'db dyn Database,
variables: &'mt mut VariableArena<'db>,
lowered: &'l Lowered<'db>,
renamed_vars: UnorderedHashMap<VariableId, VariableId>,
outputs: Vec<VariableId>,
inlining_location: StableLocation<'db>,
block_id_offset: BlockId,
return_block_id: BlockId,
}
impl<'db, 'mt, 'l> Mapper<'db, 'mt, 'l> {
pub fn new(
db: &'db dyn Database,
variables: &'mt mut VariableArena<'db>,
lowered: &'l Lowered<'db>,
call_stmt: StatementCall<'db>,
block_id_offset: usize,
) -> Self {
let renamed_vars = UnorderedHashMap::<VariableId, VariableId>::from_iter(zip_eq(
lowered.parameters.iter().cloned(),
call_stmt.inputs.iter().map(|var_usage| var_usage.var_id),
));
let inlining_location = call_stmt.location.long(db).stable_location;
Self {
db,
variables,
lowered,
renamed_vars,
block_id_offset: BlockId(block_id_offset),
return_block_id: BlockId(block_id_offset + lowered.blocks.len()),
outputs: call_stmt.outputs,
inlining_location,
}
}
}
impl<'db, 'mt> Rebuilder<'db> for Mapper<'db, 'mt, '_> {
fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
*self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
let orig_var = &self.lowered.variables[orig_var_id];
self.variables.alloc(Variable {
location: orig_var.location.inlined(self.db, self.inlining_location),
..orig_var.clone()
})
})
}
fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
BlockId(self.block_id_offset.0 + orig_block_id.0)
}
fn map_location(&mut self, location: LocationId<'db>) -> LocationId<'db> {
location.inlined(self.db, self.inlining_location)
}
fn transform_end(&mut self, end: &mut BlockEnd<'db>) {
match end {
BlockEnd::Return(returns, _location) => {
let remapping = VarRemapping {
remapping: OrderedHashMap::from_iter(zip_eq(
self.outputs.iter().cloned(),
returns.iter().cloned(),
)),
};
*end = BlockEnd::Goto(self.return_block_id, remapping);
}
BlockEnd::Panic(_) | BlockEnd::Goto(_, _) | BlockEnd::Match { .. } => {}
BlockEnd::NotSet => unreachable!(),
}
}
}
fn inner_apply_inlining<'db>(
db: &'db dyn Database,
lowered: &mut Lowered<'db>,
calling_function_id: ConcreteFunctionWithBodyId<'db>,
mut enable_const_folding: bool,
) -> Maybe<()> {
lowered.blocks.has_root()?;
let mut blocks: BlocksBuilder<'db> = BlocksBuilder::new();
let mut stack: Vec<std::vec::IntoIter<BlockId>> = vec![
lowered
.blocks
.iter()
.map(|(_, block)| blocks.alloc(block.clone()))
.collect_vec()
.into_iter(),
];
let mut const_folding_ctx =
ConstFoldingContext::new(db, calling_function_id, &mut lowered.variables);
enable_const_folding = enable_const_folding && !const_folding_ctx.should_skip_const_folding(db);
while let Some(mut func_blocks) = stack.pop() {
for block_id in func_blocks.by_ref() {
let blocks = &mut blocks;
if enable_const_folding
&& !const_folding_ctx.visit_block_start(block_id, |block_id| &blocks.0[block_id.0])
{
continue;
}
let next_block_id = blocks.len();
let block = blocks.get_mut_block(block_id);
let mut opt_inline_info = None;
for (idx, statement) in block.statements.iter_mut().enumerate() {
if enable_const_folding {
const_folding_ctx.visit_statement(statement);
}
if let Some((call_stmt, called_func)) =
should_inline(db, calling_function_id, statement)?
{
opt_inline_info = Some((idx, call_stmt.clone(), called_func));
break;
}
}
let Some((call_stmt_idx, call_stmt, called_func)) = opt_inline_info else {
if enable_const_folding {
const_folding_ctx.visit_block_end(block_id, block);
}
continue;
};
let inlined_lowered = db.lowered_body(called_func, LoweringStage::PostBaseline)?;
inlined_lowered.blocks.has_root()?;
let remaining_statements =
block.statements.drain(call_stmt_idx..).skip(1).collect_vec();
let orig_block_end = std::mem::replace(
&mut block.end,
BlockEnd::Goto(BlockId(next_block_id), VarRemapping::default()),
);
if enable_const_folding {
const_folding_ctx.visit_block_end(block_id, block);
}
let mut inline_mapper = Mapper::new(
db,
const_folding_ctx.variables,
inlined_lowered,
call_stmt,
next_block_id,
);
let mut inlined_blocks_ids = inlined_lowered
.blocks
.iter()
.map(|(_block_id, block)| blocks.alloc(inline_mapper.rebuild_block(block)))
.collect_vec();
let return_block_id =
blocks.alloc(Block { statements: remaining_statements, end: orig_block_end });
assert_eq!(return_block_id, inline_mapper.return_block_id);
inlined_blocks_ids.push(return_block_id);
stack.push(func_blocks);
stack.push(inlined_blocks_ids.into_iter());
break;
}
}
lowered.blocks = blocks.build().unwrap();
Ok(())
}
fn should_inline<'db, 'r>(
db: &'db dyn Database,
calling_function_id: ConcreteFunctionWithBodyId<'db>,
statement: &'r Statement<'db>,
) -> Maybe<Option<(&'r StatementCall<'db>, ConcreteFunctionWithBodyId<'db>)>>
where
'db: 'r,
{
if let Statement::Call(stmt) = statement {
if stmt.with_coupon {
return Ok(None);
}
if let Some(called_func) = stmt.function.body(db)? {
if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
calling_function_id.long(db)
&& specialized.long(db).base == called_func
&& stmt.is_specialization_base_call
{
return Ok(Some((stmt, called_func)));
}
if called_func != calling_function_id && db.priv_should_inline(called_func)? {
return Ok(Some((stmt, called_func)));
}
}
}
Ok(None)
}
pub fn apply_inlining<'db>(
db: &'db dyn Database,
function_id: ConcreteFunctionWithBodyId<'db>,
lowered: &mut Lowered<'db>,
enable_const_folding: bool,
) -> Maybe<()> {
if let Err(diag_added) = inner_apply_inlining(db, lowered, function_id, enable_const_folding) {
lowered.blocks = Blocks::new_errored(diag_added);
}
Ok(())
}