cairo_lang_lowering/
utils.rs

1use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
2
3use crate::ids::LocationId;
4use crate::{
5    Block, BlockEnd, BlockId, MatchArm, MatchEnumInfo, MatchEnumValue, MatchExternInfo, MatchInfo,
6    Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
7    StatementSnapshot, StatementStructConstruct, StatementStructDestructure, VarRemapping,
8    VarUsage, VariableId,
9};
10
/// The values accepted by the `inlining-strategy` argument.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum InliningStrategy {
    /// Keep the standard inlining behavior; nothing is overridden.
    ///
    /// Note: this behaves like `InlineSmallFunctions(DEFAULT_INLINE_SMALL_FUNCTIONS_THRESHOLD)`.
    #[default]
    Default,
    /// Inline small functions whose weight is up to the given bound.
    ///
    /// Note: the precise definition of "weight" is not stable and may change.
    InlineSmallFunctions(usize),
    /// Inline only functions carrying an `inline(always)` annotation.
    Avoid,
}
26
27/// A rebuilder trait for rebuilding lowered representation.
28pub trait Rebuilder<'db> {
29    fn map_var_id(&mut self, var: VariableId) -> VariableId;
30    fn map_var_usage(&mut self, var_usage: VarUsage<'db>) -> VarUsage<'db> {
31        VarUsage {
32            var_id: self.map_var_id(var_usage.var_id),
33            location: self.map_location(var_usage.location),
34        }
35    }
36    fn map_location(&mut self, location: LocationId<'db>) -> LocationId<'db> {
37        location
38    }
39    fn map_block_id(&mut self, block: BlockId) -> BlockId {
40        block
41    }
42    fn transform_statement(&mut self, _statement: &mut Statement<'db>) {}
43    fn transform_remapping(&mut self, _remapping: &mut VarRemapping<'db>) {}
44    fn transform_end(&mut self, _end: &mut BlockEnd<'db>) {}
45    fn transform_block(&mut self, _block: &mut Block<'db>) {}
46}
47
48pub trait RebuilderEx<'db>: Rebuilder<'db> {
49    /// Rebuilds the statement with renamed var and block ids.
50    fn rebuild_statement(&mut self, statement: &Statement<'db>) -> Statement<'db> {
51        let mut statement = match statement {
52            Statement::Const(stmt) => Statement::Const(StatementConst::new(
53                stmt.value,
54                self.map_var_id(stmt.output),
55                stmt.boxed,
56            )),
57            Statement::Call(stmt) => Statement::Call(StatementCall {
58                function: stmt.function,
59                inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
60                with_coupon: stmt.with_coupon,
61                outputs: stmt.outputs.iter().map(|v| self.map_var_id(*v)).collect(),
62                location: self.map_location(stmt.location),
63                is_specialization_base_call: stmt.is_specialization_base_call,
64            }),
65            Statement::StructConstruct(stmt) => {
66                Statement::StructConstruct(StatementStructConstruct {
67                    inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
68                    output: self.map_var_id(stmt.output),
69                })
70            }
71            Statement::StructDestructure(stmt) => {
72                Statement::StructDestructure(StatementStructDestructure {
73                    input: self.map_var_usage(stmt.input),
74                    outputs: stmt.outputs.iter().map(|v| self.map_var_id(*v)).collect(),
75                })
76            }
77            Statement::EnumConstruct(stmt) => Statement::EnumConstruct(StatementEnumConstruct {
78                variant: stmt.variant,
79                input: self.map_var_usage(stmt.input),
80                output: self.map_var_id(stmt.output),
81            }),
82            Statement::Snapshot(stmt) => Statement::Snapshot(StatementSnapshot::new(
83                self.map_var_usage(stmt.input),
84                self.map_var_id(stmt.original()),
85                self.map_var_id(stmt.snapshot()),
86            )),
87            Statement::Desnap(stmt) => Statement::Desnap(StatementDesnap {
88                input: self.map_var_usage(stmt.input),
89                output: self.map_var_id(stmt.output),
90            }),
91        };
92        self.transform_statement(&mut statement);
93        statement
94    }
95
96    /// Apply map_var_id to all the variables in the `remapping`.
97    fn rebuild_remapping(&mut self, remapping: &VarRemapping<'db>) -> VarRemapping<'db> {
98        let mut remapping = VarRemapping {
99            remapping: OrderedHashMap::from_iter(remapping.iter().map(|(dst, src_var_usage)| {
100                (self.map_var_id(*dst), self.map_var_usage(*src_var_usage))
101            })),
102        };
103        self.transform_remapping(&mut remapping);
104        remapping
105    }
106
107    /// Rebuilds the block end with renamed var and block ids.
108    fn rebuild_end(&mut self, end: &BlockEnd<'db>) -> BlockEnd<'db> {
109        let mut end = match end {
110            BlockEnd::Return(returns, location) => BlockEnd::Return(
111                returns.iter().map(|var_usage| self.map_var_usage(*var_usage)).collect(),
112                self.map_location(*location),
113            ),
114            BlockEnd::Panic(data) => BlockEnd::Panic(self.map_var_usage(*data)),
115            BlockEnd::Goto(block_id, remapping) => {
116                BlockEnd::Goto(self.map_block_id(*block_id), self.rebuild_remapping(remapping))
117            }
118            BlockEnd::NotSet => unreachable!(),
119            BlockEnd::Match { info } => BlockEnd::Match {
120                info: match info {
121                    MatchInfo::Extern(stmt) => MatchInfo::Extern(MatchExternInfo {
122                        function: stmt.function,
123                        inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
124                        arms: stmt
125                            .arms
126                            .iter()
127                            .map(|arm| MatchArm {
128                                arm_selector: arm.arm_selector.clone(),
129                                block_id: self.map_block_id(arm.block_id),
130                                var_ids: arm
131                                    .var_ids
132                                    .iter()
133                                    .map(|var_id| self.map_var_id(*var_id))
134                                    .collect(),
135                            })
136                            .collect(),
137                        location: self.map_location(stmt.location),
138                    }),
139                    MatchInfo::Enum(stmt) => MatchInfo::Enum(MatchEnumInfo {
140                        concrete_enum_id: stmt.concrete_enum_id,
141                        input: self.map_var_usage(stmt.input),
142                        arms: stmt
143                            .arms
144                            .iter()
145                            .map(|arm| MatchArm {
146                                arm_selector: arm.arm_selector.clone(),
147                                block_id: self.map_block_id(arm.block_id),
148                                var_ids: arm
149                                    .var_ids
150                                    .iter()
151                                    .map(|var_id| self.map_var_id(*var_id))
152                                    .collect(),
153                            })
154                            .collect(),
155                        location: self.map_location(stmt.location),
156                    }),
157                    MatchInfo::Value(stmt) => MatchInfo::Value(MatchEnumValue {
158                        num_of_arms: stmt.num_of_arms,
159                        input: self.map_var_usage(stmt.input),
160                        arms: stmt
161                            .arms
162                            .iter()
163                            .map(|arm| MatchArm {
164                                arm_selector: arm.arm_selector.clone(),
165                                block_id: self.map_block_id(arm.block_id),
166                                var_ids: arm
167                                    .var_ids
168                                    .iter()
169                                    .map(|var_id| self.map_var_id(*var_id))
170                                    .collect(),
171                            })
172                            .collect(),
173                        location: self.map_location(stmt.location),
174                    }),
175                },
176            },
177        };
178        self.transform_end(&mut end);
179        end
180    }
181
182    /// Rebuilds the block with renamed var and block ids.
183    fn rebuild_block(&mut self, block: &Block<'db>) -> Block<'db> {
184        let statements = block.statements.iter().map(|stmt| self.rebuild_statement(stmt)).collect();
185        let end = self.rebuild_end(&block.end);
186        let mut block = Block { statements, end };
187        self.transform_block(&mut block);
188        block
189    }
190}
191
192impl<'db, T: Rebuilder<'db>> RebuilderEx<'db> for T {}