Skip to main content

cairo_lang_lowering/optimizations/
split_structs.rs

1#[cfg(test)]
2#[path = "split_structs_test.rs"]
3mod test;
4
5use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
6use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
7use itertools::{Itertools, zip_eq};
8
9use super::var_renamer::VarRenamer;
10use crate::ids::LocationId;
11use crate::utils::{Rebuilder, RebuilderEx};
12use crate::{
13    BlockEnd, BlockId, Lowered, Statement, StatementStructConstruct, StatementStructDestructure,
14    VarRemapping, VarUsage, VariableArena, VariableId,
15};
16
17/// Splits all the variables that were created by struct_construct and reintroduces the
18/// struct_construct statement when needed.
19///
20/// Note that if a member is used after the struct then it means that the struct is copyable.
21pub fn split_structs(lowered: &mut Lowered<'_>) {
22    if lowered.blocks.is_empty() {
23        return;
24    }
25
26    let split = get_var_split(lowered);
27    rebuild_blocks(lowered, split);
28}
29
30/// Information about a split variable.
31struct SplitInfo {
32    /// The block_id where the variable was defined.
33    block_id: BlockId,
34    /// The variables resulting from the split.
35    vars: Vec<VariableId>,
36}
37
38type SplitMapping = UnorderedHashMap<VariableId, SplitInfo>;
39
40/// Keeps track of the variables that were reconstructed.
41/// If the value is None the variable was reconstructed at the first usage.
42/// If the value is Some(Block_id) then the variable needs to be reconstructed at the end of
43/// `block_id`.
44type ReconstructionMapping = OrderedHashMap<VariableId, Option<BlockId>>;
45
46/// Returns a mapping from variables that should be split to the variables resulting from the split.
47fn get_var_split(lowered: &mut Lowered<'_>) -> SplitMapping {
48    let mut split = UnorderedHashMap::<VariableId, SplitInfo>::default();
49
50    let mut stack = vec![BlockId::root()];
51    let mut visited = vec![false; lowered.blocks.len()];
52    while let Some(block_id) = stack.pop() {
53        if visited[block_id.0] {
54            continue;
55        }
56        visited[block_id.0] = true;
57
58        let block = &lowered.blocks[block_id];
59
60        for stmt in block.statements.iter() {
61            if let Statement::StructConstruct(stmt) = stmt {
62                assert!(
63                    split
64                        .insert(
65                            stmt.output,
66                            SplitInfo {
67                                block_id,
68                                vars: stmt.inputs.iter().map(|input| input.var_id).collect_vec(),
69                            },
70                        )
71                        .is_none()
72                );
73            }
74        }
75
76        match &block.end {
77            BlockEnd::Goto(block_id, remappings) => {
78                stack.push(*block_id);
79
80                for (dst, src) in remappings.iter() {
81                    split_remapping(
82                        *block_id,
83                        &mut split,
84                        &mut lowered.variables,
85                        *dst,
86                        src.var_id,
87                    );
88                }
89            }
90            BlockEnd::Match { info } => {
91                stack.extend(info.arms().iter().map(|arm| arm.block_id));
92            }
93            BlockEnd::Return(..) => {}
94            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
95        }
96    }
97
98    split
99}
100
101/// Splits 'dst' according to the split of 'src'.
102///
103/// For example if we have
104///     split('dst') is None
105///     split('src') = (v0, v1) and split(`v1`) = (v3, v4, v5).
106/// The function will create new variables and set:
107///     split('dst') = (v100, v101) and split(`v101`) = (v102, v103, v104).
108fn split_remapping<'db>(
109    target_block_id: BlockId,
110    split: &mut SplitMapping,
111    variables: &mut VariableArena<'db>,
112    dst: VariableId,
113    src: VariableId,
114) {
115    let mut stack = vec![(dst, src)];
116
117    while let Some((dst, src)) = stack.pop() {
118        if split.contains_key(&dst) {
119            continue;
120        }
121        if let Some(SplitInfo { block_id: _, vars: src_vars }) = split.get(&src) {
122            let mut dst_vars = vec![];
123            for split_src in src_vars {
124                let new_var = variables.alloc(variables[*split_src].clone());
125                // Queue inner remapping for possible splitting.
126                stack.push((new_var, *split_src));
127                dst_vars.push(new_var);
128            }
129
130            split.insert(dst, SplitInfo { block_id: target_block_id, vars: dst_vars });
131        }
132    }
133}
134
135// Context for rebuilding the blocks.
136struct SplitStructsContext<'db, 'a> {
137    /// The variables that were reconstructed as they were needed.
138    reconstructed: ReconstructionMapping,
139    // A renamer that keeps track of renamed vars.
140    var_remapper: VarRenamer,
141    // The variables arena.
142    variables: &'a mut VariableArena<'db>,
143}
144
145/// Rebuilds the blocks, with the splitting.
146fn rebuild_blocks(lowered: &mut Lowered<'_>, split: SplitMapping) {
147    let mut ctx = SplitStructsContext {
148        reconstructed: Default::default(),
149        var_remapper: VarRenamer::default(),
150        variables: &mut lowered.variables,
151    };
152
153    let mut stack = vec![BlockId::root()];
154    let mut visited = vec![false; lowered.blocks.len()];
155    while let Some(block_id) = stack.pop() {
156        if visited[block_id.0] {
157            continue;
158        }
159        visited[block_id.0] = true;
160
161        let block = &mut lowered.blocks[block_id];
162        let old_statements = std::mem::take(&mut block.statements);
163        let statements = &mut block.statements;
164
165        for mut stmt in old_statements {
166            match stmt {
167                Statement::StructDestructure(stmt) => {
168                    if let Some(output_split) =
169                        split.get(&ctx.var_remapper.map_var_id(stmt.input.var_id))
170                    {
171                        for (output, new_var) in zip_eq(&stmt.outputs, &output_split.vars) {
172                            assert!(
173                                ctx.var_remapper.renamed_vars.insert(*output, *new_var).is_none()
174                            )
175                        }
176                    } else {
177                        statements.push(Statement::StructDestructure(stmt));
178                    }
179                }
180                Statement::StructConstruct(stmt)
181                    if split.contains_key(&ctx.var_remapper.map_var_id(stmt.output)) =>
182                {
183                    // Remove StructConstruct statement.
184                }
185                _ => {
186                    for input in stmt.inputs_mut() {
187                        input.var_id = ctx.maybe_reconstruct_var(
188                            &split,
189                            input.var_id,
190                            block_id,
191                            statements,
192                            input.location,
193                        );
194                    }
195
196                    statements.push(stmt);
197                }
198            }
199        }
200
201        match &mut block.end {
202            BlockEnd::Goto(target_block_id, remappings) => {
203                stack.push(*target_block_id);
204
205                let old_remappings = std::mem::take(remappings);
206
207                ctx.rebuild_remapping(
208                    &split,
209                    block_id,
210                    &mut block.statements,
211                    old_remappings.remapping.into_iter(),
212                    remappings,
213                );
214            }
215            BlockEnd::Match { info } => {
216                stack.extend(info.arms().iter().map(|arm| arm.block_id));
217
218                for input in info.inputs_mut() {
219                    input.var_id = ctx.maybe_reconstruct_var(
220                        &split,
221                        input.var_id,
222                        block_id,
223                        statements,
224                        input.location,
225                    );
226                }
227            }
228            BlockEnd::Return(vars, _location) => {
229                for var in vars.iter_mut() {
230                    var.var_id = ctx.maybe_reconstruct_var(
231                        &split,
232                        var.var_id,
233                        block_id,
234                        statements,
235                        var.location,
236                    );
237                }
238            }
239            BlockEnd::Panic(_) | BlockEnd::NotSet => unreachable!(),
240        }
241
242        // Remap block variables.
243        *block = ctx.var_remapper.rebuild_block(block);
244    }
245
246    // Add all the end of block reconstructions.
247    for (var_id, opt_block_id) in ctx.reconstructed.iter() {
248        if let Some(block_id) = opt_block_id {
249            let split_vars =
250                split.get(var_id).expect("Should be check in `maybe_reconstruct_var`.");
251            lowered.blocks[*block_id].statements.push(Statement::StructConstruct(
252                StatementStructConstruct {
253                    inputs: split_vars
254                        .vars
255                        .iter()
256                        .map(|var_id| VarUsage {
257                            var_id: ctx.var_remapper.map_var_id(*var_id),
258                            location: ctx.variables[*var_id].location,
259                        })
260                        .collect_vec(),
261                    output: *var_id,
262                },
263            ));
264        }
265    }
266}
267
268impl<'db> SplitStructsContext<'db, '_> {
269    /// Given 'var_id' check if `var_remapper.map_var_id(*var_id)` was split and not reconstructed
270    /// yet, if this is the case, reconstructs the var or marks the variable for reconstruction and
271    /// returns the reconstructed variable id.
272    fn maybe_reconstruct_var(
273        &mut self,
274        split: &SplitMapping,
275        var_id: VariableId,
276        block_id: BlockId,
277        statements: &mut Vec<Statement<'db>>,
278        location: LocationId<'db>,
279    ) -> VariableId {
280        let var_id = self.var_remapper.map_var_id(var_id);
281        if self.reconstructed.contains_key(&var_id) {
282            return var_id;
283        }
284
285        let Some(split_info) = split.get(&var_id) else {
286            return var_id;
287        };
288
289        let inputs = split_info
290            .vars
291            .iter()
292            .map(|input_var_id| VarUsage {
293                var_id: self.maybe_reconstruct_var(
294                    split,
295                    *input_var_id,
296                    block_id,
297                    statements,
298                    location,
299                ),
300                location,
301            })
302            .collect_vec();
303
304        // If the variable was defined in the same block or it is non-copyable then we can
305        // reconstruct it before the first usage. If not we need to reconstruct it at the
306        // end of the original block as it might be used by more than one of the
307        // children.
308        if block_id == split_info.block_id || self.variables[var_id].info.copyable.is_err() {
309            let reconstructed_var_id = if block_id == split_info.block_id {
310                // If the reconstruction is in the original block we can reuse the variable id
311                // and mark the variable as reconstructed.
312                self.reconstructed.insert(var_id, None);
313                var_id
314            } else {
315                // Use a new variable id in case the variable is also reconstructed elsewhere.
316                self.variables.alloc(self.variables[var_id].clone())
317            };
318
319            statements.push(Statement::StructConstruct(StatementStructConstruct {
320                inputs,
321                output: reconstructed_var_id,
322            }));
323
324            reconstructed_var_id
325        } else {
326            // All the inputs should use the original var names.
327            assert!(
328                zip_eq(&inputs, &split_info.vars)
329                    .all(|(input, var_id)| input.var_id == self.var_remapper.map_var_id(*var_id))
330            );
331
332            // Mark `var_id` for reconstruction at the end of `split_info.block_id`
333            self.reconstructed.insert(var_id, Some(split_info.block_id));
334            var_id
335        }
336    }
337
338    /// Given an iterator over the original remapping, rebuilds the remapping with the given
339    /// splitting of variables.
340    fn rebuild_remapping(
341        &mut self,
342        split: &SplitMapping,
343        block_id: BlockId,
344        statements: &mut Vec<Statement<'db>>,
345        remappings: impl DoubleEndedIterator<Item = (VariableId, VarUsage<'db>)>,
346        new_remappings: &mut VarRemapping<'db>,
347    ) {
348        let mut stack = remappings.rev().collect_vec();
349        while let Some((orig_dst, orig_src)) = stack.pop() {
350            let dst = self.var_remapper.map_var_id(orig_dst);
351            let src = self.var_remapper.map_var_id(orig_src.var_id);
352            match (split.get(&dst), split.get(&src)) {
353                (None, None) => {
354                    new_remappings
355                        .insert(dst, VarUsage { var_id: src, location: orig_src.location });
356                }
357                (Some(dst_split), Some(src_split)) => {
358                    stack.extend(zip_eq(
359                        dst_split.vars.iter().cloned().rev(),
360                        src_split
361                            .vars
362                            .iter()
363                            .map(|var_id| VarUsage { var_id: *var_id, location: orig_src.location })
364                            .rev(),
365                    ));
366                }
367                (Some(dst_split), None) => {
368                    let mut src_vars = vec![];
369
370                    for dst in &dst_split.vars {
371                        src_vars.push(self.variables.alloc(self.variables[*dst].clone()));
372                    }
373
374                    statements.push(Statement::StructDestructure(StatementStructDestructure {
375                        input: VarUsage { var_id: src, location: orig_src.location },
376                        outputs: src_vars.clone(),
377                    }));
378
379                    stack.extend(zip_eq(
380                        dst_split.vars.iter().cloned().rev(),
381                        src_vars
382                            .into_iter()
383                            .map(|var_id| VarUsage { var_id, location: orig_src.location })
384                            .rev(),
385                    ));
386                }
387                (None, Some(_src_vars)) => {
388                    let reconstructed_src = self.maybe_reconstruct_var(
389                        split,
390                        src,
391                        block_id,
392                        statements,
393                        orig_src.location,
394                    );
395                    new_remappings.insert(
396                        dst,
397                        VarUsage { var_id: reconstructed_src, location: orig_src.location },
398                    );
399                }
400            }
401        }
402    }
403}