cairo_lang_lowering/optimizations/
cancel_ops.rs1#[cfg(test)]
2#[path = "cancel_ops_test.rs"]
3mod test;
4
5use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
6use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
7use itertools::{Itertools, chain, izip, zip_eq};
8
9use super::var_renamer::VarRenamer;
10use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
11use crate::utils::{Rebuilder, RebuilderEx};
12use crate::{BlockId, Lowered, MatchInfo, Statement, VarRemapping, VarUsage, VariableId};
13
14pub fn cancel_ops<'db>(lowered: &mut Lowered<'db>) {
25 if lowered.blocks.is_empty() {
26 return;
27 }
28 let ctx = CancelOpsContext {
29 lowered,
30 use_sites: Default::default(),
31 var_remapper: Default::default(),
32 aliases: Default::default(),
33 stmts_to_remove: vec![],
34 };
35 let mut analysis = BackAnalysis::new(lowered, ctx);
36 analysis.get_root_info();
37
38 let CancelOpsContext { mut var_remapper, stmts_to_remove, .. } = analysis.analyzer;
39
40 for (block_id, stmt_id) in stmts_to_remove
43 .into_iter()
44 .sorted_by_key(|(block_id, stmt_id)| (block_id.0, *stmt_id))
45 .rev()
46 .dedup()
47 {
48 lowered.blocks[block_id].statements.remove(stmt_id);
49 }
50
51 for block in lowered.blocks.iter_mut() {
53 *block = var_remapper.rebuild_block(block);
54 }
55}
56
57pub struct CancelOpsContext<'db, 'a> {
58 lowered: &'a Lowered<'db>,
59
60 use_sites: UnorderedHashMap<VariableId, Vec<StatementLocation>>,
63
64 var_remapper: VarRenamer,
66
67 aliases: UnorderedHashMap<VariableId, Vec<VariableId>>,
69
70 stmts_to_remove: Vec<StatementLocation>,
72}
73
74fn get_entry_as_slice<'a, T>(
76 mapping: &'a UnorderedHashMap<VariableId, Vec<T>>,
77 var: &VariableId,
78) -> &'a [T] {
79 match mapping.get(var) {
80 Some(entry) => &entry[..],
81 None => &[],
82 }
83}
84
85fn filter_use_sites<'a, F, T>(
89 use_sites: &'a UnorderedHashMap<VariableId, Vec<StatementLocation>>,
90 var_aliases: &'a UnorderedHashMap<VariableId, Vec<VariableId>>,
91 orig_var_id: &VariableId,
92 mut f: F,
93) -> Vec<T>
94where
95 F: FnMut(&StatementLocation) -> Option<T>,
96{
97 let mut res = vec![];
98
99 let aliases = get_entry_as_slice(var_aliases, orig_var_id);
100
101 for var in chain!(std::iter::once(orig_var_id), aliases) {
102 let use_sites = get_entry_as_slice(use_sites, var);
103 for use_site in use_sites {
104 if let Some(filtered) = f(use_site) {
105 res.push(filtered);
106 }
107 }
108 }
109 res
110}
111
112impl<'db, 'a> CancelOpsContext<'db, 'a> {
113 fn rename_var(&mut self, from: VariableId, to: VariableId) {
114 self.var_remapper.renamed_vars.insert(from, to);
115
116 let mut aliases = Vec::from_iter(chain(
117 std::iter::once(from),
118 get_entry_as_slice(&self.aliases, &from).iter().copied(),
119 ));
120 match self.aliases.entry(to) {
122 std::collections::hash_map::Entry::Occupied(entry) => {
123 aliases.extend(entry.get().iter());
124 *entry.into_mut() = aliases;
125 }
126 std::collections::hash_map::Entry::Vacant(entry) => {
127 entry.insert(aliases);
128 }
129 }
130 }
131
132 fn add_use_site(&mut self, var: VariableId, use_site: StatementLocation) {
133 self.use_sites.entry(var).or_default().push(use_site);
134 }
135
136 fn handle_stmt(
138 &mut self,
139 stmt: &'a Statement<'db>,
140 statement_location: StatementLocation,
141 ) -> bool {
142 match stmt {
143 Statement::StructDestructure(stmt) => {
144 let mut visited_use_sites = OrderedHashSet::<StatementLocation>::default();
145
146 let mut can_remove_struct_destructure = true;
147
148 let mut constructs = vec![];
149 for output in stmt.outputs.iter() {
150 constructs.extend(filter_use_sites(
151 &self.use_sites,
152 &self.aliases,
153 output,
154 |location| match self.lowered.blocks[location.0].statements.get(location.1)
155 {
156 _ if !visited_use_sites.insert(*location) => {
157 None
159 }
160 Some(Statement::StructConstruct(construct_stmt))
161 if stmt.outputs.len() == construct_stmt.inputs.len()
162 && self.lowered.variables[stmt.input.var_id].ty
163 == self.lowered.variables[construct_stmt.output].ty
164 && zip_eq(
165 stmt.outputs.iter(),
166 construct_stmt.inputs.iter(),
167 )
168 .all(|(output, input)| {
169 output == &self.var_remapper.map_var_id(input.var_id)
170 }) =>
171 {
172 self.stmts_to_remove.push(*location);
173 Some(construct_stmt)
174 }
175 _ => {
176 can_remove_struct_destructure = false;
177 None
178 }
179 },
180 ));
181 }
182
183 if !(can_remove_struct_destructure
184 || self.lowered.variables[stmt.input.var_id].info.copyable.is_ok())
185 {
186 self.stmts_to_remove.truncate(self.stmts_to_remove.len() - constructs.len());
188 return false;
189 }
190
191 if can_remove_struct_destructure {
193 self.stmts_to_remove.push(statement_location);
194 }
195
196 for construct in constructs {
197 self.rename_var(construct.output, stmt.input.var_id)
198 }
199 can_remove_struct_destructure
200 }
201 Statement::StructConstruct(stmt) => {
202 let mut can_remove_struct_construct = true;
203 let destructures =
204 filter_use_sites(&self.use_sites, &self.aliases, &stmt.output, |location| {
205 if let Some(Statement::StructDestructure(destructure_stmt)) =
206 self.lowered.blocks[location.0].statements.get(location.1)
207 {
208 self.stmts_to_remove.push(*location);
209 Some(destructure_stmt)
210 } else {
211 can_remove_struct_construct = false;
212 None
213 }
214 });
215
216 if !(can_remove_struct_construct
217 || stmt
218 .inputs
219 .iter()
220 .all(|input| self.lowered.variables[input.var_id].info.copyable.is_ok()))
221 {
222 self.stmts_to_remove.truncate(self.stmts_to_remove.len() - destructures.len());
224 return false;
225 }
226
227 if can_remove_struct_construct {
229 self.stmts_to_remove.push(statement_location);
230 }
231
232 for destructure_stmt in destructures {
233 for (output, input) in
234 izip!(destructure_stmt.outputs.iter(), stmt.inputs.iter())
235 {
236 self.rename_var(*output, input.var_id);
237 }
238 }
239 can_remove_struct_construct
240 }
241 Statement::Snapshot(stmt) => {
242 let mut can_remove_snap = true;
243
244 let desnaps = filter_use_sites(
245 &self.use_sites,
246 &self.aliases,
247 &stmt.snapshot(),
248 |location| {
249 if let Some(Statement::Desnap(desnap_stmt)) =
250 self.lowered.blocks[location.0].statements.get(location.1)
251 {
252 self.stmts_to_remove.push(*location);
253 Some(desnap_stmt)
254 } else {
255 can_remove_snap = false;
256 None
257 }
258 },
259 );
260
261 let new_var = if can_remove_snap {
262 self.stmts_to_remove.push(statement_location);
263 self.rename_var(stmt.original(), stmt.input.var_id);
264 stmt.input.var_id
265 } else if desnaps.is_empty()
266 && self.lowered.variables[stmt.input.var_id].info.copyable.is_err()
267 {
268 stmt.original()
269 } else {
270 stmt.input.var_id
271 };
272
273 for desnap in desnaps {
274 self.rename_var(desnap.output, new_var);
275 }
276 can_remove_snap
277 }
278 _ => false,
279 }
280 }
281}
282
283impl<'db, 'a> Analyzer<'db, 'a> for CancelOpsContext<'db, 'a> {
284 type Info = ();
285
286 fn visit_stmt(
287 &mut self,
288 _info: &mut Self::Info,
289 statement_location: StatementLocation,
290 stmt: &'a Statement<'db>,
291 ) {
292 if !self.handle_stmt(stmt, statement_location) {
293 for input in stmt.inputs() {
294 self.add_use_site(input.var_id, statement_location);
295 }
296 }
297 }
298
299 fn visit_goto(
300 &mut self,
301 _info: &mut Self::Info,
302 statement_location: StatementLocation,
303 _target_block_id: BlockId,
304 remapping: &VarRemapping<'db>,
305 ) {
306 for src in remapping.values() {
307 self.add_use_site(src.var_id, statement_location);
308 }
309 }
310
311 fn merge_match(
312 &mut self,
313 statement_location: StatementLocation,
314 match_info: &'a MatchInfo<'db>,
315 _infos: impl Iterator<Item = Self::Info>,
316 ) -> Self::Info {
317 for var in match_info.inputs() {
318 self.add_use_site(var.var_id, statement_location);
319 }
320 }
321
322 fn info_from_return(
323 &mut self,
324 statement_location: StatementLocation,
325 vars: &[VarUsage<'db>],
326 ) -> Self::Info {
327 for var in vars {
328 self.add_use_site(var.var_id, statement_location);
329 }
330 }
331}