cairo_lang_lowering/optimizations/split_structs.rs

1#[cfg(test)]
2#[path = "split_structs_test.rs"]
3mod test;
4
5use std::vec;
6
7use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
8use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
9use itertools::{Itertools, zip_eq};
10
11use super::var_renamer::VarRenamer;
12use crate::ids::LocationId;
13use crate::utils::{Rebuilder, RebuilderEx};
14use crate::{
15    BlockEnd, BlockId, Lowered, Statement, StatementStructConstruct, StatementStructDestructure,
16    VarRemapping, VarUsage, VariableArena, VariableId,
17};
18
19/// Splits all the variables that were created by struct_construct and reintroduces the
20/// struct_construct statement when needed.
21///
22/// Note that if a member is used after the struct then it means that the struct is copyable.
23pub fn split_structs(lowered: &mut Lowered<'_>) {
24    if lowered.blocks.is_empty() {
25        return;
26    }
27
28    let split = get_var_split(lowered);
29    rebuild_blocks(lowered, split);
30}
31
32/// Information about a split variable.
33struct SplitInfo {
34    /// The block_id where the variable was defined.
35    block_id: BlockId,
36    /// The variables resulting from the split.
37    vars: Vec<VariableId>,
38}
39
40type SplitMapping = UnorderedHashMap<VariableId, SplitInfo>;
41
42/// Keeps track of the variables that were reconstructed.
43/// If the value is None the variable was reconstructed at the first usage.
44/// If the value is Some(Block_id) then the variable needs to be reconstructed at the end of
45/// `block_id`
46type ReconstructionMapping = OrderedHashMap<VariableId, Option<BlockId>>;
47
48/// Returns a mapping from variables that should be split to the variables resulting from the split.
49fn get_var_split(lowered: &mut Lowered<'_>) -> SplitMapping {
50    let mut split = UnorderedHashMap::<VariableId, SplitInfo>::default();
51
52    let mut stack = vec![BlockId::root()];
53    let mut visited = vec![false; lowered.blocks.len()];
54    while let Some(block_id) = stack.pop() {
55        if visited[block_id.0] {
56            continue;
57        }
58        visited[block_id.0] = true;
59
60        let block = &lowered.blocks[block_id];
61
62        for stmt in block.statements.iter() {
63            if let Statement::StructConstruct(stmt) = stmt {
64                assert!(
65                    split
66                        .insert(
67                            stmt.output,
68                            SplitInfo {
69                                block_id,
70                                vars: stmt.inputs.iter().map(|input| input.var_id).collect_vec(),
71                            },
72                        )
73                        .is_none()
74                );
75            }
76        }
77
78        match &block.end {
79            BlockEnd::Goto(block_id, remappings) => {
80                stack.push(*block_id);
81
82                for (dst, src) in remappings.iter() {
83                    split_remapping(
84                        *block_id,
85                        &mut split,
86                        &mut lowered.variables,
87                        *dst,
88                        src.var_id,
89                    );
90                }
91            }
92            BlockEnd::Match { info } => {
93                stack.extend(info.arms().iter().map(|arm| arm.block_id));
94            }
95            BlockEnd::Return(..) => {}
96            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
97        }
98    }
99
100    split
101}
102
103/// Splits 'dst' according to the split of 'src'.
104///
105/// For example if we have
106///     split('dst') is None
107///     split('src') = (v0, v1) and split(`v1`) = (v3, v4, v5).
108/// The function will create new variables and set:
109///     split('dst') = (v100, v101) and split(`v101`) = (v102, v103, v104).
110fn split_remapping<'db>(
111    target_block_id: BlockId,
112    split: &mut SplitMapping,
113    variables: &mut VariableArena<'db>,
114    dst: VariableId,
115    src: VariableId,
116) {
117    let mut stack = vec![(dst, src)];
118
119    while let Some((dst, src)) = stack.pop() {
120        if split.contains_key(&dst) {
121            continue;
122        }
123        if let Some(SplitInfo { block_id: _, vars: src_vars }) = split.get(&src) {
124            let mut dst_vars = vec![];
125            for split_src in src_vars {
126                let new_var = variables.alloc(variables[*split_src].clone());
127                // Queue inner remapping for possible splitting.
128                stack.push((new_var, *split_src));
129                dst_vars.push(new_var);
130            }
131
132            split.insert(dst, SplitInfo { block_id: target_block_id, vars: dst_vars });
133        }
134    }
135}
136
137// Context for rebuilding the blocks.
138struct SplitStructsContext<'db, 'a> {
139    /// The variables that were reconstructed as they were needed.
140    reconstructed: ReconstructionMapping,
141    // A renamer that keeps track of renamed vars.
142    var_remapper: VarRenamer,
143    // The variables arena.
144    variables: &'a mut VariableArena<'db>,
145}
146
147/// Rebuilds the blocks, with the splitting.
148fn rebuild_blocks(lowered: &mut Lowered<'_>, split: SplitMapping) {
149    let mut ctx = SplitStructsContext {
150        reconstructed: Default::default(),
151        var_remapper: VarRenamer::default(),
152        variables: &mut lowered.variables,
153    };
154
155    let mut stack = vec![BlockId::root()];
156    let mut visited = vec![false; lowered.blocks.len()];
157    while let Some(block_id) = stack.pop() {
158        if visited[block_id.0] {
159            continue;
160        }
161        visited[block_id.0] = true;
162
163        let block = &mut lowered.blocks[block_id];
164        let old_statements = std::mem::take(&mut block.statements);
165        let statements = &mut block.statements;
166
167        for mut stmt in old_statements {
168            match stmt {
169                Statement::StructDestructure(stmt) => {
170                    if let Some(output_split) =
171                        split.get(&ctx.var_remapper.map_var_id(stmt.input.var_id))
172                    {
173                        for (output, new_var) in
174                            zip_eq(stmt.outputs.iter(), output_split.vars.to_vec())
175                        {
176                            assert!(
177                                ctx.var_remapper.renamed_vars.insert(*output, new_var).is_none()
178                            )
179                        }
180                    } else {
181                        statements.push(Statement::StructDestructure(stmt));
182                    }
183                }
184                Statement::StructConstruct(stmt)
185                    if split.contains_key(&ctx.var_remapper.map_var_id(stmt.output)) =>
186                {
187                    // Remove StructConstruct statement.
188                }
189                _ => {
190                    for input in stmt.inputs_mut() {
191                        input.var_id = ctx.maybe_reconstruct_var(
192                            &split,
193                            input.var_id,
194                            block_id,
195                            statements,
196                            input.location,
197                        );
198                    }
199
200                    statements.push(stmt);
201                }
202            }
203        }
204
205        match &mut block.end {
206            BlockEnd::Goto(target_block_id, remappings) => {
207                stack.push(*target_block_id);
208
209                let old_remappings = std::mem::take(remappings);
210
211                ctx.rebuild_remapping(
212                    &split,
213                    block_id,
214                    &mut block.statements,
215                    old_remappings.remapping.into_iter(),
216                    remappings,
217                );
218            }
219            BlockEnd::Match { info } => {
220                stack.extend(info.arms().iter().map(|arm| arm.block_id));
221
222                for input in info.inputs_mut() {
223                    input.var_id = ctx.maybe_reconstruct_var(
224                        &split,
225                        input.var_id,
226                        block_id,
227                        statements,
228                        input.location,
229                    );
230                }
231            }
232            BlockEnd::Return(vars, _location) => {
233                for var in vars.iter_mut() {
234                    var.var_id = ctx.maybe_reconstruct_var(
235                        &split,
236                        var.var_id,
237                        block_id,
238                        statements,
239                        var.location,
240                    );
241                }
242            }
243            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
244        }
245
246        // Remap block variables.
247        *block = ctx.var_remapper.rebuild_block(block);
248    }
249
250    // Add all the end of block reconstructions.
251    for (var_id, opt_block_id) in ctx.reconstructed.iter() {
252        if let Some(block_id) = opt_block_id {
253            let split_vars =
254                split.get(var_id).expect("Should be check in `maybe_reconstruct_var`.");
255            lowered.blocks[*block_id].statements.push(Statement::StructConstruct(
256                StatementStructConstruct {
257                    inputs: split_vars
258                        .vars
259                        .iter()
260                        .map(|var_id| VarUsage {
261                            var_id: ctx.var_remapper.map_var_id(*var_id),
262                            location: ctx.variables[*var_id].location,
263                        })
264                        .collect_vec(),
265                    output: *var_id,
266                },
267            ));
268        }
269    }
270}
271
272impl<'db> SplitStructsContext<'db, '_> {
273    /// Given 'var_id' check if `var_remapper.map_var_id(*var_id)` was split and not reconstructed
274    /// yet, if this is the case, reconstructs the var or marks the variable for reconstruction and
275    /// returns the reconstructed variable id.
276    fn maybe_reconstruct_var(
277        &mut self,
278        split: &SplitMapping,
279        var_id: VariableId,
280        block_id: BlockId,
281        statements: &mut Vec<Statement<'db>>,
282        location: LocationId<'db>,
283    ) -> VariableId {
284        let var_id = self.var_remapper.map_var_id(var_id);
285        if self.reconstructed.contains_key(&var_id) {
286            return var_id;
287        }
288
289        let Some(split_info) = split.get(&var_id) else {
290            return var_id;
291        };
292
293        let inputs = split_info
294            .vars
295            .iter()
296            .map(|input_var_id| VarUsage {
297                var_id: self.maybe_reconstruct_var(
298                    split,
299                    *input_var_id,
300                    block_id,
301                    statements,
302                    location,
303                ),
304                location,
305            })
306            .collect_vec();
307
308        // If the variable was defined in the same block or it is non-copyable then we can
309        // reconstruct it before the first usage. If not we need to reconstruct it at the
310        // end of the original block as it might be used by more than one of the
311        // children.
312        if block_id == split_info.block_id || self.variables[var_id].info.copyable.is_err() {
313            let reconstructed_var_id = if block_id == split_info.block_id {
314                // If the reconstruction is in the original block we can reuse the variable id
315                // and mark the variable as reconstructed.
316                self.reconstructed.insert(var_id, None);
317                var_id
318            } else {
319                // Use a new variable id in case the variable is also reconstructed elsewhere.
320                self.variables.alloc(self.variables[var_id].clone())
321            };
322
323            statements.push(Statement::StructConstruct(StatementStructConstruct {
324                inputs,
325                output: reconstructed_var_id,
326            }));
327
328            reconstructed_var_id
329        } else {
330            // All the inputs should use the original var names.
331            assert!(
332                zip_eq(&inputs, &split_info.vars)
333                    .all(|(input, var_id)| input.var_id == self.var_remapper.map_var_id(*var_id))
334            );
335
336            // Mark `var_id` for reconstruction at the end of `split_info.block_id`
337            self.reconstructed.insert(var_id, Some(split_info.block_id));
338            var_id
339        }
340    }
341
342    /// Given an iterator over the original remapping, rebuilds the remapping with the given
343    /// splitting of variables.
344    fn rebuild_remapping(
345        &mut self,
346        split: &SplitMapping,
347        block_id: BlockId,
348        statements: &mut Vec<Statement<'db>>,
349        remappings: impl DoubleEndedIterator<Item = (VariableId, VarUsage<'db>)>,
350        new_remappings: &mut VarRemapping<'db>,
351    ) {
352        let mut stack = remappings.rev().collect_vec();
353        while let Some((orig_dst, orig_src)) = stack.pop() {
354            let dst = self.var_remapper.map_var_id(orig_dst);
355            let src = self.var_remapper.map_var_id(orig_src.var_id);
356            match (split.get(&dst), split.get(&src)) {
357                (None, None) => {
358                    new_remappings
359                        .insert(dst, VarUsage { var_id: src, location: orig_src.location });
360                }
361                (Some(dst_split), Some(src_split)) => {
362                    stack.extend(zip_eq(
363                        dst_split.vars.iter().cloned().rev(),
364                        src_split
365                            .vars
366                            .iter()
367                            .map(|var_id| VarUsage { var_id: *var_id, location: orig_src.location })
368                            .rev(),
369                    ));
370                }
371                (Some(dst_split), None) => {
372                    let mut src_vars = vec![];
373
374                    for dst in &dst_split.vars {
375                        src_vars.push(self.variables.alloc(self.variables[*dst].clone()));
376                    }
377
378                    statements.push(Statement::StructDestructure(StatementStructDestructure {
379                        input: VarUsage { var_id: src, location: orig_src.location },
380                        outputs: src_vars.clone(),
381                    }));
382
383                    stack.extend(zip_eq(
384                        dst_split.vars.iter().cloned().rev(),
385                        src_vars
386                            .into_iter()
387                            .map(|var_id| VarUsage { var_id, location: orig_src.location })
388                            .rev(),
389                    ));
390                }
391                (None, Some(_src_vars)) => {
392                    let reconstructed_src = self.maybe_reconstruct_var(
393                        split,
394                        src,
395                        block_id,
396                        statements,
397                        orig_src.location,
398                    );
399                    new_remappings.insert(
400                        dst,
401                        VarUsage { var_id: reconstructed_src, location: orig_src.location },
402                    );
403                }
404            }
405        }
406    }
407}