use crate::{
analysis::{DominatorTree, TemporalRegion, TemporalRegionGraph},
ir::prelude::*,
ir::InstData,
opt::prelude::*,
value::IntValue,
};
use itertools::Itertools;
use std::collections::HashMap;
/// Temporal Code Motion (TCM) optimization pass.
///
/// As implemented by this type's `Pass::run_on_cfg`, the pass rearranges
/// instructions relative to the temporal regions of a unit:
///
/// - `prb` instructions whose operand is not produced by another instruction
///   are hoisted into the head block of their temporal region;
/// - identical tail instructions of a temporal region are merged into a
///   single instruction living in a fresh block;
/// - auxiliary blocks are added in front of head blocks that are entered
///   several times from the same foreign temporal region;
/// - `drv`/`drv_cond` instructions are pushed towards the tail blocks of
///   their temporal region and coalesced there.
pub struct TemporalCodeMotion;
impl Pass for TemporalCodeMotion {
    /// Runs temporal code motion on the control flow graph of `unit`.
    ///
    /// The steps are executed in this order:
    ///
    /// 1. Hoist `prb` instructions whose probed operand is not the result of
    ///    an instruction into the (single) head block of their temporal
    ///    region, provided the head block temporally dominates the `prb` and
    ///    all of its users.
    /// 2. Merge structurally equal tail instructions of a temporal region into
    ///    one instruction placed in a newly created block.
    /// 3. Insert auxiliary head blocks (`add_aux_blocks`).
    /// 4. Push drives towards the region tails (`push_drives`).
    ///
    /// Returns `true` if any step modified the unit.
    fn run_on_cfg(ctx: &PassContext, unit: &mut UnitBuilder) -> bool {
        info!("TCM [{}]", unit.name());
        let mut modified = false;

        // Step 1: hoist `prb` instructions into the head block of their region.
        let trg = unit.trg();
        let temp_dt = unit.temporal_domtree();
        for tr in trg.regions() {
            // Without a unique head block there is no single hoisting target.
            if tr.head_blocks.len() != 1 {
                trace!("Skipping {} for prb move (multiple head blocks)", tr.id);
                continue;
            }
            let head_bb = tr.head_blocks().next().unwrap();

            // Collect the candidates first; the unit is only mutated afterwards.
            let mut hoist = vec![];
            for bb in tr.blocks() {
                for inst in unit.insts(bb) {
                    // Only consider probes whose operand is not computed by
                    // another instruction (presumably a signal that enters the
                    // unit from outside -- confirm against `get_value_inst`).
                    if unit[inst].opcode() == Opcode::Prb
                        && unit.get_value_inst(unit[inst].args()[0]).is_none()
                    {
                        // The head block must temporally dominate the current
                        // location of the probe ...
                        let mut dominates = temp_dt.dominates(head_bb, bb);
                        // ... as well as the block of every user of its result.
                        for &user_inst in unit.uses(unit.inst_result(inst)) {
                            let user_bb = unit.inst_block(user_inst).unwrap();
                            let dom = temp_dt.dominates(head_bb, user_bb);
                            dominates &= dom;
                        }
                        if dominates {
                            hoist.push(inst);
                        } else {
                            trace!(
                                "Skipping {} for prb move (would not dominate uses)",
                                inst.dump(&unit)
                            );
                        }
                    }
                }
            }

            // Process the candidates in ascending handle order.
            hoist.sort();
            for inst in hoist {
                // Nothing to do if the probe already lives in the head block.
                if unit.inst_block(inst) == Some(head_bb) {
                    continue;
                }
                debug!("Hoisting {} into {}", inst.dump(&unit), head_bb.dump(&unit));
                // Unlink the instruction from its current block before
                // inserting it elsewhere; this mirrors the `remove_inst` +
                // `append_inst` relocation used for the wait merge below.
                // Prepending an instruction that is still linked into another
                // block would leave it referenced from two places.
                unit.remove_inst(inst);
                unit.prepend_inst(inst, head_bb);
                modified = true;
            }
        }

        // Step 2: fuse equal tail instructions of each temporal region. The
        // region graph is rebuilt since step 1 may have changed the unit.
        let trg = unit.trg();
        for tr in trg.regions() {
            if tr.tail_insts.len() <= 1 {
                trace!("Skipping {} for wait merge (single wait inst)", tr.id);
                continue;
            }

            // Group the tail instructions by their instruction data, such that
            // every group consists of equal instructions.
            let mut merge = HashMap::<&InstData, Vec<Inst>>::new();
            for inst in tr.tail_insts() {
                merge.entry(&unit[inst]).or_default().push(inst);
            }
            // Drop the keys to release the borrow on `unit` before mutating it.
            let merge: Vec<_> = merge.into_iter().map(|(_, is)| is).collect();

            for insts in merge {
                if insts.len() <= 1 {
                    trace!("Skipping {} (no equivalents)", insts[0].dump(&unit));
                    continue;
                }
                trace!("Merging:",);
                for i in &insts {
                    trace!(" {}", i.dump(&unit));
                }

                // Create the block which will hold the single merged instruction.
                let unified_bb = unit.block();

                // Place a branch to the unified block right after every
                // instruction of the group.
                for &inst in &insts {
                    unit.insert_after(inst);
                    unit.ins().br(unified_bb);
                }

                // Move the first instruction into the unified block and delete
                // the remaining duplicates.
                unit.remove_inst(insts[0]);
                unit.append_inst(insts[0], unified_bb);
                for &inst in &insts[1..] {
                    unit.delete_inst(inst);
                }
                modified = true;
            }
        }

        // Step 3: introduce auxiliary head blocks where necessary.
        modified |= add_aux_blocks(ctx, unit);

        // Step 4: push drives towards the tails of their temporal regions.
        modified |= push_drives(ctx, unit);

        modified
    }
}
/// Inserts auxiliary blocks in front of temporal region head blocks.
///
/// For every head block, the terminators of all predecessors that belong to a
/// *different* temporal region are grouped by that region. Whenever two or more
/// terminators of the same foreign region target the head block, a new block
/// named `aux` is created which merely branches to the head block, and those
/// terminators are redirected to the `aux` block instead.
///
/// Returns `true` if at least one auxiliary block was created.
fn add_aux_blocks(_ctx: &PassContext, unit: &mut UnitBuilder) -> bool {
    let predecessors = unit.predtbl();
    let regions = unit.trg();

    // Snapshot the head blocks up front, since the unit is modified below.
    let heads: Vec<_> = unit.blocks().filter(|&bb| regions.is_head(bb)).collect();

    let mut changed = false;
    for head in heads {
        trace!("Adding aux blocks into {}", head.dump(&unit));
        let head_tr = regions[head];

        // Terminators entering `head`, keyed by the foreign region they are in.
        let mut entries = HashMap::<TemporalRegion, Vec<Inst>>::new();
        for pred in predecessors.pred(head) {
            let pred_tr = regions[pred];
            if pred_tr == head_tr {
                continue;
            }
            let term = unit.terminator(pred);
            entries.entry(pred_tr).or_default().push(term);
        }

        for (src_tr, terms) in entries {
            if terms.len() >= 2 {
                // Build the auxiliary block, which only forwards to `head`.
                let aux = unit.named_block("aux");
                unit.append_to(aux);
                unit.ins().br(head);
                trace!(" Adding {} from {}", aux.dump(&unit), src_tr);

                // Reroute every entering terminator through the new block.
                for term in terms {
                    trace!(" Replacing {} in {}", head.dump(&unit), term.dump(&unit));
                    unit.replace_block_within_inst(head, aux, term);
                }
                changed = true;
            } else {
                trace!(" Skipping {} (single head inst)", src_tr);
            }
        }
    }
    changed
}
/// Pushes drive instructions towards the tail blocks of their temporal region
/// and coalesces the drives of every block afterwards.
///
/// The function first walks the blocks in reverse post-order of the dominator
/// tree and records, per (alias-resolved) signal, the sequence of `drv` and
/// `drv_cond` instructions targeting it. Signal-typed instruction results are
/// recorded as aliases of their signal-typed operands. The drives of each
/// signal are then handed to `push_drive` from last to first; as soon as one
/// of them cannot be moved, the remaining drives of that signal are left
/// untouched. Finally `coalesce_drives` runs on every block.
///
/// Returns `true` if the unit was modified.
fn push_drives(ctx: &PassContext, unit: &mut UnitBuilder) -> bool {
    /// Looks `value` up in the alias table, yielding `value` itself if it has
    /// no recorded alias.
    fn resolve(aliases: &HashMap<Value, Value>, value: Value) -> Value {
        aliases.get(&value).cloned().unwrap_or(value)
    }

    let predtbl = unit.predtbl();
    let domtree = unit.domtree_with_predtbl(&predtbl);

    // Phase 1: gather signal aliases and the per-signal drive sequences.
    let mut aliases = HashMap::<Value, Value>::new();
    let mut drv_seq = HashMap::<Value, Vec<Inst>>::new();
    for &bb in domtree.blocks_post_order().iter().rev() {
        trace!("Checking {} for aliases", bb.dump(&unit));
        for inst in unit.insts(bb) {
            let data = &unit[inst];
            match data.opcode() {
                Opcode::Drv | Opcode::DrvCond => {
                    // Operand 0 of a drive is the driven signal.
                    let signal = resolve(&aliases, data.args()[0]);
                    trace!(" Drive {} ({})", signal.dump(&unit), inst.dump(&unit));
                    drv_seq.entry(signal).or_default().push(inst);
                }
                _ => {
                    // Only instructions yielding a signal can create aliases.
                    let value = match unit.get_inst_result(inst) {
                        Some(value) if unit.value_type(value).is_signal() => value,
                        _ => continue,
                    };
                    for &arg in data.args() {
                        if unit.value_type(arg).is_signal() {
                            let origin = resolve(&aliases, arg);
                            trace!(
                                " Alias {} of {} ({})",
                                value.dump(&unit),
                                origin.dump(&unit),
                                inst.dump(&unit)
                            );
                            aliases.insert(value, origin);
                        }
                    }
                }
            }
        }
    }

    // Phase 2: move the drives, visiting the last drive of each signal first.
    let mut modified = false;
    let trg = unit.trg();
    for (&signal, drives) in &drv_seq {
        trace!("Moving drives on signal {}", signal.dump(&unit));
        for &drive in drives.iter().rev() {
            let home_bb = unit.inst_block(drive).unwrap();
            if trg.is_tail(home_bb) {
                trace!(" Skipping {} (already in tail block)", drive.dump(&unit),);
                continue;
            }
            if trg[trg[home_bb]].tail_blocks.is_empty() {
                trace!(" Skipping {} (no tail blocks)", drive.dump(&unit),);
                continue;
            }
            // A drive that cannot be moved blocks all earlier drives on the
            // same signal, so stop processing this signal.
            if !push_drive(ctx, drive, unit, &domtree, &trg) {
                break;
            }
            modified = true;
        }
    }

    // Phase 3: coalesce the drives within each block. The block list is
    // collected first because coalescing mutates the unit.
    let blocks: Vec<_> = unit.blocks().collect();
    for block in blocks {
        if coalesce_drives(ctx, block, unit) {
            modified = true;
        }
    }
    modified
}
/// Moves a single drive instruction into every tail block of its temporal
/// region.
///
/// For each tail block of the region containing `drive`, the function checks
/// that all operands of the drive dominate that block, and walks the dominator
/// tree `dt` from both the drive's block and the tail block up to a common
/// dominator. Every conditional branch passed on the way up from the drive's
/// block contributes a condition which must hold for the drive to take effect.
///
/// If all tail blocks pass these checks, a `drv_cond` guarded by the
/// conjunction of the collected conditions (and the original condition, for a
/// `DrvCond`) is inserted at the beginning of each tail block, and the original
/// instruction is deleted.
///
/// Returns `false`, leaving the unit untouched, if an operand or a branch
/// condition does not dominate a tail block, or if no common dominator is
/// found. Returns `true` once the drive has been replaced.
///
/// NOTE(review): if the region has no tail blocks, no replacement is inserted
/// but the drive is still deleted and `true` is returned; `push_drives` skips
/// such drives before calling this function.
fn push_drive(
_ctx: &PassContext,
drive: Inst,
unit: &mut UnitBuilder,
dt: &DominatorTree,
trg: &TemporalRegionGraph,
) -> bool {
let src_bb = unit.inst_block(drive).unwrap();
let tr = trg[src_bb];
// Planned moves as (destination block, branch conditions). Nothing is
// modified until every tail block has been checked successfully.
let mut moves = Vec::new();
for dst_bb in trg[tr].tail_blocks() {
// All operands of the drive must be available in the destination block.
for &arg in unit[drive].args() {
if !dt.value_dominates_block(unit, arg, dst_bb) {
trace!(
" Skipping {} ({} does not dominate {})",
drive.dump(&unit),
arg.dump(&unit),
dst_bb.dump(&unit)
);
return false;
}
}
// Walk both blocks up the dominator tree until the fingers meet. In each
// step the finger with the smaller `block_order` moves to its dominator.
let mut src_finger = src_bb;
let mut dst_finger = dst_bb;
// Branch conditions as (condition value, required polarity).
let mut conds = Vec::<(Value, bool)>::new();
while src_finger != dst_finger {
let i1 = dt.block_order(src_finger);
let i2 = dt.block_order(dst_finger);
if i1 < i2 {
let parent = dt.dominator(src_finger);
// A block that is its own dominator cannot be stepped up any further
// (presumably the root of the tree).
if src_finger == parent {
break;
}
let term = unit.terminator(parent);
if unit[term].opcode() == Opcode::BrCond {
let cond_val = unit[term].args()[0];
// The branch condition is re-used in the destination block, so it has
// to be available there as well.
if !dt.value_dominates_block(unit, cond_val, dst_bb) {
trace!(
" Skipping {} (branch cond {} does not dominate {})",
drive.dump(&unit),
cond_val.dump(&unit),
dst_bb.dump(&unit)
);
return false;
}
// The index of `src_finger` among the branch targets gives the polarity:
// index 0 requires the condition to be false, any other index requires
// it to be true. If `src_finger` is not a target, nothing is recorded.
let cond_pol = unit[term].blocks().iter().position(|&bb| bb == src_finger);
if let Some(cond_pol) = cond_pol {
conds.push((cond_val, cond_pol != 0));
trace!(
" {} -> {} ({} == {})",
parent.dump(&unit),
src_finger.dump(&unit),
cond_val.dump(&unit),
cond_pol
);
}
} else {
trace!(" {} -> {}", parent.dump(&unit), src_finger.dump(&unit));
}
src_finger = parent;
} else if i2 < i1 {
// Steps on the destination side do not contribute any conditions.
let parent = dt.dominator(dst_finger);
if dst_finger == parent {
break;
}
dst_finger = parent;
}
}
// The loop above may have stopped early without the fingers meeting.
if src_finger != dst_finger {
trace!(" Skipping {} (no common dominator)", drive.dump(&unit));
return false;
}
moves.push((dst_bb, conds));
}
// All checks passed; materialize the drive in every destination block.
for (dst_bb, conds) in moves {
debug!("Moving {} to {}", drive.dump(&unit), dst_bb.dump(&unit));
// New instructions are inserted at the beginning of the destination block.
unit.prepend_to(dst_bb);
// Start from a 1-bit all-ones constant and AND in the branch conditions,
// in the reverse of the order in which they were collected.
let mut cond = unit.ins().const_int(IntValue::all_ones(1));
for (value, polarity) in conds.into_iter().rev() {
let value = match polarity {
true => value,
false => unit.ins().not(value),
};
cond = unit.ins().and(cond, value);
}
// A conditional drive additionally keeps its own condition (operand 3).
if unit[drive].opcode() == Opcode::DrvCond {
let arg = unit[drive].args()[3];
cond = unit.ins().and(cond, arg);
}
let args = unit[drive].args();
let signal = args[0];
let value = args[1];
let delay = args[2];
unit.ins().drv_cond(signal, value, delay, cond);
}
// The original drive has been replaced by the conditional drives above.
unit.delete_inst(drive);
true
}
/// Coalesces the drives within `block`.
///
/// The `drv`/`drv_cond` instructions of the block are grouped by their delay
/// operand. Within each delay group, every run of consecutive drives to the
/// same target is replaced by a single `drv_cond`: its condition is the OR of
/// the individual drive conditions, and its value is selected among the
/// individual values by a chain of `mux` instructions controlled by the
/// condition of each later drive.
///
/// Returns `true` if at least one run of drives was merged.
fn coalesce_drives(_ctx: &PassContext, block: Block, unit: &mut UnitBuilder) -> bool {
    // Bucket the drives of the block by their delay (operand 2).
    let mut by_delay = HashMap::<Value, Vec<Inst>>::new();
    for inst in unit.insts(block) {
        match unit[inst].opcode() {
            Opcode::Drv | Opcode::DrvCond => {
                let delay = unit[inst].args()[2];
                by_delay.entry(delay).or_default().push(inst);
            }
            _ => {}
        }
    }

    let mut changed = false;
    for (delay, group) in by_delay {
        // Split the group into runs of consecutive drives to the same target
        // (operand 0). The runs are collected eagerly since the unit is
        // mutated while processing them.
        let runs: Vec<_> = group
            .into_iter()
            .group_by(|&inst| unit[inst].args()[0])
            .into_iter()
            .map(|(target, run)| (target, run.collect::<Vec<_>>()))
            .collect();

        for (target, run) in runs {
            // A lone drive has nothing to be merged with.
            if run.len() < 2 {
                continue;
            }
            debug!(
                "Coalescing {} drives on {}",
                run.len(),
                target.dump(&unit)
            );

            // The first drive provides the initial condition and value.
            let (&first, rest) = run.split_first().unwrap();
            unit.insert_before(first);
            let mut cond = drive_cond(unit, first);
            let mut value = unit[first].args()[1];
            unit.delete_inst(first);

            // Fold the remaining drives into the accumulated condition/value.
            for &drive in rest {
                unit.insert_before(drive);
                let next_cond = drive_cond(unit, drive);
                let next_value = unit[drive].args()[1];
                if cond != next_cond {
                    cond = unit.ins().or(cond, next_cond);
                }
                if value != next_value {
                    let choices = unit.ins().array(vec![value, next_value]);
                    value = unit.ins().mux(choices, next_cond);
                }
                unit.delete_inst(drive);
            }

            // Emit the merged drive at the position of the last drive of the run.
            unit.ins().drv_cond(target, value, delay, cond);
            changed = true;
        }
    }
    changed
}
/// Returns the condition under which the drive `inst` takes effect.
///
/// For a `DrvCond` this is the instruction's operand 3. For any other
/// instruction a 1-bit all-ones constant is emitted at the builder's current
/// insertion point and returned.
fn drive_cond(unit: &mut UnitBuilder, inst: Inst) -> Value {
    let opcode = unit[inst].opcode();
    match opcode {
        Opcode::DrvCond => unit[inst].args()[3],
        _ => unit.ins().const_int(IntValue::all_ones(1)),
    }
}