1#[cfg(test)]
2#[path = "dedup_blocks_test.rs"]
3mod test;
4
5use cairo_lang_semantic::items::constant::ConstValueId;
6use cairo_lang_semantic::{ConcreteVariant, TypeId};
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::{self, UnorderedHashMap};
9use itertools::{Itertools, zip_eq};
10
11use crate::ids::FunctionId;
12use crate::utils::{Rebuilder, RebuilderEx};
13use crate::{
14 Block, BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementConst, StatementDesnap,
15 StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
16 StatementStructDestructure, VarRemapping, VarUsage, VariableArena, VariableId,
17};
18
/// Canonic representation of a block, used to detect blocks that are identical up to a renaming
/// of their variables.
///
/// Only non-empty blocks that end with `BlockEnd::Return` are converted (see
/// `CanonicBlock::try_from_block`).
#[derive(Hash, PartialEq, Eq)]
struct CanonicBlock<'db> {
    /// The canonic form of each statement of the block, in order.
    stmts: Vec<CanonicStatement<'db>>,
    /// The type of every variable referenced by the block, indexed by the variable's
    /// `CanonicVar` index (i.e. in order of first appearance in the block).
    types: Vec<TypeId<'db>>,
    /// The canonic variables returned by the block.
    returns: Vec<CanonicVar>,
}
30
/// A variable inside a `CanonicBlock`.
///
/// It is identified by the order in which it was first encountered while scanning the block,
/// which is also its index into `CanonicBlock::types`.
#[derive(Hash, PartialEq, Eq)]
struct CanonicVar(usize);
34
/// Canonic form of a `Statement`: the same operation, with every variable replaced by its
/// `CanonicVar`.
///
/// Fields that should not affect equivalence (e.g. a call's location) are omitted.
#[derive(Hash, PartialEq, Eq)]
enum CanonicStatement<'db> {
    Const {
        value: ConstValueId<'db>,
        output: CanonicVar,
        boxed: bool,
    },
    Call {
        function: FunctionId<'db>,
        inputs: Vec<CanonicVar>,
        with_coupon: bool,
        outputs: Vec<CanonicVar>,
    },
    StructConstruct {
        inputs: Vec<CanonicVar>,
        output: CanonicVar,
    },
    StructDestructure {
        input: CanonicVar,
        outputs: Vec<CanonicVar>,
    },
    EnumConstruct {
        variant: ConcreteVariant<'db>,
        input: CanonicVar,
        output: CanonicVar,
    },

    Snapshot {
        input: CanonicVar,
        outputs: [CanonicVar; 2],
    },
    Desnap {
        input: CanonicVar,
        output: CanonicVar,
    },
}
72
/// Incrementally converts the statements of a single block into their canonic form.
struct CanonicBlockBuilder<'db, 'a> {
    /// Arena used to look up the type of each variable.
    variable: &'a VariableArena<'db>,
    /// Maps every variable seen so far in the block to its canonic index.
    vars: UnorderedHashMap<VariableId, usize>,
    /// The type of each canonic variable, indexed by canonic index.
    types: Vec<TypeId<'db>>,
    /// Variables whose first appearance in the block was as a use (not as a statement output),
    /// in order of first use; these are the block's inputs.
    inputs: Vec<VarUsage<'db>>,
}
79
80impl<'db, 'a> CanonicBlockBuilder<'db, 'a> {
81 fn new(variable: &'a VariableArena<'db>) -> CanonicBlockBuilder<'db, 'a> {
82 CanonicBlockBuilder {
83 variable,
84 vars: Default::default(),
85 types: vec![],
86 inputs: Default::default(),
87 }
88 }
89
90 fn handle_input(&mut self, var_usage: &VarUsage<'db>) -> CanonicVar {
92 let v = var_usage.var_id;
93
94 CanonicVar(match self.vars.entry(v) {
95 std::collections::hash_map::Entry::Occupied(e) => *e.get(),
96 std::collections::hash_map::Entry::Vacant(e) => {
97 self.types.push(self.variable[v].ty);
98 let new_id = *e.insert(self.types.len() - 1);
99 self.inputs.push(*var_usage);
100 new_id
101 }
102 })
103 }
104
105 fn handle_output(&mut self, v: &VariableId) -> CanonicVar {
107 CanonicVar(match self.vars.entry(*v) {
108 std::collections::hash_map::Entry::Occupied(e) => *e.get(),
109 std::collections::hash_map::Entry::Vacant(e) => {
110 self.types.push(self.variable[*v].ty);
111 *e.insert(self.types.len() - 1)
112 }
113 })
114 }
115
116 fn handle_statement(&mut self, statement: &Statement<'db>) -> CanonicStatement<'db> {
118 match statement {
119 Statement::Const(StatementConst { value, boxed, output }) => CanonicStatement::Const {
120 value: *value,
121 output: self.handle_output(output),
122 boxed: *boxed,
123 },
124 Statement::Call(StatementCall {
125 function,
126 inputs,
127 with_coupon,
128 outputs,
129 location: _,
130 is_specialization_base_call: _,
131 }) => CanonicStatement::Call {
132 function: *function,
133 inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
134 with_coupon: *with_coupon,
135 outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
136 },
137 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
138 CanonicStatement::StructConstruct {
139 inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
140 output: self.handle_output(output),
141 }
142 }
143 Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
144 CanonicStatement::StructDestructure {
145 input: self.handle_input(input),
146 outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
147 }
148 }
149 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
150 CanonicStatement::EnumConstruct {
151 variant: *variant,
152 input: self.handle_input(input),
153 output: self.handle_output(output),
154 }
155 }
156 Statement::Snapshot(StatementSnapshot { input, outputs }) => {
157 CanonicStatement::Snapshot {
158 input: self.handle_input(input),
159 outputs: outputs.map(|output| self.handle_output(&output)),
160 }
161 }
162 Statement::Desnap(StatementDesnap { input, output }) => CanonicStatement::Desnap {
163 input: self.handle_input(input),
164 output: self.handle_output(output),
165 },
166 }
167 }
168}
169
170impl<'db> CanonicBlock<'db> {
171 fn try_from_block(
175 variable: &VariableArena<'db>,
176 block: &Block<'db>,
177 ) -> Option<(CanonicBlock<'db>, Vec<VarUsage<'db>>)> {
178 let BlockEnd::Return(returned_vars, _) = &block.end else {
179 return None;
180 };
181
182 if block.statements.is_empty() {
183 return None;
185 }
186
187 let mut builder = CanonicBlockBuilder::new(variable);
188
189 let stmts = block
190 .statements
191 .iter()
192 .map(|statement| builder.handle_statement(statement))
193 .collect_vec();
194
195 let returns = returned_vars.iter().map(|input| builder.handle_input(input)).collect();
196
197 Some((CanonicBlock { stmts, types: builder.types, returns }, builder.inputs))
198 }
199}
/// A `Rebuilder` that replaces every variable it maps with a freshly allocated copy.
pub struct VarReassigner<'db, 'a> {
    /// The arena the fresh variables are allocated in.
    pub variables: &'a mut VariableArena<'db>,

    /// Maps an original variable id to the fresh id allocated for it.
    pub vars: UnorderedHashMap<VariableId, VariableId>,
}
207
208impl<'db, 'a> VarReassigner<'db, 'a> {
209 pub fn new(variables: &'a mut VariableArena<'db>) -> Self {
210 Self { variables, vars: UnorderedHashMap::default() }
211 }
212}
213
214impl<'db, 'a> Rebuilder<'db> for VarReassigner<'db, 'a> {
215 fn map_var_id(&mut self, var: VariableId) -> VariableId {
216 *self.vars.entry(var).or_insert_with(|| self.variables.alloc(self.variables[var].clone()))
217 }
218}
219
/// State collected while scanning the blocks of a function for duplicates.
#[derive(Default)]
struct DedupContext<'db> {
    /// Maps each canonic form to the first block seen with that form.
    canonic_blocks: UnorderedHashMap<CanonicBlock<'db>, BlockId>,

    /// Maps each canonicalized block to its inputs: the variables it uses but does not define,
    /// in order of first use.
    block_id_to_inputs: UnorderedHashMap<BlockId, Vec<VarUsage<'db>>>,
}
228
229fn rebuild_block_and_inputs<'db>(
232 variables: &mut VariableArena<'db>,
233 block: &Block<'db>,
234 inputs: &[VarUsage<'db>],
235) -> (Block<'db>, Vec<VarUsage<'db>>) {
236 let mut var_reassigner = VarReassigner::new(variables);
237 (
238 var_reassigner.rebuild_block(block),
239 inputs.iter().map(|var_usage| var_reassigner.map_var_usage(*var_usage)).collect(),
240 )
241}
242
243pub fn dedup_blocks<'db>(lowered: &mut Lowered<'db>) {
246 if lowered.blocks.has_root().is_err() {
247 return;
248 }
249
250 let mut ctx = DedupContext::default();
251 let mut duplicates: UnorderedHashMap<BlockId, (BlockId, Vec<VarUsage<'_>>)> =
254 Default::default();
255
256 let mut new_blocks = vec![];
257 let mut next_block_id = BlockId(lowered.blocks.len());
258
259 for (block_id, block) in lowered.blocks.iter() {
260 let Some((canonical_block, inputs)) =
261 CanonicBlock::try_from_block(&lowered.variables, block)
262 else {
263 continue;
264 };
265
266 match ctx.canonic_blocks.entry(canonical_block) {
267 unordered_hash_map::Entry::Occupied(e) => {
268 let block_and_inputs = duplicates
269 .entry(*e.get())
270 .or_insert_with(|| {
271 let (block, new_inputs) =
272 rebuild_block_and_inputs(&mut lowered.variables, block, &inputs);
273 new_blocks.push(block);
274 let new_block_id = next_block_id;
275 next_block_id = next_block_id.next_block_id();
276
277 (new_block_id, new_inputs)
278 })
279 .clone();
280
281 duplicates.insert(block_id, block_and_inputs);
282 }
283 unordered_hash_map::Entry::Vacant(e) => {
284 e.insert(block_id);
285 }
286 };
287
288 ctx.block_id_to_inputs.insert(block_id, inputs);
289 }
290
291 let mut new_goto_block =
292 |block_id, inputs: &Vec<VarUsage<'db>>, target_inputs: &Vec<VarUsage<'db>>| {
293 new_blocks.push(Block {
294 statements: vec![],
295 end: BlockEnd::Goto(
296 block_id,
297 VarRemapping {
298 remapping: OrderedHashMap::from_iter(zip_eq(
299 target_inputs.iter().map(|var_usage| var_usage.var_id),
300 inputs.iter().cloned(),
301 )),
302 },
303 ),
304 });
305
306 let new_block_id = next_block_id;
307 next_block_id = next_block_id.next_block_id();
308 new_block_id
309 };
310
311 for block in lowered.blocks.iter_mut() {
314 match &mut block.end {
315 BlockEnd::Goto(target_block, remappings) => {
316 let Some((block_id, target_inputs)) = duplicates.get(target_block) else {
317 continue;
318 };
319
320 let inputs = ctx.block_id_to_inputs.get(target_block).unwrap();
321 let mut inputs_remapping = VarRemapping {
322 remapping: OrderedHashMap::from_iter(zip_eq(
323 target_inputs.iter().map(|var_usage| var_usage.var_id),
324 inputs.iter().cloned(),
325 )),
326 };
327 for (_, src) in inputs_remapping.iter_mut() {
328 if let Some(src_before_remapping) = remappings.get(&src.var_id) {
329 *src = *src_before_remapping;
330 }
331 }
332
333 *target_block = *block_id;
334 *remappings = inputs_remapping;
335 }
336 BlockEnd::Match { info } => {
337 for arm in info.arms_mut() {
338 let Some((block_id, target_inputs)) = duplicates.get(&arm.block_id) else {
339 continue;
340 };
341
342 let inputs = &ctx.block_id_to_inputs[&arm.block_id];
343 arm.block_id = new_goto_block(*block_id, inputs, target_inputs);
344 }
345 }
346 _ => {}
347 }
348 }
349
350 for block in new_blocks {
351 lowered.blocks.push(block);
352 }
353}