cairo_lang_lowering/optimizations/
split_structs.rs1#[cfg(test)]
2#[path = "split_structs_test.rs"]
3mod test;
4
5use std::vec;
6
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use itertools::{Itertools, zip_eq};
10
11use super::var_renamer::VarRenamer;
12use crate::ids::LocationId;
13use crate::utils::{Rebuilder, RebuilderEx};
14use crate::{
15 BlockEnd, BlockId, Lowered, Statement, StatementStructConstruct, StatementStructDestructure,
16 VarRemapping, VarUsage, VariableArena, VariableId,
17};
18
19pub fn split_structs(lowered: &mut Lowered<'_>) {
24 if lowered.blocks.is_empty() {
25 return;
26 }
27
28 let split = get_var_split(lowered);
29 rebuild_blocks(lowered, split);
30}
31
32struct SplitInfo {
34 block_id: BlockId,
36 vars: Vec<VariableId>,
38}
39
40type SplitMapping = UnorderedHashMap<VariableId, SplitInfo>;
41
42type ReconstructionMapping = OrderedHashMap<VariableId, Option<BlockId>>;
47
48fn get_var_split(lowered: &mut Lowered<'_>) -> SplitMapping {
50 let mut split = UnorderedHashMap::<VariableId, SplitInfo>::default();
51
52 let mut stack = vec![BlockId::root()];
53 let mut visited = vec![false; lowered.blocks.len()];
54 while let Some(block_id) = stack.pop() {
55 if visited[block_id.0] {
56 continue;
57 }
58 visited[block_id.0] = true;
59
60 let block = &lowered.blocks[block_id];
61
62 for stmt in block.statements.iter() {
63 if let Statement::StructConstruct(stmt) = stmt {
64 assert!(
65 split
66 .insert(
67 stmt.output,
68 SplitInfo {
69 block_id,
70 vars: stmt.inputs.iter().map(|input| input.var_id).collect_vec(),
71 },
72 )
73 .is_none()
74 );
75 }
76 }
77
78 match &block.end {
79 BlockEnd::Goto(block_id, remappings) => {
80 stack.push(*block_id);
81
82 for (dst, src) in remappings.iter() {
83 split_remapping(
84 *block_id,
85 &mut split,
86 &mut lowered.variables,
87 *dst,
88 src.var_id,
89 );
90 }
91 }
92 BlockEnd::Match { info } => {
93 stack.extend(info.arms().iter().map(|arm| arm.block_id));
94 }
95 BlockEnd::Return(..) => {}
96 BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
97 }
98 }
99
100 split
101}
102
103fn split_remapping<'db>(
111 target_block_id: BlockId,
112 split: &mut SplitMapping,
113 variables: &mut VariableArena<'db>,
114 dst: VariableId,
115 src: VariableId,
116) {
117 let mut stack = vec![(dst, src)];
118
119 while let Some((dst, src)) = stack.pop() {
120 if split.contains_key(&dst) {
121 continue;
122 }
123 if let Some(SplitInfo { block_id: _, vars: src_vars }) = split.get(&src) {
124 let mut dst_vars = vec![];
125 for split_src in src_vars {
126 let new_var = variables.alloc(variables[*split_src].clone());
127 stack.push((new_var, *split_src));
129 dst_vars.push(new_var);
130 }
131
132 split.insert(dst, SplitInfo { block_id: target_block_id, vars: dst_vars });
133 }
134 }
135}
136
137struct SplitStructsContext<'db, 'a> {
139 reconstructed: ReconstructionMapping,
141 var_remapper: VarRenamer,
143 variables: &'a mut VariableArena<'db>,
145}
146
147fn rebuild_blocks(lowered: &mut Lowered<'_>, split: SplitMapping) {
149 let mut ctx = SplitStructsContext {
150 reconstructed: Default::default(),
151 var_remapper: VarRenamer::default(),
152 variables: &mut lowered.variables,
153 };
154
155 let mut stack = vec![BlockId::root()];
156 let mut visited = vec![false; lowered.blocks.len()];
157 while let Some(block_id) = stack.pop() {
158 if visited[block_id.0] {
159 continue;
160 }
161 visited[block_id.0] = true;
162
163 let block = &mut lowered.blocks[block_id];
164 let old_statements = std::mem::take(&mut block.statements);
165 let statements = &mut block.statements;
166
167 for mut stmt in old_statements {
168 match stmt {
169 Statement::StructDestructure(stmt) => {
170 if let Some(output_split) =
171 split.get(&ctx.var_remapper.map_var_id(stmt.input.var_id))
172 {
173 for (output, new_var) in
174 zip_eq(stmt.outputs.iter(), output_split.vars.to_vec())
175 {
176 assert!(
177 ctx.var_remapper.renamed_vars.insert(*output, new_var).is_none()
178 )
179 }
180 } else {
181 statements.push(Statement::StructDestructure(stmt));
182 }
183 }
184 Statement::StructConstruct(stmt)
185 if split.contains_key(&ctx.var_remapper.map_var_id(stmt.output)) =>
186 {
187 }
189 _ => {
190 for input in stmt.inputs_mut() {
191 input.var_id = ctx.maybe_reconstruct_var(
192 &split,
193 input.var_id,
194 block_id,
195 statements,
196 input.location,
197 );
198 }
199
200 statements.push(stmt);
201 }
202 }
203 }
204
205 match &mut block.end {
206 BlockEnd::Goto(target_block_id, remappings) => {
207 stack.push(*target_block_id);
208
209 let old_remappings = std::mem::take(remappings);
210
211 ctx.rebuild_remapping(
212 &split,
213 block_id,
214 &mut block.statements,
215 old_remappings.remapping.into_iter(),
216 remappings,
217 );
218 }
219 BlockEnd::Match { info } => {
220 stack.extend(info.arms().iter().map(|arm| arm.block_id));
221
222 for input in info.inputs_mut() {
223 input.var_id = ctx.maybe_reconstruct_var(
224 &split,
225 input.var_id,
226 block_id,
227 statements,
228 input.location,
229 );
230 }
231 }
232 BlockEnd::Return(vars, _location) => {
233 for var in vars.iter_mut() {
234 var.var_id = ctx.maybe_reconstruct_var(
235 &split,
236 var.var_id,
237 block_id,
238 statements,
239 var.location,
240 );
241 }
242 }
243 BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
244 }
245
246 *block = ctx.var_remapper.rebuild_block(block);
248 }
249
250 for (var_id, opt_block_id) in ctx.reconstructed.iter() {
252 if let Some(block_id) = opt_block_id {
253 let split_vars =
254 split.get(var_id).expect("Should be check in `maybe_reconstruct_var`.");
255 lowered.blocks[*block_id].statements.push(Statement::StructConstruct(
256 StatementStructConstruct {
257 inputs: split_vars
258 .vars
259 .iter()
260 .map(|var_id| VarUsage {
261 var_id: ctx.var_remapper.map_var_id(*var_id),
262 location: ctx.variables[*var_id].location,
263 })
264 .collect_vec(),
265 output: *var_id,
266 },
267 ));
268 }
269 }
270}
271
impl<'db> SplitStructsContext<'db, '_> {
    /// Returns the variable to use for `var_id` at a use site in `block_id`.
    ///
    /// `var_id` is first mapped through `var_remapper`. If the mapped variable is already in
    /// `reconstructed`, or is not split, it is returned as is. Otherwise its members are resolved
    /// recursively; this may push constructs for nested split members to `statements`. Then:
    /// * In the block the split belongs to (`split_info.block_id`), a `struct_construct`
    ///   defining the variable itself is pushed to `statements`. The variable is recorded in
    ///   `reconstructed` as `None` and returned.
    /// * In any other block, if `info.copyable` is an `Err` (presumably: not copyable), a
    ///   `struct_construct` into a freshly allocated variable is pushed to `statements`. That
    ///   new variable is returned. It is not recorded, so each such use builds its own value.
    /// * Otherwise nothing is pushed here. The variable is recorded as
    ///   `Some(split_info.block_id)`, so a construct can be appended at the end of that block
    ///   later, and the variable itself is returned.
    ///
    /// # Panics
    ///
    /// In the last case, panics if any resolved member differs from the renamed original member.
    fn maybe_reconstruct_var(
        &mut self,
        split: &SplitMapping,
        var_id: VariableId,
        block_id: BlockId,
        statements: &mut Vec<Statement<'db>>,
        location: LocationId<'db>,
    ) -> VariableId {
        let var_id = self.var_remapper.map_var_id(var_id);
        // Already re-created (or scheduled to be); reuse it.
        if self.reconstructed.contains_key(&var_id) {
            return var_id;
        }

        let Some(split_info) = split.get(&var_id) else {
            return var_id;
        };

        // Resolve the members first; nested split members may emit their own constructs here,
        // ahead of the construct for `var_id`.
        let inputs = split_info
            .vars
            .iter()
            .map(|input_var_id| VarUsage {
                var_id: self.maybe_reconstruct_var(
                    split,
                    *input_var_id,
                    block_id,
                    statements,
                    location,
                ),
                location,
            })
            .collect_vec();

        if block_id == split_info.block_id || self.variables[var_id].info.copyable.is_err() {
            let reconstructed_var_id = if block_id == split_info.block_id {
                // In the split's own block the original variable id is reused, and marked as
                // reconstructed so later uses do not emit another construct.
                self.reconstructed.insert(var_id, None);
                var_id
            } else {
                // Outside its own block, a non-copyable value is built into a new variable
                // (copying the original's variable info), presumably so that `var_id` does
                // not get defined in several blocks.
                self.variables.alloc(self.variables[var_id].clone())
            };

            statements.push(Statement::StructConstruct(StatementStructConstruct {
                inputs,
                output: reconstructed_var_id,
            }));

            reconstructed_var_id
        } else {
            // Copyable value used outside its own block: the recursive calls above must have
            // returned the (renamed) original members unchanged.
            assert!(
                zip_eq(&inputs, &split_info.vars)
                    .all(|(input, var_id)| input.var_id == self.var_remapper.map_var_id(*var_id))
            );

            // Defer the construct to the end of the split's own block (emitted by the caller
            // after all blocks are rebuilt) and keep using `var_id`.
            self.reconstructed.insert(var_id, Some(split_info.block_id));
            var_id
        }
    }

    /// Rebuilds a `goto` remapping into `new_remappings`, taking splits into account.
    ///
    /// `remappings` is processed in its original order via a stack. Member pairs produced by a
    /// split pair are pushed in reverse, so they are handled in member order before the next
    /// original pair. Both sides are mapped through `var_remapper` first. Then, by whether the
    /// destination / source is split:
    /// * neither: the pair is kept as is;
    /// * both: the pair is replaced by its member-wise pairs;
    /// * destination only: a `struct_destructure` of the source into freshly allocated variables
    ///   is pushed to `statements`, and the destination members are paired with those
    ///   variables;
    /// * source only: the source is re-created via `maybe_reconstruct_var` and the destination
    ///   is remapped to the result.
    fn rebuild_remapping(
        &mut self,
        split: &SplitMapping,
        block_id: BlockId,
        statements: &mut Vec<Statement<'db>>,
        remappings: impl DoubleEndedIterator<Item = (VariableId, VarUsage<'db>)>,
        new_remappings: &mut VarRemapping<'db>,
    ) {
        // Reversed so that popping yields the pairs in their original order.
        let mut stack = remappings.rev().collect_vec();
        while let Some((orig_dst, orig_src)) = stack.pop() {
            let dst = self.var_remapper.map_var_id(orig_dst);
            let src = self.var_remapper.map_var_id(orig_src.var_id);
            match (split.get(&dst), split.get(&src)) {
                (None, None) => {
                    new_remappings
                        .insert(dst, VarUsage { var_id: src, location: orig_src.location });
                }
                (Some(dst_split), Some(src_split)) => {
                    // Remap member-wise; members may be split themselves and are re-examined
                    // when popped.
                    stack.extend(zip_eq(
                        dst_split.vars.iter().cloned().rev(),
                        src_split
                            .vars
                            .iter()
                            .map(|var_id| VarUsage { var_id: *var_id, location: orig_src.location })
                            .rev(),
                    ));
                }
                (Some(dst_split), None) => {
                    // The source is still a whole struct: destructure it into new variables
                    // (each a copy of the matching destination member's info) and remap
                    // member-wise.
                    let mut src_vars = vec![];

                    for dst in &dst_split.vars {
                        src_vars.push(self.variables.alloc(self.variables[*dst].clone()));
                    }

                    statements.push(Statement::StructDestructure(StatementStructDestructure {
                        input: VarUsage { var_id: src, location: orig_src.location },
                        outputs: src_vars.clone(),
                    }));

                    stack.extend(zip_eq(
                        dst_split.vars.iter().cloned().rev(),
                        src_vars
                            .into_iter()
                            .map(|var_id| VarUsage { var_id, location: orig_src.location })
                            .rev(),
                    ));
                }
                (None, Some(_src_vars)) => {
                    // The destination expects a whole struct: re-create the source.
                    let reconstructed_src = self.maybe_reconstruct_var(
                        split,
                        src,
                        block_id,
                        statements,
                        orig_src.location,
                    );
                    new_remappings.insert(
                        dst,
                        VarUsage { var_id: reconstructed_src, location: orig_src.location },
                    );
                }
            }
        }
    }
}