cairo_lang_lowering/inline/
mod.rs1#[cfg(test)]
2mod test;
3
4pub mod statements_weights;
5
6use cairo_lang_defs::diagnostic_utils::StableLocation;
7use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_diagnostics::{Diagnostics, Maybe};
9use cairo_lang_semantic::items::functions::InlineConfiguration;
10use cairo_lang_utils::LookupIntern;
11use cairo_lang_utils::casts::IntoOrPanic;
12use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
13use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
14use id_arena::Arena;
15use itertools::{Itertools, zip_eq};
16
17use crate::blocks::{Blocks, BlocksBuilder};
18use crate::db::LoweringGroup;
19use crate::diagnostic::{
20 LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
21};
22use crate::ids::{
23 ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionWithBodyId,
24 FunctionWithBodyLongId, LocationId,
25};
26use crate::optimizations::const_folding::ConstFoldingContext;
27use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
28use crate::{
29 Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
30 VarRemapping, Variable, VariableId,
31};
32
33pub fn get_inline_diagnostics(
34 db: &dyn LoweringGroup,
35 function_id: FunctionWithBodyId,
36) -> Maybe<Diagnostics<LoweringDiagnostic>> {
37 let inline_config = match function_id.lookup_intern(db) {
38 FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(id)?,
39 FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
40 };
41 let mut diagnostics = LoweringDiagnostics::default();
42
43 if let InlineConfiguration::Always(_) = inline_config {
44 if db.in_cycle(function_id, crate::DependencyType::Call)? {
45 diagnostics.report(
46 function_id.base_semantic_function(db).untyped_stable_ptr(db),
47 LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
48 );
49 }
50 }
51
52 Ok(diagnostics.build())
53}
54
55pub fn priv_should_inline(
57 db: &dyn LoweringGroup,
58 function_id: ConcreteFunctionWithBodyId,
59) -> Maybe<bool> {
60 if db.priv_never_inline(function_id)? {
61 return Ok(false);
62 }
63
64 if db.concrete_in_cycle(function_id, DependencyType::Call, LoweringStage::Monomorphized)? {
66 return Ok(false);
67 }
68
69 match (db.optimization_config().inlining_strategy, function_inline_config(db, function_id)?) {
70 (_, InlineConfiguration::Always(_)) => Ok(true),
71 (InliningStrategy::Avoid, _) | (_, InlineConfiguration::Never(_)) => Ok(false),
72 (_, InlineConfiguration::Should(_)) => Ok(true),
73 (InliningStrategy::Default, InlineConfiguration::None) => {
74 const DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD: usize = 120;
77 should_inline_lowered(db, function_id, DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)
78 }
79 (InliningStrategy::InlineSmallFunctions(threshold), InlineConfiguration::None) => {
80 should_inline_lowered(db, function_id, threshold)
81 }
82 }
83}
84
85pub fn priv_never_inline(
87 db: &dyn LoweringGroup,
88 function_id: ConcreteFunctionWithBodyId,
89) -> Maybe<bool> {
90 Ok(matches!(function_inline_config(db, function_id)?, InlineConfiguration::Never(_)))
91}
92
93pub fn function_inline_config(
95 db: &dyn LoweringGroup,
96 function_id: ConcreteFunctionWithBodyId,
97) -> Maybe<InlineConfiguration> {
98 match function_id.lookup_intern(db) {
99 ConcreteFunctionWithBodyLongId::Semantic(id) => {
100 db.function_declaration_inline_config(id.function_with_body_id(db))
101 }
102 ConcreteFunctionWithBodyLongId::Generated(_) => Ok(InlineConfiguration::None),
103 ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
104 function_inline_config(db, specialized.base)
105 }
106 }
107}
108
109fn should_inline_lowered(
111 db: &dyn LoweringGroup,
112 function_id: ConcreteFunctionWithBodyId,
113 inline_small_functions_threshold: usize,
114) -> Maybe<bool> {
115 let weight_of_blocks = db.estimate_size(function_id)?;
116 Ok(weight_of_blocks < inline_small_functions_threshold.into_or_panic())
117}
118pub struct Mapper<'a> {
120 db: &'a dyn LoweringGroup,
121 variables: &'a mut Arena<Variable>,
122 lowered: &'a Lowered,
123 renamed_vars: UnorderedHashMap<VariableId, VariableId>,
124
125 outputs: Vec<VariableId>,
126 inlining_location: StableLocation,
127
128 block_id_offset: BlockId,
131
132 return_block_id: BlockId,
134}
135
136impl<'a> Mapper<'a> {
137 pub fn new(
138 db: &'a dyn LoweringGroup,
139 variables: &'a mut Arena<Variable>,
140 lowered: &'a Lowered,
141 call_stmt: StatementCall,
142 block_id_offset: usize,
143 ) -> Self {
144 let renamed_vars = UnorderedHashMap::<VariableId, VariableId>::from_iter(zip_eq(
146 lowered.parameters.iter().cloned(),
147 call_stmt.inputs.iter().map(|var_usage| var_usage.var_id),
148 ));
149
150 let inlining_location = call_stmt.location.lookup_intern(db).stable_location;
151
152 Self {
153 db,
154 variables,
155 lowered,
156 renamed_vars,
157 block_id_offset: BlockId(block_id_offset),
158 return_block_id: BlockId(block_id_offset + lowered.blocks.len()),
159 outputs: call_stmt.outputs,
160 inlining_location,
161 }
162 }
163}
164
165impl Rebuilder for Mapper<'_> {
166 fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
170 *self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
171 let orig_var = &self.lowered.variables[orig_var_id];
172 self.variables.alloc(Variable {
173 location: orig_var.location.inlined(self.db, self.inlining_location),
174 ..orig_var.clone()
175 })
176 })
177 }
178
179 fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
182 BlockId(self.block_id_offset.0 + orig_block_id.0)
183 }
184
185 fn map_location(&mut self, location: LocationId) -> LocationId {
187 location.inlined(self.db, self.inlining_location)
188 }
189
190 fn transform_end(&mut self, end: &mut BlockEnd) {
191 match end {
192 BlockEnd::Return(returns, _location) => {
193 let remapping = VarRemapping {
194 remapping: OrderedHashMap::from_iter(zip_eq(
195 self.outputs.iter().cloned(),
196 returns.iter().cloned(),
197 )),
198 };
199 *end = BlockEnd::Goto(self.return_block_id, remapping);
200 }
201 BlockEnd::Panic(_) | BlockEnd::Goto(_, _) | BlockEnd::Match { .. } => {}
202 BlockEnd::NotSet => unreachable!(),
203 }
204 }
205}
206
/// Rewrites `lowered` (the body of `calling_function_id`) by inlining every call that
/// [`should_inline`] selects. When `enable_const_folding` is set, const folding is interleaved
/// with the traversal.
///
/// Blocks are processed depth-first using a stack of block-id iterators. When a block contains a
/// call to inline, the block is split at the call:
/// - the statements after the call and the original block end move into a new return block;
/// - the callee's `PostBaseline` lowering is appended with remapped ids via [`Mapper`];
/// - the current block is redirected to the first appended id.
///
/// The newly appended blocks (including the return block) are processed before the rest of the
/// current blocks, so calls inside inlined code are considered as well.
///
/// Returns an error if `lowered` or an inlined callee has no root block, or if a query fails.
fn inner_apply_inlining(
    db: &dyn LoweringGroup,
    lowered: &mut Lowered,
    calling_function_id: ConcreteFunctionWithBodyId,
    mut enable_const_folding: bool,
) -> Maybe<()> {
    lowered.blocks.has_root()?;

    let mut blocks = BlocksBuilder::new();

    // Copy all original blocks into the builder; the stack starts with their new ids.
    let mut stack: Vec<std::vec::IntoIter<BlockId>> = vec![
        lowered
            .blocks
            .iter()
            .map(|(_, block)| blocks.alloc(block.clone()))
            .collect_vec()
            .into_iter(),
    ];

    let mut const_folding_ctx =
        ConstFoldingContext::new(db, calling_function_id, &mut lowered.variables);

    // The const-folding context may veto folding for this function altogether.
    enable_const_folding = enable_const_folding && !const_folding_ctx.should_skip_const_folding(db);

    while let Some(mut func_blocks) = stack.pop() {
        for block_id in func_blocks.by_ref() {
            // With const folding enabled, a block rejected by `visit_block_start` is skipped.
            if enable_const_folding
                && !const_folding_ctx
                    .visit_block_start(block_id, |block_id| blocks.get_mut_block(block_id))
            {
                continue;
            }

            // The id the next allocated block will get, i.e. where the callee's blocks start if
            // this block contains a call to inline.
            let next_block_id = blocks.len();
            let block = blocks.get_mut_block(block_id);

            // Find the first call to inline, const-folding statements up to and including it.
            let mut opt_inline_info = None;
            for (idx, statement) in block.statements.iter_mut().enumerate() {
                if enable_const_folding {
                    const_folding_ctx.visit_statement(statement);
                }
                if let Some((call_stmt, called_func)) =
                    should_inline(db, calling_function_id, statement)?
                {
                    opt_inline_info = Some((idx, call_stmt.clone(), called_func));
                    break;
                }
            }

            // Nothing to inline in this block: finish it and move on.
            let Some((call_stmt_idx, call_stmt, called_func)) = opt_inline_info else {
                if enable_const_folding {
                    const_folding_ctx.visit_block_end(block_id, block);
                }
                continue;
            };

            let inlined_lowered = db.lowered_body(called_func, LoweringStage::PostBaseline)?;
            inlined_lowered.blocks.has_root()?;

            // Cut the block at the call: the call itself is dropped (`skip(1)`), and the
            // statements after it are kept for the return block.
            let remaining_statements =
                block.statements.drain(call_stmt_idx..).skip(1).collect_vec();

            // Redirect the current block to `next_block_id`, where the callee's blocks are about
            // to be allocated. The original end is kept for the return block.
            let orig_block_end = std::mem::replace(
                &mut block.end,
                BlockEnd::Goto(BlockId(next_block_id), VarRemapping::default()),
            );

            if enable_const_folding {
                const_folding_ctx.visit_block_end(block_id, block);
            }

            let mut inline_mapper = Mapper::new(
                db,
                const_folding_ctx.variables,
                &inlined_lowered,
                call_stmt,
                next_block_id,
            );

            // Append the callee's blocks with remapped variables, block ids and locations.
            let mut inlined_blocks_ids = inlined_lowered
                .blocks
                .iter()
                .map(|(_block_id, block)| blocks.alloc(inline_mapper.rebuild_block(block)))
                .collect_vec();

            // The return block continues the original block after the call. The callee's
            // returns are redirected to it by `Mapper::transform_end`, so its id must be the one
            // the mapper precomputed.
            let return_block_id =
                blocks.alloc(Block { statements: remaining_statements, end: orig_block_end });
            assert_eq!(return_block_id, inline_mapper.return_block_id);

            inlined_blocks_ids.push(return_block_id);

            // Push the rest of the current blocks first and the new blocks last. The new blocks
            // are therefore popped and processed next; the remaining blocks resume afterwards.
            stack.push(func_blocks);
            stack.push(inlined_blocks_ids.into_iter());
            break;
        }
    }

    lowered.blocks = blocks.build().unwrap();
    Ok(())
}
323
324fn should_inline<'a>(
327 db: &dyn LoweringGroup,
328 calling_function_id: ConcreteFunctionWithBodyId,
329 statement: &'a Statement,
330) -> Maybe<Option<(&'a StatementCall, ConcreteFunctionWithBodyId)>> {
331 if let Statement::Call(stmt) = statement {
332 if stmt.with_coupon {
333 return Ok(None);
334 }
335
336 if let Some(called_func) = stmt.function.body(db)? {
337 if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
338 calling_function_id.lookup_intern(db)
339 {
340 if specialized.base == called_func {
341 return Ok(Some((stmt, called_func)));
343 }
344 }
345
346 if called_func != calling_function_id && db.priv_should_inline(called_func)? {
349 return Ok(Some((stmt, called_func)));
350 }
351 }
352 }
353
354 Ok(None)
355}
356
357pub fn apply_inlining(
361 db: &dyn LoweringGroup,
362 function_id: ConcreteFunctionWithBodyId,
363 lowered: &mut Lowered,
364 enable_const_folding: bool,
365) -> Maybe<()> {
366 if let Err(diag_added) = inner_apply_inlining(db, lowered, function_id, enable_const_folding) {
367 lowered.blocks = Blocks::new_errored(diag_added);
368 }
369 Ok(())
370}