sonatina_codegen/
loop_analysis.rs

1use cranelift_entity::{entity_impl, packed_option::PackedOption, PrimaryMap, SecondaryMap};
2use fxhash::FxHashMap;
3use smallvec::SmallVec;
4
5use crate::{cfg::ControlFlowGraph, domtree::DomTree};
6
7use sonatina_ir::Block;
8
9#[derive(Debug, Default)]
10pub struct LoopTree {
11    /// Stores loops.
12    /// The index of an outer loops is guaranteed to be lower than its inner loops because loops
13    /// are found in RPO.
14    loops: PrimaryMap<Loop, LoopData>,
15
16    /// Maps blocks to its contained loop.
17    /// If the block is contained by multiple nested loops, then the block is mapped to the innermost loop.
18    block_to_loop: SecondaryMap<Block, PackedOption<Loop>>,
19}
20
21impl LoopTree {
22    pub fn new() -> Self {
23        Self::default()
24    }
25
26    /// Compute the `LoopTree` of the block.
27    pub fn compute(&mut self, cfg: &ControlFlowGraph, domtree: &DomTree) {
28        self.clear();
29
30        // Find loop headers in RPO, this means outer loops are guaranteed to be inserted first,
31        // then its inner loops are inserted.
32        for &block in domtree.rpo() {
33            for &pred in cfg.preds_of(block) {
34                if domtree.dominates(block, pred) {
35                    let loop_data = LoopData {
36                        header: block,
37                        parent: None.into(),
38                        children: SmallVec::new(),
39                    };
40
41                    self.loops.push(loop_data);
42                    break;
43                }
44            }
45        }
46
47        self.analyze_loops(cfg, domtree);
48    }
49
50    /// Returns all loops.
51    /// The result iterator guarantees outer loops are returned before its inner loops.
52    pub fn loops(&self) -> impl DoubleEndedIterator<Item = Loop> {
53        self.loops.keys()
54    }
55
56    /// Returns all blocks in the loop.
57    pub fn iter_blocks_post_order<'a, 'b>(
58        &'a self,
59        cfg: &'b ControlFlowGraph,
60        lp: Loop,
61    ) -> BlocksInLoopPostOrder<'a, 'b> {
62        BlocksInLoopPostOrder::new(self, cfg, lp)
63    }
64
65    /// Returns `true` if the `block` is in the `lp`.
66    pub fn is_in_loop(&self, block: Block, lp: Loop) -> bool {
67        let mut loop_of_block = self.loop_of_block(block);
68        while let Some(cur_lp) = loop_of_block {
69            if lp == cur_lp {
70                return true;
71            }
72            loop_of_block = self.parent_loop(cur_lp);
73        }
74        false
75    }
76
77    /// Returns number of loops found.
78    pub fn loop_num(&self) -> usize {
79        self.loops.len()
80    }
81
82    /// Map `block` to `lp`.
83    pub fn map_block(&mut self, block: Block, lp: Loop) {
84        self.block_to_loop[block] = lp.into();
85    }
86
87    /// Clear the internal state of `LoopTree`.
88    pub fn clear(&mut self) {
89        self.loops.clear();
90        self.block_to_loop.clear();
91    }
92
93    /// Returns header block of the `lp`.
94    pub fn loop_header(&self, lp: Loop) -> Block {
95        self.loops[lp].header
96    }
97
98    /// Get parent loop of the `lp` if exists.
99    pub fn parent_loop(&self, lp: Loop) -> Option<Loop> {
100        self.loops[lp].parent.expand()
101    }
102
103    /// Returns the loop that the `block` belongs to.
104    /// If the `block` belongs to multiple loops, then returns the innermost loop.
105    pub fn loop_of_block(&self, block: Block) -> Option<Loop> {
106        self.block_to_loop[block].expand()
107    }
108
109    /// Analyze loops. This method does
110    /// 1. Mapping each blocks to its contained loop.
111    /// 2. Setting parent and child of the loops.
112    fn analyze_loops(&mut self, cfg: &ControlFlowGraph, domtree: &DomTree) {
113        let mut worklist = vec![];
114
115        // Iterate loops reversely to ensure analyze inner loops first.
116        for cur_lp in self.loops.keys().rev() {
117            let cur_lp_header = self.loop_header(cur_lp);
118
119            // Add predecessors of the loop header to worklist.
120            for &block in cfg.preds_of(cur_lp_header) {
121                if domtree.dominates(cur_lp_header, block) {
122                    worklist.push(block);
123                }
124            }
125
126            while let Some(block) = worklist.pop() {
127                match self.block_to_loop[block].expand() {
128                    Some(lp_of_block) => {
129                        let outermost_parent = self.outermost_parent(lp_of_block);
130
131                        // If outermost parent is current loop, then the block is already visited.
132                        if outermost_parent == cur_lp {
133                            continue;
134                        } else {
135                            self.loops[cur_lp].children.push(outermost_parent);
136                            self.loops[outermost_parent].parent = cur_lp.into();
137
138                            let lp_header_of_block = self.loop_header(lp_of_block);
139                            worklist.extend(cfg.preds_of(lp_header_of_block));
140                        }
141                    }
142
143                    // If the block is not mapped to any loops, then map it to the loop.
144                    None => {
145                        self.map_block(block, cur_lp);
146                        // If block is not loop header, then add its predecessors to the worklist.
147                        if block != cur_lp_header {
148                            worklist.extend(cfg.preds_of(block));
149                        }
150                    }
151                }
152            }
153        }
154    }
155
156    /// Returns the outermost parent loop of `lp`. If `lp` doesn't have any parent, then returns `lp`
157    /// itself.
158    fn outermost_parent(&self, mut lp: Loop) -> Loop {
159        while let Some(parent) = self.parent_loop(lp) {
160            lp = parent;
161        }
162        lp
163    }
164}
165
166#[derive(Debug, Clone, Copy, PartialEq, Eq)]
167pub struct Loop(u32);
168entity_impl!(Loop);
169
170#[derive(Debug, Clone, PartialEq, Eq)]
171struct LoopData {
172    /// A header of the loop.
173    header: Block,
174
175    /// A parent loop that includes the loop.
176    parent: PackedOption<Loop>,
177
178    /// Child loops that the loop includes.
179    children: SmallVec<[Loop; 4]>,
180}
181
182pub struct BlocksInLoopPostOrder<'a, 'b> {
183    lpt: &'a LoopTree,
184    cfg: &'b ControlFlowGraph,
185    lp: Loop,
186    stack: Vec<Block>,
187    block_state: FxHashMap<Block, BlockState>,
188}
189
190impl<'a, 'b> BlocksInLoopPostOrder<'a, 'b> {
191    fn new(lpt: &'a LoopTree, cfg: &'b ControlFlowGraph, lp: Loop) -> Self {
192        let loop_header = lpt.loop_header(lp);
193
194        Self {
195            lpt,
196            cfg,
197            lp,
198            stack: vec![loop_header],
199            block_state: FxHashMap::default(),
200        }
201    }
202}
203
204impl<'a, 'b> Iterator for BlocksInLoopPostOrder<'a, 'b> {
205    type Item = Block;
206
207    fn next(&mut self) -> Option<Self::Item> {
208        while let Some(&block) = self.stack.last() {
209            match self.block_state.get(&block) {
210                // The block is already visited, but not returned from the iterator,
211                // so mark the block as `Finished` and return the block.
212                Some(BlockState::Visited) => {
213                    let block = self.stack.pop().unwrap();
214                    self.block_state.insert(block, BlockState::Finished);
215                    return Some(block);
216                }
217
218                // The block is already returned, so just remove the block from the stack.
219                Some(BlockState::Finished) => {
220                    self.stack.pop().unwrap();
221                }
222
223                // The block is not visited yet, so push its unvisited in-loop successors to the stack and mark the block as `Visited`.
224                None => {
225                    self.block_state.insert(block, BlockState::Visited);
226                    for &succ in self.cfg.succs_of(block) {
227                        if self.block_state.get(&succ).is_none()
228                            && self.lpt.is_in_loop(succ, self.lp)
229                        {
230                            self.stack.push(succ);
231                        }
232                    }
233                }
234            }
235        }
236
237        None
238    }
239}
240
/// DFS progress of a block during a [`BlocksInLoopPostOrder`] walk.
///
/// A block with no recorded state has not been visited yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BlockState {
    /// The block's in-loop successors have been scheduled, but the block itself has not been
    /// yielded yet.
    Visited,
    /// The block has already been yielded by the iterator.
    Finished,
}
245
#[cfg(test)]
mod tests {
    use super::*;

    use sonatina_ir::{builder::test_util::*, Function, Type};

    /// Builds the CFG, dominator tree and loop tree of `func`.
    fn compute_loop(func: &Function) -> LoopTree {
        let mut cfg = ControlFlowGraph::new();
        let mut domtree = DomTree::new();
        let mut lpt = LoopTree::new();
        cfg.compute(func);
        domtree.compute(&cfg);
        lpt.compute(&cfg, &domtree);
        lpt
    }

    // NOTE: these tests use `assert_eq!` rather than `debug_assert_eq!`, because the latter is
    // compiled out without `debug_assertions` (e.g. `cargo test --release`), which would make
    // every check vacuous.

    #[test]
    fn simple_loop() {
        let mut test_module_builder = TestModuleBuilder::new();
        let mut builder = test_module_builder.func_builder(&[], Type::Void);

        let b0 = builder.append_block();
        let b1 = builder.append_block();
        let b2 = builder.append_block();
        let b3 = builder.append_block();

        builder.switch_to_block(b0);
        let v0 = builder.make_imm_value(0i32);
        builder.jump(b1);

        builder.switch_to_block(b1);
        let v1 = builder.phi(&[(v0, b0)]);
        let c0 = builder.make_imm_value(10i32);
        let v2 = builder.eq(v1, c0);
        builder.br(v2, b3, b2);

        builder.switch_to_block(b2);
        let c1 = builder.make_imm_value(1i32);
        let v3 = builder.add(v1, c1);
        builder.jump(b1);
        builder.append_phi_arg(v1, v3, b2);

        builder.switch_to_block(b3);
        builder.ret(None);

        builder.seal_all();
        let func_ref = builder.finish();

        let module = test_module_builder.build();
        let func = &module.funcs[func_ref];
        let lpt = compute_loop(func);

        assert_eq!(lpt.loop_num(), 1);
        let lp0 = lpt.loops().next().unwrap();
        assert_eq!(lpt.loop_of_block(b0), None);
        assert_eq!(lpt.loop_of_block(b1), Some(lp0));
        assert_eq!(lpt.loop_of_block(b2), Some(lp0));
        assert_eq!(lpt.loop_of_block(b3), None);

        assert_eq!(lpt.loop_header(lp0), b1);
    }

    #[test]
    fn continue_loop() {
        let mut test_module_builder = TestModuleBuilder::new();
        let mut builder = test_module_builder.func_builder(&[], Type::Void);

        let b0 = builder.append_block();
        let b1 = builder.append_block();
        let b2 = builder.append_block();
        let b3 = builder.append_block();
        let b4 = builder.append_block();
        let b5 = builder.append_block();
        let b6 = builder.append_block();

        builder.switch_to_block(b0);
        let v0 = builder.make_imm_value(0i32);
        builder.jump(b1);

        builder.switch_to_block(b1);
        let v1 = builder.phi(&[(v0, b0)]);
        let c0 = builder.make_imm_value(10i32);
        let v2 = builder.eq(v1, c0);
        // Leave the loop through `b6`. Branching to `b5` here would give the loop no exit and
        // leave `b6` unreachable.
        builder.br(v2, b6, b2);

        builder.switch_to_block(b2);
        let c1 = builder.make_imm_value(5i32);
        let v3 = builder.eq(v1, c1);
        builder.br(v3, b3, b4);

        // `b3` is the "continue" path: it jumps to the latch `b5`.
        builder.switch_to_block(b3);
        builder.jump(b5);

        builder.switch_to_block(b4);
        let c2 = builder.make_imm_value(3i32);
        let v4 = builder.add(v1, c2);
        builder.append_phi_arg(v1, v4, b4);
        builder.jump(b1);

        builder.switch_to_block(b5);
        let c3 = builder.make_imm_value(1i32);
        let v5 = builder.add(v1, c3);
        builder.append_phi_arg(v1, v5, b5);
        builder.jump(b1);

        builder.switch_to_block(b6);
        builder.ret(None);

        builder.seal_all();
        let func_ref = builder.finish();

        let module = test_module_builder.build();
        let func = &module.funcs[func_ref];
        let lpt = compute_loop(func);

        assert_eq!(lpt.loop_num(), 1);
        let lp0 = lpt.loops().next().unwrap();

        assert_eq!(lpt.loop_of_block(b0), None);
        assert_eq!(lpt.loop_of_block(b1), Some(lp0));
        assert_eq!(lpt.loop_of_block(b2), Some(lp0));
        assert_eq!(lpt.loop_of_block(b3), Some(lp0));
        assert_eq!(lpt.loop_of_block(b4), Some(lp0));
        assert_eq!(lpt.loop_of_block(b5), Some(lp0));
        assert_eq!(lpt.loop_of_block(b6), None);

        assert_eq!(lpt.loop_header(lp0), b1);
    }

    #[test]
    fn single_block_loop() {
        let mut test_module_builder = TestModuleBuilder::new();
        let mut builder = test_module_builder.func_builder(&[Type::I1], Type::Void);
        let b0 = builder.append_block();
        let b1 = builder.append_block();
        let b2 = builder.append_block();

        let arg = builder.args()[0];

        builder.switch_to_block(b0);
        builder.jump(b1);

        builder.switch_to_block(b1);
        builder.br(arg, b1, b2);

        builder.switch_to_block(b2);
        builder.ret(None);

        builder.seal_all();
        let func_ref = builder.finish();

        let module = test_module_builder.build();
        let func = &module.funcs[func_ref];
        let lpt = compute_loop(func);

        assert_eq!(lpt.loop_num(), 1);
        let lp0 = lpt.loops().next().unwrap();

        assert_eq!(lpt.loop_of_block(b0), None);
        assert_eq!(lpt.loop_of_block(b1), Some(lp0));
        assert_eq!(lpt.loop_of_block(b2), None);
    }

    #[test]
    fn nested_loop() {
        let mut test_module_builder = TestModuleBuilder::new();
        let mut builder = test_module_builder.func_builder(&[Type::I1], Type::Void);

        let b0 = builder.append_block();
        let b1 = builder.append_block();
        let b2 = builder.append_block();
        let b3 = builder.append_block();
        let b4 = builder.append_block();
        let b5 = builder.append_block();
        let b6 = builder.append_block();
        let b7 = builder.append_block();
        let b8 = builder.append_block();
        let b9 = builder.append_block();
        let b10 = builder.append_block();
        let b11 = builder.append_block();

        let arg = builder.args()[0];

        builder.switch_to_block(b0);
        builder.jump(b1);

        builder.switch_to_block(b1);
        builder.jump(b2);

        builder.switch_to_block(b2);
        builder.br(arg, b3, b7);

        builder.switch_to_block(b3);
        builder.jump(b4);

        builder.switch_to_block(b4);
        builder.jump(b5);

        builder.switch_to_block(b5);
        builder.br(arg, b4, b6);

        builder.switch_to_block(b6);
        builder.br(arg, b3, b10);

        builder.switch_to_block(b7);
        builder.jump(b8);

        builder.switch_to_block(b8);
        builder.br(arg, b9, b7);

        builder.switch_to_block(b9);
        builder.br(arg, b1, b10);

        builder.switch_to_block(b10);
        builder.jump(b1);

        builder.switch_to_block(b11);
        builder.ret(None);

        builder.seal_all();
        let func_ref = builder.finish();

        let module = test_module_builder.build();
        let func = &module.funcs[func_ref];
        let lpt = compute_loop(func);

        assert_eq!(lpt.loop_num(), 4);
        let l0 = lpt.loop_of_block(b1).unwrap();
        let l1 = lpt.loop_of_block(b3).unwrap();
        let l2 = lpt.loop_of_block(b4).unwrap();
        let l3 = lpt.loop_of_block(b7).unwrap();

        assert_eq!(lpt.loop_of_block(b0), None);
        assert_eq!(lpt.loop_of_block(b1), Some(l0));
        assert_eq!(lpt.loop_of_block(b2), Some(l0));
        assert_eq!(lpt.loop_of_block(b3), Some(l1));
        assert_eq!(lpt.loop_of_block(b4), Some(l2));
        assert_eq!(lpt.loop_of_block(b5), Some(l2));
        assert_eq!(lpt.loop_of_block(b6), Some(l1));
        assert_eq!(lpt.loop_of_block(b7), Some(l3));
        assert_eq!(lpt.loop_of_block(b8), Some(l3));
        assert_eq!(lpt.loop_of_block(b9), Some(l0));
        assert_eq!(lpt.loop_of_block(b10), Some(l0));
        assert_eq!(lpt.loop_of_block(b11), None);

        assert_eq!(lpt.parent_loop(l0), None);
        assert_eq!(lpt.parent_loop(l1), Some(l0));
        assert_eq!(lpt.parent_loop(l2), Some(l1));
        assert_eq!(lpt.parent_loop(l3), Some(l0));

        assert_eq!(lpt.loop_header(l0), b1);
        assert_eq!(lpt.loop_header(l1), b3);
        assert_eq!(lpt.loop_header(l2), b4);
        assert_eq!(lpt.loop_header(l3), b7);
    }
}