cairo_lang_lowering/optimizations/
split_structs.rs1#[cfg(test)]
2#[path = "split_structs_test.rs"]
3mod test;
4
5use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
6use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
7use itertools::{Itertools, zip_eq};
8
9use super::var_renamer::VarRenamer;
10use crate::ids::LocationId;
11use crate::utils::{Rebuilder, RebuilderEx};
12use crate::{
13 BlockEnd, BlockId, Lowered, Statement, StatementStructConstruct, StatementStructDestructure,
14 VarRemapping, VarUsage, VariableArena, VariableId,
15};
16
17pub fn split_structs(lowered: &mut Lowered<'_>) {
22 if lowered.blocks.is_empty() {
23 return;
24 }
25
26 let split = get_var_split(lowered);
27 rebuild_blocks(lowered, split);
28}
29
/// Information about a variable that is split into its struct members.
struct SplitInfo {
    /// The block in which the split is introduced: the block containing the originating
    /// `StructConstruct`, or the target block of the `Goto` whose remapping produced the split.
    block_id: BlockId,
    /// The variables resulting from the split, in the same order as the inputs of the
    /// originating `StructConstruct`.
    vars: Vec<VariableId>,
}
37
/// Maps each variable that should be split to its [`SplitInfo`].
type SplitMapping = UnorderedHashMap<VariableId, SplitInfo>;
39
/// Keeps track of split variables that were reconstructed (re-introduced by a `StructConstruct`).
///
/// - `None`: the variable was reconstructed in place, at a use inside the block where its split
///   was introduced.
/// - `Some(block_id)`: a `StructConstruct` for the variable must be appended at the end of
///   `block_id` once all blocks are rebuilt.
///
/// An ordered map is used for the iteration that appends those statements; presumably it
/// iterates in insertion order, which keeps the emitted statements deterministic.
type ReconstructionMapping = OrderedHashMap<VariableId, Option<BlockId>>;
45
46fn get_var_split(lowered: &mut Lowered<'_>) -> SplitMapping {
48 let mut split = UnorderedHashMap::<VariableId, SplitInfo>::default();
49
50 let mut stack = vec![BlockId::root()];
51 let mut visited = vec![false; lowered.blocks.len()];
52 while let Some(block_id) = stack.pop() {
53 if visited[block_id.0] {
54 continue;
55 }
56 visited[block_id.0] = true;
57
58 let block = &lowered.blocks[block_id];
59
60 for stmt in block.statements.iter() {
61 if let Statement::StructConstruct(stmt) = stmt {
62 assert!(
63 split
64 .insert(
65 stmt.output,
66 SplitInfo {
67 block_id,
68 vars: stmt.inputs.iter().map(|input| input.var_id).collect_vec(),
69 },
70 )
71 .is_none()
72 );
73 }
74 }
75
76 match &block.end {
77 BlockEnd::Goto(block_id, remappings) => {
78 stack.push(*block_id);
79
80 for (dst, src) in remappings.iter() {
81 split_remapping(
82 *block_id,
83 &mut split,
84 &mut lowered.variables,
85 *dst,
86 src.var_id,
87 );
88 }
89 }
90 BlockEnd::Match { info } => {
91 stack.extend(info.arms().iter().map(|arm| arm.block_id));
92 }
93 BlockEnd::Return(..) => {}
94 BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
95 }
96 }
97
98 split
99}
100
101fn split_remapping<'db>(
109 target_block_id: BlockId,
110 split: &mut SplitMapping,
111 variables: &mut VariableArena<'db>,
112 dst: VariableId,
113 src: VariableId,
114) {
115 let mut stack = vec![(dst, src)];
116
117 while let Some((dst, src)) = stack.pop() {
118 if split.contains_key(&dst) {
119 continue;
120 }
121 if let Some(SplitInfo { block_id: _, vars: src_vars }) = split.get(&src) {
122 let mut dst_vars = vec![];
123 for split_src in src_vars {
124 let new_var = variables.alloc(variables[*split_src].clone());
125 stack.push((new_var, *split_src));
127 dst_vars.push(new_var);
128 }
129
130 split.insert(dst, SplitInfo { block_id: target_block_id, vars: dst_vars });
131 }
132 }
133}
134
/// State shared across the rebuild of all blocks.
struct SplitStructsContext<'db, 'a> {
    /// Split variables that were reconstructed, and where (see [`ReconstructionMapping`]).
    reconstructed: ReconstructionMapping,
    /// Renames the outputs of removed `StructDestructure` statements to the split variables.
    var_remapper: VarRenamer,
    /// The variable arena of the function being rewritten; fresh variables are allocated here.
    variables: &'a mut VariableArena<'db>,
}
144
/// Rewrites every block reachable from the root according to `split`.
///
/// - A `StructDestructure` of a split variable is removed, and its outputs are renamed to the
///   corresponding split variables.
/// - A `StructConstruct` whose output is split is removed.
/// - Any other use of a split variable (statement inputs, match inputs, returned values,
///   `Goto` remappings) reconstructs it on demand through `maybe_reconstruct_var`.
///
/// Reconstructions deferred to the end of a block are appended after all blocks are rebuilt.
fn rebuild_blocks(lowered: &mut Lowered<'_>, split: SplitMapping) {
    let mut ctx = SplitStructsContext {
        reconstructed: Default::default(),
        var_remapper: VarRenamer::default(),
        variables: &mut lowered.variables,
    };

    // Depth-first traversal from the root, visiting each block once.
    let mut stack = vec![BlockId::root()];
    let mut visited = vec![false; lowered.blocks.len()];
    while let Some(block_id) = stack.pop() {
        if visited[block_id.0] {
            continue;
        }
        visited[block_id.0] = true;

        let block = &mut lowered.blocks[block_id];
        // Move the statements out; the rewritten ones are pushed back into the emptied vector.
        let old_statements = std::mem::take(&mut block.statements);
        let statements = &mut block.statements;

        for mut stmt in old_statements {
            match stmt {
                Statement::StructDestructure(stmt) => {
                    if let Some(output_split) =
                        split.get(&ctx.var_remapper.map_var_id(stmt.input.var_id))
                    {
                        // Drop the destructure: each output is renamed to the matching split
                        // variable instead.
                        for (output, new_var) in zip_eq(&stmt.outputs, &output_split.vars) {
                            assert!(
                                ctx.var_remapper.renamed_vars.insert(*output, *new_var).is_none()
                            )
                        }
                    } else {
                        statements.push(Statement::StructDestructure(stmt));
                    }
                }
                Statement::StructConstruct(stmt)
                    if split.contains_key(&ctx.var_remapper.map_var_id(stmt.output)) =>
                {
                    // Drop the construct; the struct is reconstructed only at the uses that
                    // need it (see `maybe_reconstruct_var`).
                }
                _ => {
                    for input in stmt.inputs_mut() {
                        input.var_id = ctx.maybe_reconstruct_var(
                            &split,
                            input.var_id,
                            block_id,
                            statements,
                            input.location,
                        );
                    }

                    statements.push(stmt);
                }
            }
        }

        match &mut block.end {
            BlockEnd::Goto(target_block_id, remappings) => {
                stack.push(*target_block_id);

                // Rebuild the remapping from scratch; split variables are expanded member-wise.
                let old_remappings = std::mem::take(remappings);

                ctx.rebuild_remapping(
                    &split,
                    block_id,
                    &mut block.statements,
                    old_remappings.remapping.into_iter(),
                    remappings,
                );
            }
            BlockEnd::Match { info } => {
                stack.extend(info.arms().iter().map(|arm| arm.block_id));

                for input in info.inputs_mut() {
                    input.var_id = ctx.maybe_reconstruct_var(
                        &split,
                        input.var_id,
                        block_id,
                        statements,
                        input.location,
                    );
                }
            }
            BlockEnd::Return(vars, _location) => {
                for var in vars.iter_mut() {
                    var.var_id = ctx.maybe_reconstruct_var(
                        &split,
                        var.var_id,
                        block_id,
                        statements,
                        var.location,
                    );
                }
            }
            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
        }

        // Apply the renames collected so far to the whole block.
        *block = ctx.var_remapper.rebuild_block(block);
    }

    // Emit the deferred reconstructions (`Some(block_id)`) at the end of their blocks.
    for (var_id, opt_block_id) in ctx.reconstructed.iter() {
        if let Some(block_id) = opt_block_id {
            let split_vars =
                split.get(var_id).expect("Should be check in `maybe_reconstruct_var`.");
            lowered.blocks[*block_id].statements.push(Statement::StructConstruct(
                StatementStructConstruct {
                    inputs: split_vars
                        .vars
                        .iter()
                        .map(|var_id| VarUsage {
                            var_id: ctx.var_remapper.map_var_id(*var_id),
                            location: ctx.variables[*var_id].location,
                        })
                        .collect_vec(),
                    output: *var_id,
                },
            ));
        }
    }
}
267
impl<'db> SplitStructsContext<'db, '_> {
    /// Returns a variable holding the (possibly reconstructed) value of `var_id` for a use in
    /// `block_id`.
    ///
    /// If `var_id` (after renaming) is not split, or was already reconstructed, it is returned
    /// unchanged. Otherwise its members are first reconstructed recursively, and then:
    /// - if the use is in the block where the split was introduced, a `StructConstruct` defining
    ///   `var_id` itself is pushed to `statements` and recorded as `None`;
    /// - else, if the variable is not copyable, a `StructConstruct` into a freshly allocated
    ///   variable is pushed to `statements`, and that fresh variable is returned (not recorded);
    /// - else (copyable, used in another block), the reconstruction is deferred to the end of the
    ///   split's block (recorded as `Some(block)`) and `var_id` is returned.
    fn maybe_reconstruct_var(
        &mut self,
        split: &SplitMapping,
        var_id: VariableId,
        block_id: BlockId,
        statements: &mut Vec<Statement<'db>>,
        location: LocationId<'db>,
    ) -> VariableId {
        let var_id = self.var_remapper.map_var_id(var_id);
        if self.reconstructed.contains_key(&var_id) {
            return var_id;
        }

        let Some(split_info) = split.get(&var_id) else {
            return var_id;
        };

        // Reconstruct nested split members first, so any statements they push come before
        // the construct of this variable.
        let inputs = split_info
            .vars
            .iter()
            .map(|input_var_id| VarUsage {
                var_id: self.maybe_reconstruct_var(
                    split,
                    *input_var_id,
                    block_id,
                    statements,
                    location,
                ),
                location,
            })
            .collect_vec();

        if block_id == split_info.block_id || self.variables[var_id].info.copyable.is_err() {
            // Reconstruct right here, before the current use.
            let reconstructed_var_id = if block_id == split_info.block_id {
                // Same block: redefine the variable itself; recording it makes later uses
                // return early above.
                self.reconstructed.insert(var_id, None);
                var_id
            } else {
                // Not copyable and in another block: construct into a fresh variable.
                // NOTE(review): presumably relies on non-copyable values being used at most
                // once, since this reconstruction is not recorded.
                self.variables.alloc(self.variables[var_id].clone())
            };

            statements.push(Statement::StructConstruct(StatementStructConstruct {
                inputs,
                output: reconstructed_var_id,
            }));

            reconstructed_var_id
        } else {
            // Copyable and used in another block: the members must be exactly the split
            // variables (no fresh reconstruction), so the construct can be deferred to the
            // end of the split's own block.
            assert!(
                zip_eq(&inputs, &split_info.vars)
                    .all(|(input, var_id)| input.var_id == self.var_remapper.map_var_id(*var_id))
            );

            self.reconstructed.insert(var_id, Some(split_info.block_id));
            var_id
        }
    }

    /// Rebuilds the remappings of a `Goto` leaving `block_id` into `new_remappings`, according
    /// to the splits of each destination and source.
    ///
    /// - Neither side split: the remapping is kept (on renamed variables).
    /// - Both split: replaced by member-wise remappings, which are processed the same way.
    /// - Only the destination split: the source is destructured at the end of `statements`
    ///   into fresh variables, which are then remapped member-wise.
    /// - Only the source split: the source is reconstructed and remapped as a whole.
    fn rebuild_remapping(
        &mut self,
        split: &SplitMapping,
        block_id: BlockId,
        statements: &mut Vec<Statement<'db>>,
        remappings: impl DoubleEndedIterator<Item = (VariableId, VarUsage<'db>)>,
        new_remappings: &mut VarRemapping<'db>,
    ) {
        // Reversed so that popping processes the remappings in their original order.
        let mut stack = remappings.rev().collect_vec();
        while let Some((orig_dst, orig_src)) = stack.pop() {
            let dst = self.var_remapper.map_var_id(orig_dst);
            let src = self.var_remapper.map_var_id(orig_src.var_id);
            match (split.get(&dst), split.get(&src)) {
                (None, None) => {
                    new_remappings
                        .insert(dst, VarUsage { var_id: src, location: orig_src.location });
                }
                (Some(dst_split), Some(src_split)) => {
                    // Pushed reversed so members are popped in member order.
                    stack.extend(zip_eq(
                        dst_split.vars.iter().cloned().rev(),
                        src_split
                            .vars
                            .iter()
                            .map(|var_id| VarUsage { var_id: *var_id, location: orig_src.location })
                            .rev(),
                    ));
                }
                (Some(dst_split), None) => {
                    // Allocate one fresh variable per destination member to receive the
                    // destructured source.
                    let mut src_vars = vec![];

                    for dst in &dst_split.vars {
                        src_vars.push(self.variables.alloc(self.variables[*dst].clone()));
                    }

                    statements.push(Statement::StructDestructure(StatementStructDestructure {
                        input: VarUsage { var_id: src, location: orig_src.location },
                        outputs: src_vars.clone(),
                    }));

                    stack.extend(zip_eq(
                        dst_split.vars.iter().cloned().rev(),
                        src_vars
                            .into_iter()
                            .map(|var_id| VarUsage { var_id, location: orig_src.location })
                            .rev(),
                    ));
                }
                (None, Some(_src_vars)) => {
                    let reconstructed_src = self.maybe_reconstruct_var(
                        split,
                        src,
                        block_id,
                        statements,
                        orig_src.location,
                    );
                    new_remappings.insert(
                        dst,
                        VarUsage { var_id: reconstructed_src, location: orig_src.location },
                    );
                }
            }
        }
    }
}