Skip to main content

cairo_lang_lowering/inline/
mod.rs

1#[cfg(test)]
2mod test;
3
4pub mod statements_weights;
5
6use cairo_lang_defs::diagnostic_utils::StableLocation;
7use cairo_lang_defs::ids::LanguageElementId;
8use cairo_lang_diagnostics::{Diagnostics, Maybe};
9use cairo_lang_semantic::items::functions::InlineConfiguration;
10use cairo_lang_utils::LookupIntern;
11use cairo_lang_utils::casts::IntoOrPanic;
12use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
13use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
14use id_arena::Arena;
15use itertools::{Itertools, zip_eq};
16
17use crate::blocks::{Blocks, BlocksBuilder};
18use crate::db::LoweringGroup;
19use crate::diagnostic::{
20    LoweringDiagnostic, LoweringDiagnosticKind, LoweringDiagnostics, LoweringDiagnosticsBuilder,
21};
22use crate::ids::{
23    ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionWithBodyId,
24    FunctionWithBodyLongId, LocationId,
25};
26use crate::optimizations::const_folding::ConstFoldingContext;
27use crate::utils::{InliningStrategy, Rebuilder, RebuilderEx};
28use crate::{
29    Block, BlockEnd, BlockId, DependencyType, Lowered, LoweringStage, Statement, StatementCall,
30    VarRemapping, Variable, VariableId,
31};
32
33pub fn get_inline_diagnostics(
34    db: &dyn LoweringGroup,
35    function_id: FunctionWithBodyId,
36) -> Maybe<Diagnostics<LoweringDiagnostic>> {
37    let inline_config = match function_id.lookup_intern(db) {
38        FunctionWithBodyLongId::Semantic(id) => db.function_declaration_inline_config(id)?,
39        FunctionWithBodyLongId::Generated { .. } => InlineConfiguration::None,
40    };
41    let mut diagnostics = LoweringDiagnostics::default();
42
43    if let InlineConfiguration::Always(_) = inline_config {
44        if db.in_cycle(function_id, crate::DependencyType::Call)? {
45            diagnostics.report(
46                function_id.base_semantic_function(db).untyped_stable_ptr(db),
47                LoweringDiagnosticKind::CannotInlineFunctionThatMightCallItself,
48            );
49        }
50    }
51
52    Ok(diagnostics.build())
53}
54
55/// Query implementation of [LoweringGroup::priv_should_inline].
56pub fn priv_should_inline(
57    db: &dyn LoweringGroup,
58    function_id: ConcreteFunctionWithBodyId,
59) -> Maybe<bool> {
60    if db.priv_never_inline(function_id)? {
61        return Ok(false);
62    }
63
64    // Breaks cycles.
65    if db.concrete_in_cycle(function_id, DependencyType::Call, LoweringStage::Monomorphized)? {
66        return Ok(false);
67    }
68
69    match (db.optimization_config().inlining_strategy, function_inline_config(db, function_id)?) {
70        (_, InlineConfiguration::Always(_)) => Ok(true),
71        (InliningStrategy::Avoid, _) | (_, InlineConfiguration::Never(_)) => Ok(false),
72        (_, InlineConfiguration::Should(_)) => Ok(true),
73        (InliningStrategy::Default, InlineConfiguration::None) => {
74            /// The default threshold for inlining small functions. Decided according to sample
75            /// contracts profiling.
76            const DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD: usize = 120;
77            should_inline_lowered(db, function_id, DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)
78        }
79        (InliningStrategy::InlineSmallFunctions(threshold), InlineConfiguration::None) => {
80            should_inline_lowered(db, function_id, threshold)
81        }
82    }
83}
84
85/// Query implementation of [LoweringGroup::priv_never_inline].
86pub fn priv_never_inline(
87    db: &dyn LoweringGroup,
88    function_id: ConcreteFunctionWithBodyId,
89) -> Maybe<bool> {
90    Ok(matches!(function_inline_config(db, function_id)?, InlineConfiguration::Never(_)))
91}
92
93/// Query implementation of [LoweringGroup::priv_never_inline].
94pub fn function_inline_config(
95    db: &dyn LoweringGroup,
96    function_id: ConcreteFunctionWithBodyId,
97) -> Maybe<InlineConfiguration> {
98    match function_id.lookup_intern(db) {
99        ConcreteFunctionWithBodyLongId::Semantic(id) => {
100            db.function_declaration_inline_config(id.function_with_body_id(db))
101        }
102        ConcreteFunctionWithBodyLongId::Generated(_) => Ok(InlineConfiguration::None),
103        ConcreteFunctionWithBodyLongId::Specialized(specialized) => {
104            function_inline_config(db, specialized.base)
105        }
106    }
107}
108
109// A heuristic to decide if a function without an inline attribute should be inlined.
110fn should_inline_lowered(
111    db: &dyn LoweringGroup,
112    function_id: ConcreteFunctionWithBodyId,
113    inline_small_functions_threshold: usize,
114) -> Maybe<bool> {
115    let weight_of_blocks = db.estimate_size(function_id)?;
116    Ok(weight_of_blocks < inline_small_functions_threshold.into_or_panic())
117}
118/// Context for mapping ids from `lowered` to a new `Lowered` object.
119pub struct Mapper<'a> {
120    db: &'a dyn LoweringGroup,
121    variables: &'a mut Arena<Variable>,
122    lowered: &'a Lowered,
123    renamed_vars: UnorderedHashMap<VariableId, VariableId>,
124
125    outputs: Vec<VariableId>,
126    inlining_location: StableLocation,
127
128    /// An offset that is added to all the block IDs in order to translate them into the new
129    /// lowering representation.
130    block_id_offset: BlockId,
131
132    /// Return statements are replaced with goto to this block with the appropriate remapping.
133    return_block_id: BlockId,
134}
135
136impl<'a> Mapper<'a> {
137    pub fn new(
138        db: &'a dyn LoweringGroup,
139        variables: &'a mut Arena<Variable>,
140        lowered: &'a Lowered,
141        call_stmt: StatementCall,
142        block_id_offset: usize,
143    ) -> Self {
144        // The input variables need to be renamed to match the inputs to the function call.
145        let renamed_vars = UnorderedHashMap::<VariableId, VariableId>::from_iter(zip_eq(
146            lowered.parameters.iter().cloned(),
147            call_stmt.inputs.iter().map(|var_usage| var_usage.var_id),
148        ));
149
150        let inlining_location = call_stmt.location.lookup_intern(db).stable_location;
151
152        Self {
153            db,
154            variables,
155            lowered,
156            renamed_vars,
157            block_id_offset: BlockId(block_id_offset),
158            return_block_id: BlockId(block_id_offset + lowered.blocks.len()),
159            outputs: call_stmt.outputs,
160            inlining_location,
161        }
162    }
163}
164
165impl Rebuilder for Mapper<'_> {
166    /// Maps a var id from the original lowering representation to the equivalent id in the
167    /// new lowering representation.
168    /// If the variable wasn't assigned an id yet, a new id is assigned.
169    fn map_var_id(&mut self, orig_var_id: VariableId) -> VariableId {
170        *self.renamed_vars.entry(orig_var_id).or_insert_with(|| {
171            let orig_var = &self.lowered.variables[orig_var_id];
172            self.variables.alloc(Variable {
173                location: orig_var.location.inlined(self.db, self.inlining_location),
174                ..orig_var.clone()
175            })
176        })
177    }
178
179    /// Maps a block id from the original lowering representation to the equivalent id in the
180    /// new lowering representation.
181    fn map_block_id(&mut self, orig_block_id: BlockId) -> BlockId {
182        BlockId(self.block_id_offset.0 + orig_block_id.0)
183    }
184
185    /// Adds the inlining location to a location.
186    fn map_location(&mut self, location: LocationId) -> LocationId {
187        location.inlined(self.db, self.inlining_location)
188    }
189
190    fn transform_end(&mut self, end: &mut BlockEnd) {
191        match end {
192            BlockEnd::Return(returns, _location) => {
193                let remapping = VarRemapping {
194                    remapping: OrderedHashMap::from_iter(zip_eq(
195                        self.outputs.iter().cloned(),
196                        returns.iter().cloned(),
197                    )),
198                };
199                *end = BlockEnd::Goto(self.return_block_id, remapping);
200            }
201            BlockEnd::Panic(_) | BlockEnd::Goto(_, _) | BlockEnd::Match { .. } => {}
202            BlockEnd::NotSet => unreachable!(),
203        }
204    }
205}
206
/// Inner function for applying inlining.
///
/// This function should be called through `apply_inlining` to remove all the lowered blocks in the
/// error case.
///
/// The blocks of `lowered` are walked with an explicit stack of block-id iterators. When a block
/// contains a call accepted by [should_inline], the block is split at that call:
/// - the call and everything after it are removed from the block, whose end becomes a goto to a
///   freshly allocated copy of the callee's blocks (rebuilt through [Mapper]);
/// - the statements after the call, together with the block's original end, move into a new
///   "return" block, which the copied callee blocks jump to instead of returning.
///
/// The copied blocks and the return block are visited before the rest of the current blocks, so
/// calls inside the inlined body are considered for inlining as well. When `enable_const_folding`
/// is set (and the const folding context does not ask to skip it), const folding is performed in
/// the same pass. Errors from the queries below are propagated with `?`.
fn inner_apply_inlining(
    db: &dyn LoweringGroup,
    lowered: &mut Lowered,
    calling_function_id: ConcreteFunctionWithBodyId,
    mut enable_const_folding: bool,
) -> Maybe<()> {
    lowered.blocks.has_root()?;

    // The new block arena. It starts with a copy of the original blocks (allocated in order,
    // presumably keeping their original ids); inlined blocks are appended after them.
    let mut blocks = BlocksBuilder::new();

    // Stack of pending block-id iterators. The top iterator is the one currently being walked.
    let mut stack: Vec<std::vec::IntoIter<BlockId>> = vec![
        lowered
            .blocks
            .iter()
            .map(|(_, block)| blocks.alloc(block.clone()))
            .collect_vec()
            .into_iter(),
    ];

    // The context borrows the caller's variable arena; inlined variables are allocated through it
    // below (`const_folding_ctx.variables`).
    let mut const_folding_ctx =
        ConstFoldingContext::new(db, calling_function_id, &mut lowered.variables);

    enable_const_folding = enable_const_folding && !const_folding_ctx.should_skip_const_folding(db);

    while let Some(mut func_blocks) = stack.pop() {
        // `by_ref` keeps `func_blocks` usable after the `break` below, so the blocks not visited
        // yet can be pushed back onto the stack.
        for block_id in func_blocks.by_ref() {
            // NOTE(review): a `false` result appears to mean the block should not be processed
            // (e.g. it was found unreachable) — confirm against `ConstFoldingContext`.
            if enable_const_folding
                && !const_folding_ctx
                    .visit_block_start(block_id, |block_id| blocks.get_mut_block(block_id))
            {
                continue;
            }

            // Read the next block id before `blocks` is borrowed.
            let next_block_id = blocks.len();
            let block = blocks.get_mut_block(block_id);

            // Find the first call to inline. Each statement is const folded (when enabled) before
            // the inline check, so the check sees the folded statement.
            let mut opt_inline_info = None;
            for (idx, statement) in block.statements.iter_mut().enumerate() {
                if enable_const_folding {
                    const_folding_ctx.visit_statement(statement);
                }
                if let Some((call_stmt, called_func)) =
                    should_inline(db, calling_function_id, statement)?
                {
                    opt_inline_info = Some((idx, call_stmt.clone(), called_func));
                    break;
                }
            }

            let Some((call_stmt_idx, call_stmt, called_func)) = opt_inline_info else {
                if enable_const_folding {
                    const_folding_ctx.visit_block_end(block_id, block);
                }
                // Nothing to inline in this block, go to the next block.
                continue;
            };

            let inlined_lowered = db.lowered_body(called_func, LoweringStage::PostBaseline)?;
            inlined_lowered.blocks.has_root()?;

            // Drain the statements starting at the call to the inlined function.
            // `skip(1)` drops the call itself; the remaining statements go to the return block.
            let remaining_statements =
                block.statements.drain(call_stmt_idx..).skip(1).collect_vec();

            // Replace the end of the block with a goto to the root block of the inlined function.
            // The first inlined block will be allocated at `next_block_id`.
            let orig_block_end = std::mem::replace(
                &mut block.end,
                BlockEnd::Goto(BlockId(next_block_id), VarRemapping::default()),
            );

            // The split block is complete now (it ends with the goto), so const folding can
            // finish visiting it.
            if enable_const_folding {
                const_folding_ctx.visit_block_end(block_id, block);
            }

            let mut inline_mapper = Mapper::new(
                db,
                const_folding_ctx.variables,
                &inlined_lowered,
                call_stmt,
                next_block_id,
            );

            // Apply the mapper to the inlined blocks and add them as a contiguous chunk to the
            // blocks builder.
            let mut inlined_blocks_ids = inlined_lowered
                .blocks
                .iter()
                .map(|(_block_id, block)| blocks.alloc(inline_mapper.rebuild_block(block)))
                .collect_vec();

            // Move the remaining statements and the original block end to a new return block.
            // It is allocated right after the inlined blocks, which is where the mapper's gotos
            // (replacing the callee's returns) point.
            let return_block_id =
                blocks.alloc(Block { statements: remaining_statements, end: orig_block_end });
            assert_eq!(return_block_id, inline_mapper.return_block_id);

            // Append the id of the return block to the list of blocks in the inlined function.
            // It is not part of that function, but we want to visit it right after the inlined
            // function blocks.
            inlined_blocks_ids.push(return_block_id);

            // Return the remaining blocks from the current function to the stack and add the blocks
            // of the inlined function to the top of the stack.
            stack.push(func_blocks);
            stack.push(inlined_blocks_ids.into_iter());
            break;
        }
    }

    // NOTE(review): assumes `build` can only fail for an empty builder, which the `has_root`
    // check above rules out — confirm against `BlocksBuilder::build`.
    lowered.blocks = blocks.build().unwrap();
    Ok(())
}
323
324/// Rewrites a statement and either appends it to self.statements or adds new statements to
325/// self.statements_rewrite_stack.
326fn should_inline<'a>(
327    db: &dyn LoweringGroup,
328    calling_function_id: ConcreteFunctionWithBodyId,
329    statement: &'a Statement,
330) -> Maybe<Option<(&'a StatementCall, ConcreteFunctionWithBodyId)>> {
331    if let Statement::Call(stmt) = statement {
332        if stmt.with_coupon {
333            return Ok(None);
334        }
335
336        if let Some(called_func) = stmt.function.body(db)? {
337            if let ConcreteFunctionWithBodyLongId::Specialized(specialized) =
338                calling_function_id.lookup_intern(db)
339            {
340                if specialized.base == called_func {
341                    // A specialized function should always inline its base.
342                    return Ok(Some((stmt, called_func)));
343                }
344            }
345
346            // TODO: Implement better logic to avoid inlining of destructors that call
347            // themselves.
348            if called_func != calling_function_id && db.priv_should_inline(called_func)? {
349                return Ok(Some((stmt, called_func)));
350            }
351        }
352    }
353
354    Ok(None)
355}
356
357/// Applies inlining to a lowered function.
358///
359/// Note that if const folding is enabled, the blocks must be topologically sorted.
360pub fn apply_inlining(
361    db: &dyn LoweringGroup,
362    function_id: ConcreteFunctionWithBodyId,
363    lowered: &mut Lowered,
364    enable_const_folding: bool,
365) -> Maybe<()> {
366    if let Err(diag_added) = inner_apply_inlining(db, lowered, function_id, enable_const_folding) {
367        lowered.blocks = Blocks::new_errored(diag_added);
368    }
369    Ok(())
370}