1use cranelift_entity::{entity_impl, packed_option::PackedOption, PrimaryMap, SecondaryMap};
2use fxhash::FxHashMap;
3use smallvec::SmallVec;
4
5use crate::{cfg::ControlFlowGraph, domtree::DomTree};
6
7use sonatina_ir::Block;
8
/// Loop nesting information for a single function's control flow graph.
///
/// The tree is empty until [`LoopTree::compute`] is called.
#[derive(Debug, Default)]
pub struct LoopTree {
    /// Every loop discovered by `compute`, keyed by its [`Loop`] handle.
    loops: PrimaryMap<Loop, LoopData>,

    /// The loop each block was assigned to; holds the "none" value for blocks
    /// that were not assigned to any loop.
    block_to_loop: SecondaryMap<Block, PackedOption<Loop>>,
}
20
21impl LoopTree {
22 pub fn new() -> Self {
23 Self::default()
24 }
25
26 pub fn compute(&mut self, cfg: &ControlFlowGraph, domtree: &DomTree) {
28 self.clear();
29
30 for &block in domtree.rpo() {
33 for &pred in cfg.preds_of(block) {
34 if domtree.dominates(block, pred) {
35 let loop_data = LoopData {
36 header: block,
37 parent: None.into(),
38 children: SmallVec::new(),
39 };
40
41 self.loops.push(loop_data);
42 break;
43 }
44 }
45 }
46
47 self.analyze_loops(cfg, domtree);
48 }
49
50 pub fn loops(&self) -> impl DoubleEndedIterator<Item = Loop> {
53 self.loops.keys()
54 }
55
56 pub fn iter_blocks_post_order<'a, 'b>(
58 &'a self,
59 cfg: &'b ControlFlowGraph,
60 lp: Loop,
61 ) -> BlocksInLoopPostOrder<'a, 'b> {
62 BlocksInLoopPostOrder::new(self, cfg, lp)
63 }
64
65 pub fn is_in_loop(&self, block: Block, lp: Loop) -> bool {
67 let mut loop_of_block = self.loop_of_block(block);
68 while let Some(cur_lp) = loop_of_block {
69 if lp == cur_lp {
70 return true;
71 }
72 loop_of_block = self.parent_loop(cur_lp);
73 }
74 false
75 }
76
77 pub fn loop_num(&self) -> usize {
79 self.loops.len()
80 }
81
82 pub fn map_block(&mut self, block: Block, lp: Loop) {
84 self.block_to_loop[block] = lp.into();
85 }
86
87 pub fn clear(&mut self) {
89 self.loops.clear();
90 self.block_to_loop.clear();
91 }
92
93 pub fn loop_header(&self, lp: Loop) -> Block {
95 self.loops[lp].header
96 }
97
98 pub fn parent_loop(&self, lp: Loop) -> Option<Loop> {
100 self.loops[lp].parent.expand()
101 }
102
103 pub fn loop_of_block(&self, block: Block) -> Option<Loop> {
106 self.block_to_loop[block].expand()
107 }
108
109 fn analyze_loops(&mut self, cfg: &ControlFlowGraph, domtree: &DomTree) {
113 let mut worklist = vec![];
114
115 for cur_lp in self.loops.keys().rev() {
117 let cur_lp_header = self.loop_header(cur_lp);
118
119 for &block in cfg.preds_of(cur_lp_header) {
121 if domtree.dominates(cur_lp_header, block) {
122 worklist.push(block);
123 }
124 }
125
126 while let Some(block) = worklist.pop() {
127 match self.block_to_loop[block].expand() {
128 Some(lp_of_block) => {
129 let outermost_parent = self.outermost_parent(lp_of_block);
130
131 if outermost_parent == cur_lp {
133 continue;
134 } else {
135 self.loops[cur_lp].children.push(outermost_parent);
136 self.loops[outermost_parent].parent = cur_lp.into();
137
138 let lp_header_of_block = self.loop_header(lp_of_block);
139 worklist.extend(cfg.preds_of(lp_header_of_block));
140 }
141 }
142
143 None => {
145 self.map_block(block, cur_lp);
146 if block != cur_lp_header {
148 worklist.extend(cfg.preds_of(block));
149 }
150 }
151 }
152 }
153 }
154 }
155
156 fn outermost_parent(&self, mut lp: Loop) -> Loop {
159 while let Some(parent) = self.parent_loop(lp) {
160 lp = parent;
161 }
162 lp
163 }
164}
165
/// Opaque handle identifying a loop stored in a [`LoopTree`].
///
/// Used as the key type of the tree's entity maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Loop(u32);
entity_impl!(Loop);
169
/// Per-loop data stored in the loop tree.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LoopData {
    /// Header block of the loop: the block that dominates the source of at
    /// least one of its incoming edges.
    header: Block,

    /// Loop that directly encloses this one; "none" while no parent has been
    /// recorded.
    parent: PackedOption<Loop>,

    /// Loops that were attached directly beneath this loop.
    children: SmallVec<[Loop; 4]>,
}
181
/// Iterator over the blocks of one loop in post order.
///
/// Created by `LoopTree::iter_blocks_post_order`.
pub struct BlocksInLoopPostOrder<'a, 'b> {
    /// Loop tree used to test whether a successor belongs to `lp`.
    lpt: &'a LoopTree,
    /// Control flow graph providing successor edges.
    cfg: &'b ControlFlowGraph,
    /// The loop whose blocks are being traversed.
    lp: Loop,
    /// Depth-first traversal stack; may hold the same block more than once.
    stack: Vec<Block>,
    /// Traversal state of every block seen so far; absent means unseen.
    block_state: FxHashMap<Block, BlockState>,
}
189
190impl<'a, 'b> BlocksInLoopPostOrder<'a, 'b> {
191 fn new(lpt: &'a LoopTree, cfg: &'b ControlFlowGraph, lp: Loop) -> Self {
192 let loop_header = lpt.loop_header(lp);
193
194 Self {
195 lpt,
196 cfg,
197 lp,
198 stack: vec![loop_header],
199 block_state: FxHashMap::default(),
200 }
201 }
202}
203
204impl<'a, 'b> Iterator for BlocksInLoopPostOrder<'a, 'b> {
205 type Item = Block;
206
207 fn next(&mut self) -> Option<Self::Item> {
208 while let Some(&block) = self.stack.last() {
209 match self.block_state.get(&block) {
210 Some(BlockState::Visited) => {
213 let block = self.stack.pop().unwrap();
214 self.block_state.insert(block, BlockState::Finished);
215 return Some(block);
216 }
217
218 Some(BlockState::Finished) => {
220 self.stack.pop().unwrap();
221 }
222
223 None => {
225 self.block_state.insert(block, BlockState::Visited);
226 for &succ in self.cfg.succs_of(block) {
227 if self.block_state.get(&succ).is_none()
228 && self.lpt.is_in_loop(succ, self.lp)
229 {
230 self.stack.push(succ);
231 }
232 }
233 }
234 }
235 }
236
237 None
238 }
239}
240
/// Traversal state of a block in `BlocksInLoopPostOrder`.
enum BlockState {
    /// The block's successors have been pushed; the block itself is still on
    /// the stack and has not been yielded yet.
    Visited,
    /// The block has already been yielded by the iterator.
    Finished,
}
245
#[cfg(test)]
mod tests {
    use super::*;

    use sonatina_ir::{builder::test_util::*, Function, Type};

    /// Builds the CFG, dominator tree and loop tree for `func`.
    fn compute_loop(func: &Function) -> LoopTree {
        let mut cfg = ControlFlowGraph::new();
        let mut domtree = DomTree::new();
        let mut lpt = LoopTree::new();
        cfg.compute(func);
        domtree.compute(&cfg);
        lpt.compute(&cfg, &domtree);
        lpt
    }

    // NOTE: the checks below use `assert_eq!` rather than `debug_assert_eq!` so
    // that they are still evaluated when the tests are built without debug
    // assertions (e.g. `cargo test --release`).

    /// A single loop `b1 -> b2 -> b1` with entry `b0` and exit `b3`.
    #[test]
    fn simple_loop() {
        let mut test_module_builder = TestModuleBuilder::new();
        let mut builder = test_module_builder.func_builder(&[], Type::Void);

        let b0 = builder.append_block();
        let b1 = builder.append_block();
        let b2 = builder.append_block();
        let b3 = builder.append_block();

        builder.switch_to_block(b0);
        let v0 = builder.make_imm_value(0i32);
        builder.jump(b1);

        builder.switch_to_block(b1);
        let v1 = builder.phi(&[(v0, b0)]);
        let c0 = builder.make_imm_value(10i32);
        let v2 = builder.eq(v1, c0);
        builder.br(v2, b3, b2);

        builder.switch_to_block(b2);
        let c1 = builder.make_imm_value(1i32);
        let v3 = builder.add(v1, c1);
        builder.jump(b1);
        builder.append_phi_arg(v1, v3, b2);

        builder.switch_to_block(b3);
        builder.ret(None);

        builder.seal_all();
        let func_ref = builder.finish();

        let module = test_module_builder.build();
        let func = &module.funcs[func_ref];
        let lpt = compute_loop(func);

        assert_eq!(lpt.loop_num(), 1);
        let lp0 = lpt.loops().next().unwrap();
        assert_eq!(lpt.loop_of_block(b0), None);
        assert_eq!(lpt.loop_of_block(b1), Some(lp0));
        assert_eq!(lpt.loop_of_block(b2), Some(lp0));
        assert_eq!(lpt.loop_of_block(b3), None);

        assert_eq!(lpt.loop_header(lp0), b1);
    }

    /// A single loop headed by `b1` with two back edges (from `b4` and `b5`).
    #[test]
    fn continue_loop() {
        let mut test_module_builder = TestModuleBuilder::new();
        let mut builder = test_module_builder.func_builder(&[], Type::Void);

        let b0 = builder.append_block();
        let b1 = builder.append_block();
        let b2 = builder.append_block();
        let b3 = builder.append_block();
        let b4 = builder.append_block();
        let b5 = builder.append_block();
        let b6 = builder.append_block();

        builder.switch_to_block(b0);
        let v0 = builder.make_imm_value(0i32);
        builder.jump(b1);

        builder.switch_to_block(b1);
        let v1 = builder.phi(&[(v0, b0)]);
        let c0 = builder.make_imm_value(10i32);
        let v2 = builder.eq(v1, c0);
        builder.br(v2, b5, b2);

        builder.switch_to_block(b2);
        let c1 = builder.make_imm_value(5i32);
        let v3 = builder.eq(v1, c1);
        builder.br(v3, b3, b4);

        builder.switch_to_block(b3);
        builder.jump(b5);

        builder.switch_to_block(b4);
        let c2 = builder.make_imm_value(3i32);
        let v4 = builder.add(v1, c2);
        builder.append_phi_arg(v1, v4, b4);
        builder.jump(b1);

        builder.switch_to_block(b5);
        let c3 = builder.make_imm_value(1i32);
        let v5 = builder.add(v1, c3);
        builder.append_phi_arg(v1, v5, b5);
        builder.jump(b1);

        builder.switch_to_block(b6);
        builder.ret(None);

        builder.seal_all();
        let func_ref = builder.finish();

        let module = test_module_builder.build();
        let func = &module.funcs[func_ref];
        let lpt = compute_loop(func);

        assert_eq!(lpt.loop_num(), 1);
        let lp0 = lpt.loops().next().unwrap();

        assert_eq!(lpt.loop_of_block(b0), None);
        assert_eq!(lpt.loop_of_block(b1), Some(lp0));
        assert_eq!(lpt.loop_of_block(b2), Some(lp0));
        assert_eq!(lpt.loop_of_block(b3), Some(lp0));
        assert_eq!(lpt.loop_of_block(b4), Some(lp0));
        assert_eq!(lpt.loop_of_block(b5), Some(lp0));
        assert_eq!(lpt.loop_of_block(b6), None);

        assert_eq!(lpt.loop_header(lp0), b1);
    }

    /// A loop consisting of the single block `b1`, which branches to itself.
    #[test]
    fn single_block_loop() {
        let mut test_module_builder = TestModuleBuilder::new();
        let mut builder = test_module_builder.func_builder(&[Type::I1], Type::Void);
        let b0 = builder.append_block();
        let b1 = builder.append_block();
        let b2 = builder.append_block();

        let arg = builder.args()[0];

        builder.switch_to_block(b0);
        builder.jump(b1);

        builder.switch_to_block(b1);
        builder.br(arg, b1, b2);

        builder.switch_to_block(b2);
        builder.ret(None);

        builder.seal_all();
        let func_ref = builder.finish();

        let module = test_module_builder.build();
        let func = &module.funcs[func_ref];
        let lpt = compute_loop(func);

        assert_eq!(lpt.loop_num(), 1);
        let lp0 = lpt.loops().next().unwrap();

        assert_eq!(lpt.loop_of_block(b0), None);
        assert_eq!(lpt.loop_of_block(b1), Some(lp0));
        assert_eq!(lpt.loop_of_block(b2), None);
    }

    /// Four loops: `l0` (header `b1`) encloses `l1` (header `b3`) and `l3`
    /// (header `b7`); `l1` in turn encloses `l2` (header `b4`).
    #[test]
    fn nested_loop() {
        let mut test_module_builder = TestModuleBuilder::new();
        let mut builder = test_module_builder.func_builder(&[Type::I1], Type::Void);

        let b0 = builder.append_block();
        let b1 = builder.append_block();
        let b2 = builder.append_block();
        let b3 = builder.append_block();
        let b4 = builder.append_block();
        let b5 = builder.append_block();
        let b6 = builder.append_block();
        let b7 = builder.append_block();
        let b8 = builder.append_block();
        let b9 = builder.append_block();
        let b10 = builder.append_block();
        let b11 = builder.append_block();

        let arg = builder.args()[0];

        builder.switch_to_block(b0);
        builder.jump(b1);

        builder.switch_to_block(b1);
        builder.jump(b2);

        builder.switch_to_block(b2);
        builder.br(arg, b3, b7);

        builder.switch_to_block(b3);
        builder.jump(b4);

        builder.switch_to_block(b4);
        builder.jump(b5);

        builder.switch_to_block(b5);
        builder.br(arg, b4, b6);

        builder.switch_to_block(b6);
        builder.br(arg, b3, b10);

        builder.switch_to_block(b7);
        builder.jump(b8);

        builder.switch_to_block(b8);
        builder.br(arg, b9, b7);

        builder.switch_to_block(b9);
        builder.br(arg, b1, b10);

        builder.switch_to_block(b10);
        builder.jump(b1);

        builder.switch_to_block(b11);
        builder.ret(None);

        builder.seal_all();
        let func_ref = builder.finish();

        let module = test_module_builder.build();
        let func = &module.funcs[func_ref];
        let lpt = compute_loop(func);

        assert_eq!(lpt.loop_num(), 4);
        let l0 = lpt.loop_of_block(b1).unwrap();
        let l1 = lpt.loop_of_block(b3).unwrap();
        let l2 = lpt.loop_of_block(b4).unwrap();
        let l3 = lpt.loop_of_block(b7).unwrap();

        assert_eq!(lpt.loop_of_block(b0), None);
        assert_eq!(lpt.loop_of_block(b1), Some(l0));
        assert_eq!(lpt.loop_of_block(b2), Some(l0));
        assert_eq!(lpt.loop_of_block(b3), Some(l1));
        assert_eq!(lpt.loop_of_block(b4), Some(l2));
        assert_eq!(lpt.loop_of_block(b5), Some(l2));
        assert_eq!(lpt.loop_of_block(b6), Some(l1));
        assert_eq!(lpt.loop_of_block(b7), Some(l3));
        assert_eq!(lpt.loop_of_block(b8), Some(l3));
        assert_eq!(lpt.loop_of_block(b9), Some(l0));
        assert_eq!(lpt.loop_of_block(b10), Some(l0));
        assert_eq!(lpt.loop_of_block(b11), None);

        assert_eq!(lpt.parent_loop(l0), None);
        assert_eq!(lpt.parent_loop(l1), Some(l0));
        assert_eq!(lpt.parent_loop(l2), Some(l1));
        assert_eq!(lpt.parent_loop(l3), Some(l0));

        assert_eq!(lpt.loop_header(l0), b1);
        assert_eq!(lpt.loop_header(l1), b3);
        assert_eq!(lpt.loop_header(l2), b4);
        assert_eq!(lpt.loop_header(l3), b7);
    }
}