cairo_lang_lowering/inline/
mod.rs1#[cfg(test)]
2mod test;
3
4pub mod statements_weights;
5
6use cairo_lang_defs::diagnostic_utils::StableLocation;
7use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_diagnostics::{Diagnostics, Maybe};
9use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
10use cairo_lang_semantic::items::functions::InlineConfiguration;
11use cairo_lang_utils::casts::IntoOrPanic;
12use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
13use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
14use itertools::{Itertools, zip_eq};
15use salsa::Database;
16
17use crate::blocks::{Blocks, BlocksBuilder};
18use crate::db::LoweringGroup;
19use crate::diagnostic::{
20 LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
21};
22use crate::ids::{
23 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionWithBodyId,
24 FunctionWithBodyLongId, LocationId,
25};
26use crate::optimizations::const_folding::ConstFoldingContext;
27use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
28use crate::{
29 Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
30 VarRemapping, Variable, VariableArena, VariableId,
31};
32
33pub fn get_inline_diagnostics<'db>(
34 db: &'db dyn Database,
35 function_id: FunctionWithBodyId<'db>,
36) -> Maybe<Diagnostics<'db, LoweringDiagnostic<'db>>> {
37 let inline_config = match function_id.long(db) {
38 FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(*id)?,
39 FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
40 };
41 let mut diagnostics = LoweringDiagnostics::default();
42
43 if let InlineConfiguration::Always(_) = inline_config
44 && db.in_cycle(function_id, crate::DependencyType::Call)?
45 {
46 diagnostics.report(
47 function_id.base_semantic_function(db).untyped_stable_ptr(db),
48 LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
49 );
50 }
51
52 Ok(diagnostics.build())
53}
54
55#[salsa::tracked]
57pub fn priv_should_inline<'db>(
58 db: &'db dyn Database,
59 function_id: ConcreteFunctionWithBodyId<'db>,
60) -> Maybe<bool> {
61 if db.priv_never_inline(function_id)? {
62 return Ok(false);
63 }
64 let base = match function_id.long(db) {
71 ConcreteFunctionWithBodyLongId::Semantic(_)
72 | ConcreteFunctionWithBodyLongId::Generated(_) => function_id,
73 ConcreteFunctionWithBodyLongId::Specialized(specialized) => specialized.base,
74 };
75 if db.concrete_in_cycle(base, DependencyType::Call, LoweringStage::Monomorphized)? {
76 return Ok(false);
77 }
78
79 match (db.optimizations().inlining_strategy(), function_inline_config(db, function_id)?) {
80 (_, InlineConfiguration::Always(_)) => Ok(true),
81 (InliningStrategy::Avoid, _) | (_, InlineConfiguration::Never(_)) => Ok(false),
82 (_, InlineConfiguration::Should(_)) => Ok(true),
83 (InliningStrategy::Default, InlineConfiguration::None) => {
84 const DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD: usize = 120;
87 should_inline_lowered(db, function_id, DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)
88 }
89 (InliningStrategy::InlineSmallFunctions(threshold), InlineConfiguration::None) => {
90 should_inline_lowered(db, function_id, threshold)
91 }
92 }
93}
94
95#[salsa::tracked]
97pub fn priv_never_inline<'db>(
98 db: &'db dyn Database,
99 function_id: ConcreteFunctionWithBodyId<'db>,
100) -> Maybe<bool> {
101 Ok(matches!(function_inline_config(db, function_id)?, InlineConfiguration::Never(_)))
102}
103
104fn function_inline_config<'db>(
106 db: &'db dyn Database,
107 function_id: ConcreteFunctionWithBodyId<'db>,
108) -> Maybe<InlineConfiguration<'db>> {
109 match function_id.long(db) {
110 ConcreteFunctionWithBodyLongId::Semantic(id) => {
111 db.function_declaration_inline_config(id.function_with_body_id(db))
112 }
113 ConcreteFunctionWithBodyLongId::Generated(_) => Ok(InlineConfiguration::None),
114 ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
115 function_inline_config(db, specialized.base)
116 }
117 }
118}
119
120fn should_inline_lowered(
122 db: &dyn Database,
123 function_id: ConcreteFunctionWithBodyId<'_>,
124 inline_small_functions_threshold: usize,
125) -> Maybe<bool> {
126 let weight_of_blocks = db.estimate_size(function_id)?;
127 Ok(weight_of_blocks < inline_small_functions_threshold.into_or_panic())
128}
129pub struct Mapper<'db, 'mt, 'l> {
131 db: &'db dyn Database,
132 variables: &'mt mut VariableArena<'db>,
133 lowered: &'l Lowered<'db>,
134 renamed_vars: UnorderedHashMap<VariableId, VariableId>,
135
136 outputs: Vec<VariableId>,
137 inlining_location: StableLocation<'db>,
138
139 block_id_offset: BlockId,
142
143 return_block_id: BlockId,
145}
146
147impl<'db, 'mt, 'l> Mapper<'db, 'mt, 'l> {
148 pub fn new(
149 db: &'db dyn Database,
150 variables: &'mt mut VariableArena<'db>,
151 lowered: &'l Lowered<'db>,
152 call_stmt: StatementCall<'db>,
153 block_id_offset: usize,
154 ) -> Self {
155 let renamed_vars = UnorderedHashMap::<VariableId, VariableId>::from_iter(zip_eq(
157 lowered.parameters.iter().cloned(),
158 call_stmt.inputs.iter().map(|var_usage| var_usage.var_id),
159 ));
160
161 let inlining_location = call_stmt.location.long(db).stable_location;
162
163 Self {
164 db,
165 variables,
166 lowered,
167 renamed_vars,
168 block_id_offset: BlockId(block_id_offset),
169 return_block_id: BlockId(block_id_offset + lowered.blocks.len()),
170 outputs: call_stmt.outputs,
171 inlining_location,
172 }
173 }
174}
175
176impl<'db, 'mt> Rebuilder<'db> for Mapper<'db, 'mt, '_> {
177 fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
181 *self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
182 let orig_var = &self.lowered.variables[orig_var_id];
183 self.variables.alloc(Variable {
184 location: orig_var.location.inlined(self.db, self.inlining_location),
185 ..orig_var.clone()
186 })
187 })
188 }
189
190 fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
193 BlockId(self.block_id_offset.0 + orig_block_id.0)
194 }
195
196 fn map_location(&mut self, location: LocationId<'db>) -> LocationId<'db> {
198 location.inlined(self.db, self.inlining_location)
199 }
200
201 fn transform_end(&mut self, end: &mut BlockEnd<'db>) {
202 match end {
203 BlockEnd::Return(returns, _location) => {
204 let remapping = VarRemapping {
205 remapping: OrderedHashMap::from_iter(zip_eq(
206 self.outputs.iter().cloned(),
207 returns.iter().cloned(),
208 )),
209 };
210 *end = BlockEnd::Goto(self.return_block_id, remapping);
211 }
212 BlockEnd::Panic(_) | BlockEnd::Goto(_, _) | BlockEnd::Match { .. } => {}
213 BlockEnd::NotSet => unreachable!(),
214 }
215 }
216}
217
218fn inner_apply_inlining<'db>(
223 db: &'db dyn Database,
224 lowered: &mut Lowered<'db>,
225 calling_function_id: ConcreteFunctionWithBodyId<'db>,
226 mut enable_const_folding: bool,
227) -> Maybe<()> {
228 lowered.blocks.has_root()?;
229
230 let mut blocks: BlocksBuilder<'db> = BlocksBuilder::new();
231
232 let mut stack: Vec<std::vec::IntoIter<BlockId>> = vec![
233 lowered
234 .blocks
235 .iter()
236 .map(|(_, block)| blocks.alloc(block.clone()))
237 .collect_vec()
238 .into_iter(),
239 ];
240
241 let mut const_folding_ctx =
242 ConstFoldingContext::new(db, calling_function_id, &mut lowered.variables);
243
244 enable_const_folding = enable_const_folding && !const_folding_ctx.should_skip_const_folding(db);
245
246 while let Some(mut func_blocks) = stack.pop() {
247 for block_id in func_blocks.by_ref() {
248 let blocks = &mut blocks;
249 if enable_const_folding
250 && !const_folding_ctx.visit_block_start(block_id, |block_id| &blocks.0[block_id.0])
251 {
252 continue;
253 }
254
255 let next_block_id = blocks.len();
257 let block = blocks.get_mut_block(block_id);
258
259 let mut opt_inline_info = None;
260 for (idx, statement) in block.statements.iter_mut().enumerate() {
261 if enable_const_folding {
262 const_folding_ctx.visit_statement(statement);
263 }
264 if let Some((call_stmt, called_func)) =
265 should_inline(db, calling_function_id, statement)?
266 {
267 opt_inline_info = Some((idx, call_stmt.clone(), called_func));
268 break;
269 }
270 }
271
272 let Some((call_stmt_idx, call_stmt, called_func)) = opt_inline_info else {
273 if enable_const_folding {
274 const_folding_ctx.visit_block_end(block_id, block);
275 }
276 continue;
278 };
279
280 let inlined_lowered = db.lowered_body(called_func, LoweringStage::PostBaseline)?;
281 inlined_lowered.blocks.has_root()?;
282
283 let remaining_statements =
285 block.statements.drain(call_stmt_idx..).skip(1).collect_vec();
286
287 let orig_block_end = std::mem::replace(
289 &mut block.end,
290 BlockEnd::Goto(BlockId(next_block_id), VarRemapping::default()),
291 );
292
293 if enable_const_folding {
294 const_folding_ctx.visit_block_end(block_id, block);
295 }
296
297 let mut inline_mapper = Mapper::new(
298 db,
299 const_folding_ctx.variables,
300 inlined_lowered,
301 call_stmt,
302 next_block_id,
303 );
304
305 let mut inlined_blocks_ids = inlined_lowered
308 .blocks
309 .iter()
310 .map(|(_block_id, block)| blocks.alloc(inline_mapper.rebuild_block(block)))
311 .collect_vec();
312
313 let return_block_id =
315 blocks.alloc(Block { statements: remaining_statements, end: orig_block_end });
316 assert_eq!(return_block_id, inline_mapper.return_block_id);
317
318 inlined_blocks_ids.push(return_block_id);
322
323 stack.push(func_blocks);
326 stack.push(inlined_blocks_ids.into_iter());
327 break;
328 }
329 }
330
331 lowered.blocks = blocks.build().unwrap();
332 Ok(())
333}
334
335fn should_inline<'db, 'r>(
338 db: &'db dyn Database,
339 calling_function_id: ConcreteFunctionWithBodyId<'db>,
340 statement: &'r Statement<'db>,
341) -> Maybe<Option<(&'r StatementCall<'db>, ConcreteFunctionWithBodyId<'db>)>>
342where
343 'db: 'r,
344{
345 if let Statement::Call(stmt) = statement {
346 if stmt.with_coupon {
347 return Ok(None);
348 }
349
350 if let Some(called_func) = stmt.function.body(db)? {
351 if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
352 calling_function_id.long(db)
353 && specialized.base == called_func
354 && stmt.is_specialization_base_call
355 {
356 return Ok(Some((stmt, called_func)));
358 }
359
360 if called_func != calling_function_id && db.priv_should_inline(called_func)? {
363 return Ok(Some((stmt, called_func)));
364 }
365 }
366 }
367
368 Ok(None)
369}
370
371pub fn apply_inlining<'db>(
375 db: &'db dyn Database,
376 function_id: ConcreteFunctionWithBodyId<'db>,
377 lowered: &mut Lowered<'db>,
378 enable_const_folding: bool,
379) -> Maybe<()> {
380 if let Err(diag_added) = inner_apply_inlining(db, lowered, function_id, enable_const_folding) {
381 lowered.blocks = Blocks::new_errored(diag_added);
382 }
383 Ok(())
384}