use std::collections::{BTreeMap, BTreeSet};
use crate::cfg::{BlockRef, Cfg, DataflowFacts, EdgeRef, GraphFacts};
use crate::transformer::{LowInstr, LoweredProto, Reg, ResultPack};
use super::common::{
LoopCandidate, LoopExitValueMergeCandidate, LoopKindHint, LoopSourceBindings, LoopValueMerge,
};
use super::helpers::{collect_region_exits, is_reducible_region};
use super::phi_facts::loop_value_merges_in_block;
/// Builds one [`LoopCandidate`] per distinct loop header listed in
/// `graph_facts.natural_loops`.
///
/// Natural loops that share a header are merged first: their block sets are
/// unioned and their backedges are gathered into a single sorted,
/// duplicate-free list. Each merged loop is then annotated with:
///
/// * the preheader found by `unique_loop_preheader`, if any,
/// * the exits reported by `collect_region_exits`,
/// * the flag returned by `is_reducible_region`,
/// * the value merges at the header and at the exits,
/// * the kind hint, continue target and source bindings from `infer_loop_shape`.
///
/// The returned candidates are ordered by header.
pub(super) fn analyze_loops(
    proto: &LoweredProto,
    cfg: &Cfg,
    graph_facts: &GraphFacts,
    dataflow: &DataflowFacts,
) -> Vec<LoopCandidate> {
    // Bucket every natural loop under its header.
    let mut loops_by_header: BTreeMap<BlockRef, (BTreeSet<BlockRef>, Vec<EdgeRef>)> =
        BTreeMap::new();
    for natural_loop in &graph_facts.natural_loops {
        let (member_blocks, loop_backedges) =
            loops_by_header.entry(natural_loop.header).or_default();
        member_blocks.extend(natural_loop.blocks.iter().copied());
        loop_backedges.push(natural_loop.backedge);
    }

    let mut candidates = Vec::with_capacity(loops_by_header.len());
    for (header, (blocks, mut backedges)) in loops_by_header {
        // The same backedge may have been recorded more than once while merging.
        backedges.sort();
        backedges.dedup();

        let preheader = unique_loop_preheader(cfg, header, &blocks);
        let exits = collect_region_exits(cfg, &blocks);
        let reducible = is_reducible_region(cfg, header, &blocks);
        let header_value_merges = analyze_loop_header_value_merges(dataflow, header, &blocks);
        let (kind_hint, continue_target, source_bindings) = infer_loop_shape(
            proto,
            cfg,
            header,
            &blocks,
            &backedges,
            preheader,
            &header_value_merges,
        );
        let exit_value_merges = analyze_loop_exit_value_merges(dataflow, &exits, &blocks);

        candidates.push(LoopCandidate {
            header,
            preheader,
            blocks,
            backedges,
            exits,
            continue_target,
            kind_hint,
            source_bindings,
            header_value_merges,
            exit_value_merges,
            reducible,
        });
    }

    candidates.sort_by_key(|candidate| candidate.header);
    candidates
}
/// Classifies the shape of a loop and picks its continue target and source
/// bindings.
///
/// The checks run in priority order and the first one that matches wins:
///
/// 1. Every backedge leaves the same block and that block's terminator is a
///    `NumericForLoop`: `NumericForLike`, continuing at that block, with the
///    bindings from `numeric_for_source_bindings`.
/// 2. The header's terminator is a `GenericForLoop` accepted by
///    `generic_for_has_loop_body_and_exit`: `GenericForLike`, continuing at
///    the header, with the bindings from `generic_for_source_bindings`.
/// 3. The header passes `block_is_while_header_like` and
///    `branch_has_loop_body_and_exit`: `WhileLike`, continuing at the header.
/// 4. The single backedge source ends in a `Branch` accepted by
///    `branch_has_header_and_exit`: `RepeatLike`, continuing at that source.
/// 5. The single backedge source ends in a `Jump` to the header and
///    `repeat_continue_target_via_backedge_pad` finds a block: `RepeatLike`,
///    continuing at the block it found.
///
/// Otherwise the hint is `Unknown`; the continue target is the backedge
/// source when there is exactly one, and `None` when there are several.
fn infer_loop_shape(
    proto: &LoweredProto,
    cfg: &Cfg,
    header: BlockRef,
    blocks: &BTreeSet<BlockRef>,
    backedges: &[EdgeRef],
    preheader: Option<BlockRef>,
    header_value_merges: &[LoopValueMerge],
) -> (LoopKindHint, Option<BlockRef>, Option<LoopSourceBindings>) {
    let backedge_sources = backedges
        .iter()
        .map(|edge_ref| cfg.edges[edge_ref.index()].from)
        .collect::<BTreeSet<_>>();
    // Resolved once: several shapes only apply when every backedge leaves the
    // same block, and the fallback result reuses the same value.
    let single_backedge_source = if backedge_sources.len() == 1 {
        backedge_sources.iter().next().copied()
    } else {
        None
    };

    if let Some(source) = single_backedge_source
        && matches!(
            cfg.terminator(&proto.instrs, source),
            Some(LowInstr::NumericForLoop(_))
        )
    {
        return (
            LoopKindHint::NumericForLike,
            Some(source),
            numeric_for_source_bindings(proto, cfg, preheader),
        );
    }

    if let Some(LowInstr::GenericForLoop(instr)) = cfg.terminator(&proto.instrs, header)
        && generic_for_has_loop_body_and_exit(proto, cfg, header, instr, blocks)
    {
        return (
            LoopKindHint::GenericForLike,
            Some(header),
            generic_for_source_bindings(proto, cfg, header),
        );
    }

    if block_is_while_header_like(proto, cfg, header, header_value_merges)
        && branch_has_loop_body_and_exit(cfg, header, blocks)
    {
        return (LoopKindHint::WhileLike, Some(header), None);
    }

    if let Some(source) = single_backedge_source {
        match cfg.terminator(&proto.instrs, source) {
            Some(LowInstr::Branch(_))
                if branch_has_header_and_exit(cfg, source, header, blocks) =>
            {
                return (LoopKindHint::RepeatLike, Some(source), None);
            }
            Some(LowInstr::Jump(jump)) if cfg.instr_to_block[jump.target.index()] == header => {
                // Looked up once; the same value serves as both the test and
                // the reported continue target.
                if let Some(continue_target) =
                    repeat_continue_target_via_backedge_pad(proto, cfg, source, blocks)
                {
                    return (LoopKindHint::RepeatLike, Some(continue_target), None);
                }
            }
            _ => {}
        }
    }

    (LoopKindHint::Unknown, single_backedge_source, None)
}
/// Returns the numeric-for binding carried by the last instruction of the
/// preheader.
///
/// Yields `None` when there is no preheader, when the preheader block has no
/// last instruction, when that instruction cannot be looked up in
/// `proto.instrs`, or when it is not a `NumericForInit`.
fn numeric_for_source_bindings(
    proto: &LoweredProto,
    cfg: &Cfg,
    preheader: Option<BlockRef>,
) -> Option<LoopSourceBindings> {
    let last_instr = cfg.blocks[preheader?.index()].instrs.last()?;
    let LowInstr::NumericForInit(init) = proto.instrs.get(last_instr.index())? else {
        return None;
    };
    Some(LoopSourceBindings::Numeric(init.binding))
}
/// Returns the generic-for bindings carried by the last instruction of
/// `header`.
///
/// Yields `None` when the header block has no last instruction, when that
/// instruction cannot be looked up in `proto.instrs`, or when it is not a
/// `GenericForLoop`.
fn generic_for_source_bindings(
    proto: &LoweredProto,
    cfg: &Cfg,
    header: BlockRef,
) -> Option<LoopSourceBindings> {
    let last_instr = cfg.blocks[header.index()].instrs.last()?;
    if let LowInstr::GenericForLoop(for_loop) = proto.instrs.get(last_instr.index())? {
        Some(LoopSourceBindings::Generic(for_loop.bindings))
    } else {
        None
    }
}
fn analyze_loop_header_value_merges(
dataflow: &DataflowFacts,
header: BlockRef,
loop_blocks: &BTreeSet<BlockRef>,
) -> Vec<LoopValueMerge> {
loop_value_merges_in_block(dataflow, header, loop_blocks)
.into_iter()
.filter(loop_value_has_inside_and_outside_incoming)
.collect()
}
/// For every loop exit, gathers the value merges in that exit block that have
/// at least one incoming value from inside the loop.
///
/// Exits with no such merge are left out, so every returned candidate has a
/// non-empty `values` list. Candidates follow the iteration order of `exits`.
fn analyze_loop_exit_value_merges(
    dataflow: &DataflowFacts,
    exits: &BTreeSet<BlockRef>,
    loop_blocks: &BTreeSet<BlockRef>,
) -> Vec<LoopExitValueMergeCandidate> {
    exits
        .iter()
        .map(|&exit| {
            let values: Vec<_> = loop_value_merges_in_block(dataflow, exit, loop_blocks)
                .into_iter()
                .filter(|merge| !merge.inside_arm.is_empty())
                .collect();
            LoopExitValueMergeCandidate { exit, values }
        })
        // Drop exits where nothing flows out of the loop through a merge.
        .filter(|candidate| !candidate.values.is_empty())
        .collect()
}
/// Returns `true` when the merge has a non-empty `inside_arm` and a non-empty
/// `outside_arm`.
fn loop_value_has_inside_and_outside_incoming(value: &LoopValueMerge) -> bool {
    let missing_one_side = value.inside_arm.is_empty() || value.outside_arm.is_empty();
    !missing_one_side
}
/// Returns the single reachable predecessor of `header` that lies outside the
/// loop.
///
/// Yields `None` when the header has no such predecessor or more than one.
fn unique_loop_preheader(
    cfg: &Cfg,
    header: BlockRef,
    loop_blocks: &BTreeSet<BlockRef>,
) -> Option<BlockRef> {
    let mut outside_preds: Vec<BlockRef> = cfg
        .reachable_predecessors(header)
        .into_iter()
        .filter(|pred| !loop_blocks.contains(pred))
        .collect();
    match outside_preds.len() {
        1 => outside_preds.pop(),
        _ => None,
    }
}
/// Returns `true` when `header` has branch edges and exactly one of the two
/// branch targets belongs to `blocks`.
///
/// Returns `false` when `cfg.branch_edges(header)` yields nothing.
fn branch_has_loop_body_and_exit(cfg: &Cfg, header: BlockRef, blocks: &BTreeSet<BlockRef>) -> bool {
    match cfg.branch_edges(header) {
        Some((then_edge, else_edge)) => {
            let then_in_loop = blocks.contains(&cfg.edges[then_edge.index()].to);
            let else_in_loop = blocks.contains(&cfg.edges[else_edge.index()].to);
            // One arm stays in the loop and the other leaves it.
            then_in_loop != else_in_loop
        }
        None => false,
    }
}
/// Returns `true` when `block` has branch edges where one target is `header`
/// and the other target lies outside `blocks`.
///
/// Returns `false` when `cfg.branch_edges(block)` yields nothing.
fn branch_has_header_and_exit(
    cfg: &Cfg,
    block: BlockRef,
    header: BlockRef,
    blocks: &BTreeSet<BlockRef>,
) -> bool {
    let Some((then_edge, else_edge)) = cfg.branch_edges(block) else {
        return false;
    };
    let then_target = cfg.edges[then_edge.index()].to;
    let else_target = cfg.edges[else_edge.index()].to;

    // `back` must return to the header while `other` leaves the loop.
    let loops_back_and_exits =
        |back: BlockRef, other: BlockRef| back == header && !blocks.contains(&other);
    loops_back_and_exits(then_target, else_target) || loops_back_and_exits(else_target, then_target)
}
fn block_is_while_header_like(
proto: &LoweredProto,
cfg: &Cfg,
block: BlockRef,
header_value_merges: &[LoopValueMerge],
) -> bool {
let range = cfg.blocks[block.index()].instrs;
if !matches!(
cfg.terminator(&proto.instrs, block),
Some(LowInstr::Branch(_))
) {
return false;
}
if range.len == 1 {
return true;
}
let carried_regs = header_value_merges
.iter()
.map(|value| value.reg)
.collect::<BTreeSet<_>>();
(range.start.index()..range.end() - 1).all(|instr_index| {
let instr = &proto.instrs[instr_index];
instr_is_while_header_prefix(instr) && !instr_writes_any_reg(instr, &carried_regs)
})
}
/// Returns `true` unless `instr` is one of the variants excluded below:
/// upvalue/table/list stores, tail call, return, close, tbc, the numeric and
/// generic for-loop instructions, jump, or branch.
fn instr_is_while_header_prefix(instr: &LowInstr) -> bool {
    let excluded = matches!(
        instr,
        LowInstr::SetUpvalue(_)
            | LowInstr::SetTable(_)
            | LowInstr::SetList(_)
            | LowInstr::TailCall(_)
            | LowInstr::Return(_)
            | LowInstr::Close(_)
            | LowInstr::Tbc(_)
            | LowInstr::NumericForInit(_)
            | LowInstr::NumericForLoop(_)
            | LowInstr::GenericForCall(_)
            | LowInstr::GenericForLoop(_)
            | LowInstr::Jump(_)
            | LowInstr::Branch(_)
    );
    !excluded
}
/// Reports whether `instr` writes at least one of the registers in `regs`.
///
/// * Single-destination instructions are checked through their `dst`.
/// * `LoadNil` is checked across its whole `dst` range (`start` .. `start + len`).
/// * `Call` and `VarArg` are checked through their result pack, see
///   `result_pack_writes_any_reg`.
/// * Every remaining variant is reported as writing none of `regs`.
///
/// The match is exhaustive on purpose, so a new `LowInstr` variant has to be
/// classified here before the code compiles.
fn instr_writes_any_reg(instr: &LowInstr, regs: &BTreeSet<Reg>) -> bool {
    match instr {
        LowInstr::Move(instr) => regs.contains(&instr.dst),
        LowInstr::LoadNil(instr) => (0..instr.dst.len)
            .map(|offset| Reg(instr.dst.start.index() + offset))
            .any(|reg| regs.contains(&reg)),
        LowInstr::LoadBool(instr) => regs.contains(&instr.dst),
        LowInstr::LoadConst(instr) => regs.contains(&instr.dst),
        LowInstr::LoadInteger(instr) => regs.contains(&instr.dst),
        LowInstr::LoadNumber(instr) => regs.contains(&instr.dst),
        LowInstr::UnaryOp(instr) => regs.contains(&instr.dst),
        LowInstr::BinaryOp(instr) => regs.contains(&instr.dst),
        LowInstr::Concat(instr) => regs.contains(&instr.dst),
        LowInstr::GetUpvalue(instr) => regs.contains(&instr.dst),
        LowInstr::GetTable(instr) => regs.contains(&instr.dst),
        LowInstr::NewTable(instr) => regs.contains(&instr.dst),
        LowInstr::Closure(instr) => regs.contains(&instr.dst),
        LowInstr::Call(instr) => result_pack_writes_any_reg(&instr.results, regs),
        LowInstr::VarArg(instr) => result_pack_writes_any_reg(&instr.results, regs),
        LowInstr::ErrNil(_)
        | LowInstr::SetUpvalue(_)
        | LowInstr::SetTable(_)
        | LowInstr::SetList(_)
        | LowInstr::TailCall(_)
        | LowInstr::Return(_)
        | LowInstr::Close(_)
        | LowInstr::Tbc(_)
        | LowInstr::NumericForInit(_)
        | LowInstr::NumericForLoop(_)
        | LowInstr::GenericForCall(_)
        | LowInstr::GenericForLoop(_)
        | LowInstr::Jump(_)
        | LowInstr::Branch(_) => false,
    }
}
/// Reports whether a result pack writes at least one of the registers in
/// `regs`.
///
/// * `Fixed(range)`: true when any register in `start .. start + len` is in
///   `regs`.
/// * `Open(start)`: true when `regs` holds any register whose index is at
///   least `start`'s index.
/// * `Ignore`: always false.
fn result_pack_writes_any_reg(results: &ResultPack, regs: &BTreeSet<Reg>) -> bool {
    match results {
        ResultPack::Fixed(range) => (0..range.len)
            .map(|offset| Reg(range.start.index() + offset))
            .any(|reg| regs.contains(&reg)),
        ResultPack::Open(start) => regs.iter().any(|reg| reg.index() >= start.index()),
        ResultPack::Ignore => false,
    }
}
/// Looks for the block that feeds `backedge_source` when the backedge is taken
/// through a separate jump block.
///
/// Succeeds only when `backedge_source` has exactly one reachable predecessor
/// inside `blocks`, that predecessor ends in a `Branch`, one branch target is
/// `backedge_source`, and the other target lies outside `blocks`. The
/// predecessor is returned in that case; otherwise the result is `None`.
fn repeat_continue_target_via_backedge_pad(
    proto: &LoweredProto,
    cfg: &Cfg,
    backedge_source: BlockRef,
    blocks: &BTreeSet<BlockRef>,
) -> Option<BlockRef> {
    let mut in_loop_preds: Vec<BlockRef> = cfg
        .reachable_predecessors(backedge_source)
        .into_iter()
        .filter(|pred| blocks.contains(pred))
        .collect();
    if in_loop_preds.len() != 1 {
        return None;
    }
    let candidate = in_loop_preds.pop()?;

    let Some(LowInstr::Branch(_)) = cfg.terminator(&proto.instrs, candidate) else {
        return None;
    };
    let (then_edge, else_edge) = cfg.branch_edges(candidate)?;
    let then_target = cfg.edges[then_edge.index()].to;
    let else_target = cfg.edges[else_edge.index()].to;

    // `pad_side` must lead to the backedge block while `other_side` leaves the loop.
    let pads_and_exits = |pad_side: BlockRef, other_side: BlockRef| {
        pad_side == backedge_source && !blocks.contains(&other_side)
    };
    (pads_and_exits(then_target, else_target) || pads_and_exits(else_target, then_target))
        .then_some(candidate)
}
fn generic_for_has_loop_body_and_exit(
proto: &LoweredProto,
cfg: &Cfg,
header: BlockRef,
instr: &crate::transformer::GenericForLoopInstr,
blocks: &BTreeSet<BlockRef>,
) -> bool {
let range = cfg.blocks[header.index()].instrs;
if range.len < 2 {
return false;
}
let Some(call_instr_index) = range.end().checked_sub(2) else {
return false;
};
let Some(LowInstr::GenericForCall(call)) = proto.instrs.get(call_instr_index) else {
return false;
};
let body_block = cfg.instr_to_block[instr.body_target.index()];
let exit_block = cfg.instr_to_block[instr.exit_target.index()];
matches!(call.results, crate::transformer::ResultPack::Fixed(range) if range == instr.bindings)
&& blocks.contains(&body_block)
&& !blocks.contains(&exit_block)
}