cairo_lang_lowering/inline/statements_weights.rs

1use cairo_lang_semantic::TypeId;
2use cairo_lang_utils::casts::IntoOrPanic;
3use salsa::Database;
4
5use crate::db::LoweringGroup;
6use crate::{BlockEnd, Lowered, Statement, VarUsage, VariableId};
7
8/// Trait for calculating the weight of a lowered function, for the purpose of inlining.
9pub trait InlineWeight<'db> {
10    /// The weight of calling the function.
11    fn calling_weight(&self, lowered: &Lowered<'db>) -> isize;
12    /// The weight of a statement in the lowered function.
13    fn statement_weight(&self, statement: &Statement<'db>) -> isize;
14    /// The weight of the block end in the lowered function.
15    fn block_end_weight(&self, block_end: &BlockEnd<'db>) -> isize;
16    /// The weight of the entire lowered function.
17    fn lowered_weight(&self, lowered: &Lowered<'db>) -> isize {
18        self.calling_weight(lowered)
19            + lowered
20                .blocks
21                .iter()
22                .map(|(_, block)| {
23                    block
24                        .statements
25                        .iter()
26                        .map(|statement| self.statement_weight(statement))
27                        .sum::<isize>()
28                        + self.block_end_weight(&block.end)
29                })
30                .sum::<isize>()
31    }
32}
33
34/// A simple inline weight that gives a weight of 1 to each statement and block end.
35pub struct SimpleInlineWeight;
36impl<'db> InlineWeight<'db> for SimpleInlineWeight {
37    fn calling_weight(&self, _lowered: &Lowered<'db>) -> isize {
38        0
39    }
40
41    fn statement_weight(&self, _statement: &Statement<'db>) -> isize {
42        1
43    }
44
45    fn block_end_weight(&self, _block_end: &BlockEnd<'db>) -> isize {
46        1
47    }
48}
49
50/// Try to approximate the weight of a lowered function by counting the number of CASM statements it
51/// will add to the code.
52pub struct ApproxCasmInlineWeight<'db> {
53    db: &'db dyn Database,
54    lowered: &'db Lowered<'db>,
55}
56impl<'db> ApproxCasmInlineWeight<'db> {
57    /// Create a new `ApproxCasmInlineWeight` for the given lowered function.
58    pub fn new(db: &'db dyn Database, lowered: &'db Lowered<'db>) -> Self {
59        Self { db, lowered }
60    }
61    /// Calculate the total size of the given types.
62    fn tys_total_size(&self, tys: impl IntoIterator<Item = TypeId<'db>>) -> usize {
63        tys.into_iter().map(|ty| self.db.type_size(ty)).sum()
64    }
65    /// Calculate the total size of the given variables.
66    fn vars_size<'b, I: IntoIterator<Item = &'db VariableId>>(&self, vars: I) -> usize {
67        self.tys_total_size(vars.into_iter().map(|v| self.lowered.variables[*v].ty))
68    }
69    /// Calculate the total size of the given inputs.
70    fn inputs_size<'b, I: IntoIterator<Item = &'db VarUsage<'db>>>(&self, vars: I) -> usize {
71        self.vars_size(vars.into_iter().map(|v| &v.var_id))
72    }
73}
74
75impl<'db> InlineWeight<'db> for ApproxCasmInlineWeight<'db> {
76    fn calling_weight(&self, _lowered: &Lowered<'db>) -> isize {
77        0
78    }
79    fn statement_weight(&self, statement: &Statement<'db>) -> isize {
80        match statement {
81            // TODO(orizi): Add analysis of existing compilation to provide proper approximation for
82            // libfunc sizes.
83
84            // Current approximation is only based on assuming all libfunc require preparation for
85            // their arguments.
86            Statement::Call(statement_call) => self.inputs_size(&statement_call.inputs),
87            _ => 0,
88        }
89        .into_or_panic()
90    }
91
92    fn block_end_weight(&self, block_end: &BlockEnd<'db>) -> isize {
93        match block_end {
94            // Returns are removed when the function is inlined.
95            BlockEnd::Return(..) => 0,
96            // Goto requires the size of the variables in the mappings, as these are likely to be
97            // stored for merge.
98            BlockEnd::Goto(_, r) => self.vars_size(r.keys()),
99            // The required store for the branch parameter, as well as the branch aligns.
100            BlockEnd::Match { info } => info.arms().len() + self.inputs_size(info.inputs()),
101            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
102        }
103        .into_or_panic()
104    }
105}