use std::collections::{HashMap, HashSet};
use crate::{
analysis::{SsaFunction, SsaOp, SsaVarId},
utils::graph::{algorithms::DominatorTree, GraphBase, NodeId, Predecessors, Successors},
};
/// Structural classification of a loop, assigned by `classify_loop` based on
/// which blocks the loop's exit edges leave from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LoopType {
/// Every exit edge leaves from the loop header (checked after the
/// post-tested case, so a single latch that is also the header is reported
/// as `PostTested`).
PreTested,
/// The loop has exactly one latch and every exit edge leaves from that latch.
PostTested,
/// The loop has no exit edges at all.
Infinite,
/// Any other shape, including every loop with exits and more than one latch.
/// This is also the initial value set by `LoopInfo::new`.
Complex,
}
/// A single edge leaving a loop, as recorded by `compute_exits`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoopExit {
/// Block inside the loop body that the edge leaves from.
pub exiting_block: NodeId,
/// Successor of `exiting_block` that is not part of the loop body.
pub exit_block: NodeId,
}
/// Kind of instruction that produces the updated value of an induction
/// variable candidate, as determined by `LoopInfo::analyze_update_instruction`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InductionUpdateKind {
/// The update is an `SsaOp::Add` with the phi result as either operand.
Add,
/// The update is an `SsaOp::Sub` with the phi result as the left operand.
Sub,
/// The update is an `SsaOp::Mul` with the phi result as either operand.
Mul,
/// The defining instruction could not be resolved or matched none of the
/// patterns above.
Unknown,
}
/// An induction variable candidate found by `LoopInfo::find_induction_vars`:
/// a phi node in the loop header with exactly one operand coming from outside
/// the loop body and at least one coming from inside it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InductionVar {
/// SSA variable defined by the header phi node.
pub phi_result: SsaVarId,
/// Value of the single phi operand whose predecessor is outside the loop body.
pub init_value: SsaVarId,
/// Predecessor block of that outside operand.
pub init_block: NodeId,
/// Value of the first phi operand whose predecessor is inside the loop body.
pub update_value: SsaVarId,
/// Predecessor block of that inside operand.
pub update_block: NodeId,
/// Classification of the instruction that defines `update_value`.
pub update_kind: InductionUpdateKind,
/// Constant operand of the update instruction, when the other operand is
/// defined by an `SsaOp::Const` whose value converts via `as_i64`; `None`
/// otherwise.
pub stride: Option<i64>,
}
/// Description of a single loop, identified by its header block.
///
/// `detect_loops` builds one `LoopInfo` per header and fills in every field;
/// `LoopInfo::new` only seeds `header` and `body`.
#[derive(Debug, Clone)]
pub struct LoopInfo {
/// Header block; in `detect_loops` this is the target of a back edge, i.e.
/// a successor that dominates the edge's source.
pub header: NodeId,
/// Blocks belonging to the loop. `LoopInfo::new` inserts the header.
pub body: HashSet<NodeId>,
/// Source blocks of the back edges that target `header`.
pub latches: Vec<NodeId>,
/// The only predecessor of `header` outside `body`, or `None` when there are
/// zero or several such predecessors (see `compute_preheader`).
pub preheader: Option<NodeId>,
/// Edges from a body block to a block outside the body.
pub exits: Vec<LoopExit>,
/// Number of ancestor loops reached by following `parent` links
/// (0 for a loop without a parent), as computed by `compute_nesting`.
pub depth: usize,
/// Structural classification produced by `classify_loop`.
pub loop_type: LoopType,
/// Header of the smallest other loop whose body contains this loop's
/// header, if any.
pub parent: Option<NodeId>,
/// Headers of the loops whose `parent` is this loop's header.
pub children: Vec<NodeId>,
}
impl LoopInfo {
#[must_use]
pub fn new(header: NodeId) -> Self {
let mut body = HashSet::new();
body.insert(header);
Self {
header,
body,
latches: Vec::new(),
preheader: None,
exits: Vec::new(),
depth: 0,
loop_type: LoopType::Complex,
parent: None,
children: Vec::new(),
}
}
#[must_use]
pub fn contains(&self, node: NodeId) -> bool {
self.body.contains(&node)
}
#[must_use]
pub fn size(&self) -> usize {
self.body.len()
}
#[must_use]
pub fn has_single_latch(&self) -> bool {
self.latches.len() == 1
}
#[must_use]
pub fn single_latch(&self) -> Option<NodeId> {
if self.latches.len() == 1 {
Some(self.latches[0])
} else {
None
}
}
#[must_use]
pub fn has_preheader(&self) -> bool {
self.preheader.is_some()
}
#[must_use]
pub fn is_canonical(&self) -> bool {
self.has_preheader() && self.has_single_latch()
}
#[must_use]
pub fn is_innermost(&self) -> bool {
self.children.is_empty()
}
#[must_use]
pub fn is_outermost(&self) -> bool {
self.parent.is_none()
}
pub fn exit_blocks(&self) -> impl Iterator<Item = NodeId> + '_ {
self.exits.iter().map(|e| e.exit_block)
}
pub fn exiting_blocks(&self) -> impl Iterator<Item = NodeId> + '_ {
self.exits.iter().map(|e| e.exiting_block)
}
#[must_use]
pub fn exit_count(&self) -> usize {
self.exits.len()
}
#[must_use]
pub fn header_is_exiting(&self) -> bool {
self.exits.iter().any(|e| e.exiting_block == self.header)
}
#[must_use]
pub fn latch_is_exiting(&self) -> bool {
self.exits
.iter()
.any(|e| self.latches.contains(&e.exiting_block))
}
#[must_use]
pub fn find_condition_in_body(&self, ssa: &SsaFunction) -> Option<NodeId> {
for &block_id in &self.body {
if let Some(block) = ssa.block(block_id.index()) {
if matches!(block.terminator_op(), Some(SsaOp::Branch { .. })) {
return Some(block_id);
}
}
}
None
}
#[must_use]
pub fn find_all_conditions_in_body(&self, ssa: &SsaFunction) -> Vec<NodeId> {
self.body
.iter()
.filter(|&&block_id| {
ssa.block(block_id.index())
.is_some_and(|b| matches!(b.terminator_op(), Some(SsaOp::Branch { .. })))
})
.copied()
.collect()
}
#[must_use]
pub fn find_induction_vars(&self, ssa: &SsaFunction) -> Vec<InductionVar> {
let mut induction_vars = Vec::new();
let Some(header_block) = ssa.block(self.header.index()) else {
return induction_vars;
};
for phi in header_block.phi_nodes() {
let operands = phi.operands();
if operands.len() < 2 {
continue;
}
let (inside_ops, outside_ops): (Vec<&_>, Vec<&_>) = operands
.iter()
.partition(|op| self.body.contains(&NodeId::new(op.predecessor())));
if outside_ops.len() == 1 && !inside_ops.is_empty() {
let init_op = outside_ops[0];
let update_op = inside_ops[0];
let (update_kind, stride) =
Self::analyze_update_instruction(ssa, update_op.value(), phi.result());
induction_vars.push(InductionVar {
phi_result: phi.result(),
init_value: init_op.value(),
init_block: NodeId::new(init_op.predecessor()),
update_value: update_op.value(),
update_block: NodeId::new(update_op.predecessor()),
update_kind,
stride,
});
}
}
induction_vars
}
fn analyze_update_instruction(
ssa: &SsaFunction,
update_var: SsaVarId,
phi_result: SsaVarId,
) -> (InductionUpdateKind, Option<i64>) {
let Some(var) = ssa.variable(update_var) else {
return (InductionUpdateKind::Unknown, None);
};
let def_site = var.def_site();
if def_site.is_phi() {
return (InductionUpdateKind::Unknown, None);
}
let Some(block) = ssa.block(def_site.block) else {
return (InductionUpdateKind::Unknown, None);
};
let Some(instr_idx) = def_site.instruction else {
return (InductionUpdateKind::Unknown, None);
};
let Some(instr) = block.instruction(instr_idx) else {
return (InductionUpdateKind::Unknown, None);
};
match instr.op() {
SsaOp::Add { left, right, .. } => {
if *left == phi_result || *right == phi_result {
let other = if *left == phi_result { *right } else { *left };
let stride = Self::try_get_constant(ssa, other);
return (InductionUpdateKind::Add, stride);
}
}
SsaOp::Sub { left, right, .. } => {
if *left == phi_result {
let stride = Self::try_get_constant(ssa, *right);
return (InductionUpdateKind::Sub, stride);
}
}
SsaOp::Mul { left, right, .. } => {
if *left == phi_result || *right == phi_result {
let other = if *left == phi_result { *right } else { *left };
let stride = Self::try_get_constant(ssa, other);
return (InductionUpdateKind::Mul, stride);
}
}
_ => {}
}
(InductionUpdateKind::Unknown, None)
}
fn try_get_constant(ssa: &SsaFunction, var: SsaVarId) -> Option<i64> {
let variable = ssa.variable(var)?;
let def_site = variable.def_site();
if def_site.is_phi() {
return None;
}
let block = ssa.block(def_site.block)?;
let instr_idx = def_site.instruction?;
let instr = block.instruction(instr_idx)?;
match instr.op() {
SsaOp::Const { value, .. } => value.as_i64(),
_ => None,
}
}
}
/// Collection of detected loops plus a per-block lookup of the deepest loop
/// containing each block.
#[derive(Debug, Clone)]
pub struct LoopForest {
// Loops in insertion order; `block_to_loop` stores indices into this vector.
loops: Vec<LoopInfo>,
// Indexed by block index: the entry in `loops` with the greatest `depth`
// among the loops added so far that contain the block, or `None`.
block_to_loop: Vec<Option<usize>>,
}
impl LoopForest {
    /// Builds an empty forest with a lookup table covering `block_count`
    /// blocks, none of which belongs to a loop yet.
    #[must_use]
    pub fn new(block_count: usize) -> Self {
        let block_to_loop = vec![None; block_count];
        Self {
            loops: Vec::new(),
            block_to_loop,
        }
    }

    /// Appends `loop_info` and updates the block lookup table.
    ///
    /// A body block is mapped to the new loop when it has no mapping yet or
    /// when the currently mapped loop has a strictly smaller `depth`. Body
    /// blocks whose index is outside the lookup table are ignored.
    pub fn add_loop(&mut self, loop_info: LoopInfo) {
        let new_idx = self.loops.len();
        let new_depth = loop_info.depth;
        for block in loop_info.body.iter().copied() {
            let Some(slot) = self.block_to_loop.get_mut(block.index()) else {
                continue;
            };
            let takes_over = match *slot {
                None => true,
                Some(current) => self.loops[current].depth < new_depth,
            };
            if takes_over {
                *slot = Some(new_idx);
            }
        }
        self.loops.push(loop_info);
    }

    /// Returns all loops in insertion order.
    #[must_use]
    pub fn loops(&self) -> &[LoopInfo] {
        self.loops.as_slice()
    }

    /// Returns the number of loops in the forest.
    #[must_use]
    pub fn len(&self) -> usize {
        self.loops.len()
    }

    /// Returns `true` if the forest holds no loops.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.loops.is_empty()
    }

    /// Returns the loop currently mapped to `block` in the lookup table, or
    /// `None` when the block is unmapped or its index is out of range.
    #[must_use]
    pub fn innermost_loop(&self, block: NodeId) -> Option<&LoopInfo> {
        let loop_idx = self.block_to_loop.get(block.index()).copied().flatten()?;
        Some(&self.loops[loop_idx])
    }

    /// Returns the first loop (in insertion order) whose header is `header`.
    #[must_use]
    pub fn loop_for_header(&self, header: NodeId) -> Option<&LoopInfo> {
        self.loops.iter().find(|candidate| candidate.header == header)
    }

    /// Returns the nesting depth of `block`: 0 when it is not in any loop,
    /// otherwise the innermost loop's `depth` plus one.
    #[must_use]
    pub fn loop_depth(&self, block: NodeId) -> usize {
        match self.innermost_loop(block) {
            Some(innermost) => innermost.depth + 1,
            None => 0,
        }
    }

    /// Returns `true` if `block` is mapped to a loop.
    #[must_use]
    pub fn is_in_loop(&self, block: NodeId) -> bool {
        self.innermost_loop(block).is_some()
    }

    /// Iterates over all loops in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = &LoopInfo> {
        self.loops.iter()
    }

    /// Returns the loops ordered from smallest to largest `depth`; loops of
    /// equal depth keep their insertion order.
    #[must_use]
    pub fn by_depth_ascending(&self) -> Vec<&LoopInfo> {
        let mut ordered: Vec<&LoopInfo> = self.loops.iter().collect();
        ordered.sort_by(|a, b| a.depth.cmp(&b.depth));
        ordered
    }

    /// Returns the loops ordered from largest to smallest `depth`; loops of
    /// equal depth keep their insertion order.
    #[must_use]
    pub fn by_depth_descending(&self) -> Vec<&LoopInfo> {
        let mut ordered: Vec<&LoopInfo> = self.loops.iter().collect();
        ordered.sort_by(|a, b| b.depth.cmp(&a.depth));
        ordered
    }
}
/// Detects the natural loops of `graph` and returns them as a [`LoopForest`].
///
/// An edge `node -> succ` is treated as a back edge when `succ` dominates
/// `node`. All back edges sharing a target are merged into one loop with that
/// target as header. For every loop the body, latches, preheader, exits,
/// classification and nesting (parent, children, depth) are computed.
///
/// Loops are ordered by ascending header index *before* nesting is computed,
/// so the order of each loop's `children` does not depend on hash map
/// iteration order. They are added to the forest in that same order.
#[must_use]
pub fn detect_loops<G>(graph: &G, dominators: &DominatorTree) -> LoopForest
where
    G: GraphBase + Successors + Predecessors,
{
    let mut loops_by_header: HashMap<NodeId, LoopInfo> = HashMap::new();
    for node in graph.node_ids() {
        for succ in graph.successors(node) {
            if !dominators.dominates(succ, node) {
                continue;
            }
            let header = succ;
            let loop_info = loops_by_header
                .entry(header)
                .or_insert_with(|| LoopInfo::new(header));
            // Parallel edges from one block to the header would otherwise
            // record the same latch twice and make a single-latch loop look
            // like a multi-latch one.
            if !loop_info.latches.contains(&node) {
                loop_info.latches.push(node);
            }
            expand_loop_body(graph, loop_info, node);
        }
    }

    let mut loops: Vec<LoopInfo> = loops_by_header.into_values().collect();
    // Headers are unique per loop, so this fixes a total, reproducible order.
    loops.sort_unstable_by_key(|l| l.header.index());

    for loop_info in &mut loops {
        compute_preheader(graph, loop_info);
        compute_exits(graph, loop_info);
        loop_info.loop_type = classify_loop(loop_info);
    }
    compute_nesting(&mut loops);

    let mut forest = LoopForest::new(graph.node_count());
    for loop_info in loops {
        forest.add_loop(loop_info);
    }
    forest
}
/// Reports whether `graph` contains at least one back edge, i.e. an edge whose
/// target dominates its source. Stops at the first such edge.
#[must_use]
pub fn has_back_edges<G>(graph: &G, dominators: &DominatorTree) -> bool
where
    G: GraphBase + Successors,
{
    for tail in graph.node_ids() {
        for head in graph.successors(tail) {
            if !dominators.dominates(head, tail) {
                continue;
            }
            return true;
        }
    }
    false
}
/// Grows `loop_info.body` with every block that reaches `latch` by walking
/// predecessor edges backwards, never walking past blocks already in the body
/// (which includes the header). Does nothing if `latch` is already in the body.
fn expand_loop_body<G>(graph: &G, loop_info: &mut LoopInfo, latch: NodeId)
where
    G: Predecessors,
{
    if loop_info.contains(latch) {
        return;
    }
    let mut pending = vec![latch];
    while let Some(current) = pending.pop() {
        // `insert` returns false for blocks that were already visited.
        if !loop_info.body.insert(current) {
            continue;
        }
        for pred in graph.predecessors(current) {
            let already_covered = pred == loop_info.header || loop_info.body.contains(&pred);
            if !already_covered {
                pending.push(pred);
            }
        }
    }
}
/// Sets `loop_info.preheader` to the header's only predecessor outside the
/// loop body, or to `None` when the header has zero or several such
/// predecessor entries.
fn compute_preheader<G>(graph: &G, loop_info: &mut LoopInfo)
where
    G: Predecessors,
{
    let mut outside_count = 0usize;
    let mut first_outside: Option<NodeId> = None;
    for pred in graph.predecessors(loop_info.header) {
        if loop_info.body.contains(&pred) {
            continue;
        }
        outside_count += 1;
        if first_outside.is_none() {
            first_outside = Some(pred);
        }
    }
    loop_info.preheader = match outside_count {
        1 => first_outside,
        _ => None,
    };
}
/// Recomputes `loop_info.exits`: one entry for every edge from a body block to
/// a successor outside the body.
///
/// The body is a `HashSet`, whose iteration order is arbitrary, so the exits
/// are ordered by ascending exiting-block index afterwards. The sort is stable,
/// which keeps the exits of a single block in the order its successors were
/// reported by `graph`.
fn compute_exits<G>(graph: &G, loop_info: &mut LoopInfo)
where
    G: Successors,
{
    loop_info.exits.clear();
    for &body_block in &loop_info.body {
        for succ in graph.successors(body_block) {
            if !loop_info.body.contains(&succ) {
                loop_info.exits.push(LoopExit {
                    exiting_block: body_block,
                    exit_block: succ,
                });
            }
        }
    }
    // Remove the dependence on hash iteration order.
    loop_info
        .exits
        .sort_by_key(|exit| exit.exiting_block.index());
}
fn classify_loop(loop_info: &LoopInfo) -> LoopType {
if loop_info.exits.is_empty() {
return LoopType::Infinite;
}
if loop_info.latches.len() > 1 {
return LoopType::Complex;
}
let latch = loop_info.single_latch();
if let Some(latch) = latch {
let latch_exits = loop_info
.exits
.iter()
.filter(|e| e.exiting_block == latch)
.count();
if latch_exits == loop_info.exits.len() && latch_exits > 0 {
return LoopType::PostTested;
}
}
let header_exits = loop_info
.exits
.iter()
.filter(|e| e.exiting_block == loop_info.header)
.count();
if header_exits == loop_info.exits.len() && header_exits > 0 {
return LoopType::PreTested;
}
LoopType::Complex
}
/// Fills in `parent`, `children` and `depth` for every loop in `loops`.
///
/// A loop's parent is the smallest other loop whose body contains this loop's
/// header (the earliest such loop in `loops` on a size tie). Children are
/// appended in the order the child loops appear in `loops`. Depth is the
/// number of parent links that can be followed from the loop.
fn compute_nesting(loops: &mut [LoopInfo]) {
    let count = loops.len();

    let mut index_of_header: HashMap<NodeId, usize> = HashMap::new();
    for (idx, info) in loops.iter().enumerate() {
        index_of_header.insert(info.header, idx);
    }

    // Pass 1: pick the tightest enclosing loop as parent. This depends only on
    // loop bodies, never on parents assigned earlier in the pass.
    for inner in 0..count {
        let inner_header = loops[inner].header;
        let enclosing = (0..count)
            .filter(|&outer| outer != inner && loops[outer].body.contains(&inner_header))
            .min_by_key(|&outer| loops[outer].size());
        if let Some(outer) = enclosing {
            loops[inner].parent = Some(loops[outer].header);
        }
    }

    // Pass 2: register each loop with its parent.
    for child in 0..count {
        let Some(parent_header) = loops[child].parent else {
            continue;
        };
        if let Some(&parent_idx) = index_of_header.get(&parent_header) {
            let child_header = loops[child].header;
            loops[parent_idx].children.push(child_header);
        }
    }

    // Pass 3: depth is the length of the parent chain. A parent header with no
    // matching loop still counts as one level and then ends the walk.
    for idx in 0..count {
        let mut nesting = 0;
        let mut cursor = loops[idx].parent;
        while let Some(ancestor) = cursor {
            nesting += 1;
            cursor = index_of_header
                .get(&ancestor)
                .and_then(|&ancestor_idx| loops[ancestor_idx].parent);
        }
        loops[idx].depth = nesting;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created loop holds only its header and is not canonical.
    #[test]
    fn test_loop_info_creation() {
        let entry = NodeId::new(0);
        let fresh = LoopInfo::new(entry);

        assert_eq!(fresh.header, entry);
        assert!(fresh.contains(entry));
        assert_eq!(fresh.size(), 1);
        assert!(!fresh.has_single_latch());
        assert!(!fresh.has_preheader());
        assert!(!fresh.is_canonical());
    }

    /// A preheader plus exactly one latch makes a loop canonical.
    #[test]
    fn test_loop_info_canonical() {
        let mut canonical = LoopInfo::new(NodeId::new(1));
        canonical.latches.push(NodeId::new(2));
        canonical.preheader = Some(NodeId::new(0));

        assert!(canonical.has_preheader());
        assert!(canonical.has_single_latch());
        assert!(canonical.is_canonical());
    }

    /// Blocks shared by nested loops resolve to the deeper loop.
    #[test]
    fn test_loop_forest() {
        // Outer loop: header 1, body {1, 2, 3}, depth 0.
        let mut outer = LoopInfo::new(NodeId::new(1));
        outer.depth = 0;
        for block in [2, 3] {
            outer.body.insert(NodeId::new(block));
        }

        // Inner loop: header 2, body {2, 3}, depth 1.
        let mut inner = LoopInfo::new(NodeId::new(2));
        inner.depth = 1;
        inner.body.insert(NodeId::new(3));

        let mut forest = LoopForest::new(10);
        forest.add_loop(outer);
        forest.add_loop(inner);

        assert_eq!(forest.len(), 2);
        assert_eq!(forest.loop_depth(NodeId::new(3)), 2);
        assert_eq!(forest.loop_depth(NodeId::new(1)), 1);
        assert_eq!(forest.loop_depth(NodeId::new(0)), 0);
    }
}