cairo_lang_lowering/inline/
mod.rs

1#[cfg(test)]
2mod test;
3
4pub mod statements_weights;
5
6use cairo_lang_defs::diagnostic_utils::StableLocation;
7use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_diagnostics::{Diagnostics, Maybe};
9use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
10use cairo_lang_semantic::items::functions::InlineConfiguration;
11use cairo_lang_utils::casts::IntoOrPanic;
12use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
13use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
14use itertools::{Itertools, zip_eq};
15use salsa::Database;
16
17use crate::blocks::{Blocks, BlocksBuilder};
18use crate::db::LoweringGroup;
19use crate::diagnostic::{
20    LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
21};
22use crate::ids::{
23    ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionWithBodyId,
24    FunctionWithBodyLongId, LocationId,
25};
26use crate::optimizations::const_folding::ConstFoldingContext;
27use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
28use crate::{
29    Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
30    VarRemapping, Variable, VariableArena, VariableId,
31};
32
33pub fn get_inline_diagnostics<'db>(
34    db: &'db dyn Database,
35    function_id: FunctionWithBodyId<'db>,
36) -> Maybe<Diagnostics<'db, LoweringDiagnostic<'db>>> {
37    let inline_config = match function_id.long(db) {
38        FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(*id)?,
39        FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
40    };
41    let mut diagnostics = LoweringDiagnostics::default();
42
43    if let InlineConfiguration::Always(_) = inline_config
44        && db.in_cycle(function_id, crate::DependencyType::Call)?
45    {
46        diagnostics.report(
47            function_id.base_semantic_function(db).untyped_stable_ptr(db),
48            LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
49        );
50    }
51
52    Ok(diagnostics.build())
53}
54
55/// Query implementation of [LoweringGroup::priv_should_inline].
56#[salsa::tracked]
57pub fn priv_should_inline<'db>(
58    db: &'db dyn Database,
59    function_id: ConcreteFunctionWithBodyId<'db>,
60) -> Maybe<bool> {
61    if db.priv_never_inline(function_id)? {
62        return Ok(false);
63    }
64
65    // Breaks cycles.
66    if db.concrete_in_cycle(function_id, DependencyType::Call, LoweringStage::Monomorphized)? {
67        return Ok(false);
68    }
69
70    match (db.optimizations().inlining_strategy(), function_inline_config(db, function_id)?) {
71        (_, InlineConfiguration::Always(_)) => Ok(true),
72        (InliningStrategy::Avoid, _) | (_, InlineConfiguration::Never(_)) => Ok(false),
73        (_, InlineConfiguration::Should(_)) => Ok(true),
74        (InliningStrategy::Default, InlineConfiguration::None) => {
75            /// The default threshold for inlining small functions. Decided according to sample
76            /// contracts profiling.
77            const DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD: usize = 120;
78            should_inline_lowered(db, function_id, DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)
79        }
80        (InliningStrategy::InlineSmallFunctions(threshold), InlineConfiguration::None) => {
81            should_inline_lowered(db, function_id, threshold)
82        }
83    }
84}
85
86/// Query implementation of [LoweringGroup::priv_never_inline].
87#[salsa::tracked]
88pub fn priv_never_inline<'db>(
89    db: &'db dyn Database,
90    function_id: ConcreteFunctionWithBodyId<'db>,
91) -> Maybe<bool> {
92    Ok(matches!(function_inline_config(db, function_id)?, InlineConfiguration::Never(_)))
93}
94
95/// Returns the [InlineConfiguration] of a function.
96fn function_inline_config<'db>(
97    db: &'db dyn Database,
98    function_id: ConcreteFunctionWithBodyId<'db>,
99) -> Maybe<InlineConfiguration<'db>> {
100    match function_id.long(db) {
101        ConcreteFunctionWithBodyLongId::Semantic(id) => {
102            db.function_declaration_inline_config(id.function_with_body_id(db))
103        }
104        ConcreteFunctionWithBodyLongId::Generated(_) => Ok(InlineConfiguration::None),
105        ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
106            function_inline_config(db, specialized.base)
107        }
108    }
109}
110
111// A heuristic to decide if a function without an inline attribute should be inlined.
112fn should_inline_lowered(
113    db: &dyn Database,
114    function_id: ConcreteFunctionWithBodyId<'_>,
115    inline_small_functions_threshold: usize,
116) -> Maybe<bool> {
117    let weight_of_blocks = db.estimate_size(function_id)?;
118    Ok(weight_of_blocks < inline_small_functions_threshold.into_or_panic())
119}
120/// Context for mapping ids from `lowered` to a new `Lowered` object.
121pub struct Mapper<'db, 'mt, 'l> {
122    db: &'db dyn Database,
123    variables: &'mt mut VariableArena<'db>,
124    lowered: &'l Lowered<'db>,
125    renamed_vars: UnorderedHashMap<VariableId, VariableId>,
126
127    outputs: Vec<VariableId>,
128    inlining_location: StableLocation<'db>,
129
130    /// An offset that is added to all the block IDs in order to translate them into the new
131    /// lowering representation.
132    block_id_offset: BlockId,
133
134    /// Return statements are replaced with goto to this block with the appropriate remapping.
135    return_block_id: BlockId,
136}
137
138impl<'db, 'mt, 'l> Mapper<'db, 'mt, 'l> {
139    pub fn new(
140        db: &'db dyn Database,
141        variables: &'mt mut VariableArena<'db>,
142        lowered: &'l Lowered<'db>,
143        call_stmt: StatementCall<'db>,
144        block_id_offset: usize,
145    ) -> Self {
146        // The input variables need to be renamed to match the inputs to the function call.
147        let renamed_vars = UnorderedHashMap::<VariableId, VariableId>::from_iter(zip_eq(
148            lowered.parameters.iter().cloned(),
149            call_stmt.inputs.iter().map(|var_usage| var_usage.var_id),
150        ));
151
152        let inlining_location = call_stmt.location.long(db).stable_location;
153
154        Self {
155            db,
156            variables,
157            lowered,
158            renamed_vars,
159            block_id_offset: BlockId(block_id_offset),
160            return_block_id: BlockId(block_id_offset + lowered.blocks.len()),
161            outputs: call_stmt.outputs,
162            inlining_location,
163        }
164    }
165}
166
167impl<'db, 'mt> Rebuilder<'db> for Mapper<'db, 'mt, '_> {
168    /// Maps a var id from the original lowering representation to the equivalent id in the
169    /// new lowering representation.
170    /// If the variable wasn't assigned an id yet, a new id is assigned.
171    fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
172        *self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
173            let orig_var = &self.lowered.variables[orig_var_id];
174            self.variables.alloc(Variable {
175                location: orig_var.location.inlined(self.db, self.inlining_location),
176                ..orig_var.clone()
177            })
178        })
179    }
180
181    /// Maps a block id from the original lowering representation to the equivalent id in the
182    /// new lowering representation.
183    fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
184        BlockId(self.block_id_offset.0 + orig_block_id.0)
185    }
186
187    /// Adds the inlining location to a location.
188    fn map_location(&mut self, location: LocationId<'db>) -> LocationId<'db> {
189        location.inlined(self.db, self.inlining_location)
190    }
191
192    fn transform_end(&mut self, end: &mut BlockEnd<'db>) {
193        match end {
194            BlockEnd::Return(returns, _location) => {
195                let remapping = VarRemapping {
196                    remapping: OrderedHashMap::from_iter(zip_eq(
197                        self.outputs.iter().cloned(),
198                        returns.iter().cloned(),
199                    )),
200                };
201                *end = BlockEnd::Goto(self.return_block_id, remapping);
202            }
203            BlockEnd::Panic(_) | BlockEnd::Goto(_, _) | BlockEnd::Match { .. } => {}
204            BlockEnd::NotSet => unreachable!(),
205        }
206    }
207}
208
/// Inner function for applying inlining.
///
/// Walks the blocks of `lowered`. Whenever a block contains a call that `should_inline` accepts,
/// the callee's `PostBaseline` lowering is spliced into the caller:
/// - the block is cut at the call and ends with a `goto` to the (re-mapped) callee root block;
/// - the callee blocks are appended;
/// - the statements after the call, plus the original block end, move to a new "return block"
///   that the callee's returns jump to.
///
/// The inlined blocks are visited afterwards too, so calls inside them are also considered for
/// inlining.
///
/// When `enable_const_folding` is set and the const-folding context does not ask to skip it,
/// const folding is applied to blocks and statements as they are visited.
///
/// This function should be called through `apply_inlining` to remove all the lowered blocks in the
/// error case.
fn inner_apply_inlining<'db>(
    db: &'db dyn Database,
    lowered: &mut Lowered<'db>,
    calling_function_id: ConcreteFunctionWithBodyId<'db>,
    mut enable_const_folding: bool,
) -> Maybe<()> {
    // Fails (propagating the error) if `lowered` has no root block.
    lowered.blocks.has_root()?;

    let mut blocks: BlocksBuilder<'db> = BlocksBuilder::new();

    // A stack of pending block-id sequences. The bottom entry walks the original blocks, which
    // are copied into the new builder here. Each inlining pushes the callee's blocks on top so
    // they are visited next.
    let mut stack: Vec<std::vec::IntoIter<BlockId>> = vec![
        lowered
            .blocks
            .iter()
            .map(|(_, block)| blocks.alloc(block.clone()))
            .collect_vec()
            .into_iter(),
    ];

    let mut const_folding_ctx =
        ConstFoldingContext::new(db, calling_function_id, &mut lowered.variables);

    enable_const_folding = enable_const_folding && !const_folding_ctx.should_skip_const_folding(db);

    while let Some(mut func_blocks) = stack.pop() {
        for block_id in func_blocks.by_ref() {
            let blocks = &mut blocks;
            // When const folding reports `false` at the block start, the block is skipped.
            // NOTE(review): presumably because it was found unreachable - confirm.
            if enable_const_folding
                && !const_folding_ctx.visit_block_start(block_id, |block_id| &blocks.0[block_id.0])
            {
                continue;
            }

            // Read the next block id before `blocks` is borrowed.
            let next_block_id = blocks.len();
            let block = blocks.get_mut_block(block_id);

            // Find the first statement that is a call to inline, const-folding each statement
            // visited up to and including that call.
            let mut opt_inline_info = None;
            for (idx, statement) in block.statements.iter_mut().enumerate() {
                if enable_const_folding {
                    const_folding_ctx.visit_statement(statement);
                }
                if let Some((call_stmt, called_func)) =
                    should_inline(db, calling_function_id, statement)?
                {
                    opt_inline_info = Some((idx, call_stmt.clone(), called_func));
                    break;
                }
            }

            let Some((call_stmt_idx, call_stmt, called_func)) = opt_inline_info else {
                if enable_const_folding {
                    const_folding_ctx.visit_block_end(block_id, block);
                }
                // Nothing to inline in this block, go to the next block.
                continue;
            };

            let inlined_lowered = db.lowered_body(called_func, LoweringStage::PostBaseline)?;
            inlined_lowered.blocks.has_root()?;

            // Drain the statements starting at the call to the inlined function.
            // The call itself (the first drained statement) is dropped.
            let remaining_statements =
                block.statements.drain(call_stmt_idx..).skip(1).collect_vec();

            // Replace the end of the block with a goto to the root block of the inlined function.
            let orig_block_end = std::mem::replace(
                &mut block.end,
                BlockEnd::Goto(BlockId(next_block_id), VarRemapping::default()),
            );

            if enable_const_folding {
                const_folding_ctx.visit_block_end(block_id, block);
            }

            let mut inline_mapper = Mapper::new(
                db,
                const_folding_ctx.variables,
                inlined_lowered,
                call_stmt,
                next_block_id,
            );

            // Apply the mapper to the inlined blocks and add them as a contiguous chunk to the
            // blocks builder.
            let mut inlined_blocks_ids = inlined_lowered
                .blocks
                .iter()
                .map(|(_block_id, block)| blocks.alloc(inline_mapper.rebuild_block(block)))
                .collect_vec();

            // Move the remaining statements and the original block end to a new return block.
            let return_block_id =
                blocks.alloc(Block { statements: remaining_statements, end: orig_block_end });
            // The mapper computed the return block id in advance; the allocation must match it.
            assert_eq!(return_block_id, inline_mapper.return_block_id);

            // Append the id of the return block to the list of blocks in the inlined function.
            // It is not part of that function, but we want to visit it right after the inlined
            // function blocks.
            inlined_blocks_ids.push(return_block_id);

            // Return the remaining blocks from the current function to the stack and add the blocks
            // of the inlined function to the top of the stack.
            stack.push(func_blocks);
            stack.push(inlined_blocks_ids.into_iter());
            // Stop walking `func_blocks`. It was pushed back above and resumes after the inlined
            // blocks.
            break;
        }
    }

    // NOTE(review): `build` is unwrapped. Presumably it cannot fail here, since at least the
    // root block was copied into the builder - confirm.
    lowered.blocks = blocks.build().unwrap();
    Ok(())
}
325
326/// Rewrites a statement and either appends it to self.statements or adds new statements to
327/// self.statements_rewrite_stack.
328fn should_inline<'db, 'r>(
329    db: &'db dyn Database,
330    calling_function_id: ConcreteFunctionWithBodyId<'db>,
331    statement: &'r Statement<'db>,
332) -> Maybe<Option<(&'r StatementCall<'db>, ConcreteFunctionWithBodyId<'db>)>>
333where
334    'db: 'r,
335{
336    if let Statement::Call(stmt) = statement {
337        if stmt.with_coupon {
338            return Ok(None);
339        }
340
341        if let Some(called_func) = stmt.function.body(db)? {
342            if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
343                calling_function_id.long(db)
344                && specialized.base == called_func
345            {
346                // A specialized function should always inline its base.
347                return Ok(Some((stmt, called_func)));
348            }
349
350            // TODO: Implement better logic to avoid inlining of destructors that call
351            // themselves.
352            if called_func != calling_function_id && db.priv_should_inline(called_func)? {
353                return Ok(Some((stmt, called_func)));
354            }
355        }
356    }
357
358    Ok(None)
359}
360
361/// Applies inlining to a lowered function.
362///
363/// Note that if const folding is enabled, the blocks must be topologically sorted.
364pub fn apply_inlining<'db>(
365    db: &'db dyn Database,
366    function_id: ConcreteFunctionWithBodyId<'db>,
367    lowered: &mut Lowered<'db>,
368    enable_const_folding: bool,
369) -> Maybe<()> {
370    if let Err(diag_added) = inner_apply_inlining(db, lowered, function_id, enable_const_folding) {
371        lowered.blocks = Blocks::new_errored(diag_added);
372    }
373    Ok(())
374}