cairo_lang_lowering/optimizations/
dedup_blocks.rs1#[cfg(test)]
2#[path = "dedup_blocks_test.rs"]
3mod test;
4
5use cairo_lang_semantic::items::constant::ConstValueId;
6use cairo_lang_semantic::{ConcreteVariant, TypeId};
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::{self, UnorderedHashMap};
9use itertools::{Itertools, zip_eq};
10
11use crate::ids::FunctionId;
12use crate::utils::{Rebuilder, RebuilderEx};
13use crate::{
14 Block, BlockEnd, BlockId, Lowered, Statement, StatementCall, StatementConst, StatementDesnap,
15 StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
16 StatementStructDestructure, VarRemapping, VarUsage, VariableArena, VariableId,
17};
18
19#[derive(Hash, PartialEq, Eq)]
22struct CanonicBlock<'db> {
23 stmts: Vec<CanonicStatement<'db>>,
25 types: Vec<TypeId<'db>>,
27 returns: Vec<CanonicVar>,
29}
30
/// A variable of a canonic block, identified by the order in which it first appears in the
/// block: `0` for the first variable encountered, `1` for the next new one, and so on.
#[derive(PartialEq, Eq, Hash)]
struct CanonicVar(usize);
34
35#[derive(Hash, PartialEq, Eq)]
37enum CanonicStatement<'db> {
38 Const {
39 value: ConstValueId<'db>,
40 output: CanonicVar,
41 boxed: bool,
42 },
43 Call {
44 function: FunctionId<'db>,
45 inputs: Vec<CanonicVar>,
46 with_coupon: bool,
47 outputs: Vec<CanonicVar>,
48 },
49 StructConstruct {
50 inputs: Vec<CanonicVar>,
51 output: CanonicVar,
52 },
53 StructDestructure {
54 input: CanonicVar,
55 outputs: Vec<CanonicVar>,
56 },
57 EnumConstruct {
58 variant: ConcreteVariant<'db>,
59 input: CanonicVar,
60 output: CanonicVar,
61 },
62
63 Snapshot {
64 input: CanonicVar,
65 outputs: [CanonicVar; 2],
66 },
67 Desnap {
68 input: CanonicVar,
69 output: CanonicVar,
70 },
71}
72
73struct CanonicBlockBuilder<'db, 'a> {
74 variable: &'a VariableArena<'db>,
75 vars: UnorderedHashMap<VariableId, usize>,
76 types: Vec<TypeId<'db>>,
77 inputs: Vec<VarUsage<'db>>,
78}
79
80impl<'db, 'a> CanonicBlockBuilder<'db, 'a> {
81 fn new(variable: &'a VariableArena<'db>) -> CanonicBlockBuilder<'db, 'a> {
82 CanonicBlockBuilder {
83 variable,
84 vars: Default::default(),
85 types: vec![],
86 inputs: Default::default(),
87 }
88 }
89
90 fn handle_input(&mut self, var_usage: &VarUsage<'db>) -> CanonicVar {
92 let v = var_usage.var_id;
93
94 CanonicVar(match self.vars.entry(v) {
95 std::collections::hash_map::Entry::Occupied(e) => *e.get(),
96 std::collections::hash_map::Entry::Vacant(e) => {
97 self.types.push(self.variable[v].ty);
98 let new_id = *e.insert(self.types.len() - 1);
99 self.inputs.push(*var_usage);
100 new_id
101 }
102 })
103 }
104
105 fn handle_output(&mut self, v: &VariableId) -> CanonicVar {
107 CanonicVar(match self.vars.entry(*v) {
108 std::collections::hash_map::Entry::Occupied(e) => *e.get(),
109 std::collections::hash_map::Entry::Vacant(e) => {
110 self.types.push(self.variable[*v].ty);
111 *e.insert(self.types.len() - 1)
112 }
113 })
114 }
115
116 fn handle_statement(&mut self, statement: &Statement<'db>) -> CanonicStatement<'db> {
118 match statement {
119 Statement::Const(StatementConst { value, boxed, output }) => CanonicStatement::Const {
120 value: *value,
121 output: self.handle_output(output),
122 boxed: *boxed,
123 },
124 Statement::Call(StatementCall {
125 function,
126 inputs,
127 with_coupon,
128 outputs,
129 location: _,
130 }) => CanonicStatement::Call {
131 function: *function,
132 inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
133 with_coupon: *with_coupon,
134 outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
135 },
136 Statement::StructConstruct(StatementStructConstruct { inputs, output }) => {
137 CanonicStatement::StructConstruct {
138 inputs: inputs.iter().map(|input| self.handle_input(input)).collect(),
139 output: self.handle_output(output),
140 }
141 }
142 Statement::StructDestructure(StatementStructDestructure { input, outputs }) => {
143 CanonicStatement::StructDestructure {
144 input: self.handle_input(input),
145 outputs: outputs.iter().map(|output| self.handle_output(output)).collect(),
146 }
147 }
148 Statement::EnumConstruct(StatementEnumConstruct { variant, input, output }) => {
149 CanonicStatement::EnumConstruct {
150 variant: *variant,
151 input: self.handle_input(input),
152 output: self.handle_output(output),
153 }
154 }
155 Statement::Snapshot(StatementSnapshot { input, outputs }) => {
156 CanonicStatement::Snapshot {
157 input: self.handle_input(input),
158 outputs: outputs.map(|output| self.handle_output(&output)),
159 }
160 }
161 Statement::Desnap(StatementDesnap { input, output }) => CanonicStatement::Desnap {
162 input: self.handle_input(input),
163 output: self.handle_output(output),
164 },
165 }
166 }
167}
168
169impl<'db> CanonicBlock<'db> {
170 fn try_from_block(
174 variable: &VariableArena<'db>,
175 block: &Block<'db>,
176 ) -> Option<(CanonicBlock<'db>, Vec<VarUsage<'db>>)> {
177 let BlockEnd::Return(returned_vars, _) = &block.end else {
178 return None;
179 };
180
181 if block.statements.is_empty() {
182 return None;
184 }
185
186 let mut builder = CanonicBlockBuilder::new(variable);
187
188 let stmts = block
189 .statements
190 .iter()
191 .map(|statement| builder.handle_statement(statement))
192 .collect_vec();
193
194 let returns = returned_vars.iter().map(|input| builder.handle_input(input)).collect();
195
196 Some((CanonicBlock { stmts, types: builder.types, returns }, builder.inputs))
197 }
198}
199pub struct VarReassigner<'db, 'a> {
201 pub variables: &'a mut VariableArena<'db>,
202
203 pub vars: UnorderedHashMap<VariableId, VariableId>,
205}
206
207impl<'db, 'a> VarReassigner<'db, 'a> {
208 pub fn new(variables: &'a mut VariableArena<'db>) -> Self {
209 Self { variables, vars: UnorderedHashMap::default() }
210 }
211}
212
213impl<'db, 'a> Rebuilder<'db> for VarReassigner<'db, 'a> {
214 fn map_var_id(&mut self, var: VariableId) -> VariableId {
215 *self.vars.entry(var).or_insert_with(|| self.variables.alloc(self.variables[var].clone()))
216 }
217}
218
219#[derive(Default)]
220struct DedupContext<'db> {
221 canonic_blocks: UnorderedHashMap<CanonicBlock<'db>, BlockId>,
223
224 block_id_to_inputs: UnorderedHashMap<BlockId, Vec<VarUsage<'db>>>,
226}
227
228fn rebuild_block_and_inputs<'db>(
231 variables: &mut VariableArena<'db>,
232 block: &Block<'db>,
233 inputs: &[VarUsage<'db>],
234) -> (Block<'db>, Vec<VarUsage<'db>>) {
235 let mut var_reassigner = VarReassigner::new(variables);
236 (
237 var_reassigner.rebuild_block(block),
238 inputs.iter().map(|var_usage| var_reassigner.map_var_usage(*var_usage)).collect(),
239 )
240}
241
242pub fn dedup_blocks<'db>(lowered: &mut Lowered<'db>) {
245 if lowered.blocks.has_root().is_err() {
246 return;
247 }
248
249 let mut ctx = DedupContext::default();
250 let mut duplicates: UnorderedHashMap<BlockId, (BlockId, Vec<VarUsage<'_>>)> =
253 Default::default();
254
255 let mut new_blocks = vec![];
256 let mut next_block_id = BlockId(lowered.blocks.len());
257
258 for (block_id, block) in lowered.blocks.iter() {
259 let Some((canonical_block, inputs)) =
260 CanonicBlock::try_from_block(&lowered.variables, block)
261 else {
262 continue;
263 };
264
265 match ctx.canonic_blocks.entry(canonical_block) {
266 unordered_hash_map::Entry::Occupied(e) => {
267 let block_and_inputs = duplicates
268 .entry(*e.get())
269 .or_insert_with(|| {
270 let (block, new_inputs) =
271 rebuild_block_and_inputs(&mut lowered.variables, block, &inputs);
272 new_blocks.push(block);
273 let new_block_id = next_block_id;
274 next_block_id = next_block_id.next_block_id();
275
276 (new_block_id, new_inputs)
277 })
278 .clone();
279
280 duplicates.insert(block_id, block_and_inputs);
281 }
282 unordered_hash_map::Entry::Vacant(e) => {
283 e.insert(block_id);
284 }
285 };
286
287 ctx.block_id_to_inputs.insert(block_id, inputs);
288 }
289
290 let mut new_goto_block =
291 |block_id, inputs: &Vec<VarUsage<'db>>, target_inputs: &Vec<VarUsage<'db>>| {
292 new_blocks.push(Block {
293 statements: vec![],
294 end: BlockEnd::Goto(
295 block_id,
296 VarRemapping {
297 remapping: OrderedHashMap::from_iter(zip_eq(
298 target_inputs.iter().map(|var_usage| var_usage.var_id),
299 inputs.iter().cloned(),
300 )),
301 },
302 ),
303 });
304
305 let new_block_id = next_block_id;
306 next_block_id = next_block_id.next_block_id();
307 new_block_id
308 };
309
310 for block in lowered.blocks.iter_mut() {
313 match &mut block.end {
314 BlockEnd::Goto(target_block, remappings) => {
315 let Some((block_id, target_inputs)) = duplicates.get(target_block) else {
316 continue;
317 };
318
319 let inputs = ctx.block_id_to_inputs.get(target_block).unwrap();
320 let mut inputs_remapping = VarRemapping {
321 remapping: OrderedHashMap::from_iter(zip_eq(
322 target_inputs.iter().map(|var_usage| var_usage.var_id),
323 inputs.iter().cloned(),
324 )),
325 };
326 for (_, src) in inputs_remapping.iter_mut() {
327 if let Some(src_before_remapping) = remappings.get(&src.var_id) {
328 *src = *src_before_remapping;
329 }
330 }
331
332 *target_block = *block_id;
333 *remappings = inputs_remapping;
334 }
335 BlockEnd::Match { info } => {
336 for arm in info.arms_mut() {
337 let Some((block_id, target_inputs)) = duplicates.get(&arm.block_id) else {
338 continue;
339 };
340
341 let inputs = &ctx.block_id_to_inputs[&arm.block_id];
342 arm.block_id = new_goto_block(*block_id, inputs, target_inputs);
343 }
344 }
345 _ => {}
346 }
347 }
348
349 for block in new_blocks {
350 lowered.blocks.push(block);
351 }
352}