Skip to main content

cairo_lang_lowering/optimizations/
reorder_statements.rs

1#[cfg(test)]
2#[path = "reorder_statements_test.rs"]
3mod test;
4
5use std::cmp::Reverse;
6
7use cairo_lang_defs::ids::ExternFunctionId;
8use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9use cairo_lang_utils::unordered_hash_map::{Entry, UnorderedHashMap};
10use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
11use itertools::{Itertools, zip_eq};
12use salsa::Database;
13
14use crate::analysis::{Analyzer, BackAnalysis, StatementLocation};
15use crate::db::LoweringGroup;
16use crate::{
17    BlockId, Lowered, MatchInfo, Statement, StatementCall, VarRemapping, VarUsage, VariableId,
18};
19
20/// Reorder the statements in the lowering in order to move variable definitions closer to their
21/// usage. Statement with no side effects and unused outputs are removed.
22///
23/// The list of call statements that can be moved is currently hardcoded.
24///
25/// Removing unnecessary remapping before this optimization will result in better code.
26pub fn reorder_statements(db: &dyn Database, lowered: &mut Lowered<'_>) {
27    if lowered.blocks.is_empty() {
28        return;
29    }
30    let ctx = ReorderStatementsContext {
31        db,
32        lowered: &*lowered,
33        moveable_functions: db.priv_movable_function_ids(),
34        statement_to_move: vec![],
35    };
36    let mut analysis = BackAnalysis::new(lowered, ctx);
37    analysis.get_root_info();
38    let ctx = analysis.analyzer;
39
40    let mut changes_by_block =
41        OrderedHashMap::<BlockId, Vec<(usize, Option<Statement<'_>>)>>::default();
42
43    for (src, opt_dst) in ctx.statement_to_move {
44        changes_by_block.entry(src.0).or_insert_with(Vec::new).push((src.1, None));
45
46        if let Some(dst) = opt_dst {
47            let statement = lowered[src].clone();
48            changes_by_block.entry(dst.0).or_insert_with(Vec::new).push((dst.1, Some(statement)));
49        }
50    }
51
52    for (block_id, block_changes) in changes_by_block {
53        let statements = &mut lowered.blocks[block_id].statements;
54
55        // Apply block changes in reverse order to prevent a change from invalidating the
56        // indices of the other changes.
57        for (index, opt_statement) in
58            block_changes.into_iter().sorted_by_key(|(index, _)| Reverse(*index))
59        {
60            match opt_statement {
61                Some(stmt) => statements.insert(index, stmt),
62                None => {
63                    statements.remove(index);
64                }
65            }
66        }
67    }
68}
69
70#[derive(Clone, Default)]
71pub struct ReorderStatementsInfo {
72    // A mapping from var_id to a candidate location that it can be moved to.
73    // If the variable is used in multiple match arms we define the next use to be
74    // the match.
75    next_use: UnorderedHashMap<VariableId, StatementLocation>,
76}
77
78pub struct ReorderStatementsContext<'db> {
79    db: &'db dyn Database,
80    lowered: &'db Lowered<'db>,
81    // A list of function that can be moved.
82    moveable_functions: &'db UnorderedHashSet<ExternFunctionId<'db>>,
83    statement_to_move: Vec<(StatementLocation, Option<StatementLocation>)>,
84}
85impl<'db> ReorderStatementsContext<'db> {
86    fn call_can_be_moved(&mut self, stmt: &StatementCall<'db>) -> bool {
87        if let Some((extern_id, _)) = stmt.function.get_extern(self.db) {
88            self.moveable_functions.contains(&extern_id)
89        } else {
90            false
91        }
92    }
93}
94impl<'db> Analyzer<'db, '_> for ReorderStatementsContext<'db> {
95    type Info = ReorderStatementsInfo;
96
97    fn visit_stmt(
98        &mut self,
99        info: &mut Self::Info,
100        statement_location: StatementLocation,
101        stmt: &Statement<'db>,
102    ) {
103        let mut immovable = matches!(stmt, Statement::Call(stmt) if !self.call_can_be_moved(stmt));
104        let mut optional_target_location = None;
105        for var_to_move in stmt.outputs() {
106            let Some((block_id, index)) = info.next_use.remove(var_to_move) else { continue };
107            if let Some((target_block_id, target_index)) = &mut optional_target_location {
108                *target_index = std::cmp::min(*target_index, index);
109                // If the output is used in multiple places we can't move their creation point.
110                immovable |= target_block_id != &block_id;
111            } else {
112                optional_target_location = Some((block_id, index));
113            }
114        }
115        if immovable {
116            for var_usage in stmt.inputs() {
117                info.next_use.insert(var_usage.var_id, statement_location);
118            }
119            return;
120        }
121
122        if let Some(target_location) = optional_target_location {
123            // If the statement is not removed add demand for its inputs.
124            for var_usage in stmt.inputs() {
125                match info.next_use.entry(var_usage.var_id) {
126                    Entry::Occupied(mut e) => {
127                        // Since we don't know where `e.get()` and `target_location` converge
128                        // we use `statement_location` as a conservative estimate.
129                        &e.insert(statement_location)
130                    }
131                    Entry::Vacant(e) => e.insert(target_location),
132                };
133            }
134
135            self.statement_to_move.push((statement_location, Some(target_location)))
136        } else if stmt
137            .inputs()
138            .iter()
139            .all(|v| self.lowered.variables[v.var_id].info.droppable.is_ok())
140        {
141            // If a movable statement is unused, and all its inputs are droppable removing it is
142            // valid.
143            self.statement_to_move.push((statement_location, None))
144        } else {
145            // Statement is unused but can't be removed.
146            for var_usage in stmt.inputs() {
147                info.next_use.insert(var_usage.var_id, statement_location);
148            }
149        }
150    }
151
152    fn visit_goto(
153        &mut self,
154        info: &mut Self::Info,
155        statement_location: StatementLocation,
156        _target_block_id: BlockId,
157        remapping: &VarRemapping<'db>,
158    ) {
159        for VarUsage { var_id, .. } in remapping.values() {
160            info.next_use.insert(*var_id, statement_location);
161        }
162    }
163
164    fn merge_match(
165        &mut self,
166        statement_location: StatementLocation,
167        match_info: &MatchInfo<'db>,
168        infos: impl Iterator<Item = Self::Info>,
169    ) -> Self::Info {
170        let mut infos = zip_eq(infos, match_info.arms()).map(|(mut info, arm)| {
171            for var_id in &arm.var_ids {
172                info.next_use.remove(var_id);
173            }
174            info
175        });
176        let mut info = infos.next().unwrap_or_default();
177        for arm_info in infos {
178            info.next_use.merge(&arm_info.next_use, |e, _| {
179                *e.into_mut() = statement_location;
180            });
181        }
182
183        for var_usage in match_info.inputs() {
184            info.next_use.insert(var_usage.var_id, statement_location);
185        }
186
187        info
188    }
189
190    fn info_from_return(
191        &mut self,
192        statement_location: StatementLocation,
193        vars: &[VarUsage<'db>],
194    ) -> Self::Info {
195        let mut info = Self::Info::default();
196        for var_usage in vars {
197            info.next_use.insert(var_usage.var_id, statement_location);
198        }
199        info
200    }
201}