cairo_lang_lowering/inline/
mod.rs

1#[cfg(test)]
2mod test;
3
4pub mod statements_weights;
5
6use cairo_lang_defs::diagnostic_utils::StableLocation;
7use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_diagnostics::{Diagnostics, Maybe};
9use cairo_lang_semantic::items::function_with_body::FunctionWithBodySemantic;
10use cairo_lang_semantic::items::functions::InlineConfiguration;
11use cairo_lang_utils::casts::IntoOrPanic;
12use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
13use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
14use itertools::{Itertools, zip_eq};
15use salsa::Database;
16
17use crate::blocks::{Blocks, BlocksBuilder};
18use crate::db::LoweringGroup;
19use crate::diagnostic::{
20    LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
21};
22use crate::ids::{
23    ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionWithBodyId,
24    FunctionWithBodyLongId, LocationId,
25};
26use crate::optimizations::const_folding::ConstFoldingContext;
27use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
28use crate::{
29    Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
30    VarRemapping, Variable, VariableArena, VariableId,
31};
32
33pub fn get_inline_diagnostics<'db>(
34    db: &'db dyn Database,
35    function_id: FunctionWithBodyId<'db>,
36) -> Maybe<Diagnostics<'db, LoweringDiagnostic<'db>>> {
37    let inline_config = match function_id.long(db) {
38        FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(*id)?,
39        FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
40    };
41    let mut diagnostics = LoweringDiagnostics::default();
42
43    if let InlineConfiguration::Always(_) = inline_config
44        && db.in_cycle(function_id, crate::DependencyType::Call)?
45    {
46        diagnostics.report(
47            function_id.base_semantic_function(db).untyped_stable_ptr(db),
48            LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
49        );
50    }
51
52    Ok(diagnostics.build())
53}
54
55/// Query implementation of [LoweringGroup::priv_should_inline].
56#[salsa::tracked]
57pub fn priv_should_inline<'db>(
58    db: &'db dyn Database,
59    function_id: ConcreteFunctionWithBodyId<'db>,
60) -> Maybe<bool> {
61    if db.priv_never_inline(function_id)? {
62        return Ok(false);
63    }
64    // Prevents inlining of functions that may call themselves, by checking if the base of the
65    // function (the function without specialization) is in a call cycle (we cannot use the
66    // specialized version as it may call the base function with different specialization which may
67    // cause long inlining chains).
68    // TODO(Tomerstarkware): allow inlining of specialized recursive functions for one level of
69    // recursion.
70    let base = match function_id.long(db) {
71        ConcreteFunctionWithBodyLongId::Semantic(_)
72        | ConcreteFunctionWithBodyLongId::Generated(_) => function_id,
73        ConcreteFunctionWithBodyLongId::Specialized(specialized) => specialized.base,
74    };
75    if db.concrete_in_cycle(base, DependencyType::Call, LoweringStage::Monomorphized)? {
76        return Ok(false);
77    }
78
79    match (db.optimizations().inlining_strategy(), function_inline_config(db, function_id)?) {
80        (_, InlineConfiguration::Always(_)) => Ok(true),
81        (InliningStrategy::Avoid, _) | (_, InlineConfiguration::Never(_)) => Ok(false),
82        (_, InlineConfiguration::Should(_)) => Ok(true),
83        (InliningStrategy::Default, InlineConfiguration::None) => {
84            /// The default threshold for inlining small functions. Decided according to sample
85            /// contracts profiling.
86            const DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD: usize = 120;
87            should_inline_lowered(db, function_id, DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)
88        }
89        (InliningStrategy::InlineSmallFunctions(threshold), InlineConfiguration::None) => {
90            should_inline_lowered(db, function_id, threshold)
91        }
92    }
93}
94
95/// Query implementation of [LoweringGroup::priv_never_inline].
96#[salsa::tracked]
97pub fn priv_never_inline<'db>(
98    db: &'db dyn Database,
99    function_id: ConcreteFunctionWithBodyId<'db>,
100) -> Maybe<bool> {
101    Ok(matches!(function_inline_config(db, function_id)?, InlineConfiguration::Never(_)))
102}
103
104/// Returns the [InlineConfiguration] of a function.
105fn function_inline_config<'db>(
106    db: &'db dyn Database,
107    function_id: ConcreteFunctionWithBodyId<'db>,
108) -> Maybe<InlineConfiguration<'db>> {
109    match function_id.long(db) {
110        ConcreteFunctionWithBodyLongId::Semantic(id) => {
111            db.function_declaration_inline_config(id.function_with_body_id(db))
112        }
113        ConcreteFunctionWithBodyLongId::Generated(_) => Ok(InlineConfiguration::None),
114        ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
115            function_inline_config(db, specialized.base)
116        }
117    }
118}
119
120// A heuristic to decide if a function without an inline attribute should be inlined.
121fn should_inline_lowered(
122    db: &dyn Database,
123    function_id: ConcreteFunctionWithBodyId<'_>,
124    inline_small_functions_threshold: usize,
125) -> Maybe<bool> {
126    let weight_of_blocks = db.estimate_size(function_id)?;
127    Ok(weight_of_blocks < inline_small_functions_threshold.into_or_panic())
128}
129/// Context for mapping ids from `lowered` to a new `Lowered` object.
130pub struct Mapper<'db, 'mt, 'l> {
131    db: &'db dyn Database,
132    variables: &'mt mut VariableArena<'db>,
133    lowered: &'l Lowered<'db>,
134    renamed_vars: UnorderedHashMap<VariableId, VariableId>,
135
136    outputs: Vec<VariableId>,
137    inlining_location: StableLocation<'db>,
138
139    /// An offset that is added to all the block IDs in order to translate them into the new
140    /// lowering representation.
141    block_id_offset: BlockId,
142
143    /// Return statements are replaced with goto to this block with the appropriate remapping.
144    return_block_id: BlockId,
145}
146
147impl<'db, 'mt, 'l> Mapper<'db, 'mt, 'l> {
148    pub fn new(
149        db: &'db dyn Database,
150        variables: &'mt mut VariableArena<'db>,
151        lowered: &'l Lowered<'db>,
152        call_stmt: StatementCall<'db>,
153        block_id_offset: usize,
154    ) -> Self {
155        // The input variables need to be renamed to match the inputs to the function call.
156        let renamed_vars = UnorderedHashMap::<VariableId, VariableId>::from_iter(zip_eq(
157            lowered.parameters.iter().cloned(),
158            call_stmt.inputs.iter().map(|var_usage| var_usage.var_id),
159        ));
160
161        let inlining_location = call_stmt.location.long(db).stable_location;
162
163        Self {
164            db,
165            variables,
166            lowered,
167            renamed_vars,
168            block_id_offset: BlockId(block_id_offset),
169            return_block_id: BlockId(block_id_offset + lowered.blocks.len()),
170            outputs: call_stmt.outputs,
171            inlining_location,
172        }
173    }
174}
175
176impl<'db, 'mt> Rebuilder<'db> for Mapper<'db, 'mt, '_> {
177    /// Maps a var id from the original lowering representation to the equivalent id in the
178    /// new lowering representation.
179    /// If the variable wasn't assigned an id yet, a new id is assigned.
180    fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
181        *self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
182            let orig_var = &self.lowered.variables[orig_var_id];
183            self.variables.alloc(Variable {
184                location: orig_var.location.inlined(self.db, self.inlining_location),
185                ..orig_var.clone()
186            })
187        })
188    }
189
190    /// Maps a block id from the original lowering representation to the equivalent id in the
191    /// new lowering representation.
192    fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
193        BlockId(self.block_id_offset.0 + orig_block_id.0)
194    }
195
196    /// Adds the inlining location to a location.
197    fn map_location(&mut self, location: LocationId<'db>) -> LocationId<'db> {
198        location.inlined(self.db, self.inlining_location)
199    }
200
201    fn transform_end(&mut self, end: &mut BlockEnd<'db>) {
202        match end {
203            BlockEnd::Return(returns, _location) => {
204                let remapping = VarRemapping {
205                    remapping: OrderedHashMap::from_iter(zip_eq(
206                        self.outputs.iter().cloned(),
207                        returns.iter().cloned(),
208                    )),
209                };
210                *end = BlockEnd::Goto(self.return_block_id, remapping);
211            }
212            BlockEnd::Panic(_) | BlockEnd::Goto(_, _) | BlockEnd::Match { .. } => {}
213            BlockEnd::NotSet => unreachable!(),
214        }
215    }
216}
217
/// Inner function for applying inlining.
///
/// Walks the blocks of `lowered`. For the first call statement in a block that [should_inline]
/// accepts, it splices in the callee's `PostBaseline` lowering:
/// - the calling block is cut at the call and ends with a goto to the inlined root block;
/// - the callee's blocks are appended, remapped by a [Mapper];
/// - a new return block holds the statements after the call and the original block end.
///
/// The inlined blocks and the return block are visited next, so further calls inside them are
/// inlined too. After that, the remaining blocks of the current sequence are visited.
///
/// When `enable_const_folding` is set, and the const folding context does not ask to skip,
/// const folding runs on blocks as they are visited. Blocks for which `visit_block_start`
/// returns `false` are not scanned for inlining.
///
/// This function should be called through `apply_inlining` to remove all the lowered blocks in the
/// error case.
fn inner_apply_inlining<'db>(
    db: &'db dyn Database,
    lowered: &mut Lowered<'db>,
    calling_function_id: ConcreteFunctionWithBodyId<'db>,
    mut enable_const_folding: bool,
) -> Maybe<()> {
    // Fail early if the function being processed has no root block.
    lowered.blocks.has_root()?;

    let mut blocks: BlocksBuilder<'db> = BlocksBuilder::new();

    // Stack of pending block-id sequences. The top iterator is the sequence being visited. It
    // starts with the blocks of `lowered` itself, copied into the builder in order.
    let mut stack: Vec<std::vec::IntoIter<BlockId>> = vec![
        lowered
            .blocks
            .iter()
            .map(|(_, block)| blocks.alloc(block.clone()))
            .collect_vec()
            .into_iter(),
    ];

    let mut const_folding_ctx =
        ConstFoldingContext::new(db, calling_function_id, &mut lowered.variables);

    enable_const_folding = enable_const_folding && !const_folding_ctx.should_skip_const_folding(db);

    while let Some(mut func_blocks) = stack.pop() {
        for block_id in func_blocks.by_ref() {
            let blocks = &mut blocks;
            // Blocks rejected by const folding are skipped entirely (no inlining scan).
            if enable_const_folding
                && !const_folding_ctx.visit_block_start(block_id, |block_id| &blocks.0[block_id.0])
            {
                continue;
            }

            // Read the next block id before `blocks` is borrowed.
            let next_block_id = blocks.len();
            let block = blocks.get_mut_block(block_id);

            // Find the first statement to inline. Const folding visits every statement up to and
            // including that one, before it is checked.
            let mut opt_inline_info = None;
            for (idx, statement) in block.statements.iter_mut().enumerate() {
                if enable_const_folding {
                    const_folding_ctx.visit_statement(statement);
                }
                if let Some((call_stmt, called_func)) =
                    should_inline(db, calling_function_id, statement)?
                {
                    opt_inline_info = Some((idx, call_stmt.clone(), called_func));
                    break;
                }
            }

            let Some((call_stmt_idx, call_stmt, called_func)) = opt_inline_info else {
                if enable_const_folding {
                    const_folding_ctx.visit_block_end(block_id, block);
                }
                // Nothing to inline in this block, go to the next block.
                continue;
            };

            let inlined_lowered = db.lowered_body(called_func, LoweringStage::PostBaseline)?;
            inlined_lowered.blocks.has_root()?;

            // Drain the statements starting at the call to the inlined function.
            // `skip(1)` drops the call statement itself; the rest move to the return block below.
            let remaining_statements =
                block.statements.drain(call_stmt_idx..).skip(1).collect_vec();

            // Replace the end of the block with a goto to the root block of the inlined function.
            // The first inlined block allocated below receives id `next_block_id`.
            let orig_block_end = std::mem::replace(
                &mut block.end,
                BlockEnd::Goto(BlockId(next_block_id), VarRemapping::default()),
            );

            // Const folding sees the truncated block with its new `goto` end.
            if enable_const_folding {
                const_folding_ctx.visit_block_end(block_id, block);
            }

            let mut inline_mapper = Mapper::new(
                db,
                const_folding_ctx.variables,
                inlined_lowered,
                call_stmt,
                next_block_id,
            );

            // Apply the mapper to the inlined blocks and add them as a contiguous chunk to the
            // blocks builder.
            let mut inlined_blocks_ids = inlined_lowered
                .blocks
                .iter()
                .map(|(_block_id, block)| blocks.alloc(inline_mapper.rebuild_block(block)))
                .collect_vec();

            // Move the remaining statements and the original block end to a new return block.
            // Its id must match the one the mapper targeted when rewriting `Return` ends.
            let return_block_id =
                blocks.alloc(Block { statements: remaining_statements, end: orig_block_end });
            assert_eq!(return_block_id, inline_mapper.return_block_id);

            // Append the id of the return block to the list of blocks in the inlined function.
            // It is not part of that function, but we want to visit it right after the inlined
            // function blocks.
            inlined_blocks_ids.push(return_block_id);

            // Return the remaining blocks from the current function to the stack and add the blocks
            // of the inlined function to the top of the stack.
            stack.push(func_blocks);
            stack.push(inlined_blocks_ids.into_iter());
            break;
        }
    }

    // NOTE(review): presumably `build` returns `None` only for an empty builder. That cannot
    // happen here, because the root block of `lowered` was copied in above.
    lowered.blocks = blocks.build().unwrap();
    Ok(())
}
334
335/// Rewrites a statement and either appends it to self.statements or adds new statements to
336/// self.statements_rewrite_stack.
337fn should_inline<'db, 'r>(
338    db: &'db dyn Database,
339    calling_function_id: ConcreteFunctionWithBodyId<'db>,
340    statement: &'r Statement<'db>,
341) -> Maybe<Option<(&'r StatementCall<'db>, ConcreteFunctionWithBodyId<'db>)>>
342where
343    'db: 'r,
344{
345    if let Statement::Call(stmt) = statement {
346        if stmt.with_coupon {
347            return Ok(None);
348        }
349
350        if let Some(called_func) = stmt.function.body(db)? {
351            if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
352                calling_function_id.long(db)
353                && specialized.base == called_func
354                && stmt.is_specialization_base_call
355            {
356                // A specialized function should always inline its base.
357                return Ok(Some((stmt, called_func)));
358            }
359
360            // TODO: Implement better logic to avoid inlining of destructors that call
361            // themselves.
362            if called_func != calling_function_id && db.priv_should_inline(called_func)? {
363                return Ok(Some((stmt, called_func)));
364            }
365        }
366    }
367
368    Ok(None)
369}
370
371/// Applies inlining to a lowered function.
372///
373/// Note that if const folding is enabled, the blocks must be topologically sorted.
374pub fn apply_inlining<'db>(
375    db: &'db dyn Database,
376    function_id: ConcreteFunctionWithBodyId<'db>,
377    lowered: &mut Lowered<'db>,
378    enable_const_folding: bool,
379) -> Maybe<()> {
380    if let Err(diag_added) = inner_apply_inlining(db, lowered, function_id, enable_const_folding) {
381        lowered.blocks = Blocks::new_errored(diag_added);
382    }
383    Ok(())
384}