use alloc::vec::Vec;
use midenc_hir::{
adt::SmallDenseMap,
dominance::DominanceInfo,
matchers::{self, Matcher},
pass::{Pass, PassExecutionState, PostPassStatus},
traits::{ConstantLike, Terminator},
Backward, Builder, EntityMut, Forward, FxHashSet, OpBuilder, Operation, OperationName,
OperationRef, ProgramPoint, RawWalk, Region, RegionBranchOpInterface,
RegionBranchTerminatorOpInterface, RegionRef, Report, SmallVec, Usable, ValueRef,
};
/// A pass that sinks operations into the regions of region-branching operations.
///
/// For every operation implementing [`RegionBranchOpInterface`] nested under the pass target,
/// the pass collects the regions selected by `get_singly_executed_regions_to_sink` and hands
/// them to [`control_flow_sink`]. Only operations reported as memory-effect free are sunk, and
/// a sunk operation is moved to the start of the destination region's entry block.
pub struct ControlFlowSink;
/// [`Pass`] implementation for [`ControlFlowSink`]; schedulable on any operation.
impl Pass for ControlFlowSink {
    type Target = Operation;

    /// The name of this pass.
    fn name(&self) -> &'static str {
        "control-flow-sink"
    }

    /// The argument string identifying this pass.
    fn argument(&self) -> &'static str {
        "control-flow-sink"
    }

    /// This pass places no restriction on the operation it is scheduled on.
    fn can_schedule_on(&self, _name: &OperationName) -> bool {
        true
    }

    /// Walks every operation nested under `op` and, for each one implementing
    /// [`RegionBranchOpInterface`], sinks memory-effect-free definitions into its selected
    /// regions.
    ///
    /// The post-pass status is `Changed` if *any* visited operation had something sunk into
    /// one of its regions, and `Unchanged` otherwise.
    ///
    /// # Errors
    ///
    /// Returns an error if the [`DominanceInfo`] analysis cannot be obtained.
    fn run_on_operation(
        &mut self,
        op: EntityMut<'_, Self::Target>,
        state: &mut PassExecutionState,
    ) -> Result<(), Report> {
        let op = op.into_entity_ref();
        log::debug!(target: "control-flow-sink", "sinking operations in {op}");

        // Release the borrow of the root operation before walking, since the walk callback
        // borrows (and the sink callback mutably borrows) operations itself.
        let operation = op.as_operation_ref();
        drop(op);

        let dominfo = state.analysis_manager().get_analysis::<DominanceInfo>()?;

        let mut sunk = PostPassStatus::Unchanged;
        operation.raw_prewalk_all::<Forward, _>(|op: OperationRef| {
            let regions_to_sink = {
                let op = op.borrow();
                let Some(branch) = op.as_trait::<dyn RegionBranchOpInterface>() else {
                    return;
                };
                let mut regions = SmallVec::<[_; 4]>::default();
                get_singly_executed_regions_to_sink(branch, &mut regions);
                regions
            };

            let status = control_flow_sink(
                &regions_to_sink,
                &dominfo,
                // Only sink operations without memory effects.
                |op: &Operation, _region: &Region| op.is_memory_effect_free(),
                // Sunk operations are placed at the start of the region's entry block.
                |mut op: OperationRef, region: RegionRef| {
                    let entry_block = region.borrow().entry_block_ref().unwrap();
                    op.borrow_mut().move_to(ProgramPoint::at_start_of(entry_block));
                },
            );

            // Accumulate rather than overwrite: a later branch op that sinks nothing must not
            // reset a change recorded for an earlier one.
            if matches!(status, PostPassStatus::Changed) {
                sunk = PostPassStatus::Changed;
            }
        });

        state.set_post_pass_status(sunk);

        Ok(())
    }
}
/// A pass that moves the definitions of an operation's operands next to that operation.
///
/// Definitions that have exactly one result, are memory-effect free, and whose value is used
/// only by the operation being visited are moved in front of it. Constant definitions that
/// have other users are not moved; a fresh copy is materialized for the operation instead.
/// Operations found to be unused, effect-free and non-terminators while processing are erased.
pub struct SinkOperandDefs;
impl Pass for SinkOperandDefs {
type Target = Operation;
fn name(&self) -> &'static str {
"sink-operand-defs"
}
fn argument(&self) -> &'static str {
"sink-operand-defs"
}
fn can_schedule_on(&self, _name: &OperationName) -> bool {
true
}
fn run_on_operation(
&mut self,
op: EntityMut<'_, Self::Target>,
state: &mut PassExecutionState,
) -> Result<(), Report> {
let operation = op.as_operation_ref();
drop(op);
log::debug!(target: "sink-operand-defs", "sinking operand defs for regions of {}", operation.borrow());
let mut worklist = alloc::collections::VecDeque::default();
let mut changed = PostPassStatus::Unchanged;
operation.raw_postwalk_all::<Backward, _>(|operation: OperationRef| {
let op = operation.borrow();
log::trace!(target: "sink-operand-defs", "visiting {op}");
for operand in op.operands().iter().rev() {
let value = operand.borrow();
let value = value.value();
let is_sole_user = value.iter_uses().all(|user| user.owner == operation);
let Some(defining_op) = value.get_defining_op() else {
log::trace!(target: "sink-operand-defs", " ignoring block argument operand '{value}'");
continue;
};
log::trace!(target: "sink-operand-defs", " evaluating operand '{value}'");
let def = defining_op.borrow();
if def.implements::<dyn ConstantLike>() {
log::trace!(target: "sink-operand-defs", " defining '{}' is constant-like", def.name());
worklist.push_back(OpOperandSink::new(operation));
break;
}
let incorrect_result_count = def.num_results() != 1;
let has_effects = !def.is_memory_effect_free();
if !is_sole_user || incorrect_result_count || has_effects {
log::trace!(target: "sink-operand-defs", " defining '{}' cannot be moved:", def.name());
log::trace!(target: "sink-operand-defs", " * op has multiple uses");
if incorrect_result_count {
log::trace!(target: "sink-operand-defs", " * op has incorrect number of results ({})", def.num_results());
}
if has_effects {
log::trace!(target: "sink-operand-defs", " * op has memory effects");
}
} else {
log::trace!(target: "sink-operand-defs", " defining '{}' is moveable, but is non-constant", def.name());
worklist.push_back(OpOperandSink::new(operation));
break;
}
}
});
for sinker in worklist.iter() {
log::debug!(target: "sink-operand-defs", "sink scheduled for {}", sinker.operation.borrow());
}
let mut visited = FxHashSet::default();
let mut erased = FxHashSet::default();
'next_operation: while let Some(mut sink_state) = worklist.pop_front() {
let mut operation = sink_state.operation;
let op = operation.borrow();
let is_memory_effect_free =
op.is_memory_effect_free() || op.implements::<dyn ConstantLike>();
if !op.is_used()
&& is_memory_effect_free
&& !op.implements::<dyn Terminator>()
&& !op.implements::<dyn RegionBranchTerminatorOpInterface>()
&& erased.insert(operation)
{
log::debug!(target: "sink-operand-defs", "erasing unused, effect-free, non-terminator op {op}");
drop(op);
operation.borrow_mut().erase();
continue;
}
if !visited.insert(operation) && sink_state.next_operand_index == op.num_operands() {
log::trace!(target: "sink-operand-defs", "already visited {}", operation.borrow());
continue;
} else {
log::trace!(target: "sink-operand-defs", "visiting {}", operation.borrow());
}
let mut builder = OpBuilder::new(op.context_rc());
builder.set_insertion_point(sink_state.ip);
'next_operand: loop {
let Some(next_operand_index) = sink_state.next_operand_index.checked_sub(1) else {
break;
};
log::debug!(target: "sink-operand-defs", " sinking next operand def for {op} at index {next_operand_index}");
let mut operand = op.operands()[next_operand_index];
sink_state.next_operand_index = next_operand_index;
let operand_value = operand.borrow().as_value_ref();
log::trace!(target: "sink-operand-defs", " visiting operand {operand_value}");
if let Some(replacement) = sink_state.replacements.get(&operand_value).copied() {
if replacement != operand_value {
log::trace!(target: "sink-operand-defs", " rewriting operand {operand_value} as {replacement}");
operand.borrow_mut().set(replacement);
changed = PostPassStatus::Changed;
if !operand_value.borrow().is_used() {
log::trace!(target: "sink-operand-defs", " {operand_value} is no longer used, erasing definition");
let mut defining_op = operand_value.borrow().get_defining_op().unwrap();
defining_op.borrow_mut().erase();
}
}
continue 'next_operand;
}
let value = operand_value.borrow();
let is_sole_user = value.iter_uses().all(|user| user.owner == operation);
let Some(mut defining_op) = value.get_defining_op() else {
log::trace!(target: "sink-operand-defs", " {value} is a block argument, ignoring..");
continue 'next_operand;
};
log::trace!(target: "sink-operand-defs", " is sole user of {value}? {is_sole_user}");
let def = defining_op.borrow();
if let Some(attr) = matchers::constant().matches(&*def) {
if !is_sole_user {
log::trace!(target: "sink-operand-defs", " defining op is a constant with multiple uses, materializing fresh copy");
let span = value.span();
let ty = value.ty();
let Some(new_def) =
def.dialect().materialize_constant(&mut builder, attr, ty, span)
else {
log::trace!(target: "sink-operand-defs", " unable to materialize copy, skipping rewrite of this operand");
continue 'next_operand;
};
drop(def);
drop(value);
let replacement = new_def.borrow().results()[0] as ValueRef;
log::trace!(target: "sink-operand-defs", " rewriting operand {operand_value} as {replacement}");
sink_state.replacements.insert(operand_value, replacement);
operand.borrow_mut().set(replacement);
changed = PostPassStatus::Changed;
} else {
log::trace!(target: "sink-operand-defs", " defining op is a constant with no other uses, moving into place");
drop(def);
drop(value);
defining_op.borrow_mut().move_to(*builder.insertion_point());
sink_state.replacements.insert(operand_value, operand_value);
}
} else if !is_sole_user || def.num_results() != 1 || !def.is_memory_effect_free() {
log::trace!(target: "sink-operand-defs", " defining op is unsuitable for sinking, ignoring this operand");
} else {
drop(def);
drop(value);
log::trace!(target: "sink-operand-defs", " defining op can be moved and has no other uses, moving into place");
defining_op.borrow_mut().move_to(*builder.insertion_point());
sink_state.replacements.insert(operand_value, operand_value);
log::trace!(target: "sink-operand-defs", " enqueing defining op for immediate processing");
sink_state.ip = ProgramPoint::before(operation);
worklist.push_front(sink_state);
worklist.push_front(OpOperandSink::new(defining_op));
continue 'next_operation;
}
}
}
state.set_post_pass_status(changed);
Ok(())
}
}
/// Work item tracking the progress of sinking the operand definitions of one operation.
struct OpOperandSink {
    /// The operation whose operand definitions are being sunk.
    operation: OperationRef,
    /// Where sunk or freshly materialized definitions are placed; initialized to the point
    /// immediately before `operation`.
    ip: ProgramPoint,
    /// Operand values already handled for this operation, mapped to the value the operand
    /// should use. A value maps to itself when its definition was moved rather than replaced.
    replacements: SmallDenseMap<ValueRef, ValueRef, 4>,
    /// One past the index of the next operand to process. Operands are processed from last to
    /// first, so this starts at the operand count and zero means all operands are done.
    next_operand_index: usize,
}
impl OpOperandSink {
    /// Creates the initial sink state for `operation`: no operands processed yet, no recorded
    /// replacements, and an insertion point immediately before `operation`.
    pub fn new(operation: OperationRef) -> Self {
        let operand_count = operation.borrow().num_operands();
        let insertion_point = ProgramPoint::before(operation);

        Self {
            operation,
            ip: insertion_point,
            replacements: SmallDenseMap::new(),
            next_operand_index: operand_count,
        }
    }
}
/// Helper that sinks operations into regions, driven by caller-supplied callbacks.
struct Sinker<'a, P, F> {
    /// Dominance information used to check that all users of an operation are dominated by
    /// the entry block of the candidate region.
    dominfo: &'a DominanceInfo,
    /// Predicate deciding whether an operation should be moved into a given region.
    should_move_into_region: P,
    /// Callback that performs the actual move of an operation into a region.
    move_into_region: F,
    /// Number of operations sunk so far.
    num_sunk: usize,
}
impl<'a, P, F> Sinker<'a, P, F>
where
    P: Fn(&Operation, &Region) -> bool,
    F: Fn(OperationRef, RegionRef),
{
    /// Creates a sinker with the given dominance info and callbacks, and a zeroed counter.
    pub fn new(
        dominfo: &'a DominanceInfo,
        should_move_into_region: P,
        move_into_region: F,
    ) -> Self {
        Self {
            dominfo,
            should_move_into_region,
            move_into_region,
            num_sunk: 0,
        }
    }

    /// Sinks operations into each non-empty region of `regions`, returning the total number
    /// of operations that were sunk.
    pub fn sink_regions(mut self, regions: &[RegionRef]) -> usize {
        for region in regions.iter().copied() {
            if !region.borrow().is_empty() {
                self.sink_region(region);
            }
        }

        self.num_sunk
    }

    /// Returns true if every use of every result of `op` is owned by an operation whose
    /// parent block is dominated by the entry block of `region`.
    ///
    /// # Panics
    ///
    /// Panics if `op` is nested within `region`, if `region` has no entry block, or if a
    /// using operation has no parent block.
    fn all_users_dominated_by(&self, op: &Operation, region: &Region) -> bool {
        assert!(
            region.find_ancestor_op(op.as_operation_ref()).is_none(),
            "expected op to be defined outside the region"
        );
        let region_entry = region.entry_block_ref().unwrap();
        op.results().iter().all(|result| {
            let result = result.borrow();
            result.iter_uses().all(|user| {
                self.dominfo.dominates(&region_entry, &user.owner.parent().unwrap())
            })
        })
    }

    /// For each operand of `user`, tries to sink the operand's defining operation into
    /// `region`.
    ///
    /// Operands without a defining operation, or whose defining operation already has
    /// `region` as its parent region, are skipped. Every operation that is sunk is pushed on
    /// `stack` so that its own operands are considered as well.
    fn try_to_sink_predecessors(
        &mut self,
        user: OperationRef,
        region: RegionRef,
        stack: &mut Vec<OperationRef>,
    ) {
        log::trace!(target: "control-flow-sink", "contained op: {}", user.borrow());
        let user = user.borrow();
        for operand in user.operands().iter() {
            // Skip block arguments: they have no defining operation to move.
            let Some(op) = operand.borrow().value().get_defining_op() else {
                continue;
            };
            // Skip definitions that already live in the target region.
            if op.grandparent().is_some_and(|r| r == region) {
                continue;
            }

            log::trace!(target: "control-flow-sink", "try to sink op: {}", op.borrow());

            // Evaluate both conditions in a scope of their own, so that the borrows are
            // released before the move callback runs.
            let (all_users_dominated_by, should_move_into_region) = {
                let op = op.borrow();
                let region = region.borrow();
                let all_users_dominated_by = self.all_users_dominated_by(&op, &region);
                let should_move_into_region = (self.should_move_into_region)(&op, &region);
                (all_users_dominated_by, should_move_into_region)
            };
            if all_users_dominated_by && should_move_into_region {
                (self.move_into_region)(op, region);

                self.num_sunk += 1;

                // The sunk operation is now in the region, so its operands become candidates.
                stack.push(op);
            }
        }
    }

    /// Sinks operations into `region`, starting from the operations in the region's blocks
    /// and continuing with every operation that gets sunk, until no work remains.
    fn sink_region(&mut self, region: RegionRef) {
        // Seed the worklist with all operations currently in the region's blocks.
        let mut stack = Vec::new();
        for block in region.borrow().body() {
            for op in block.body() {
                stack.push(op.as_operation_ref());
            }
        }

        while let Some(op) = stack.pop() {
            self.try_to_sink_predecessors(op, region, &mut stack);
        }
    }
}
/// Sinks operations into the given `regions`.
///
/// `should_move_into_region` decides whether a candidate operation may be moved into a
/// region, and `move_into_region` performs the move. `dominfo` is used to verify that all
/// users of a candidate are dominated by the region's entry block.
///
/// Returns the status obtained by converting "at least one operation was sunk" into a
/// [`PostPassStatus`].
pub fn control_flow_sink<P, F>(
    regions: &[RegionRef],
    dominfo: &DominanceInfo,
    should_move_into_region: P,
    move_into_region: F,
) -> PostPassStatus
where
    P: Fn(&Operation, &Region) -> bool,
    F: Fn(OperationRef, RegionRef),
{
    let num_sunk =
        Sinker::new(dominfo, should_move_into_region, move_into_region).sink_regions(regions);
    let any_sunk = num_sunk > 0;
    any_sunk.into()
}
/// Appends to `regions` every region of `branch` whose maximum invocation bound passes the
/// checks below.
///
/// The invocation bounds are requested from `branch` using the result of matching each of
/// its operands with `matchers::foldable_operand()`. A region is skipped when its maximum
/// bound is unbounded, is an exclusive bound of zero or greater than two, or is an inclusive
/// bound greater than one; all other regions are appended.
fn get_singly_executed_regions_to_sink(
    branch: &dyn RegionBranchOpInterface,
    regions: &mut SmallVec<[RegionRef; 4]>,
) {
    use core::range::Bound;

    use midenc_hir::matchers::Matcher;

    // Match each operand of the branch op; the results feed the bounds query.
    let mut operands = SmallVec::<[_; 4]>::with_capacity(branch.num_operands());
    for operand in branch.operands().iter() {
        operands.push(matchers::foldable_operand().matches(operand));
    }

    let bounds = branch.get_region_invocation_bounds(&operands);
    for (region, bound) in branch.regions().iter().zip(bounds) {
        let is_candidate = match bound.max() {
            Bound::Unbounded => false,
            Bound::Excluded(limit) if *limit > 2 => false,
            Bound::Excluded(0) => false,
            Bound::Included(limit) if *limit > 1 => false,
            _ => true,
        };
        if is_candidate {
            regions.push(region.as_region_ref());
        }
    }
}