use crate::ir::prelude::*;
use crate::ir::{DataFlowGraph, FunctionLayout, InstData};
use crate::opt::prelude::*;
use crate::pass::gcse::{DominatorTree, PredecessorTable};
use std::{
collections::{HashMap, HashSet, VecDeque},
ops::Index,
};
/// Temporal Code Motion (TCM) pass.
///
/// The pass itself carries no state; all work happens in its `Pass`
/// implementation, which hoists `Prb` instructions, moves `Drv` instructions
/// along control flow edges, and merges identical temporal terminators.
pub struct TemporalCodeMotion;
impl Pass for TemporalCodeMotion {
/// Runs temporal code motion on a unit that has a control flow graph.
///
/// The pass proceeds in three phases:
///
/// 1. For every temporal region with exactly one head block, `Prb`
///    instructions whose probed value is not the result of another
///    instruction are moved to the start of that head block, provided the
///    head block dominates the instruction's block and the blocks of all its
///    users in the temporal dominator tree.
/// 2. `Drv` instructions are repeatedly pushed down (`diverge_drives`) and
///    merged (`reconverge_drives`) until neither reports a change, for at most
///    100 iterations.
/// 3. Within each temporal region, tail instructions whose `InstData` compare
///    equal are merged into a single instruction placed in a fresh block that
///    all former locations branch to.
///
/// Returns `true` if the unit was modified in any phase.
///
/// # Panics
///
/// Panics if the number of temporal regions changes while drives are being
/// moved (see the `assert_eq!` in phase 2), or if an instruction unexpectedly
/// has no containing block.
fn run_on_cfg(_ctx: &PassContext, unit: &mut impl UnitBuilder) -> bool {
info!("TCM [{}]", unit.unit().name());
let mut modified = false;
// Phase 1: hoist `prb` instructions into the head block of their region.
// The region graph and the temporal dominator tree are built once up front.
let trg = TemporalRegionGraph::new(unit.dfg(), unit.func_layout());
let temp_pt = PredecessorTable::new_temporal(unit.dfg(), unit.func_layout());
let temp_dt = DominatorTree::new(unit.func_layout(), &temp_pt);
for tr in &trg.regions {
let dfg = unit.dfg();
let layout = unit.func_layout();
// Hoisting needs a unique destination, so regions with zero or several
// head blocks are left alone.
if tr.head_blocks.len() != 1 {
trace!("Skipping {} for prb move (multiple head blocks)", tr.id);
continue;
}
let head_bb = tr.head_blocks().next().unwrap();
let mut hoist = vec![];
for bb in tr.blocks() {
for inst in layout.insts(bb) {
// Only consider probes whose first argument is not produced by an
// instruction (presumably a unit argument -- TODO confirm against
// `DataFlowGraph::get_value_inst`).
if dfg[inst].opcode() == Opcode::Prb
&& dfg.get_value_inst(dfg[inst].args()[0]).is_none()
{
// The head block must dominate the probe itself and every block that
// uses the probed value.
let mut dominates = temp_dt.dominates(head_bb, bb);
for (user_inst, _) in dfg.uses(dfg.inst_result(inst)) {
let user_bb = unit.func_layout().inst_block(user_inst).unwrap();
let dom = temp_dt.dominates(head_bb, user_bb);
dominates &= dom;
}
if dominates {
hoist.push(inst);
} else {
trace!(
"Skipping {} for prb move (would not dominate uses)",
inst.dump(dfg, unit.try_cfg())
);
}
}
}
}
// Sort so the hoisting order does not depend on hash set iteration order.
hoist.sort();
for inst in hoist {
debug!(
"Hoisting {} into {}",
inst.dump(unit.dfg(), unit.try_cfg()),
head_bb.dump(unit.cfg())
);
// Move the existing instruction rather than recreating it, so its result
// value stays valid for all users.
let layout = unit.func_layout_mut();
layout.remove_inst(inst);
layout.prepend_inst(inst, head_bb);
modified = true;
}
}
// Phase 2: alternate between pushing drives down and merging them again
// until a fixed point is reached. The iteration count is capped at 100.
for i in 0..100 {
let mut changes = false;
debug!("Moving `drv` iteration {}", i);
// The region graph is rebuilt each round because the previous round may
// have added blocks; the number of regions is expected to stay the same.
let inner_trg = TemporalRegionGraph::new(unit.dfg(), unit.func_layout());
assert_eq!(
inner_trg.regions.len(),
trg.regions.len(),
"{:#?}",
inner_trg
);
let pred = PredecessorTable::new(unit.dfg(), unit.func_layout());
changes |= diverge_drives(unit, &inner_trg, &pred);
// `diverge_drives` may have inserted helper blocks, so the predecessor
// table is recomputed before reconverging.
let pred = PredecessorTable::new(unit.dfg(), unit.func_layout());
let dt = DominatorTree::new(unit.func_layout(), &pred);
changes |= reconverge_drives(unit, &inner_trg, &dt, &pred);
modified |= changes;
if !changes {
break;
}
}
// Phase 3: merge identical temporal terminators within each region.
let trg = TemporalRegionGraph::new(unit.dfg(), unit.func_layout());
for tr in &trg.regions {
if tr.tail_insts.len() <= 1 {
trace!("Skipping {} for wait merge (single wait inst)", tr.id);
continue;
}
// Group the tail instructions by their instruction data; instructions
// in the same group are considered equivalent.
let mut merge = HashMap::<&InstData, Vec<Inst>>::new();
for inst in tr.tail_insts() {
merge.entry(&unit.dfg()[inst]).or_default().push(inst);
}
// Drop the keys, which borrow from the DFG, so the unit can be mutated.
let merge: Vec<_> = merge.into_iter().map(|(_, is)| is).collect();
for insts in merge {
if insts.len() <= 1 {
trace!(
"Skipping {} (no equivalents)",
insts[0].dump(unit.dfg(), unit.try_cfg())
);
continue;
}
trace!("Merging:",);
for i in &insts {
trace!(" {}", i.dump(unit.dfg(), unit.try_cfg()));
}
// Create a shared block and make every former location branch to it.
let unified_bb = unit.block();
for &inst in &insts {
unit.insert_after(inst);
unit.ins().br(unified_bb);
}
// Keep the first instruction as the surviving copy in the shared block
// and delete all the others.
unit.func_layout_mut().remove_inst(insts[0]);
unit.func_layout_mut().append_inst(insts[0], unified_bb);
for &inst in &insts[1..] {
unit.remove_inst(inst);
}
modified = true;
}
}
modified
}
}
/// Pushes `Drv` instructions downwards along control flow edges.
///
/// Every drive starts out at its current block. A worklist then moves each
/// drive into the successors of its block as far as possible:
///
/// - Drives in a temporal tail block (`trg.is_tail`) stay where they are.
/// - If the block has at most one successor, the drive moves into that
///   successor when the block is its sole predecessor and the successor does
///   not already drive the same signal.
/// - If the block has several successors, the drive is duplicated into every
///   successor that does not already drive the same signal. Successors with
///   other predecessors receive the drive through a newly created helper block
///   placed on the edge.
///
/// Afterwards the planned relocations are applied by inserting new `drv`
/// instructions at the targets and removing the originals.
///
/// Returns `true` if any instruction was inserted or removed.
///
/// # Panics
///
/// Panics if two drives of the same signal and delay coming from different
/// blocks would end up in the same block, or if a non-tail block containing a
/// drive has no successor.
fn diverge_drives(
unit: &mut impl UnitBuilder,
trg: &TemporalRegionGraph,
pt: &PredecessorTable,
) -> bool {
// `relocs` holds the planned final `(block, drive)` placements; it starts
// with every drive at its current location.
let mut relocs = HashSet::new();
let mut worklist = vec![];
// `(block, signal)` pairs for blocks that contain a drive of that signal.
let mut driven_sigs = HashSet::new();
// Position of each drive among the drives of its block, used to decide
// which of two conflicting drives comes later.
let mut drive_order = HashMap::new();
for bb in unit.func_layout().blocks() {
let mut order_id = 0;
for inst in unit.func_layout().insts(bb) {
if unit.dfg()[inst].opcode() == Opcode::Drv {
let reloc = (bb, inst);
relocs.insert(reloc);
worklist.push(reloc);
driven_sigs.insert((bb, unit.dfg()[inst].args()[0]));
drive_order.insert(inst, order_id);
order_id += 1;
}
}
}
trace!("Considering drv diverges:");
// `(from, to, drive)` triples that need a helper block on the `from -> to`
// edge because `to` has predecessors other than `from`.
let mut helper_blocks = HashSet::new();
while let Some((bb, inst)) = worklist.pop() {
let sig = unit.dfg()[inst].args()[0];
// Drives must not cross a temporal boundary.
if trg.is_tail(bb) {
trace!(
" Skipping {} (in temporal tail)",
inst.dump(unit.dfg(), unit.try_cfg())
);
continue;
}
let diverging = pt.succ_set(bb).len() > 1;
if !diverging {
// Straight-line control flow: move into the only successor if this block
// is its sole predecessor and the successor does not drive the signal
// itself. Otherwise the drive keeps its current placement.
let succ = pt.succ(bb).next().unwrap();
if pt.is_sole_pred(bb, succ) && !driven_sigs.contains(&(succ, sig)) {
let reloc = (succ, inst);
relocs.remove(&(bb, inst));
relocs.insert(reloc);
worklist.push(reloc);
trace!(
" Pushing {} into non-diverging succ {}",
inst.dump(unit.dfg(), unit.try_cfg()),
succ.dump(unit.cfg())
);
}
}
else {
// Branching control flow: the drive always leaves this block and is
// replicated into each successor that does not drive the signal itself.
relocs.remove(&(bb, inst));
for succ in pt.succ(bb) {
if driven_sigs.contains(&(succ, sig)) {
continue;
}
if pt.is_sole_pred(bb, succ) {
let reloc = (succ, inst);
relocs.insert(reloc);
worklist.push(reloc);
trace!(
" Pushing {} into diverging succ {}",
inst.dump(unit.dfg(), unit.try_cfg()),
succ.dump(unit.cfg())
);
}
else {
// The successor is shared with other predecessors, so the drive can
// only live on the edge itself.
trace!(
" Pushing {} into helper block from {} to {}",
inst.dump(unit.dfg(), unit.try_cfg()),
bb.dump(unit.cfg()),
succ.dump(unit.cfg())
);
helper_blocks.insert((bb, succ, inst));
}
}
}
}
// Resolve cases where two drives with the same signal and delay end up in
// the same target block.
// `relocated` collects the original instructions to delete at the end.
let mut relocated = HashSet::new();
// `skip` collects placements that must not be materialized.
let mut skip = HashSet::new();
let mut lookup = HashMap::<(Block, Value, Value), Inst>::new();
let dfg = unit.dfg();
for &(into_bb, inst) in &relocs {
let sig = dfg[inst].args()[0];
let delay = dfg[inst].args()[2];
let key = (into_bb, sig, delay);
if let Some(&other) = lookup.get(&key) {
let inst_bb = unit.func_layout().inst_block(inst).unwrap();
let other_bb = unit.func_layout().inst_block(other).unwrap();
trace!(
"Double drive {} in {}",
inst.dump(dfg, unit.try_cfg()),
into_bb.dump(unit.cfg())
);
if inst_bb == other_bb {
// Both drives come from the same block: the later one wins and the
// earlier one is dropped.
if drive_order[&inst] > drive_order[&other] {
debug!(
"Removing overdriven {} in {}",
other.dump(dfg, unit.try_cfg()),
other_bb.dump(unit.cfg())
);
skip.insert((into_bb, other));
relocated.insert(other);
lookup.insert(key, inst);
} else {
debug!(
"Removing overdriven {} in {}",
inst.dump(dfg, unit.try_cfg()),
inst_bb.dump(unit.cfg())
);
skip.insert((into_bb, inst));
relocated.insert(inst);
}
}
else {
panic!("Cannot resolve double drive originating in different bbs");
}
} else {
lookup.insert(key, inst);
}
}
// Materialize the relocations: insert a copy of the drive at the start of
// each new target block.
let mut modified = false;
for (into_bb, inst) in relocs {
let bb = unit.func_layout().inst_block(inst).unwrap();
// Placements that equal the current location or were overdriven need no
// new instruction.
if into_bb == bb || skip.contains(&(into_bb, inst)) {
continue;
}
debug!(
"Moving {} into {}",
inst.dump(unit.dfg(), unit.try_cfg()),
into_bb.dump(unit.cfg())
);
modified |= true;
relocated.insert(inst);
let dfg = unit.dfg();
let sig = dfg[inst].args()[0];
let value = dfg[inst].args()[1];
let delay = dfg[inst].args()[2];
unit.prepend_to(into_bb);
unit.ins().drv(sig, value, delay);
}
// Materialize the helper blocks. One helper is created per edge and shared
// by all drives that travel along that edge.
let mut helper_cache = HashMap::new();
for (from_bb, to_bb, inst) in helper_blocks {
debug!(
"Moving {} into helper from {} to {}",
inst.dump(unit.dfg(), unit.try_cfg()),
from_bb.dump(unit.cfg()),
to_bb.dump(unit.cfg())
);
modified |= true;
relocated.insert(inst);
let helper = *helper_cache.entry((from_bb, to_bb)).or_insert_with(|| {
// The helper only branches on to the original target; the terminator of
// the source block is redirected to the helper.
let helper = unit.block();
unit.append_to(helper);
unit.ins().br(to_bb);
let term = unit.func_layout().terminator(from_bb);
unit.dfg_mut()[term].replace_block(to_bb, helper);
helper
});
let dfg = unit.dfg();
let sig = dfg[inst].args()[0];
let value = dfg[inst].args()[1];
let delay = dfg[inst].args()[2];
// The drive goes in front of the helper's branch.
unit.insert_before(unit.func_layout().terminator(helper));
unit.ins().drv(sig, value, delay);
}
// Finally delete the originals that were copied elsewhere or overdriven.
for inst in relocated {
unit.remove_inst(inst);
modified |= true;
}
modified
}
/// Merges `Drv` instructions at points where control flow joins again.
///
/// Within each temporal region the drives are grouped by signal and delay.
/// When every predecessor of a block contains a drive from the same group,
/// has that block as its only successor, and is not a temporal tail block,
/// the drives are replaced by a single drive at the start of the join block.
/// The driven value is the common value if all drives agree, or a new `phi`
/// over the individual values otherwise.
///
/// The dominator tree argument `_dt` is currently unused.
///
/// Returns `true` if any drives were merged.
///
/// # Panics
///
/// Panics if two drives in the same merge group originate from the same
/// block.
fn reconverge_drives(
unit: &mut impl UnitBuilder,
trg: &TemporalRegionGraph,
_dt: &DominatorTree,
pt: &PredecessorTable,
) -> bool {
let mut modified = false;
for tr in &trg.regions {
let dfg = unit.dfg();
let layout = unit.func_layout();
// Collect all drives in the region, keyed by (signal, delay).
let mut drvs = HashMap::<(Value, Value), HashSet<Inst>>::new();
for bb in tr.blocks() {
for inst in layout.insts(bb) {
if dfg[inst].opcode() == Opcode::Drv {
drvs.entry((dfg[inst].args()[0], dfg[inst].args()[2]))
.or_default()
.insert(inst);
}
}
}
trace!("Considering drv reconverges:");
// Merge candidates per join block: (signal, delay, drives to merge).
let mut candidates = HashMap::<Block, Vec<(Value, Value, Vec<Inst>)>>::new();
for ((sig, del), drvs) in drvs {
// For each potential join block: the drives flowing into it and the
// blocks those drives come from.
let mut into = HashMap::<Block, (Vec<Inst>, HashSet<Block>)>::new();
for drv in drvs {
let drv_bb = layout.inst_block(drv).unwrap();
// Drives must not cross a temporal boundary.
if trg.is_tail(drv_bb) {
trace!(
" Skipping {} (in temporal tail)",
drv.dump(unit.dfg(), unit.try_cfg())
);
continue;
}
// Only drives in blocks with exactly one successor can move into that
// successor.
let succ = pt.succ_set(drv_bb);
if succ.len() == 1 {
let e = into.entry(*succ.iter().next().unwrap()).or_default();
(e.0).push(drv);
(e.1).insert(drv_bb);
trace!(" Considering {} ", drv.dump(unit.dfg(), unit.try_cfg()));
} else {
trace!(
" Skipping {} (divergent control flow)",
drv.dump(unit.dfg(), unit.try_cfg())
);
}
}
// A merge is only valid if the drives cover exactly the predecessors of
// the join block.
for (into_bb, (insts, from_bbs)) in into {
let pred_set = pt.pred_set(into_bb);
if from_bbs == *pred_set {
candidates
.entry(into_bb)
.or_default()
.push((sig, del, insts));
} else {
trace!(
" Skipping merge of {:?} into {} (predecessors {:?} not fully covered by {:?})",
insts,
into_bb.dump(unit.cfg()),
pred_set, from_bbs
);
}
}
}
// Apply the merges.
for (into_bb, candidates) in candidates {
// New instructions are inserted at the start of the join block.
unit.prepend_to(into_bb);
for (sig, delay, insts) in candidates {
let dfg = unit.dfg();
let layout = unit.func_layout();
debug!(
"Grouping {} drives {:?} into {}",
sig.dump(dfg),
insts,
into_bb.dump(unit.cfg()),
);
// Gather the driven value and origin block of every drive; these become
// the arms of the phi node.
let mut phi_args = vec![];
let mut phi_blocks = vec![];
for &inst in &insts {
phi_args.push(dfg[inst].args()[1]);
phi_blocks.push(layout.inst_block(inst).unwrap());
}
// Each origin block may contribute only one drive, otherwise the phi
// arms would be ambiguous.
let phi_blocks_unique: HashSet<_> = phi_blocks.iter().cloned().collect();
assert_eq!(
phi_blocks_unique.len(),
phi_blocks.len(),
"merging multiple drives to {} from the same origin bb should never happen: {:?}",
sig.dump(dfg),
phi_blocks
);
// If all drives use the same value no phi node is needed.
let homogenous = phi_args.iter().all(|&x| x == phi_args[0]);
let phi = if homogenous {
trace!("Using single value {}", phi_args[0].dump(unit.dfg()));
phi_args[0]
} else {
trace!("Add phi node in {} with arms:", into_bb.dump(unit.cfg()));
for (v, bb) in phi_args.iter().zip(phi_blocks.iter()) {
trace!(" [{}, {}]", v.dump(dfg), bb.dump(unit.cfg()));
}
unit.ins().phi(phi_args, phi_blocks)
};
// Insert the merged drive and delete the originals.
unit.ins().drv(sig, phi, delay);
for inst in insts {
unit.remove_inst(inst);
}
modified = true;
}
}
}
modified
}
/// A partition of a unit's basic blocks into temporal regions.
///
/// The graph can be indexed with a `Block` to obtain the region it belongs
/// to, and with a `TemporalRegion` to obtain that region's data.
#[derive(Debug)]
pub struct TemporalRegionGraph {
/// Maps each visited block to the temporal region it was assigned to.
blocks: HashMap<Block, TemporalRegion>,
/// Per-region data, stored at the position given by the region's index.
regions: Vec<TemporalRegionData>,
}
impl TemporalRegionGraph {
/// Computes the temporal region graph of a unit.
///
/// The entry block and every block targeted by a temporal terminator each
/// start a region of their own. All remaining blocks reachable through
/// non-temporal terminators are assigned to the region of the block they
/// were first reached from, in breadth-first order.
pub fn new(dfg: &DataFlowGraph, layout: &FunctionLayout) -> Self {
trace!("Constructing TRG:");
// Collect the region roots: the entry block and all targets of temporal
// terminators. `seen` guards against enqueueing a block twice.
let mut todo = VecDeque::new();
let mut seen = HashSet::new();
todo.push_back(layout.entry());
seen.insert(layout.entry());
trace!(" Root {:?} (entry)", layout.entry());
for bb in layout.blocks() {
let term = layout.terminator(bb);
if dfg[term].opcode().is_temporal() {
for &target in dfg[term].blocks() {
if seen.insert(target) {
trace!(" Root {:?} (wait target)", target);
todo.push_back(target);
}
}
}
}
// Give every root its own region and mark it as a head block.
let mut next_id = 0;
let mut blocks = HashMap::<Block, TemporalRegion>::new();
let mut head_blocks = HashSet::new();
let mut tail_blocks = HashSet::new();
// Temporal terminators encountered during the traversal.
let mut breaks = vec![];
for &bb in &todo {
blocks.insert(bb, TemporalRegion(next_id));
head_blocks.insert(bb);
next_id += 1;
}
// Propagate the region assignment along non-temporal control flow edges.
while let Some(bb) = todo.pop_front() {
let tr = blocks[&bb];
trace!(" Pushing {:?} ({})", bb, tr);
let term = layout.terminator(bb);
// A temporal terminator ends the region; its targets are roots that
// already have their own region.
if dfg[term].opcode().is_temporal() {
breaks.push(term);
tail_blocks.insert(bb);
continue;
}
for &target in dfg[term].blocks() {
if seen.insert(target) {
todo.push_back(target);
trace!(" Assigning {:?} <- {:?}", target, tr);
// NOTE(review): a block is only ever added to `blocks` together with
// being added to `seen`, so this insert appears to never find an
// existing entry and the override branch looks unreachable -- confirm.
if blocks.insert(target, tr).is_some() {
let tr = TemporalRegion(next_id);
blocks.insert(target, tr);
head_blocks.insert(target);
tail_blocks.insert(bb);
trace!(" Assigning {:?} <- {:?} (override)", target, tr);
next_id += 1;
}
}
}
}
trace!(" Blocks: {:#?}", blocks);
// Create one empty data record per region and fill it in below.
let mut regions: Vec<_> = (0..next_id)
.map(|id| TemporalRegionData {
id: TemporalRegion(id),
blocks: Default::default(),
head_insts: Default::default(),
head_blocks: Default::default(),
tail_insts: Default::default(),
tail_blocks: Default::default(),
})
.collect();
for (&bb, &id) in &blocks {
regions[id.0].blocks.insert(bb);
}
// A temporal terminator is a head instruction of every region it
// branches into, and a tail instruction of the region containing it.
for &inst in &breaks {
let bb = layout.inst_block(inst).unwrap();
for &target in dfg[inst].blocks() {
let data = &mut regions[blocks[&target].0];
data.head_insts.insert(inst);
}
let data = &mut regions[blocks[&bb].0];
data.tail_insts.insert(inst);
}
for bb in head_blocks {
regions[blocks[&bb].0].head_blocks.insert(bb);
}
for bb in tail_blocks {
regions[blocks[&bb].0].tail_blocks.insert(bb);
}
Self { blocks, regions }
}
/// Checks whether `bb` is a head block of the region it belongs to.
///
/// Panics if `bb` is not part of the graph.
pub fn is_head(&self, bb: Block) -> bool {
self[self[bb]].is_head(bb)
}
/// Checks whether `bb` is a tail block of the region it belongs to.
///
/// Panics if `bb` is not part of the graph.
pub fn is_tail(&self, bb: Block) -> bool {
self[self[bb]].is_tail(bb)
}
}
impl Index<TemporalRegion> for TemporalRegionGraph {
type Output = TemporalRegionData;
fn index(&self, idx: TemporalRegion) -> &Self::Output {
&self.regions[idx.0]
}
}
impl Index<Block> for TemporalRegionGraph {
    type Output = TemporalRegion;

    /// Returns the temporal region that `bb` was assigned to.
    ///
    /// Panics if `bb` is not known to this graph.
    fn index(&self, bb: Block) -> &Self::Output {
        let region_of = &self.blocks;
        &region_of[&bb]
    }
}
/// Identifier of a temporal region, wrapping the region's index.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TemporalRegion(usize);

impl std::fmt::Display for TemporalRegion {
    /// Renders the region as `t` followed by its index, e.g. `t3`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let TemporalRegion(index) = *self;
        f.write_fmt(format_args!("t{}", index))
    }
}

impl std::fmt::Debug for TemporalRegion {
    /// Debug output is the same text as the `Display` form.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}
/// The data associated with a single temporal region.
#[derive(Debug, Clone)]
pub struct TemporalRegionData {
/// The identifier of this region.
pub id: TemporalRegion,
/// All blocks assigned to this region.
pub blocks: HashSet<Block>,
/// The temporal terminators that branch into a block of this region.
pub head_insts: HashSet<Inst>,
/// The blocks of this region that were marked as heads during construction
/// (the entry block and targets of temporal terminators).
pub head_blocks: HashSet<Block>,
/// The temporal terminators located in a block of this region.
pub tail_insts: HashSet<Inst>,
/// The blocks of this region that were marked as tails during construction
/// (blocks ending in a temporal terminator).
pub tail_blocks: HashSet<Block>,
}
impl TemporalRegionData {
    /// Iterates over every block that belongs to this region.
    pub fn blocks(&self) -> impl Iterator<Item = Block> + Clone + '_ {
        self.blocks.iter().copied()
    }

    /// Iterates over the instructions recorded as heads of this region.
    pub fn head_insts(&self) -> impl Iterator<Item = Inst> + Clone + '_ {
        self.head_insts.iter().copied()
    }

    /// Iterates over the blocks recorded as heads of this region.
    pub fn head_blocks(&self) -> impl Iterator<Item = Block> + Clone + '_ {
        self.head_blocks.iter().copied()
    }

    /// Iterates over the instructions recorded as tails of this region.
    pub fn tail_insts(&self) -> impl Iterator<Item = Inst> + Clone + '_ {
        self.tail_insts.iter().copied()
    }

    /// Iterates over the blocks recorded as tails of this region.
    pub fn tail_blocks(&self) -> impl Iterator<Item = Block> + Clone + '_ {
        self.tail_blocks.iter().copied()
    }

    /// Tells whether `bb` is one of this region's head blocks.
    pub fn is_head(&self, bb: Block) -> bool {
        let heads = &self.head_blocks;
        heads.contains(&bb)
    }

    /// Tells whether `bb` is one of this region's tail blocks.
    pub fn is_tail(&self, bb: Block) -> bool {
        let tails = &self.tail_blocks;
        tails.contains(&bb)
    }
}