// cairo_lang_lowering/inline/statements_weights.rs

1use cairo_lang_semantic::TypeId;
2use cairo_lang_utils::casts::IntoOrPanic;
3
4use crate::db::LoweringGroup;
5use crate::{BlockEnd, Lowered, Statement, VarUsage, VariableId};
6
7/// Trait for calculating the weight of a lowered function, for the purpose of inlining.
8pub trait InlineWeight {
9    /// The weight of calling the function.
10    fn calling_weight(&self, lowered: &Lowered) -> isize;
11    /// The weight of a statement in the lowered function.
12    fn statement_weight(&self, statement: &Statement) -> isize;
13    /// The weight of the block end in the lowered function.
14    fn block_end_weight(&self, block_end: &BlockEnd) -> isize;
15    /// The weight of the entire lowered function.
16    fn lowered_weight(&self, lowered: &Lowered) -> isize {
17        self.calling_weight(lowered)
18            + lowered
19                .blocks
20                .iter()
21                .map(|(_, block)| {
22                    block
23                        .statements
24                        .iter()
25                        .map(|statement| self.statement_weight(statement))
26                        .sum::<isize>()
27                        + self.block_end_weight(&block.end)
28                })
29                .sum::<isize>()
30    }
31}
32
33/// A simple inline weight that gives a weight of 1 to each statement and block end.
34pub struct SimpleInlineWeight;
35impl InlineWeight for SimpleInlineWeight {
36    fn calling_weight(&self, _lowered: &Lowered) -> isize {
37        0
38    }
39
40    fn statement_weight(&self, _statement: &Statement) -> isize {
41        1
42    }
43
44    fn block_end_weight(&self, _block_end: &BlockEnd) -> isize {
45        1
46    }
47}
48
49/// Try to approximate the weight of a lowered function by counting the number of casm statements it
50/// will add to the code.
51pub struct ApproxCasmInlineWeight<'a> {
52    db: &'a dyn LoweringGroup,
53    lowered: &'a Lowered,
54}
55impl<'a> ApproxCasmInlineWeight<'a> {
56    /// Create a new `ApproxCasmInlineWeight` for the given lowered function.
57    pub fn new(db: &'a dyn LoweringGroup, lowered: &'a Lowered) -> Self {
58        Self { db, lowered }
59    }
60    /// Calculate the total size of the given types.
61    fn tys_total_size(&self, tys: impl IntoIterator<Item = TypeId>) -> usize {
62        tys.into_iter().map(|ty| self.db.type_size(ty)).sum()
63    }
64    /// Calculate the total size of the given variables.
65    fn vars_size<'b, I: IntoIterator<Item = &'b VariableId>>(&self, vars: I) -> usize {
66        self.tys_total_size(vars.into_iter().map(|v| self.lowered.variables[*v].ty))
67    }
68    /// Calculate the total size of the given inputs.
69    fn inputs_size<'b, I: IntoIterator<Item = &'b VarUsage>>(&self, vars: I) -> usize {
70        self.vars_size(vars.into_iter().map(|v| &v.var_id))
71    }
72}
73
74impl InlineWeight for ApproxCasmInlineWeight<'_> {
75    fn calling_weight(&self, _lowered: &Lowered) -> isize {
76        0
77    }
78    fn statement_weight(&self, statement: &Statement) -> isize {
79        match statement {
80            // TODO(orizi): Add analysis of existing compilation to provide proper approximation for
81            // libfunc sizes.
82
83            // Current approximation is only based on assuming all libfunc require preparation for
84            // their arguments.
85            Statement::Call(statement_call) => self.inputs_size(&statement_call.inputs),
86            _ => 0,
87        }
88        .into_or_panic()
89    }
90
91    fn block_end_weight(&self, block_end: &BlockEnd) -> isize {
92        match block_end {
93            // Return are removed when the function is inlined.
94            BlockEnd::Return(..) => 0,
95            // Goto requires the size of the variables in the mappings, as these are likely to be
96            // stored for merge.
97            BlockEnd::Goto(_, r) => self.vars_size(r.keys()),
98            // The required store for the branch parameter, as well as the branch aligns.
99            BlockEnd::Match { info } => info.arms().len() + self.inputs_size(info.inputs()),
100            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
101        }
102        .into_or_panic()
103    }
104}