use crate::cfg::MirCfg;
use crate::dominators::DominatorTree;
use crate::BlockId;
/// Identifier of a loop within a [`LoopTree`].
///
/// `LoopId(i)` names `LoopTree::loops[i]`: `LoopTree::get` indexes the
/// `loops` vector directly by the wrapped value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct LoopId(pub u32);
/// A single natural loop discovered by [`compute_loop_tree`].
///
/// Loops are keyed by their header block: all back edges that target the
/// same header are merged into one `LoopInfo`.
#[derive(Debug, Clone)]
pub struct LoopInfo {
/// This loop's id; also its index in `LoopTree::loops`.
pub id: LoopId,
/// Block targeted by every back edge of this loop.
pub header: BlockId,
/// Every block of the loop, including the header and the blocks of any
/// nested loops, sorted ascending (callers rely on `binary_search`).
pub body_blocks: Vec<BlockId>,
/// Sources of the back edges into `header` (edges whose source `header`
/// dominates), sorted ascending.
pub back_edge_sources: Vec<BlockId>,
/// Blocks outside `body_blocks` that some body block branches to,
/// sorted ascending and without duplicates.
pub exit_blocks: Vec<BlockId>,
/// The unique block through which the loop is entered from outside, when
/// one exists; `None` otherwise. See `compute_loop_tree` for the exact
/// criterion.
pub preheader: Option<BlockId>,
/// Smallest other loop whose body contains `header`; `None` for an
/// outermost loop.
pub parent: Option<LoopId>,
/// Loops whose `parent` is this loop, sorted ascending.
pub children: Vec<LoopId>,
/// Number of enclosing loops (0 for an outermost loop).
pub depth: u32,
/// `compute_loop_tree` leaves this `false`; presumably filled in by a later
/// analysis (TODO confirm).
pub is_countable: bool,
/// `compute_loop_tree` leaves this `None`; presumably filled in by a later
/// analysis (TODO confirm).
pub trip_count_hint: Option<u64>,
/// Cached `exit_blocks.len()` (distinct exit blocks, not exiting edges).
pub num_exits: u32,
/// Schedule for this loop; `compute_loop_tree` sets `SchedulePlan::default()`.
pub schedule: SchedulePlan,
}
/// Execution schedule attached to a loop (`LoopInfo::schedule`).
///
/// `compute_loop_tree` assigns the default, `SequentialStrict`, to every loop.
/// NOTE(review): the `Descriptive*` variants are not produced anywhere in this
/// file; the parameter descriptions below are inferred from their names and
/// should be confirmed against the pass that sets them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchedulePlan {
/// Default plan; displayed as `sequential_strict`.
SequentialStrict,
/// Displayed as `descriptive_tiled(<tile_size>)`; `tile_size` is presumably
/// the number of iterations per tile.
DescriptiveTiled { tile_size: u32 },
/// Displayed as `descriptive_vectorized(<width>)`; `width` is presumably the
/// vector lane count.
DescriptiveVectorized { width: u32 },
/// Displayed as `descriptive_materialize_boundary`.
DescriptiveMaterializeBoundary,
/// Displayed as `descriptive_static_partition(<chunk_size>)`; `chunk_size` is
/// presumably the number of iterations per static chunk.
DescriptiveStaticPartition { chunk_size: u32 },
}
impl Default for SchedulePlan {
fn default() -> Self {
SchedulePlan::SequentialStrict
}
}
impl std::fmt::Display for SchedulePlan {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SchedulePlan::SequentialStrict => write!(f, "sequential_strict"),
SchedulePlan::DescriptiveTiled { tile_size } => {
write!(f, "descriptive_tiled({})", tile_size)
}
SchedulePlan::DescriptiveVectorized { width } => {
write!(f, "descriptive_vectorized({})", width)
}
SchedulePlan::DescriptiveMaterializeBoundary => {
write!(f, "descriptive_materialize_boundary")
}
SchedulePlan::DescriptiveStaticPartition { chunk_size } => {
write!(f, "descriptive_static_partition({})", chunk_size)
}
}
}
}
/// Loop-nesting forest for one CFG, as built by [`compute_loop_tree`].
#[derive(Debug, Clone)]
pub struct LoopTree {
/// All loops, indexed by `LoopId.0`.
pub loops: Vec<LoopInfo>,
/// For each block index (`BlockId.0`), the innermost loop containing that
/// block (the containing loop with the smallest body), or `None` if the
/// block belongs to no loop.
pub block_to_loop: Vec<Option<LoopId>>,
/// Number of basic blocks in the CFG the tree was built from.
pub num_blocks: usize,
}
impl LoopTree {
    /// Returns the number of loops in the tree.
    pub fn len(&self) -> usize {
        self.loops.len()
    }

    /// Returns `true` when no loops were found.
    pub fn is_empty(&self) -> bool {
        self.loops.is_empty()
    }

    /// Returns the loop named by `id`.
    ///
    /// # Panics
    /// Panics if `id` is not a valid index into `loops`.
    pub fn get(&self, id: LoopId) -> &LoopInfo {
        let index = id.0 as usize;
        &self.loops[index]
    }

    /// Returns the innermost loop containing `block`, or `None` if the block
    /// is in no loop or lies outside `block_to_loop`.
    pub fn loop_for_block(&self, block: BlockId) -> Option<LoopId> {
        match self.block_to_loop.get(block.0 as usize) {
            Some(&innermost) => innermost,
            None => None,
        }
    }

    /// Returns `true` if `block` belongs to `loop_id` directly or through any
    /// loop nested inside it.
    pub fn is_block_in_loop(&self, block: BlockId, loop_id: LoopId) -> bool {
        self.parent_chain_contains(self.loop_for_block(block), loop_id)
    }

    /// Returns `true` if `inner` is `outer` itself or is nested at any depth
    /// inside `outer`.
    pub fn is_nested_in(&self, inner: LoopId, outer: LoopId) -> bool {
        self.parent_chain_contains(Some(inner), outer)
    }

    /// Follows `parent` links upward from `start`, reporting whether `target`
    /// is met. The parent of the matching loop is never looked up.
    fn parent_chain_contains(&self, start: Option<LoopId>, target: LoopId) -> bool {
        let mut cursor = start;
        loop {
            match cursor {
                None => return false,
                Some(lid) if lid == target => return true,
                Some(lid) => cursor = self.loops[lid.0 as usize].parent,
            }
        }
    }

    /// Returns the ids of all loops without a parent, in the order they are
    /// stored in `loops`.
    pub fn root_loops(&self) -> Vec<LoopId> {
        let mut roots = Vec::new();
        for info in &self.loops {
            if info.parent.is_none() {
                roots.push(info.id);
            }
        }
        roots
    }

    /// Returns the greatest `depth` of any loop, or 0 for an empty tree.
    pub fn max_depth(&self) -> u32 {
        self.loops.iter().fold(0, |deepest, info| deepest.max(info.depth))
    }
}
pub fn compute_loop_tree(cfg: &MirCfg, domtree: &DominatorTree) -> LoopTree {
let num_blocks = cfg.basic_blocks.len();
let preds = cfg.predecessors();
let mut back_edges: Vec<Vec<BlockId>> = vec![Vec::new(); num_blocks];
for bb in &cfg.basic_blocks {
for succ in bb.terminator.successors() {
if domtree.dominates(succ, bb.id) {
back_edges[succ.0 as usize].push(bb.id);
}
}
}
let mut loop_infos: Vec<LoopInfo> = Vec::new();
let mut next_loop_id: u32 = 0;
for header_idx in 0..num_blocks {
let sources = &back_edges[header_idx];
if sources.is_empty() {
continue;
}
let header = BlockId(header_idx as u32);
let mut sorted_sources = sources.clone();
sorted_sources.sort();
let body = compute_loop_body(header, &sorted_sources, &preds, num_blocks);
let lid = LoopId(next_loop_id);
next_loop_id += 1;
loop_infos.push(LoopInfo {
id: lid,
header,
body_blocks: body,
back_edge_sources: sorted_sources,
exit_blocks: Vec::new(), preheader: None, parent: None, children: Vec::new(), depth: 0, is_countable: false,
trip_count_hint: None,
num_exits: 0,
schedule: SchedulePlan::default(),
});
}
let n_loops = loop_infos.len();
for i in 0..n_loops {
let header_i = loop_infos[i].header;
let mut best_parent: Option<LoopId> = None;
let mut best_parent_size = usize::MAX;
for j in 0..n_loops {
if i == j {
continue;
}
if loop_infos[j]
.body_blocks
.binary_search(&header_i)
.is_ok()
{
let size = loop_infos[j].body_blocks.len();
if size < best_parent_size {
best_parent = Some(loop_infos[j].id);
best_parent_size = size;
}
}
}
loop_infos[i].parent = best_parent;
}
for i in 0..n_loops {
let parent = loop_infos[i].parent;
if let Some(pid) = parent {
let _ = pid; }
}
let parents: Vec<Option<LoopId>> = loop_infos.iter().map(|l| l.parent).collect();
for (i, parent) in parents.iter().enumerate() {
if let Some(pid) = parent {
loop_infos[pid.0 as usize].children.push(LoopId(i as u32));
}
}
for info in &mut loop_infos {
info.children.sort();
}
for i in 0..n_loops {
let mut depth = 0u32;
let mut cur = loop_infos[i].parent;
while let Some(pid) = cur {
depth += 1;
cur = loop_infos[pid.0 as usize].parent;
}
loop_infos[i].depth = depth;
}
for i in 0..n_loops {
let body = &loop_infos[i].body_blocks;
let mut exits: Vec<BlockId> = Vec::new();
for &block in body {
for succ in cfg.block(block).terminator.successors() {
if body.binary_search(&succ).is_err() && !exits.contains(&succ) {
exits.push(succ);
}
}
}
exits.sort();
loop_infos[i].num_exits = exits.len() as u32;
loop_infos[i].exit_blocks = exits;
let header = loop_infos[i].header;
let back_srcs = &loop_infos[i].back_edge_sources;
let non_back_preds: Vec<BlockId> = preds[header.0 as usize]
.iter()
.filter(|p| back_srcs.binary_search(p).is_err())
.copied()
.collect();
loop_infos[i].preheader = if non_back_preds.len() == 1 {
Some(non_back_preds[0])
} else {
None
};
}
let mut block_to_loop: Vec<Option<LoopId>> = vec![None; num_blocks];
let mut by_size: Vec<usize> = (0..n_loops).collect();
by_size.sort_by(|&a, &b| {
loop_infos[b]
.body_blocks
.len()
.cmp(&loop_infos[a].body_blocks.len())
});
for idx in by_size {
let lid = loop_infos[idx].id;
for &block in &loop_infos[idx].body_blocks {
block_to_loop[block.0 as usize] = Some(lid);
}
}
LoopTree {
loops: loop_infos,
block_to_loop,
num_blocks,
}
}
/// Collects the body of the natural loop headed by `header`.
///
/// Starting from the back-edge sources, walks predecessor edges backwards.
/// `header` is pre-marked, so the walk never crosses it. Returns every marked
/// block, including `header`, in ascending `BlockId` order. The visited set is
/// sized from `preds`; `_num_blocks` is unused.
fn compute_loop_body(
    header: BlockId,
    back_edge_sources: &[BlockId],
    preds: &[Vec<BlockId>],
    _num_blocks: usize,
) -> Vec<BlockId> {
    let mut member = vec![false; preds.len()];
    member[header.0 as usize] = true;

    // `mem::replace` marks a block and reports whether it was already marked,
    // so each block is queued at most once.
    let mut stack: Vec<BlockId> = Vec::new();
    for &src in back_edge_sources {
        if !std::mem::replace(&mut member[src.0 as usize], true) {
            stack.push(src);
        }
    }
    while let Some(block) = stack.pop() {
        for &pred in &preds[block.0 as usize] {
            if !std::mem::replace(&mut member[pred.0 as usize], true) {
                stack.push(pred);
            }
        }
    }

    // Scanning indices in order already yields the body sorted ascending.
    (0..member.len())
        .filter(|&idx| member[idx])
        .map(|idx| BlockId(idx as u32))
        .collect()
}
#[cfg(test)]
mod tests {
    // Loop-tree tests: each builds a small `MirBody`, lowers it through
    // `CfgBuilder` and `DominatorTree`, then inspects the resulting `LoopTree`.
    use super::*;
    use crate::cfg::{CfgBuilder, MirCfg};
    use crate::dominators::DominatorTree;
    use crate::{MirBody, MirExpr, MirExprKind, MirStmt};

    /// Integer-literal expression.
    fn int_expr(v: i64) -> MirExpr {
        MirExpr {
            kind: MirExprKind::IntLit(v),
        }
    }

    /// Boolean-literal expression.
    fn bool_expr(b: bool) -> MirExpr {
        MirExpr {
            kind: MirExprKind::BoolLit(b),
        }
    }

    /// Lowers `body` and runs loop discovery, returning every intermediate result.
    fn loop_tree_from_body(body: &MirBody) -> (MirCfg, DominatorTree, LoopTree) {
        let cfg = CfgBuilder::build(body);
        let domtree = DominatorTree::compute(&cfg);
        let loops = compute_loop_tree(&cfg, &domtree);
        (cfg, domtree, loops)
    }

    #[test]
    fn test_no_loops_empty_tree() {
        let body = MirBody {
            stmts: vec![MirStmt::Let {
                name: "x".into(),
                mutable: false,
                init: int_expr(42),
                alloc_hint: None,
            }],
            result: Some(Box::new(int_expr(42))),
        };
        let (_, _, loops) = loop_tree_from_body(&body);
        assert!(loops.is_empty());
        assert_eq!(loops.len(), 0);
        assert_eq!(loops.max_depth(), 0);
    }

    #[test]
    fn test_single_while_loop() {
        let body = MirBody {
            stmts: vec![MirStmt::While {
                cond: bool_expr(true),
                body: MirBody {
                    stmts: vec![MirStmt::Expr(int_expr(1))],
                    result: None,
                },
            }],
            result: None,
        };
        let (cfg, _, loops) = loop_tree_from_body(&body);
        assert_eq!(loops.len(), 1, "should detect exactly 1 loop");
        let loop0 = loops.get(LoopId(0));
        assert_eq!(loop0.depth, 0, "outermost loop has depth 0");
        assert!(loop0.parent.is_none(), "no parent loop");
        assert!(loop0.children.is_empty(), "no child loops");
        assert!(
            !loop0.body_blocks.is_empty(),
            "body should have at least header"
        );
        assert!(
            loop0.body_blocks.contains(&loop0.header),
            "header is part of body"
        );
        assert!(
            !loop0.back_edge_sources.is_empty(),
            "should have at least one back-edge source"
        );
        assert!(
            !loop0.exit_blocks.is_empty(),
            "should have at least one exit block"
        );
        assert!(cfg.is_loop_header(loop0.header));
    }

    #[test]
    fn test_nested_while_loops() {
        let body = MirBody {
            stmts: vec![MirStmt::While {
                cond: bool_expr(true),
                body: MirBody {
                    stmts: vec![MirStmt::While {
                        cond: bool_expr(true),
                        body: MirBody {
                            stmts: vec![MirStmt::Expr(int_expr(1))],
                            result: None,
                        },
                    }],
                    result: None,
                },
            }],
            result: None,
        };
        let (_, _, loops) = loop_tree_from_body(&body);
        assert_eq!(loops.len(), 2, "should detect 2 loops (outer + inner)");
        let outer = loops.loops.iter().find(|l| l.depth == 0).unwrap();
        let inner = loops.loops.iter().find(|l| l.depth == 1).unwrap();
        assert!(outer.parent.is_none());
        assert_eq!(inner.parent, Some(outer.id));
        assert!(outer.children.contains(&inner.id));
        assert!(inner.children.is_empty());
        assert!(outer.body_blocks.contains(&inner.header));
        assert!(loops.is_nested_in(inner.id, outer.id));
        assert!(!loops.is_nested_in(outer.id, inner.id));
        assert_eq!(loops.max_depth(), 1);
        assert_eq!(loops.root_loops().len(), 1);
    }

    #[test]
    fn test_sequential_loops() {
        let body = MirBody {
            stmts: vec![
                MirStmt::While {
                    cond: bool_expr(true),
                    body: MirBody {
                        stmts: vec![MirStmt::Expr(int_expr(1))],
                        result: None,
                    },
                },
                MirStmt::While {
                    cond: bool_expr(true),
                    body: MirBody {
                        stmts: vec![MirStmt::Expr(int_expr(2))],
                        result: None,
                    },
                },
            ],
            result: None,
        };
        let (_, _, loops) = loop_tree_from_body(&body);
        assert_eq!(loops.len(), 2, "should detect 2 independent loops");
        assert!(loops.loops.iter().all(|l| l.depth == 0));
        assert!(loops.loops.iter().all(|l| l.parent.is_none()));
        assert_eq!(loops.root_loops().len(), 2);
        assert!(!loops.is_nested_in(LoopId(0), LoopId(1)));
        assert!(!loops.is_nested_in(LoopId(1), LoopId(0)));
    }

    #[test]
    fn test_preheader_detection() {
        let body = MirBody {
            stmts: vec![MirStmt::While {
                cond: bool_expr(true),
                body: MirBody {
                    stmts: vec![],
                    result: None,
                },
            }],
            result: None,
        };
        let (_, _, loops) = loop_tree_from_body(&body);
        assert_eq!(loops.len(), 1);
        let loop0 = loops.get(LoopId(0));
        assert!(
            loop0.preheader.is_some(),
            "simple while loop should have a preheader"
        );
    }

    #[test]
    fn test_block_to_loop_mapping() {
        let body = MirBody {
            stmts: vec![MirStmt::While {
                cond: bool_expr(true),
                body: MirBody {
                    stmts: vec![MirStmt::Expr(int_expr(1))],
                    result: None,
                },
            }],
            result: None,
        };
        let (_, _, loops) = loop_tree_from_body(&body);
        assert_eq!(loops.len(), 1);
        let loop0 = loops.get(LoopId(0));
        for &block in &loop0.body_blocks {
            assert_eq!(
                loops.loop_for_block(block),
                Some(LoopId(0)),
                "body block {:?} should map to loop 0",
                block
            );
        }
        // Exit blocks are defined as successors outside the body, so none of
        // them may lie in the body or map to loop 0.
        for &block in &loop0.exit_blocks {
            assert!(
                loop0.body_blocks.binary_search(&block).is_err(),
                "exit block {:?} must lie outside the loop body",
                block
            );
            assert_ne!(
                loops.loop_for_block(block),
                Some(LoopId(0)),
                "exit block {:?} should not be in loop 0",
                block
            );
        }
    }

    #[test]
    fn test_loop_tree_determinism() {
        let body = MirBody {
            stmts: vec![
                MirStmt::While {
                    cond: bool_expr(true),
                    body: MirBody {
                        stmts: vec![MirStmt::While {
                            cond: bool_expr(true),
                            body: MirBody {
                                stmts: vec![MirStmt::Expr(int_expr(1))],
                                result: None,
                            },
                        }],
                        result: None,
                    },
                },
                MirStmt::While {
                    cond: bool_expr(false),
                    body: MirBody {
                        stmts: vec![],
                        result: None,
                    },
                },
            ],
            result: None,
        };
        // Two independent builds of the same program must agree field by field.
        let (_, _, loops1) = loop_tree_from_body(&body);
        let (_, _, loops2) = loop_tree_from_body(&body);
        assert_eq!(loops1.len(), loops2.len());
        for i in 0..loops1.len() {
            let a = &loops1.loops[i];
            let b = &loops2.loops[i];
            assert_eq!(a.id, b.id);
            assert_eq!(a.header, b.header);
            assert_eq!(a.body_blocks, b.body_blocks);
            assert_eq!(a.back_edge_sources, b.back_edge_sources);
            assert_eq!(a.exit_blocks, b.exit_blocks);
            assert_eq!(a.preheader, b.preheader);
            assert_eq!(a.parent, b.parent);
            assert_eq!(a.children, b.children);
            assert_eq!(a.depth, b.depth);
        }
        assert_eq!(loops1.block_to_loop, loops2.block_to_loop);
    }

    #[test]
    fn test_is_block_in_loop_transitive() {
        let body = MirBody {
            stmts: vec![MirStmt::While {
                cond: bool_expr(true),
                body: MirBody {
                    stmts: vec![MirStmt::While {
                        cond: bool_expr(true),
                        body: MirBody {
                            stmts: vec![MirStmt::Expr(int_expr(1))],
                            result: None,
                        },
                    }],
                    result: None,
                },
            }],
            result: None,
        };
        let (_, _, loops) = loop_tree_from_body(&body);
        assert_eq!(loops.len(), 2);
        let outer = loops.loops.iter().find(|l| l.depth == 0).unwrap();
        let inner = loops.loops.iter().find(|l| l.depth == 1).unwrap();
        // Inner blocks map to the inner loop, but membership must propagate
        // up the parent chain to the outer loop.
        assert!(loops.is_block_in_loop(inner.header, outer.id));
        for &block in &inner.body_blocks {
            assert!(
                loops.is_block_in_loop(block, outer.id),
                "inner block {:?} should be transitively in outer loop",
                block
            );
        }
    }
}