cairo_lang_lowering/
utils.rs

1use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
2
3use crate::ids::LocationId;
4use crate::{
5    Block, BlockEnd, BlockId, MatchArm, MatchEnumInfo, MatchEnumValue, MatchExternInfo, MatchInfo,
6    Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
7    StatementSnapshot, StatementStructConstruct, StatementStructDestructure, VarRemapping,
8    VarUsage, VariableId,
9};
10
/// Possible values of the `inlining-strategy` argument.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum InliningStrategy {
    /// Keep the default inlining behavior; nothing is overridden.
    ///
    /// Note: behaves the same as
    /// `InlineSmallFunctions(DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)`.
    #[default]
    Default,
    /// Inline small functions, up to the given weight.
    ///
    /// Note: the exact definition of a function's weight may change in the future.
    InlineSmallFunctions(usize),
    /// Only inline functions carrying an `inline(always)` annotation.
    Avoid,
}
26
27/// A rebuilder trait for rebuilding lowered representation.
28pub trait Rebuilder<'db> {
29    fn map_var_id(&mut self, var: VariableId) -> VariableId;
30    fn map_var_usage(&mut self, var_usage: VarUsage<'db>) -> VarUsage<'db> {
31        VarUsage {
32            var_id: self.map_var_id(var_usage.var_id),
33            location: self.map_location(var_usage.location),
34        }
35    }
36    fn map_location(&mut self, location: LocationId<'db>) -> LocationId<'db> {
37        location
38    }
39    fn map_block_id(&mut self, block: BlockId) -> BlockId {
40        block
41    }
42    fn transform_statement(&mut self, _statement: &mut Statement<'db>) {}
43    fn transform_remapping(&mut self, _remapping: &mut VarRemapping<'db>) {}
44    fn transform_end(&mut self, _end: &mut BlockEnd<'db>) {}
45    fn transform_block(&mut self, _block: &mut Block<'db>) {}
46}
47
48pub trait RebuilderEx<'db>: Rebuilder<'db> {
49    /// Rebuilds the statement with renamed var and block ids.
50    fn rebuild_statement(&mut self, statement: &Statement<'db>) -> Statement<'db> {
51        let mut statement = match statement {
52            Statement::Const(stmt) => Statement::Const(StatementConst::new(
53                stmt.value,
54                self.map_var_id(stmt.output),
55                stmt.boxed,
56            )),
57            Statement::Call(stmt) => Statement::Call(StatementCall {
58                function: stmt.function,
59                inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
60                with_coupon: stmt.with_coupon,
61                outputs: stmt.outputs.iter().map(|v| self.map_var_id(*v)).collect(),
62                location: self.map_location(stmt.location),
63            }),
64            Statement::StructConstruct(stmt) => {
65                Statement::StructConstruct(StatementStructConstruct {
66                    inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
67                    output: self.map_var_id(stmt.output),
68                })
69            }
70            Statement::StructDestructure(stmt) => {
71                Statement::StructDestructure(StatementStructDestructure {
72                    input: self.map_var_usage(stmt.input),
73                    outputs: stmt.outputs.iter().map(|v| self.map_var_id(*v)).collect(),
74                })
75            }
76            Statement::EnumConstruct(stmt) => Statement::EnumConstruct(StatementEnumConstruct {
77                variant: stmt.variant,
78                input: self.map_var_usage(stmt.input),
79                output: self.map_var_id(stmt.output),
80            }),
81            Statement::Snapshot(stmt) => Statement::Snapshot(StatementSnapshot::new(
82                self.map_var_usage(stmt.input),
83                self.map_var_id(stmt.original()),
84                self.map_var_id(stmt.snapshot()),
85            )),
86            Statement::Desnap(stmt) => Statement::Desnap(StatementDesnap {
87                input: self.map_var_usage(stmt.input),
88                output: self.map_var_id(stmt.output),
89            }),
90        };
91        self.transform_statement(&mut statement);
92        statement
93    }
94
95    /// Apply map_var_id to all the variables in the `remapping`.
96    fn rebuild_remapping(&mut self, remapping: &VarRemapping<'db>) -> VarRemapping<'db> {
97        let mut remapping = VarRemapping {
98            remapping: OrderedHashMap::from_iter(remapping.iter().map(|(dst, src_var_usage)| {
99                (self.map_var_id(*dst), self.map_var_usage(*src_var_usage))
100            })),
101        };
102        self.transform_remapping(&mut remapping);
103        remapping
104    }
105
106    /// Rebuilds the block end with renamed var and block ids.
107    fn rebuild_end(&mut self, end: &BlockEnd<'db>) -> BlockEnd<'db> {
108        let mut end = match end {
109            BlockEnd::Return(returns, location) => BlockEnd::Return(
110                returns.iter().map(|var_usage| self.map_var_usage(*var_usage)).collect(),
111                self.map_location(*location),
112            ),
113            BlockEnd::Panic(data) => BlockEnd::Panic(self.map_var_usage(*data)),
114            BlockEnd::Goto(block_id, remapping) => {
115                BlockEnd::Goto(self.map_block_id(*block_id), self.rebuild_remapping(remapping))
116            }
117            BlockEnd::NotSet => unreachable!(),
118            BlockEnd::Match { info } => BlockEnd::Match {
119                info: match info {
120                    MatchInfo::Extern(stmt) => MatchInfo::Extern(MatchExternInfo {
121                        function: stmt.function,
122                        inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
123                        arms: stmt
124                            .arms
125                            .iter()
126                            .map(|arm| MatchArm {
127                                arm_selector: arm.arm_selector.clone(),
128                                block_id: self.map_block_id(arm.block_id),
129                                var_ids: arm
130                                    .var_ids
131                                    .iter()
132                                    .map(|var_id| self.map_var_id(*var_id))
133                                    .collect(),
134                            })
135                            .collect(),
136                        location: self.map_location(stmt.location),
137                    }),
138                    MatchInfo::Enum(stmt) => MatchInfo::Enum(MatchEnumInfo {
139                        concrete_enum_id: stmt.concrete_enum_id,
140                        input: self.map_var_usage(stmt.input),
141                        arms: stmt
142                            .arms
143                            .iter()
144                            .map(|arm| MatchArm {
145                                arm_selector: arm.arm_selector.clone(),
146                                block_id: self.map_block_id(arm.block_id),
147                                var_ids: arm
148                                    .var_ids
149                                    .iter()
150                                    .map(|var_id| self.map_var_id(*var_id))
151                                    .collect(),
152                            })
153                            .collect(),
154                        location: self.map_location(stmt.location),
155                    }),
156                    MatchInfo::Value(stmt) => MatchInfo::Value(MatchEnumValue {
157                        num_of_arms: stmt.num_of_arms,
158                        input: self.map_var_usage(stmt.input),
159                        arms: stmt
160                            .arms
161                            .iter()
162                            .map(|arm| MatchArm {
163                                arm_selector: arm.arm_selector.clone(),
164                                block_id: self.map_block_id(arm.block_id),
165                                var_ids: arm
166                                    .var_ids
167                                    .iter()
168                                    .map(|var_id| self.map_var_id(*var_id))
169                                    .collect(),
170                            })
171                            .collect(),
172                        location: self.map_location(stmt.location),
173                    }),
174                },
175            },
176        };
177        self.transform_end(&mut end);
178        end
179    }
180
181    /// Rebuilds the block with renamed var and block ids.
182    fn rebuild_block(&mut self, block: &Block<'db>) -> Block<'db> {
183        let statements = block.statements.iter().map(|stmt| self.rebuild_statement(stmt)).collect();
184        let end = self.rebuild_end(&block.end);
185        let mut block = Block { statements, end };
186        self.transform_block(&mut block);
187        block
188    }
189}
190
191impl<'db, T: Rebuilder<'db>> RebuilderEx<'db> for T {}