use pliron_derive::op_interface;
use rustc_hash::FxHashSet;
use crate::{
basic_block::BasicBlock,
builtin::op_interfaces::BranchOpInterface,
context::{Context, Ptr},
graph::{
HasLabel,
walkers::{IRNode, WALKCONFIG_PREORDER_FORWARD, uninterruptible::mutable::walk_op},
},
irbuild::{
inserter::Inserter,
listener::{Recorder, RecorderEvent},
rewriter::{IRRewriter, Rewriter},
},
op::{Op, op_cast, op_impls},
operation::{OpDbg, Operation},
opts::OptStatus,
printable::Printable,
result::Result,
value::{DefiningEntity, Value},
};
/// Op interface through which an operation reports whether it has side effects.
///
/// `is_safe_to_erase` only treats an unused operation as dead when the operation
/// implements this interface and `has_side_effects` returns `false`; operations
/// that do not implement it are conservatively assumed to have side effects.
#[op_interface]
pub trait SideEffects {
/// Returns `true` if this operation has side effects. DCE will not erase
/// the operation when this returns `true`.
fn has_side_effects(&self, ctx: &Context) -> bool;
/// Verifier for this interface. It performs no checks and always returns `Ok(())`.
fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
where
Self: Sized,
{
Ok(())
}
}
/// Op interface through which an operation states whether arguments of a block
/// it contains may be removed.
///
/// `is_safe_to_erase` queries this interface on the parent operation of the block
/// that defines an argument; if the parent operation does not implement it, the
/// argument is never considered removable.
#[op_interface]
pub trait BlockArgRemoval {
/// Returns `true` if arguments of `block` may be removed. DCE only removes an
/// unused argument of `block` when this returns `true`.
fn can_remove_block_args(&self, ctx: &Context, block: Ptr<BasicBlock>) -> bool;
/// Verifier for this interface. It performs no checks and always returns `Ok(())`.
fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
where
Self: Sized,
{
Ok(())
}
}
/// An IR entity that DCE considers for erasure.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
enum DCECandidate {
/// An operation, erased as a whole.
Op(Ptr<Operation>),
/// A block argument; users of this variant call `defining_block()` on the
/// value and expect it to succeed.
BlockArg(Value),
}
impl Printable for DCECandidate {
    /// Writes a short description of the candidate, used in the DCE trace logs.
    ///
    /// Operations are rendered through [`OpDbg`]; block arguments are rendered
    /// as their index together with the label of the defining block.
    ///
    /// # Panics
    ///
    /// Panics if a `BlockArg` candidate holds a value with no defining block.
    fn fmt(
        &self,
        ctx: &Context,
        _state: &crate::printable::State,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        match *self {
            DCECandidate::Op(op) => {
                let op_dbg = OpDbg { op, ctx };
                write!(f, "Operation {}", op_dbg)
            }
            DCECandidate::BlockArg(arg) => {
                let owner = arg.defining_block().expect("Expected a block argument");
                let arg_idx = arg.find_index(ctx);
                let owner_label = owner.label(ctx);
                write!(f, "Block argument {} of block {}", arg_idx, owner_label)
            }
        }
    }
}
fn is_safe_to_erase(cand: DCECandidate, ctx: &Context) -> bool {
match cand {
DCECandidate::Op(def_op) => {
let def_op_ref = def_op.deref(ctx);
if def_op_ref.has_use() {
return false;
}
let def_op_dyn = Operation::get_op_dyn(def_op, ctx);
let has_side_effects = match op_cast::<dyn SideEffects>(&*def_op_dyn) {
Some(side_effects_op) => side_effects_op.has_side_effects(ctx),
None => true, };
!has_side_effects
}
DCECandidate::BlockArg(arg) => {
let block = arg.defining_block().expect("Expected a block argument");
if arg.is_used(ctx) {
return false;
}
let Some(parent_op) = block.deref(ctx).get_parent_op(ctx) else {
return false;
};
let parent_op_dyn = Operation::get_op_dyn(parent_op, ctx);
let Some(block_arg_removal_op) = op_cast::<dyn BlockArgRemoval>(&*parent_op_dyn) else {
return false;
};
if !block_arg_removal_op.can_remove_block_args(ctx, block) {
return false;
}
block.preds(ctx).iter().all(|pred| {
let Some(pred_terminator) = pred.deref(ctx).get_terminator(ctx) else {
return false;
};
let pred_terminator_dyn = Operation::get_op_dyn(pred_terminator, ctx);
op_impls::<dyn crate::builtin::op_interfaces::BranchOpInterface>(
&*pred_terminator_dyn,
)
})
}
}
}
/// Drains all pending events from `recorder` and records the erasures.
///
/// Erased operations are added to `erased_ops` and erased blocks to
/// `erased_blocks`; erased regions are ignored.
///
/// # Panics
///
/// Panics on any other recorder event.
fn note_erased_ops(
    recorder: &mut Recorder,
    erased_ops: &mut FxHashSet<Ptr<Operation>>,
    erased_blocks: &mut FxHashSet<Ptr<BasicBlock>>,
) {
    recorder.events.drain(..).for_each(|event| match event {
        RecorderEvent::ErasedOperation(op) => {
            erased_ops.insert(op);
        }
        RecorderEvent::ErasedBlock(block) => {
            erased_blocks.insert(block);
        }
        // Regions are not tracked.
        RecorderEvent::ErasedRegion(_) => {}
        // Listed explicitly (no wildcard) so a new event kind forces a decision here.
        RecorderEvent::UnlinkedBlock(..)
        | RecorderEvent::InsertedOperation(..)
        | RecorderEvent::InsertedBlock(..)
        | RecorderEvent::ReplacedValueUses { .. }
        | RecorderEvent::ValueTypeChanged { .. }
        | RecorderEvent::UnlinkedOperation(..) => {
            panic!("Unexpected event in DCE recorder: {:?}", event)
        }
    });
}
/// Runs dead code elimination on `op` and everything nested inside it.
///
/// The pass works in two phases:
///
/// 1. Walk `op` in forward pre-order and collect every operation and block
///    argument that [`is_safe_to_erase`] accepts into a worklist (the
///    "cemetery").
/// 2. Pop candidates off the worklist and erase them. The values the erased
///    entity referenced (an operation's operands, or the successor operands
///    that fed a removed block argument) are then re-examined; any whose
///    definition has become dead is queued as well.
///
/// Returns `Ok(OptStatus::IRChanged)` if at least one entity was erased and
/// `Ok(OptStatus::IRUnchanged)` otherwise. This function itself never
/// constructs an `Err`.
///
/// # Panics
///
/// Panics if a user of a block whose argument is being removed does not
/// implement [`BranchOpInterface`], or if the rewriter's recorder reports an
/// event that `note_erased_ops` does not expect.
pub fn dce(op: Ptr<Operation>, ctx: &mut Context) -> Result<OptStatus> {
    let mut rewriter = IRRewriter::<Recorder>::default();
    let mut cemetery: Vec<DCECandidate> = Vec::new();
    // Erasing an operation also erases whatever is nested in it. Candidates
    // queued earlier may therefore already be gone; these sets let us skip them.
    let mut erased_blocks: FxHashSet<Ptr<BasicBlock>> = FxHashSet::default();
    let mut erased_ops: FxHashSet<Ptr<Operation>> = FxHashSet::default();

    walk_op(
        ctx,
        &mut cemetery,
        &WALKCONFIG_PREORDER_FORWARD,
        op,
        |ctx, cemetery, ir_node| match ir_node {
            IRNode::Operation(opr) => {
                let cand = DCECandidate::Op(opr);
                if is_safe_to_erase(cand, ctx) {
                    log::trace!("Adding to DCE cemetery: {}", cand.disp(ctx));
                    cemetery.push(cand);
                }
            }
            IRNode::BasicBlock(block) => {
                for arg in block.deref(ctx).arguments() {
                    let cand = DCECandidate::BlockArg(arg);
                    if is_safe_to_erase(cand, ctx) {
                        log::trace!("Adding to DCE cemetery: {}", cand.disp(ctx));
                        cemetery.push(cand);
                    }
                }
            }
            IRNode::Region(_) => {}
        },
    );

    let mut modified = OptStatus::IRUnchanged;
    // Candidates already examined while processing the current dead entity.
    // Allocated once and cleared per step.
    let mut examined: FxHashSet<DCECandidate> = FxHashSet::default();
    while let Some(dead) = cemetery.pop() {
        let operands_of_dead = match dead {
            DCECandidate::BlockArg(arg) => {
                let block = arg.defining_block().expect("Expected a block argument");
                if erased_blocks.contains(&block) {
                    continue;
                }
                let opd_idx = arg.find_index(ctx);
                // Drop the operand that every branch into `block` passes for
                // this argument, remembering the values that were forwarded.
                let successor_operands = block
                    .uses(ctx)
                    .iter()
                    .map(|pred| {
                        let succ_idx = pred.find_index(ctx);
                        let pred_terminator_dyn = Operation::get_op_dyn(pred.user_op, ctx);
                        let branch_interface =
                            op_cast::<dyn BranchOpInterface>(&*pred_terminator_dyn)
                                .expect("Terminator must implement BranchOpInterface");
                        branch_interface.remove_successor_operand(ctx, succ_idx, opd_idx)
                    })
                    .collect::<Vec<_>>();
                log::trace!(
                    "Erasing block argument {} of block {}",
                    opd_idx,
                    block.label(ctx)
                );
                BasicBlock::remove_argument(block, ctx, opd_idx);
                successor_operands
            }
            DCECandidate::Op(dead_op) => {
                if erased_ops.contains(&dead_op) {
                    continue;
                }
                // Snapshot the operands before the operation disappears.
                let operands: Vec<Value> = dead_op.deref(ctx).operands().collect();
                log::trace!("Erasing dead operation: {}", OpDbg { op: dead_op, ctx });
                rewriter.erase_operation(ctx, dead_op);
                note_erased_ops(
                    rewriter.get_listener_mut(),
                    &mut erased_ops,
                    &mut erased_blocks,
                );
                operands
            }
        };
        modified = OptStatus::IRChanged;

        // The erased entity may have referenced the same value several times
        // (`add %x, %x`, or several predecessors forwarding one value). Queue
        // each definition at most once: a block argument queued twice would be
        // removed a second time after it is already gone.
        examined.clear();
        for def_val in operands_of_dead {
            let dce_cand = match def_val.defining_entity() {
                DefiningEntity::Op(def_op) => DCECandidate::Op(def_op),
                DefiningEntity::Block(_) => DCECandidate::BlockArg(def_val),
            };
            if !examined.insert(dce_cand) {
                continue;
            }
            if is_safe_to_erase(dce_cand, ctx) {
                log::trace!("Adding to DCE cemetery: {}", dce_cand.disp(ctx));
                cemetery.push(dce_cand);
            }
        }
    }
    Ok(modified)
}