use crate::ir::prelude::*;
use crate::opt::prelude::*;
use crate::{
ir::{DataFlowGraph, FunctionLayout, InstData},
pass::gcse::{DominatorTree, PredecessorTable},
value::IntValue,
};
use itertools::Itertools;
use std::{
collections::{HashMap, HashSet, VecDeque},
ops::Index,
};
/// Temporal Code Motion (TCM) pass.
///
/// Working on the temporal regions of a unit, this pass hoists `prb`
/// instructions whose probed value is not defined by an instruction into the
/// head block of their region, merges identical region-leaving terminators
/// into a shared block, inserts auxiliary blocks in front of head blocks that
/// are entered several times from the same region, and pushes drive
/// instructions towards the tail blocks of their region.
pub struct TemporalCodeMotion;
impl Pass for TemporalCodeMotion {
    /// Runs temporal code motion on `unit`.
    ///
    /// The pass performs four steps in order:
    ///
    /// 1. hoist `prb` instructions whose probed value is not defined by an
    ///    instruction into the single head block of their temporal region,
    /// 2. merge identical region-leaving terminators of a region into one
    ///    shared block,
    /// 3. add auxiliary blocks via `add_aux_blocks`,
    /// 4. push drives towards the region tails via `push_drives`.
    ///
    /// Returns `true` if the unit was modified.
    fn run_on_cfg(ctx: &PassContext, unit: &mut impl UnitBuilder) -> bool {
        info!("TCM [{}]", unit.unit().name());
        let mut modified = false;

        // Step 1: hoist probes into the head block of their region.
        let trg = TemporalRegionGraph::new(unit.dfg(), unit.func_layout());
        let temp_pt = PredecessorTable::new_temporal(unit.dfg(), unit.func_layout());
        let temp_dt = DominatorTree::new(unit.cfg(), unit.func_layout(), &temp_pt);
        for tr in &trg.regions {
            let dfg = unit.dfg();
            let layout = unit.func_layout();
            if tr.head_blocks.len() != 1 {
                trace!("Skipping {} for prb move (multiple head blocks)", tr.id);
                continue;
            }
            let head_bb = tr.head_blocks().next().unwrap();

            // Collect the candidates first; the layout is only modified once
            // the immutable borrows above are no longer needed.
            let mut hoist = vec![];
            for bb in tr.blocks() {
                for inst in layout.insts(bb) {
                    if dfg[inst].opcode() == Opcode::Prb
                        && dfg.get_value_inst(dfg[inst].args()[0]).is_none()
                    {
                        // The head block must dominate (in the temporal
                        // dominator tree) both the probe's current block and
                        // the block of every user of its result.
                        let mut dominates = temp_dt.dominates(head_bb, bb);
                        for user_inst in dfg.uses(dfg.inst_result(inst)) {
                            let user_bb = unit.func_layout().inst_block(user_inst).unwrap();
                            let dom = temp_dt.dominates(head_bb, user_bb);
                            dominates &= dom;
                        }
                        if dominates {
                            hoist.push(inst);
                        } else {
                            trace!(
                                "Skipping {} for prb move (would not dominate uses)",
                                inst.dump(dfg, unit.try_cfg())
                            );
                        }
                    }
                }
            }

            // The region's blocks are iterated in hash order; sorting keeps
            // the order of the hoisted instructions reproducible.
            hoist.sort();
            for inst in hoist {
                // Probes already located in the head block stay where they are.
                if unit.func_layout().inst_block(inst) == Some(head_bb) {
                    continue;
                }
                debug!(
                    "Hoisting {} into {}",
                    inst.dump(unit.dfg(), unit.try_cfg()),
                    head_bb.dump(unit.cfg())
                );
                let layout = unit.func_layout_mut();
                layout.remove_inst(inst);
                layout.prepend_inst(inst, head_bb);
                modified = true;
            }
        }

        // Step 2: merge equivalent tail instructions. The region graph is
        // rebuilt because the layout may have changed above.
        let trg = TemporalRegionGraph::new(unit.dfg(), unit.func_layout());
        for tr in &trg.regions {
            if tr.tail_insts.len() <= 1 {
                trace!("Skipping {} for wait merge (single wait inst)", tr.id);
                continue;
            }

            // Group the tail instructions by their instruction data; equal
            // data makes the instructions interchangeable.
            let mut merge = HashMap::<&InstData, Vec<Inst>>::new();
            for inst in tr.tail_insts() {
                merge.entry(&unit.dfg()[inst]).or_default().push(inst);
            }

            // Drop the keys (and with them the borrow of the DFG). Both the
            // set of tail instructions and the map iterate in hash order, so
            // the groups and their members are sorted to make the choice of
            // the surviving instruction and the creation order of the new
            // blocks independent of the hasher.
            let mut merge: Vec<Vec<Inst>> = merge.into_iter().map(|(_, is)| is).collect();
            for insts in &mut merge {
                insts.sort();
            }
            merge.sort();

            for insts in merge {
                if insts.len() <= 1 {
                    trace!(
                        "Skipping {} (no equivalents)",
                        insts[0].dump(unit.dfg(), unit.try_cfg())
                    );
                    continue;
                }
                trace!("Merging:",);
                for i in &insts {
                    trace!(" {}", i.dump(unit.dfg(), unit.try_cfg()));
                }

                // Every instruction of the group is followed by a branch into
                // a fresh block ...
                let unified_bb = unit.block();
                for &inst in &insts {
                    unit.insert_after(inst);
                    unit.ins().br(unified_bb);
                }

                // ... the first one is moved into that block, and the
                // remaining duplicates are deleted.
                unit.func_layout_mut().remove_inst(insts[0]);
                unit.func_layout_mut().append_inst(insts[0], unified_bb);
                for &inst in &insts[1..] {
                    unit.remove_inst(inst);
                }
                modified = true;
            }
        }

        // Steps 3 and 4.
        modified |= add_aux_blocks(ctx, unit);
        modified |= push_drives(ctx, unit);
        modified
    }
}
/// Inserts auxiliary blocks in front of head blocks that are entered more
/// than once from the same foreign temporal region.
///
/// For every head block, the terminators of its predecessors in other regions
/// are grouped by the predecessor's region. Each group with at least two
/// terminators gets a new block named `aux` that branches to the head block,
/// and the terminators of the group are redirected to that new block.
///
/// Returns `true` if any block was added.
fn add_aux_blocks(_ctx: &PassContext, unit: &mut impl UnitBuilder) -> bool {
    let pt = PredecessorTable::new(unit.dfg(), unit.func_layout());
    let trg = TemporalRegionGraph::new(unit.dfg(), unit.func_layout());
    let mut modified = false;

    // Snapshot the head blocks so the unit can be modified while iterating.
    let head_bbs: Vec<_> = unit
        .func_layout()
        .blocks()
        .filter(|&bb| trg.is_head(bb))
        .collect();

    for bb in head_bbs {
        trace!("Adding aux blocks into {}", bb.dump(unit.cfg()));
        let tr = trg[bb];

        // Terminators that enter `bb` from another region, keyed by the
        // region they come from.
        let mut insts_by_region = HashMap::<TemporalRegion, Vec<Inst>>::new();
        for pred in pt.pred(bb) {
            let pred_tr = trg[pred];
            if pred_tr != tr {
                let inst = unit.func_layout().terminator(pred);
                insts_by_region.entry(pred_tr).or_default().push(inst);
            }
        }

        // A hash map iterates in arbitrary order. Visit the source regions in
        // ascending id order so that the aux blocks are always created in the
        // same order.
        let mut insts_by_region: Vec<(TemporalRegion, Vec<Inst>)> =
            insts_by_region.into_iter().collect();
        insts_by_region.sort_by_key(|entry| entry.0);

        for (src_tr, insts) in insts_by_region {
            if insts.len() < 2 {
                trace!(" Skipping {} (single head inst)", src_tr);
                continue;
            }
            let aux_bb = unit.named_block("aux");
            unit.append_to(aux_bb);
            unit.ins().br(bb);
            trace!(" Adding {} from {}", aux_bb.dump(unit.cfg()), src_tr);
            for inst in insts {
                trace!(
                    " Replacing {} in {}",
                    bb.dump(unit.cfg()),
                    inst.dump(unit.dfg(), unit.try_cfg())
                );
                unit.dfg_mut().replace_block_within_inst(bb, aux_bb, inst);
            }
            modified = true;
        }
    }
    modified
}
/// Moves drive instructions towards the tail blocks of their temporal region
/// and coalesces drives afterwards.
///
/// The function first walks all blocks in reverse post-order of the dominator
/// tree, recording which signal values are derived from other signals
/// (aliases) and, per resolved signal, the sequence of drives encountered.
/// The drives of each signal are then pushed from the last to the first via
/// `push_drive`; as soon as one drive cannot be moved, the earlier drives of
/// that signal are left in place. Finally `coalesce_drives` is run on every
/// block.
///
/// Returns `true` if the unit was modified.
fn push_drives(ctx: &PassContext, unit: &mut impl UnitBuilder) -> bool {
    let mut modified = false;
    let pt = PredecessorTable::new(unit.dfg(), unit.func_layout());
    let dt = DominatorTree::new(unit.cfg(), unit.func_layout(), &pt);

    // `aliases` maps a signal-typed instruction result to the (resolved)
    // signal argument it was derived from; `drv_seq` lists the drives per
    // resolved signal in visiting order.
    let mut aliases = HashMap::<Value, Value>::new();
    let mut drv_seq = HashMap::<Value, Vec<Inst>>::new();
    let dfg = unit.dfg();
    let cfg = unit.cfg();
    for &bb in dt.blocks_post_order().iter().rev() {
        trace!("Checking {} for aliases", bb.dump(unit.cfg()));
        for inst in unit.func_layout().insts(bb) {
            let data = &dfg[inst];
            if let Opcode::Drv | Opcode::DrvCond = data.opcode() {
                let signal = data.args()[0];
                let signal = aliases.get(&signal).cloned().unwrap_or(signal);
                trace!(
                    " Drive {} ({})",
                    signal.dump(dfg),
                    inst.dump(dfg, Some(cfg))
                );
                drv_seq.entry(signal).or_default().push(inst);
            } else if let Some(value) = dfg.get_inst_result(inst) {
                if !dfg.value_type(value).is_signal() {
                    continue;
                }
                // If several arguments are signals, the last one wins.
                for &arg in data.args() {
                    if !dfg.value_type(arg).is_signal() {
                        continue;
                    }
                    let arg = aliases.get(&arg).cloned().unwrap_or(arg);
                    trace!(
                        " Alias {} of {} ({})",
                        value.dump(dfg),
                        arg.dump(dfg),
                        inst.dump(dfg, Some(cfg))
                    );
                    aliases.insert(value, arg);
                }
            }
        }
    }

    let trg = TemporalRegionGraph::new(unit.dfg(), unit.func_layout());

    // A hash map iterates in arbitrary order, and every successful push
    // prepends instructions to the tail blocks. Process the signals ordered
    // by their first drive so the resulting instruction order is the same on
    // every run. Each sequence holds at least one drive because entries are
    // only created when a drive is pushed.
    let mut sequences: Vec<(&Value, &Vec<Inst>)> = drv_seq.iter().collect();
    sequences.sort_by_key(|&(_, drives)| drives[0]);

    for (&signal, drives) in sequences {
        trace!("Moving drives on signal {}", signal.dump(unit.dfg()));
        // Start with the last drive; an earlier drive must not be moved past
        // a later one that stays in place.
        for &drive in drives.iter().rev() {
            let drive_bb = unit.func_layout().inst_block(drive).unwrap();
            if trg.is_tail(drive_bb) {
                trace!(
                    " Skipping {} (already in tail block)",
                    drive.dump(unit.dfg(), unit.try_cfg()),
                );
                continue;
            }
            if trg[trg[drive_bb]].tail_blocks.is_empty() {
                trace!(
                    " Skipping {} (no tail blocks)",
                    drive.dump(unit.dfg(), unit.try_cfg()),
                );
                continue;
            }
            let moved = push_drive(ctx, drive, unit, &dt, &trg);
            modified |= moved;
            if !moved {
                break;
            }
        }
    }

    // Snapshot the block list since coalescing modifies the unit.
    for block in unit.func_layout().blocks().collect::<Vec<_>>() {
        modified |= coalesce_drives(ctx, block, unit);
    }
    modified
}
/// Tries to move a single drive instruction into the tail blocks of its
/// temporal region.
///
/// For every tail block of the region containing `drive`, the function checks
/// that all arguments of the drive dominate that block and collects the
/// conditional-branch conditions found while walking from the drive's block
/// up the dominator tree. Only if this succeeds for all tail blocks is a
/// `drv_cond` inserted at the start of each tail block and the original drive
/// removed.
///
/// Returns `true` if the drive was moved, and `false` (leaving the unit
/// unchanged) otherwise.
fn push_drive(
_ctx: &PassContext,
drive: Inst,
unit: &mut impl UnitBuilder,
dt: &DominatorTree,
trg: &TemporalRegionGraph,
) -> bool {
let dfg = unit.dfg();
let cfg = unit.cfg();
let layout = unit.func_layout();
let src_bb = layout.inst_block(drive).unwrap();
let tr = trg[src_bb];
// Planned moves: destination block plus the branch conditions that guard
// the drive. Nothing is modified until all destinations were validated.
let mut moves = Vec::new();
for dst_bb in trg[tr].tail_blocks() {
// Every operand of the drive must be available in the destination.
for &arg in dfg[drive].args() {
if !dt.value_dominates_block(dfg, layout, arg, dst_bb) {
trace!(
" Skipping {} ({} does not dominate {})",
drive.dump(dfg, Some(cfg)),
arg.dump(dfg),
dst_bb.dump(cfg)
);
return false;
}
}
// Walk two fingers up the dominator tree until they meet. In each step
// the finger with the smaller `block_order` is replaced by its
// dominator. NOTE(review): if two distinct blocks ever had the same
// `block_order`, neither branch below would advance and the loop would
// not terminate -- presumably the order is unique per block; confirm.
let mut src_finger = src_bb;
let mut dst_finger = dst_bb;
let mut conds = Vec::<(Value, bool)>::new();
while src_finger != dst_finger {
let i1 = dt.block_order(src_finger);
let i2 = dt.block_order(dst_finger);
if i1 < i2 {
let parent = dt.dominator(src_finger);
// A block that is its own dominator cannot be walked up further.
if src_finger == parent {
break;
}
let term = layout.terminator(parent);
if dfg[term].opcode() == Opcode::BrCond {
let cond_val = dfg[term].args()[0];
// The branch condition is re-used in the destination block, so
// it has to be available there as well.
if !dt.value_dominates_block(dfg, layout, cond_val, dst_bb) {
trace!(
" Skipping {} (branch cond {} does not dominate {})",
drive.dump(dfg, Some(cfg)),
cond_val.dump(dfg),
dst_bb.dump(cfg)
);
return false;
}
// The polarity is derived from the position of `src_finger` among
// the branch targets: position 0 means the condition is negated.
// No condition is recorded if `src_finger` is not a direct target.
let cond_pol = dfg[term].blocks().iter().position(|&bb| bb == src_finger);
if let Some(cond_pol) = cond_pol {
conds.push((cond_val, cond_pol != 0));
trace!(
" {} -> {} ({} == {})",
parent.dump(cfg),
src_finger.dump(cfg),
cond_val.dump(dfg),
cond_pol
);
}
} else {
trace!(" {} -> {}", parent.dump(cfg), src_finger.dump(cfg));
}
src_finger = parent;
} else if i2 < i1 {
let parent = dt.dominator(dst_finger);
if dst_finger == parent {
break;
}
dst_finger = parent;
}
}
// The loop was left early without the fingers meeting.
if src_finger != dst_finger {
trace!(
" Skipping {} (no common dominator)",
drive.dump(dfg, Some(cfg))
);
return false;
}
moves.push((dst_bb, conds));
}
// All destinations are valid; materialize the conditional drives.
for (dst_bb, conds) in moves {
debug!(
"Moving {} to {}",
drive.dump(unit.dfg(), unit.try_cfg()),
dst_bb.dump(unit.cfg())
);
unit.prepend_to(dst_bb);
// Build the conjunction of all recorded branch conditions, starting
// from a constant one and applying the conditions in reverse order of
// their discovery.
let mut cond = unit.ins().const_int(IntValue::all_ones(1));
for (value, polarity) in conds.into_iter().rev() {
let value = match polarity {
true => value,
false => unit.ins().not(value),
};
cond = unit.ins().and(cond, value);
}
// A conditional drive keeps its own condition (fourth argument).
if unit.dfg()[drive].opcode() == Opcode::DrvCond {
let arg = unit.dfg()[drive].args()[3];
cond = unit.ins().and(cond, arg);
}
let args = unit.dfg()[drive].args();
let signal = args[0];
let value = args[1];
let delay = args[2];
unit.ins().drv_cond(signal, value, delay, cond);
}
// The original drive is now replaced by the copies in the tail blocks.
unit.remove_inst(drive);
true
}
/// Merges drives within `block` that share delay and target into a single
/// conditional drive.
///
/// The drives of the block are grouped by their delay value (third argument).
/// Within each group, runs of consecutive drives to the same target (first
/// argument) are replaced by one `drv_cond` whose condition is the `or` of
/// the individual conditions and whose value is selected by `mux`
/// instructions.
///
/// Returns `true` if at least one run was coalesced.
fn coalesce_drives(_ctx: &PassContext, block: Block, unit: &mut impl UnitBuilder) -> bool {
let mut modified = false;
let dfg = unit.dfg();
// Drives of this block keyed by their delay operand, in layout order.
let mut delay_groups = HashMap::<Value, Vec<Inst>>::new();
for inst in unit.func_layout().insts(block) {
if let Opcode::Drv | Opcode::DrvCond = dfg[inst].opcode() {
let delay = dfg[inst].args()[2];
delay_groups.entry(delay).or_default().push(inst);
}
}
// NOTE(review): the groups are visited in hash-map iteration order, which
// is not stable across runs.
for (delay, drives) in delay_groups {
// `group_by` only groups adjacent elements, so each run consists of
// drives to one target that are neighbours within this delay group.
// The runs are collected up front because the unit is modified below.
let runs: Vec<_> = drives
.into_iter()
.group_by(|&inst| unit.dfg()[inst].args()[0])
.into_iter()
.map(|(target, drives)| (target, drives.collect::<Vec<_>>()))
.collect();
for (target, drives) in runs {
if drives.len() <= 1 {
continue;
}
debug!(
"Coalescing {} drives on {}",
drives.len(),
target.dump(unit.dfg())
);
let mut drives = drives.into_iter();
// Seed condition and value with those of the first drive of the run.
let first = drives.next().unwrap();
unit.insert_before(first);
let mut cond = drive_cond(unit, first);
let mut value = unit.dfg()[first].args()[1];
unit.remove_inst(first);
for drive in drives {
unit.insert_before(drive);
let c = drive_cond(unit, drive);
let v = unit.dfg()[drive].args()[1];
// Identical condition values need no `or`.
if cond != c {
cond = unit.ins().or(cond, c);
}
// Select between the accumulated value and this drive's value using
// this drive's condition.
if value != v {
let vs = unit.ins().array(vec![value, v]);
value = unit.ins().mux(vs, c);
}
unit.remove_inst(drive);
}
// Emit the single replacement drive at the builder's current position.
unit.ins().drv_cond(target, value, delay, cond);
modified = true;
}
}
modified
}
/// Returns the condition under which the drive `inst` takes effect.
///
/// For a `DrvCond` this is its fourth argument. For any other instruction a
/// new one-bit constant with all bits set is inserted at the builder's
/// current position and returned.
fn drive_cond(unit: &mut impl UnitBuilder, inst: Inst) -> Value {
    let is_conditional = unit.dfg()[inst].opcode() == Opcode::DrvCond;
    if !is_conditional {
        // Unconditional drives are treated as "always enabled".
        return unit.ins().const_int(IntValue::all_ones(1));
    }
    let operands = unit.dfg()[inst].args();
    operands[3]
}
/// The assignment of a function's blocks to temporal regions, together with
/// per-region summary data.
#[derive(Debug)]
pub struct TemporalRegionGraph {
/// Maps every block visited during construction to its temporal region.
blocks: HashMap<Block, TemporalRegion>,
/// Data of each region, stored at the index given by the region's id.
regions: Vec<TemporalRegionData>,
}
impl TemporalRegionGraph {
    /// Computes the temporal region graph of a function.
    ///
    /// The entry block and every block targeted by a temporal terminator
    /// start a region of their own. The remaining blocks are reached by a
    /// breadth-first walk that stops at temporal terminators; each of them
    /// inherits the region of the block it was first reached from.
    ///
    /// # Panics
    ///
    /// Panics if a predecessor, successor or branch target of a visited block
    /// was not itself visited, since its region is looked up by indexing.
    pub fn new(dfg: &DataFlowGraph, layout: &FunctionLayout) -> Self {
        // Seed the worklist with the entry block and all targets of temporal
        // terminators.
        let mut todo = VecDeque::new();
        let mut seen = HashSet::new();
        todo.push_back(layout.entry());
        seen.insert(layout.entry());
        for bb in layout.blocks() {
            let term = layout.terminator(bb);
            if dfg[term].opcode().is_temporal() {
                for &target in dfg[term].blocks() {
                    if seen.insert(target) {
                        todo.push_back(target);
                    }
                }
            }
        }

        // Every seed block is the head of a fresh region.
        let mut next_id = 0;
        let mut blocks = HashMap::<Block, TemporalRegion>::new();
        let mut head_blocks = HashSet::new();
        let mut tail_blocks = HashSet::new();
        for &bb in &todo {
            blocks.insert(bb, TemporalRegion(next_id));
            head_blocks.insert(bb);
            next_id += 1;
        }

        // Propagate the regions along non-temporal control flow.
        while let Some(bb) = todo.pop_front() {
            let tr = blocks[&bb];
            let term = layout.terminator(bb);
            if dfg[term].opcode().is_temporal() {
                // A temporal terminator ends the region here.
                tail_blocks.insert(bb);
                continue;
            }
            for &target in dfg[term].blocks() {
                if seen.insert(target) {
                    todo.push_back(target);
                    // NOTE(review): `seen` and `blocks` are always extended
                    // together, so a block that is new to `seen` should never
                    // be present in `blocks` already; this override branch
                    // looks unreachable. Kept unchanged -- confirm the intent.
                    if blocks.insert(target, tr).is_some() {
                        let tr = TemporalRegion(next_id);
                        blocks.insert(target, tr);
                        head_blocks.insert(target);
                        tail_blocks.insert(bb);
                        next_id += 1;
                    }
                }
            }
        }

        let mut regions: Vec<_> = (0..next_id)
            .map(|id| TemporalRegionData {
                id: TemporalRegion(id),
                blocks: Default::default(),
                entry: false,
                head_insts: Default::default(),
                head_blocks: Default::default(),
                head_tight: true,
                tail_insts: Default::default(),
                tail_blocks: Default::default(),
                tail_tight: true,
            })
            .collect();
        regions[blocks[&layout.entry()].0].entry = true;

        // Fill in the per-region block sets and the head/tail information.
        let pt = PredecessorTable::new(dfg, layout);
        for (&bb, &id) in &blocks {
            let reg = &mut regions[id.0];
            reg.blocks.insert(bb);

            // A block is a head if it seeded a region or has a predecessor in
            // another region; it is tight if all predecessors are in other
            // regions. Terminators of such foreign predecessors are the
            // region's head instructions.
            let mut is_head = head_blocks.contains(&bb);
            let mut is_tight = true;
            for pred in pt.pred(bb) {
                let diff_trs = blocks[&pred] != id;
                is_head |= diff_trs;
                is_tight &= diff_trs;
                if diff_trs {
                    reg.head_insts.insert(layout.terminator(pred));
                }
            }
            if is_head {
                reg.head_blocks.insert(bb);
                reg.head_tight &= is_tight;
            }

            // Same for tails, looking at the successors.
            let mut is_tail = tail_blocks.contains(&bb);
            let mut is_tight = true;
            for succ in pt.succ(bb) {
                let diff_trs = blocks[&succ] != id;
                is_tail |= diff_trs;
                is_tight &= diff_trs;
            }
            if is_tail {
                reg.tail_blocks.insert(bb);
                reg.tail_tight &= is_tight;
            }

            // A terminator that targets a block of another region is a tail
            // instruction of this region.
            let term = layout.terminator(bb);
            if dfg[term].blocks().iter().any(|bb| blocks[bb] != id) {
                reg.tail_insts.insert(term);
            }
        }
        Self { blocks, regions }
    }

    /// Checks whether `bb` is a head block of its temporal region.
    ///
    /// Panics if `bb` has no region assigned.
    pub fn is_head(&self, bb: Block) -> bool {
        self[self[bb]].is_head(bb)
    }

    /// Checks whether `bb` is a tail block of its temporal region.
    ///
    /// Panics if `bb` has no region assigned.
    pub fn is_tail(&self, bb: Block) -> bool {
        self[self[bb]].is_tail(bb)
    }

    /// Iterates over all regions together with their ids, in id order.
    pub fn regions(&self) -> impl Iterator<Item = (TemporalRegion, &TemporalRegionData)> {
        self.regions
            .iter()
            .enumerate()
            .map(|(i, tr)| (TemporalRegion(i), tr))
    }
}
impl Index<TemporalRegion> for TemporalRegionGraph {
type Output = TemporalRegionData;
fn index(&self, idx: TemporalRegion) -> &Self::Output {
&self.regions[idx.0]
}
}
impl Index<Block> for TemporalRegionGraph {
    type Output = TemporalRegion;

    /// Looks up the temporal region a block belongs to.
    ///
    /// Panics if the block has no region assigned in this graph.
    fn index(&self, block: Block) -> &TemporalRegion {
        let assignment = &self.blocks;
        &assignment[&block]
    }
}
/// The id of a temporal region; printed as `t<index>`.
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TemporalRegion(usize);

impl std::fmt::Display for TemporalRegion {
    /// Writes the region as the letter `t` followed by its numeric index.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let TemporalRegion(index) = *self;
        write!(f, "t{}", index)
    }
}

impl std::fmt::Debug for TemporalRegion {
    /// Debug output is identical to the `Display` output.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}
/// Summary data of a single temporal region, as computed by
/// `TemporalRegionGraph::new`.
#[derive(Debug, Clone)]
pub struct TemporalRegionData {
/// The id of this region.
pub id: TemporalRegion,
/// All blocks assigned to this region.
pub blocks: HashSet<Block>,
/// Whether this region contains the entry block of the function layout.
pub entry: bool,
/// Terminators of blocks in other regions that precede a block of this
/// region.
pub head_insts: HashSet<Inst>,
/// Blocks that started this region or that have a predecessor in another
/// region.
pub head_blocks: HashSet<Block>,
/// Whether every head block has only predecessors from other regions.
pub head_tight: bool,
/// Terminators of this region's blocks that target a block of another
/// region.
pub tail_insts: HashSet<Inst>,
/// Blocks that end in a temporal terminator or that have a successor in
/// another region.
pub tail_blocks: HashSet<Block>,
/// Whether every tail block has only successors in other regions.
pub tail_tight: bool,
}
impl TemporalRegionData {
    /// Iterates over the blocks of this region, by value and in no
    /// particular order.
    pub fn blocks(&self) -> impl Iterator<Item = Block> + Clone + '_ {
        let members = &self.blocks;
        members.iter().copied()
    }

    /// Iterates over the head instructions of this region.
    pub fn head_insts(&self) -> impl Iterator<Item = Inst> + Clone + '_ {
        let insts = &self.head_insts;
        insts.iter().copied()
    }

    /// Iterates over the head blocks of this region.
    pub fn head_blocks(&self) -> impl Iterator<Item = Block> + Clone + '_ {
        let heads = &self.head_blocks;
        heads.iter().copied()
    }

    /// Iterates over the tail instructions of this region.
    pub fn tail_insts(&self) -> impl Iterator<Item = Inst> + Clone + '_ {
        let insts = &self.tail_insts;
        insts.iter().copied()
    }

    /// Iterates over the tail blocks of this region.
    pub fn tail_blocks(&self) -> impl Iterator<Item = Block> + Clone + '_ {
        let tails = &self.tail_blocks;
        tails.iter().copied()
    }

    /// Returns `true` if `bb` is one of this region's head blocks.
    pub fn is_head(&self, bb: Block) -> bool {
        self.head_blocks.get(&bb).is_some()
    }

    /// Returns `true` if `bb` is one of this region's tail blocks.
    pub fn is_tail(&self, bb: Block) -> bool {
        self.tail_blocks.get(&bb).is_some()
    }
}