Skip to main content

cairo_lang_lowering/
reorganize_blocks.rs

1use std::collections::HashMap;
2
3use itertools::Itertools;
4
5use crate::analysis::core::StatementLocation;
6use crate::analysis::{DataflowAnalyzer, DataflowBackAnalysis, Direction, Edge};
7use crate::blocks::BlocksBuilder;
8use crate::optimizations::remappings::{self, Context};
9use crate::utils::{Rebuilder, RebuilderEx};
10use crate::{
11    Block, BlockEnd, BlockId, Lowered, Statement, VarRemapping, VariableArena, VariableId,
12};
13
14/// Reorganizes the blocks in lowered function and removes unnecessary remappings.
15///
16/// Removes unreachable blocks.
17/// Blocks that are reachable only through goto are combined with the block that does the goto.
18/// The order of the blocks is changed to be a topologically sorted.
19pub fn reorganize_blocks<'db>(lowered: &mut Lowered<'db>) {
20    if lowered.blocks.is_empty() {
21        return;
22    }
23    let mut ctx = TopSortContext {
24        old_block_rev_order: Vec::with_capacity(lowered.blocks.len()),
25        incoming_gotos: vec![0; lowered.blocks.len()],
26        can_be_merged: vec![true; lowered.blocks.len()],
27        remappings_ctx: Context::new(lowered.variables.len()),
28    };
29
30    remappings::visit_remappings(lowered, |remapping| {
31        for (dst, src) in remapping.iter() {
32            ctx.remappings_ctx.dest_to_srcs[dst.index()].push(src.var_id);
33        }
34    });
35
36    DataflowBackAnalysis::new(lowered, &mut ctx).run();
37
38    // Rebuild the blocks in the correct order.
39    let mut new_blocks = BlocksBuilder::default();
40
41    // Keep only blocks that can't be merged or have more than 1 incoming
42    // goto.
43    // Note that unreachable blocks were not added to `ctx.old_block_rev_order` during
44    // the analysis above.
45    let mut old_block_rev_order = ctx
46        .old_block_rev_order
47        .into_iter()
48        .filter(|block_id| !ctx.can_be_merged[block_id.0] || ctx.incoming_gotos[block_id.0] > 1)
49        .collect_vec();
50
51    // Add the root block as it was filtered above.
52    old_block_rev_order.push(BlockId::root());
53
54    let n_visited_blocks = old_block_rev_order.len();
55
56    let mut rebuilder = RebuildContext {
57        block_remapping: HashMap::from_iter(
58            old_block_rev_order
59                .iter()
60                .enumerate()
61                .map(|(idx, block_id)| (*block_id, BlockId(n_visited_blocks - idx - 1))),
62        ),
63        remappings_ctx: ctx.remappings_ctx,
64    };
65
66    let mut var_reassigner = VarReassigner::new(&lowered.variables);
67    for param in lowered.parameters.iter_mut() {
68        *param = var_reassigner.map_var_id(*param);
69    }
70
71    for block_id in old_block_rev_order.into_iter().rev() {
72        let mut statements = vec![];
73
74        let mut block = &lowered.blocks[block_id];
75        loop {
76            statements.extend(
77                block.statements.iter().map(|stmt| {
78                    var_reassigner.rebuild_statement(&rebuilder.rebuild_statement(stmt))
79                }),
80            );
81            if let BlockEnd::Goto(target_block_id, remappings) = &block.end
82                && !rebuilder.block_remapping.contains_key(target_block_id)
83            {
84                assert!(
85                    rebuilder.rebuild_remapping(remappings).is_empty(),
86                    "Remapping should be empty."
87                );
88                block = &lowered.blocks[*target_block_id];
89                continue;
90            }
91            break;
92        }
93
94        let end = var_reassigner.rebuild_end(&rebuilder.rebuild_end(&block.end));
95        new_blocks.alloc(Block { statements, end });
96    }
97
98    lowered.variables = var_reassigner.new_vars;
99    lowered.blocks = new_blocks.build().unwrap();
100}
101
102pub struct TopSortContext {
103    old_block_rev_order: Vec<BlockId>,
104    // The number of incoming gotos, indexed by block_id.
105    incoming_gotos: Vec<usize>,
106
107    // True if the block can be merged with the block that goes to it.
108    can_be_merged: Vec<bool>,
109
110    remappings_ctx: remappings::Context,
111}
112
113impl<'db, 'a> DataflowAnalyzer<'db, 'a> for TopSortContext {
114    type Info = ();
115
116    const DIRECTION: Direction = Direction::Backward;
117
118    fn initial_info(&mut self, _block_id: BlockId, block_end: &'a BlockEnd<'db>) -> Self::Info {
119        // For zero-arm matches (e.g. `never` type), no MatchArm edges are traversed,
120        // so mark match inputs as used here.
121        if let BlockEnd::Match { info } = block_end {
122            for var_usage in info.inputs() {
123                self.remappings_ctx.set_used(var_usage.var_id);
124            }
125        }
126    }
127
128    fn merge(
129        &mut self,
130        _lowered: &Lowered<'db>,
131        _statement_location: StatementLocation,
132        _info1: Self::Info,
133        _info2: Self::Info,
134    ) -> Self::Info {
135    }
136
137    fn visit_block_start(
138        &mut self,
139        _info: &mut Self::Info,
140        block_id: BlockId,
141        _block: &Block<'db>,
142    ) {
143        self.old_block_rev_order.push(block_id);
144    }
145
146    fn transfer_stmt(
147        &mut self,
148        _info: &mut Self::Info,
149        _statement_location: StatementLocation,
150        stmt: &'a Statement<'db>,
151    ) {
152        for var_usage in stmt.inputs() {
153            self.remappings_ctx.set_used(var_usage.var_id);
154        }
155    }
156
157    fn transfer_edge(&mut self, info: &Self::Info, edge: &Edge<'db, 'a>) -> Self::Info {
158        match edge {
159            Edge::Goto { target, remapping: _ } => {
160                // Note that the remappings of a goto are not considered a usage. Later usages
161                // (such as a merge) would catch them if used.
162                self.incoming_gotos[target.0] += 1;
163            }
164            Edge::MatchArm { arm, match_info } => {
165                self.can_be_merged[arm.block_id.0] = false;
166                // Mark match inputs as used.
167                for var_usage in match_info.inputs() {
168                    self.remappings_ctx.set_used(var_usage.var_id);
169                }
170            }
171            Edge::Return { vars, .. } => {
172                for var_usage in vars.iter() {
173                    self.remappings_ctx.set_used(var_usage.var_id);
174                }
175            }
176            Edge::Panic { var } => {
177                self.remappings_ctx.set_used(var.var_id);
178            }
179        }
180        *info
181    }
182}
183
184pub struct RebuildContext {
185    block_remapping: HashMap<BlockId, BlockId>,
186    remappings_ctx: remappings::Context,
187}
188impl<'db> Rebuilder<'db> for RebuildContext {
189    fn map_block_id(&mut self, block: BlockId) -> BlockId {
190        self.block_remapping[&block]
191    }
192
193    fn map_var_id(&mut self, var: VariableId) -> VariableId {
194        self.remappings_ctx.map_var_id(var)
195    }
196
197    fn transform_remapping(&mut self, remapping: &mut VarRemapping<'db>) {
198        self.remappings_ctx.transform_remapping(remapping)
199    }
200}
201
202/// Helper class to reassign variable ids according to the rebuild order.
203///
204/// Note that it can't be integrated into the RebuildContext above because rebuild_remapping might
205/// call `map_var_id` on variables that are going to be removed.
206pub struct VarReassigner<'db, 'a> {
207    pub old_vars: &'a VariableArena<'db>,
208    pub new_vars: VariableArena<'db>,
209
210    // Maps old var_id to new_var_id
211    pub vars: Vec<Option<VariableId>>,
212}
213
214impl<'db, 'a> VarReassigner<'db, 'a> {
215    pub fn new(old_vars: &'a VariableArena<'db>) -> Self {
216        Self { old_vars, new_vars: Default::default(), vars: vec![None; old_vars.len()] }
217    }
218}
219
220impl<'db> Rebuilder<'db> for VarReassigner<'db, '_> {
221    fn map_var_id(&mut self, var: VariableId) -> VariableId {
222        *self.vars[var.index()]
223            .get_or_insert_with(|| self.new_vars.alloc(self.old_vars[var].clone()))
224    }
225}