1use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
2
3use crate::ids::LocationId;
4use crate::{
5 Block, BlockEnd, BlockId, MatchArm, MatchEnumInfo, MatchEnumValue, MatchExternInfo, MatchInfo,
6 Statement, StatementCall, StatementConst, StatementDesnap, StatementEnumConstruct,
7 StatementSnapshot, StatementStructConstruct, StatementStructDestructure, VarRemapping,
8 VarUsage, VariableId,
9};
10
/// Selects which inlining strategy to use.
///
/// [`InliningStrategy::Default`] is the value returned by [`Default::default`].
// NOTE(review): how each variant is interpreted lives in the inlining pass, which is not visible
// here; the per-variant notes below are inferred from the variant names — confirm.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum InliningStrategy {
    /// The default strategy; this is the `#[default]` variant.
    #[default]
    Default,
    /// Presumably inlines functions whose size does not exceed the given threshold — confirm
    /// against the inlining pass.
    InlineSmallFunctions(usize),
    /// Presumably avoids inlining where possible — confirm against the inlining pass.
    Avoid,
}
26
27pub trait Rebuilder<'db> {
29 fn map_var_id(&mut self, var: VariableId) -> VariableId;
30 fn map_var_usage(&mut self, var_usage: VarUsage<'db>) -> VarUsage<'db> {
31 VarUsage {
32 var_id: self.map_var_id(var_usage.var_id),
33 location: self.map_location(var_usage.location),
34 }
35 }
36 fn map_location(&mut self, location: LocationId<'db>) -> LocationId<'db> {
37 location
38 }
39 fn map_block_id(&mut self, block: BlockId) -> BlockId {
40 block
41 }
42 fn transform_statement(&mut self, _statement: &mut Statement<'db>) {}
43 fn transform_remapping(&mut self, _remapping: &mut VarRemapping<'db>) {}
44 fn transform_end(&mut self, _end: &mut BlockEnd<'db>) {}
45 fn transform_block(&mut self, _block: &mut Block<'db>) {}
46}
47
48pub trait RebuilderEx<'db>: Rebuilder<'db> {
49 fn rebuild_statement(&mut self, statement: &Statement<'db>) -> Statement<'db> {
51 let mut statement = match statement {
52 Statement::Const(stmt) => Statement::Const(StatementConst::new(
53 stmt.value,
54 self.map_var_id(stmt.output),
55 stmt.boxed,
56 )),
57 Statement::Call(stmt) => Statement::Call(StatementCall {
58 function: stmt.function,
59 inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
60 with_coupon: stmt.with_coupon,
61 outputs: stmt.outputs.iter().map(|v| self.map_var_id(*v)).collect(),
62 location: self.map_location(stmt.location),
63 is_specialization_base_call: stmt.is_specialization_base_call,
64 }),
65 Statement::StructConstruct(stmt) => {
66 Statement::StructConstruct(StatementStructConstruct {
67 inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
68 output: self.map_var_id(stmt.output),
69 })
70 }
71 Statement::StructDestructure(stmt) => {
72 Statement::StructDestructure(StatementStructDestructure {
73 input: self.map_var_usage(stmt.input),
74 outputs: stmt.outputs.iter().map(|v| self.map_var_id(*v)).collect(),
75 })
76 }
77 Statement::EnumConstruct(stmt) => Statement::EnumConstruct(StatementEnumConstruct {
78 variant: stmt.variant,
79 input: self.map_var_usage(stmt.input),
80 output: self.map_var_id(stmt.output),
81 }),
82 Statement::Snapshot(stmt) => Statement::Snapshot(StatementSnapshot::new(
83 self.map_var_usage(stmt.input),
84 self.map_var_id(stmt.original()),
85 self.map_var_id(stmt.snapshot()),
86 )),
87 Statement::Desnap(stmt) => Statement::Desnap(StatementDesnap {
88 input: self.map_var_usage(stmt.input),
89 output: self.map_var_id(stmt.output),
90 }),
91 };
92 self.transform_statement(&mut statement);
93 statement
94 }
95
96 fn rebuild_remapping(&mut self, remapping: &VarRemapping<'db>) -> VarRemapping<'db> {
98 let mut remapping = VarRemapping {
99 remapping: OrderedHashMap::from_iter(remapping.iter().map(|(dst, src_var_usage)| {
100 (self.map_var_id(*dst), self.map_var_usage(*src_var_usage))
101 })),
102 };
103 self.transform_remapping(&mut remapping);
104 remapping
105 }
106
107 fn rebuild_end(&mut self, end: &BlockEnd<'db>) -> BlockEnd<'db> {
109 let mut end = match end {
110 BlockEnd::Return(returns, location) => BlockEnd::Return(
111 returns.iter().map(|var_usage| self.map_var_usage(*var_usage)).collect(),
112 self.map_location(*location),
113 ),
114 BlockEnd::Panic(data) => BlockEnd::Panic(self.map_var_usage(*data)),
115 BlockEnd::Goto(block_id, remapping) => {
116 BlockEnd::Goto(self.map_block_id(*block_id), self.rebuild_remapping(remapping))
117 }
118 BlockEnd::NotSet => unreachable!(),
119 BlockEnd::Match { info } => BlockEnd::Match {
120 info: match info {
121 MatchInfo::Extern(stmt) => MatchInfo::Extern(MatchExternInfo {
122 function: stmt.function,
123 inputs: stmt.inputs.iter().map(|v| self.map_var_usage(*v)).collect(),
124 arms: stmt
125 .arms
126 .iter()
127 .map(|arm| MatchArm {
128 arm_selector: arm.arm_selector.clone(),
129 block_id: self.map_block_id(arm.block_id),
130 var_ids: arm
131 .var_ids
132 .iter()
133 .map(|var_id| self.map_var_id(*var_id))
134 .collect(),
135 })
136 .collect(),
137 location: self.map_location(stmt.location),
138 }),
139 MatchInfo::Enum(stmt) => MatchInfo::Enum(MatchEnumInfo {
140 concrete_enum_id: stmt.concrete_enum_id,
141 input: self.map_var_usage(stmt.input),
142 arms: stmt
143 .arms
144 .iter()
145 .map(|arm| MatchArm {
146 arm_selector: arm.arm_selector.clone(),
147 block_id: self.map_block_id(arm.block_id),
148 var_ids: arm
149 .var_ids
150 .iter()
151 .map(|var_id| self.map_var_id(*var_id))
152 .collect(),
153 })
154 .collect(),
155 location: self.map_location(stmt.location),
156 }),
157 MatchInfo::Value(stmt) => MatchInfo::Value(MatchEnumValue {
158 num_of_arms: stmt.num_of_arms,
159 input: self.map_var_usage(stmt.input),
160 arms: stmt
161 .arms
162 .iter()
163 .map(|arm| MatchArm {
164 arm_selector: arm.arm_selector.clone(),
165 block_id: self.map_block_id(arm.block_id),
166 var_ids: arm
167 .var_ids
168 .iter()
169 .map(|var_id| self.map_var_id(*var_id))
170 .collect(),
171 })
172 .collect(),
173 location: self.map_location(stmt.location),
174 }),
175 },
176 },
177 };
178 self.transform_end(&mut end);
179 end
180 }
181
182 fn rebuild_block(&mut self, block: &Block<'db>) -> Block<'db> {
184 let statements = block.statements.iter().map(|stmt| self.rebuild_statement(stmt)).collect();
185 let end = self.rebuild_end(&block.end);
186 let mut block = Block { statements, end };
187 self.transform_block(&mut block);
188 block
189 }
190}
191
192impl<'db, T: Rebuilder<'db>> RebuilderEx<'db> for T {}