// cairo_lang_lowering/optimizations/cancel_ops.rs

1#[cfg(test)]
2#[path = "cancel_ops_test.rs"]
3mod test;
4
5use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
6use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
7use itertools::{Itertools, chain, izip, zip_eq};
8
9use super::var_renamer::VarRenamer;
10use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
11use crate::utils::{Rebuilder, RebuilderEx};
12use crate::{BlockId, Lowered, MatchInfo, Statement, VarRemapping, VarUsage, VariableId};
13
14/// Cancels out a (StructConstruct, StructDestructure) and (Snap, Desnap) pair.
15///
16///
17/// The algorithm is as follows:
18/// Run backwards analysis with demand to find all the use sites.
19/// When we reach the first item in the pair, check which statement can be removed and
20/// construct the relevant `renamed_vars` mapping.
21///
22/// See CancelOpsContext::handle_stmt for more detail on when it is safe
23/// to remove a statement.
24pub fn cancel_ops<'db>(lowered: &mut Lowered<'db>) {
25    if lowered.blocks.is_empty() {
26        return;
27    }
28    let ctx = CancelOpsContext {
29        lowered,
30        use_sites: Default::default(),
31        var_remapper: Default::default(),
32        aliases: Default::default(),
33        stmts_to_remove: vec![],
34    };
35    let mut analysis = BackAnalysis::new(lowered, ctx);
36    analysis.get_root_info();
37
38    let CancelOpsContext { mut var_remapper, stmts_to_remove, .. } = analysis.analyzer;
39
40    // Remove no-longer needed statements.
41    // Note that dedup() is used since a statement might be marked for removal more than once.
42    for (block_id, stmt_id) in stmts_to_remove
43        .into_iter()
44        .sorted_by_key(|(block_id, stmt_id)| (block_id.0, *stmt_id))
45        .rev()
46        .dedup()
47    {
48        lowered.blocks[block_id].statements.remove(stmt_id);
49    }
50
51    // Rebuild the blocks with the new variable names.
52    for block in lowered.blocks.iter_mut() {
53        *block = var_remapper.rebuild_block(block);
54    }
55}
56
57pub struct CancelOpsContext<'db, 'a> {
58    lowered: &'a Lowered<'db>,
59
60    /// Maps a variable to the use sites of that variable.
61    /// Note that a remapping is considered as usage here.
62    use_sites: UnorderedHashMap<VariableId, Vec<StatementLocation>>,
63
64    /// Maps a variable to the variable that it was renamed to.
65    var_remapper: VarRenamer,
66
67    /// Keeps track of all the aliases created by the renaming.
68    aliases: UnorderedHashMap<VariableId, Vec<VariableId>>,
69
70    /// Statements that can be removed.
71    stmts_to_remove: Vec<StatementLocation>,
72}
73
74/// Similar to `mapping.get(var).or_default()` but works for types that don't implement Default.
75fn get_entry_as_slice<'a, T>(
76    mapping: &'a UnorderedHashMap<VariableId, Vec<T>>,
77    var: &VariableId,
78) -> &'a [T] {
79    match mapping.get(var) {
80        Some(entry) => &entry[..],
81        None => &[],
82    }
83}
84
85/// Returns the use sites of a variable.
86///
87/// Takes 'use_sites' map rather than `CancelOpsContext` to avoid borrowing the entire context.
88fn filter_use_sites<'a, F, T>(
89    use_sites: &'a UnorderedHashMap<VariableId, Vec<StatementLocation>>,
90    var_aliases: &'a UnorderedHashMap<VariableId, Vec<VariableId>>,
91    orig_var_id: &VariableId,
92    mut f: F,
93) -> Vec<T>
94where
95    F: FnMut(&StatementLocation) -> Option<T>,
96{
97    let mut res = vec![];
98
99    let aliases = get_entry_as_slice(var_aliases, orig_var_id);
100
101    for var in chain!(std::iter::once(orig_var_id), aliases) {
102        let use_sites = get_entry_as_slice(use_sites, var);
103        for use_site in use_sites {
104            if let Some(filtered) = f(use_site) {
105                res.push(filtered);
106            }
107        }
108    }
109    res
110}
111
112impl<'db, 'a> CancelOpsContext<'db, 'a> {
113    fn rename_var(&mut self, from: VariableId, to: VariableId) {
114        self.var_remapper.renamed_vars.insert(from, to);
115
116        let mut aliases = Vec::from_iter(chain(
117            std::iter::once(from),
118            get_entry_as_slice(&self.aliases, &from).iter().copied(),
119        ));
120        // Optimize for the case where the alias list of `to` is empty.
121        match self.aliases.entry(to) {
122            std::collections::hash_map::Entry::Occupied(entry) => {
123                aliases.extend(entry.get().iter());
124                *entry.into_mut() = aliases;
125            }
126            std::collections::hash_map::Entry::Vacant(entry) => {
127                entry.insert(aliases);
128            }
129        }
130    }
131
132    fn add_use_site(&mut self, var: VariableId, use_site: StatementLocation) {
133        self.use_sites.entry(var).or_default().push(use_site);
134    }
135
136    /// Handles a statement and returns true if it can be removed.
137    fn handle_stmt(
138        &mut self,
139        stmt: &'a Statement<'db>,
140        statement_location: StatementLocation,
141    ) -> bool {
142        match stmt {
143            Statement::StructDestructure(stmt) => {
144                let mut visited_use_sites = OrderedHashSet::<StatementLocation>::default();
145
146                let mut can_remove_struct_destructure = true;
147
148                let mut constructs = vec![];
149                for output in stmt.outputs.iter() {
150                    constructs.extend(filter_use_sites(
151                        &self.use_sites,
152                        &self.aliases,
153                        output,
154                        |location| match self.lowered.blocks[location.0].statements.get(location.1)
155                        {
156                            _ if !visited_use_sites.insert(*location) => {
157                                // Filter previously seen use sites.
158                                None
159                            }
160                            Some(Statement::StructConstruct(construct_stmt))
161                                if stmt.outputs.len() == construct_stmt.inputs.len()
162                                    && self.lowered.variables[stmt.input.var_id].ty
163                                        == self.lowered.variables[construct_stmt.output].ty
164                                    && zip_eq(
165                                        stmt.outputs.iter(),
166                                        construct_stmt.inputs.iter(),
167                                    )
168                                    .all(|(output, input)| {
169                                        output == &self.var_remapper.map_var_id(input.var_id)
170                                    }) =>
171                            {
172                                self.stmts_to_remove.push(*location);
173                                Some(construct_stmt)
174                            }
175                            _ => {
176                                can_remove_struct_destructure = false;
177                                None
178                            }
179                        },
180                    ));
181                }
182
183                if !(can_remove_struct_destructure
184                    || self.lowered.variables[stmt.input.var_id].info.copyable.is_ok())
185                {
186                    // We can't remove any of the construct statements.
187                    self.stmts_to_remove.truncate(self.stmts_to_remove.len() - constructs.len());
188                    return false;
189                }
190
191                // Mark the statements for removal and set the renaming for it outputs.
192                if can_remove_struct_destructure {
193                    self.stmts_to_remove.push(statement_location);
194                }
195
196                for construct in constructs {
197                    self.rename_var(construct.output, stmt.input.var_id)
198                }
199                can_remove_struct_destructure
200            }
201            Statement::StructConstruct(stmt) => {
202                let mut can_remove_struct_construct = true;
203                let destructures =
204                    filter_use_sites(&self.use_sites, &self.aliases, &stmt.output, |location| {
205                        if let Some(Statement::StructDestructure(destructure_stmt)) =
206                            self.lowered.blocks[location.0].statements.get(location.1)
207                        {
208                            self.stmts_to_remove.push(*location);
209                            Some(destructure_stmt)
210                        } else {
211                            can_remove_struct_construct = false;
212                            None
213                        }
214                    });
215
216                if !(can_remove_struct_construct
217                    || stmt
218                        .inputs
219                        .iter()
220                        .all(|input| self.lowered.variables[input.var_id].info.copyable.is_ok()))
221                {
222                    // We can't remove any of the destructure statements.
223                    self.stmts_to_remove.truncate(self.stmts_to_remove.len() - destructures.len());
224                    return false;
225                }
226
227                // Mark the statements for removal and set the renaming for it outputs.
228                if can_remove_struct_construct {
229                    self.stmts_to_remove.push(statement_location);
230                }
231
232                for destructure_stmt in destructures {
233                    for (output, input) in
234                        izip!(destructure_stmt.outputs.iter(), stmt.inputs.iter())
235                    {
236                        self.rename_var(*output, input.var_id);
237                    }
238                }
239                can_remove_struct_construct
240            }
241            Statement::Snapshot(stmt) => {
242                let mut can_remove_snap = true;
243
244                let desnaps = filter_use_sites(
245                    &self.use_sites,
246                    &self.aliases,
247                    &stmt.snapshot(),
248                    |location| {
249                        if let Some(Statement::Desnap(desnap_stmt)) =
250                            self.lowered.blocks[location.0].statements.get(location.1)
251                        {
252                            self.stmts_to_remove.push(*location);
253                            Some(desnap_stmt)
254                        } else {
255                            can_remove_snap = false;
256                            None
257                        }
258                    },
259                );
260
261                let new_var = if can_remove_snap {
262                    self.stmts_to_remove.push(statement_location);
263                    self.rename_var(stmt.original(), stmt.input.var_id);
264                    stmt.input.var_id
265                } else if desnaps.is_empty()
266                    && self.lowered.variables[stmt.input.var_id].info.copyable.is_err()
267                {
268                    stmt.original()
269                } else {
270                    stmt.input.var_id
271                };
272
273                for desnap in desnaps {
274                    self.rename_var(desnap.output, new_var);
275                }
276                can_remove_snap
277            }
278            _ => false,
279        }
280    }
281}
282
283impl<'db, 'a> Analyzer<'db, 'a> for CancelOpsContext<'db, 'a> {
284    type Info = ();
285
286    fn visit_stmt(
287        &mut self,
288        _info: &mut Self::Info,
289        statement_location: StatementLocation,
290        stmt: &'a Statement<'db>,
291    ) {
292        if !self.handle_stmt(stmt, statement_location) {
293            for input in stmt.inputs() {
294                self.add_use_site(input.var_id, statement_location);
295            }
296        }
297    }
298
299    fn visit_goto(
300        &mut self,
301        _info: &mut Self::Info,
302        statement_location: StatementLocation,
303        _target_block_id: BlockId,
304        remapping: &VarRemapping<'db>,
305    ) {
306        for src in remapping.values() {
307            self.add_use_site(src.var_id, statement_location);
308        }
309    }
310
311    fn merge_match(
312        &mut self,
313        statement_location: StatementLocation,
314        match_info: &'a MatchInfo<'db>,
315        _infos: impl Iterator<Item = Self::Info>,
316    ) -> Self::Info {
317        for var in match_info.inputs() {
318            self.add_use_site(var.var_id, statement_location);
319        }
320    }
321
322    fn info_from_return(
323        &mut self,
324        statement_location: StatementLocation,
325        vars: &[VarUsage<'db>],
326    ) -> Self::Info {
327        for var in vars {
328            self.add_use_site(var.var_id, statement_location);
329        }
330    }
331}