cairo_lang_lowering/inline/
mod.rs1#[cfg(test)]
2mod test;
3
4pub mod statements_weights;
5
6use cairo_lang_defs::diagnostic_utils::StableLocation;
7use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_diagnostics::{Diagnostics, Maybe};
9use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
10use cairo_lang_semantic::items::functions::InlineConfiguration;
11use cairo_lang_utils::casts::IntoOrPanic;
12use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
13use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
14use itertools::{Itertools, zip_eq};
15use salsa::Database;
16
17use crate::blocks::{Blocks, BlocksBuilder};
18use crate::db::LoweringGroup;
19use crate::diagnostic::{
20 LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
21};
22use crate::ids::{
23 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionWithBodyId,
24 FunctionWithBodyLongId, LocationId,
25};
26use crate::optimizations::const_folding::ConstFoldingContext;
27use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
28use crate::{
29 Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
30 VarRemapping, Variable, VariableArena, VariableId,
31};
32
33pub fn get_inline_diagnostics<'db>(
34 db: &'db dyn Database,
35 function_id: FunctionWithBodyId<'db>,
36) -> Maybe<Diagnostics<'db, LoweringDiagnostic<'db>>> {
37 let inline_config = match function_id.long(db) {
38 FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(*id)?,
39 FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
40 };
41 let mut diagnostics = LoweringDiagnostics::default();
42
43 if let InlineConfiguration::Always(_) = inline_config
44 && db.in_cycle(function_id, crate::DependencyType::Call)?
45 {
46 diagnostics.report(
47 function_id.base_semantic_function(db).untyped_stable_ptr(db),
48 LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
49 );
50 }
51
52 Ok(diagnostics.build())
53}
54
55#[salsa::tracked]
57pub fn priv_should_inline<'db>(
58 db: &'db dyn Database,
59 function_id: ConcreteFunctionWithBodyId<'db>,
60) -> Maybe<bool> {
61 if db.priv_never_inline(function_id)? {
62 return Ok(false);
63 }
64
65 if db.concrete_in_cycle(function_id, DependencyType::Call, LoweringStage::Monomorphized)? {
67 return Ok(false);
68 }
69
70 match (db.optimizations().inlining_strategy(), function_inline_config(db, function_id)?) {
71 (_, InlineConfiguration::Always(_)) => Ok(true),
72 (InliningStrategy::Avoid, _) | (_, InlineConfiguration::Never(_)) => Ok(false),
73 (_, InlineConfiguration::Should(_)) => Ok(true),
74 (InliningStrategy::Default, InlineConfiguration::None) => {
75 const DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD: usize = 120;
78 should_inline_lowered(db, function_id, DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)
79 }
80 (InliningStrategy::InlineSmallFunctions(threshold), InlineConfiguration::None) => {
81 should_inline_lowered(db, function_id, threshold)
82 }
83 }
84}
85
86#[salsa::tracked]
88pub fn priv_never_inline<'db>(
89 db: &'db dyn Database,
90 function_id: ConcreteFunctionWithBodyId<'db>,
91) -> Maybe<bool> {
92 Ok(matches!(function_inline_config(db, function_id)?, InlineConfiguration::Never(_)))
93}
94
95fn function_inline_config<'db>(
97 db: &'db dyn Database,
98 function_id: ConcreteFunctionWithBodyId<'db>,
99) -> Maybe<InlineConfiguration<'db>> {
100 match function_id.long(db) {
101 ConcreteFunctionWithBodyLongId::Semantic(id) => {
102 db.function_declaration_inline_config(id.function_with_body_id(db))
103 }
104 ConcreteFunctionWithBodyLongId::Generated(_) => Ok(InlineConfiguration::None),
105 ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
106 function_inline_config(db, specialized.base)
107 }
108 }
109}
110
111fn should_inline_lowered(
113 db: &dyn Database,
114 function_id: ConcreteFunctionWithBodyId<'_>,
115 inline_small_functions_threshold: usize,
116) -> Maybe<bool> {
117 let weight_of_blocks = db.estimate_size(function_id)?;
118 Ok(weight_of_blocks < inline_small_functions_threshold.into_or_panic())
119}
/// Rebuilder state used to copy the blocks of an inlined callee into the caller's body.
pub struct Mapper<'db, 'mt, 'l> {
    db: &'db dyn Database,
    /// The caller's variable arena; fresh copies of callee variables are allocated here.
    variables: &'mt mut VariableArena<'db>,
    /// The lowered body of the function being inlined.
    lowered: &'l Lowered<'db>,
    /// Maps callee variables to caller variables. `Mapper::new` seeds it with the callee
    /// parameters mapped to the call's input variables; other variables are added lazily.
    renamed_vars: UnorderedHashMap<VariableId, VariableId>,

    /// The output variables of the inlined call statement; callee return values are remapped
    /// into these.
    outputs: Vec<VariableId>,
    /// Stable location of the call site, attached to every location copied from the callee.
    inlining_location: StableLocation<'db>,

    /// Offset added to every callee block id to obtain the id of its copy in the caller.
    block_id_offset: BlockId,

    /// The block that inlined `return`s jump to: the first id after the copied callee blocks.
    return_block_id: BlockId,
}
137
138impl<'db, 'mt, 'l> Mapper<'db, 'mt, 'l> {
139 pub fn new(
140 db: &'db dyn Database,
141 variables: &'mt mut VariableArena<'db>,
142 lowered: &'l Lowered<'db>,
143 call_stmt: StatementCall<'db>,
144 block_id_offset: usize,
145 ) -> Self {
146 let renamed_vars = UnorderedHashMap::<VariableId, VariableId>::from_iter(zip_eq(
148 lowered.parameters.iter().cloned(),
149 call_stmt.inputs.iter().map(|var_usage| var_usage.var_id),
150 ));
151
152 let inlining_location = call_stmt.location.long(db).stable_location;
153
154 Self {
155 db,
156 variables,
157 lowered,
158 renamed_vars,
159 block_id_offset: BlockId(block_id_offset),
160 return_block_id: BlockId(block_id_offset + lowered.blocks.len()),
161 outputs: call_stmt.outputs,
162 inlining_location,
163 }
164 }
165}
166
167impl<'db, 'mt> Rebuilder<'db> for Mapper<'db, 'mt, '_> {
168 fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
172 *self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
173 let orig_var = &self.lowered.variables[orig_var_id];
174 self.variables.alloc(Variable {
175 location: orig_var.location.inlined(self.db, self.inlining_location),
176 ..orig_var.clone()
177 })
178 })
179 }
180
181 fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
184 BlockId(self.block_id_offset.0 + orig_block_id.0)
185 }
186
187 fn map_location(&mut self, location: LocationId<'db>) -> LocationId<'db> {
189 location.inlined(self.db, self.inlining_location)
190 }
191
192 fn transform_end(&mut self, end: &mut BlockEnd<'db>) {
193 match end {
194 BlockEnd::Return(returns, _location) => {
195 let remapping = VarRemapping {
196 remapping: OrderedHashMap::from_iter(zip_eq(
197 self.outputs.iter().cloned(),
198 returns.iter().cloned(),
199 )),
200 };
201 *end = BlockEnd::Goto(self.return_block_id, remapping);
202 }
203 BlockEnd::Panic(_) | BlockEnd::Goto(_, _) | BlockEnd::Match { .. } => {}
204 BlockEnd::NotSet => unreachable!(),
205 }
206 }
207}
208
/// Rewrites `lowered` in place, inlining every call statement selected by `should_inline`.
///
/// Blocks are processed from a stack of block-id iterators. The original blocks are copied into a
/// fresh `BlocksBuilder` first. When a block contains an inlinable call, the block is cut at the
/// call:
/// - it jumps to the copied callee root;
/// - the callee blocks are appended;
/// - a continuation block holds the statements after the call plus the original block end.
///
/// The new blocks are pushed on top of the stack, so they are visited (and inlined into) next.
/// If `enable_const_folding` is set (and the const-folding context does not ask to skip), const
/// folding runs on the blocks as they are visited.
///
/// Errors from `has_root` and from lowering/inlining queries are propagated with `?`.
fn inner_apply_inlining<'db>(
    db: &'db dyn Database,
    lowered: &mut Lowered<'db>,
    calling_function_id: ConcreteFunctionWithBodyId<'db>,
    mut enable_const_folding: bool,
) -> Maybe<()> {
    lowered.blocks.has_root()?;

    let mut blocks: BlocksBuilder<'db> = BlocksBuilder::new();

    // Seed the worklist with copies of the original blocks, allocated in their original order.
    let mut stack: Vec<std::vec::IntoIter<BlockId>> = vec![
        lowered
            .blocks
            .iter()
            .map(|(_, block)| blocks.alloc(block.clone()))
            .collect_vec()
            .into_iter(),
    ];

    let mut const_folding_ctx =
        ConstFoldingContext::new(db, calling_function_id, &mut lowered.variables);

    enable_const_folding = enable_const_folding && !const_folding_ctx.should_skip_const_folding(db);

    while let Some(mut func_blocks) = stack.pop() {
        for block_id in func_blocks.by_ref() {
            let blocks = &mut blocks;
            // When `visit_block_start` returns false the block is skipped entirely.
            // NOTE(review): presumably because const folding found it unreachable — confirm.
            if enable_const_folding
                && !const_folding_ctx.visit_block_start(block_id, |block_id| &blocks.0[block_id.0])
            {
                continue;
            }

            // The blocks of a callee inlined into this block will be allocated from this id on.
            let next_block_id = blocks.len();
            let block = blocks.get_mut_block(block_id);

            // Find the first inlinable call; statements before it are const-folded on the way.
            let mut opt_inline_info = None;
            for (idx, statement) in block.statements.iter_mut().enumerate() {
                if enable_const_folding {
                    const_folding_ctx.visit_statement(statement);
                }
                if let Some((call_stmt, called_func)) =
                    should_inline(db, calling_function_id, statement)?
                {
                    opt_inline_info = Some((idx, call_stmt.clone(), called_func));
                    break;
                }
            }

            // No call to inline: finish the block and move on.
            let Some((call_stmt_idx, call_stmt, called_func)) = opt_inline_info else {
                if enable_const_folding {
                    const_folding_ctx.visit_block_end(block_id, block);
                }
                continue;
            };

            let inlined_lowered = db.lowered_body(called_func, LoweringStage::PostBaseline)?;
            inlined_lowered.blocks.has_root()?;

            // Remove the call and everything after it; `skip(1)` drops the call statement itself,
            // keeping only the statements that must run after the callee returns.
            let remaining_statements =
                block.statements.drain(call_stmt_idx..).skip(1).collect_vec();

            // The truncated block now jumps into the first copied callee block (its root).
            let orig_block_end = std::mem::replace(
                &mut block.end,
                BlockEnd::Goto(BlockId(next_block_id), VarRemapping::default()),
            );

            if enable_const_folding {
                const_folding_ctx.visit_block_end(block_id, block);
            }

            let mut inline_mapper = Mapper::new(
                db,
                const_folding_ctx.variables,
                inlined_lowered,
                call_stmt,
                next_block_id,
            );

            // Copy the callee blocks into the caller. `rebuild_block` presumably applies the
            // mapper's variable/block/location renaming and `transform_end` hooks — confirm
            // against `RebuilderEx`.
            let mut inlined_blocks_ids = inlined_lowered
                .blocks
                .iter()
                .map(|(_block_id, block)| blocks.alloc(inline_mapper.rebuild_block(block)))
                .collect_vec();

            // Continuation block: the statements after the call plus the original block end.
            // Its id must match the return block the mapper redirected callee returns to.
            let return_block_id =
                blocks.alloc(Block { statements: remaining_statements, end: orig_block_end });
            assert_eq!(return_block_id, inline_mapper.return_block_id);

            inlined_blocks_ids.push(return_block_id);

            // Save the rest of the current list, then push the new blocks on top so they are
            // popped first. Calls inside the inlined code (and the continuation) are then
            // processed before the remaining original blocks.
            stack.push(func_blocks);
            stack.push(inlined_blocks_ids.into_iter());
            break;
        }
    }

    // NOTE(review): assumes `build()` cannot fail once a root block exists (checked above via
    // `has_root()?`) — confirm.
    lowered.blocks = blocks.build().unwrap();
    Ok(())
}
325
326fn should_inline<'db, 'r>(
329 db: &'db dyn Database,
330 calling_function_id: ConcreteFunctionWithBodyId<'db>,
331 statement: &'r Statement<'db>,
332) -> Maybe<Option<(&'r StatementCall<'db>, ConcreteFunctionWithBodyId<'db>)>>
333where
334 'db: 'r,
335{
336 if let Statement::Call(stmt) = statement {
337 if stmt.with_coupon {
338 return Ok(None);
339 }
340
341 if let Some(called_func) = stmt.function.body(db)? {
342 if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
343 calling_function_id.long(db)
344 && specialized.base == called_func
345 {
346 return Ok(Some((stmt, called_func)));
348 }
349
350 if called_func != calling_function_id && db.priv_should_inline(called_func)? {
353 return Ok(Some((stmt, called_func)));
354 }
355 }
356 }
357
358 Ok(None)
359}
360
361pub fn apply_inlining<'db>(
365 db: &'db dyn Database,
366 function_id: ConcreteFunctionWithBodyId<'db>,
367 lowered: &mut Lowered<'db>,
368 enable_const_folding: bool,
369) -> Maybe<()> {
370 if let Err(diag_added) = inner_apply_inlining(db, lowered, function_id, enable_const_folding) {
371 lowered.blocks = Blocks::new_errored(diag_added);
372 }
373 Ok(())
374}