use alloc::{collections::VecDeque, rc::Rc};
use midenc_hir::{
BlockRef, Builder, Context, FxHashMap, OpBuilder, OpOperand, Operation, OperationRef,
ProgramPoint, Region, RegionBranchOpInterface, RegionBranchPoint, RegionRef, Report, Rewriter,
SmallVec, SourceSpan, Spanned, StorableEntity, TraceTarget, Usable, ValueRange, ValueRef,
adt::{SmallDenseMap, SmallSet},
cfg::Graph,
dominance::{DomTreeNode, DominanceFrontier, DominanceInfo},
pass::{AnalysisManager, PostPassStatus},
traits::{IsolatedFromAbove, SingleRegion},
};
use midenc_hir_analysis::analyses::{
SpillAnalysis,
spills::{Placement, Predecessor},
};
/// Dialect-specific hooks used by [`transform_spills`] to create and lower spill/reload ops.
///
/// `transform_spills` decides where edges are split and where spills and reloads go; it calls
/// back into this interface whenever a concrete operation must be created or lowered.
pub trait TransformSpillsInterface {
    /// Creates an unconditional branch to `destination`, forwarding `arguments`, using `builder`.
    ///
    /// Called when splitting a control flow edge, right after the split block has been created
    /// with `builder.create_block`, to connect the split block to the original successor.
    fn create_unconditional_branch(
        &self,
        builder: &mut OpBuilder,
        destination: BlockRef,
        arguments: &[ValueRef],
        span: SourceSpan,
    ) -> Result<(), Report>;
    /// Creates a spill pseudo-op for `value` at the current insertion point of `builder`.
    ///
    /// The returned op is later required to implement [`SpillLike`].
    fn create_spill(
        &self,
        builder: &mut OpBuilder,
        value: ValueRef,
        span: SourceSpan,
    ) -> Result<OperationRef, Report>;
    /// Creates a reload pseudo-op for `value` at the current insertion point of `builder`.
    ///
    /// The returned op is later required to implement [`ReloadLike`].
    fn create_reload(
        &self,
        builder: &mut OpBuilder,
        value: ValueRef,
        span: SourceSpan,
    ) -> Result<OperationRef, Report>;
    /// Lowers a spill pseudo-op that has at least one live reload into a concrete store.
    ///
    /// The rewriter's insertion point is set immediately after `spill` before this is called.
    fn convert_spill_to_store(
        &mut self,
        rewriter: &mut dyn Rewriter,
        spill: OperationRef,
    ) -> Result<(), Report>;
    /// Lowers a reload pseudo-op whose result is used into a concrete load.
    ///
    /// The rewriter's insertion point is set immediately after `reload` before this is called.
    fn convert_reload_to_load(
        &mut self,
        rewriter: &mut dyn Rewriter,
        reload: OperationRef,
    ) -> Result<(), Report>;
}
/// Implemented by the op returned from [`TransformSpillsInterface::create_spill`].
pub trait SpillLike {
    /// The operand of the spill op that refers to the value being spilled.
    fn spilled(&self) -> OpOperand;
    /// The value being spilled, i.e. the value referenced by [`SpillLike::spilled`].
    fn spilled_value(&self) -> ValueRef {
        self.spilled().borrow().as_value_ref()
    }
}
/// Implemented by the op returned from [`TransformSpillsInterface::create_reload`].
pub trait ReloadLike {
    /// The operand of the reload op that refers to the spilled value being reloaded.
    fn spilled(&self) -> OpOperand;
    /// The spilled value being reloaded, i.e. the value referenced by [`ReloadLike::spilled`].
    fn spilled_value(&self) -> ValueRef {
        self.spilled().borrow().as_value_ref()
    }
    /// The value produced by the reload; pending uses of the spilled value are rewritten to it.
    fn reloaded(&self) -> ValueRef;
}
/// Materializes the spills, reloads and edge splits computed by [`SpillAnalysis`] for the
/// single-region op `op`, rewrites uses of spilled values, and lowers the spill/reload
/// pseudo-ops through `interface`.
///
/// The transformation proceeds in phases:
///
/// 1. Every control flow edge recorded by the analysis is split: a new block is created after the
///    predecessor block, the predecessor's successor is redirected to it, and an unconditional
///    branch (carrying the original successor arguments) is created from it to the original
///    destination.
/// 2. Spill and reload pseudo-ops are inserted at the placements chosen by the analysis.
/// 3. Uses of spilled values are rewritten, with separate strategies for single-block and
///    multi-block regions; the pseudo-ops are then converted to stores/loads or erased.
///
/// Always returns `PostPassStatus::Changed` on success.
///
/// # Panics
///
/// Panics if `op` does not implement `SingleRegion`, or if the analysis requests a split that is
/// not yet supported (splits following, or exiting from, region branch ops).
///
/// # Errors
///
/// Propagates errors from `interface`, and from computing `DominanceInfo` when the region has
/// more than one block.
pub fn transform_spills(
    op: OperationRef,
    analysis: &mut SpillAnalysis,
    interface: &mut dyn TransformSpillsInterface,
    analysis_manager: AnalysisManager,
) -> Result<PostPassStatus, Report> {
    assert!(
        op.borrow().implements::<dyn SingleRegion>(),
        "the spills transformation is not supported when the root op is multi-region"
    );
    let mut builder = OpBuilder::new(op.borrow().context_rc());
    // All diagnostics of this pass share one target, tagged with the root symbol if there is one.
    let trace_target = TraceTarget::category("pass").with_topic("spills");
    let trace_target = if let Some(sym) = op.borrow().as_symbol() {
        trace_target.with_relevant_symbol(sym.name())
    } else {
        trace_target
    };
    log::debug!(
        target: &trace_target,
        symbol = trace_target.relevant_symbol();
        "analysis determined that some spills were required\n\
         edges to split = {}\n\
         values spilled = {}\n\
         reloads issued = {}\n",
        analysis.splits().len(),
        analysis.spills().len(),
        analysis.reloads().len(),
    );

    // Phase 1: materialize a block on every control flow edge the analysis asked to split.
    for split_info in analysis.splits_mut() {
        log::trace!(
            target: &trace_target,
            symbol = trace_target.relevant_symbol();
            "splitting control flow edge {} -> {}",
            match split_info.predecessor {
                Predecessor::Parent => ProgramPoint::before(split_info.predecessor.operation(split_info.point)),
                Predecessor::Block { op, .. } | Predecessor::Region(op) => ProgramPoint::at_end_of(op.parent().unwrap()),
            },
            split_info.point,
        );
        let predecessor_block = split_info
            .predecessor
            .block()
            .unwrap_or_else(|| todo!("implement support for splits following a region branch op"));
        let predecessor_region = predecessor_block.parent().unwrap();
        let split = builder.create_block(predecessor_region, Some(predecessor_block), &[]);
        log::trace!(
            target: &trace_target,
            symbol = trace_target.relevant_symbol();
            "created {split} to hold contents of split edge"
        );
        split_info.split = Some(split);
        match split_info.predecessor {
            Predecessor::Block { mut op, index } => {
                log::trace!(
                    target: &trace_target,
                    symbol = trace_target.relevant_symbol();
                    "redirecting {predecessor_block} to {split}"
                );
                let mut op = op.borrow_mut();
                let mut succ = op.successor_mut(index as usize);
                let prev_dest = succ.dest.parent().unwrap();
                succ.dest.borrow_mut().set(split);
                log::trace!(
                    target: &trace_target,
                    symbol = trace_target.relevant_symbol();
                    "creating edge from {split} to {prev_dest}"
                );
                // Move the successor arguments off the redirected edge; they are forwarded by the
                // branch out of the split block instead.
                let arguments = succ
                    .arguments
                    .take()
                    .into_iter()
                    .map(|mut operand| {
                        let mut operand = operand.borrow_mut();
                        let value = operand.as_value_ref();
                        operand.unlink();
                        value
                    })
                    .collect::<SmallVec<[_; 4]>>();
                match split_info.point {
                    ProgramPoint::Block { block, .. } => {
                        assert_eq!(
                            prev_dest, block,
                            "unexpected mismatch between predecessor target and successor block"
                        );
                        interface.create_unconditional_branch(
                            &mut builder,
                            block,
                            &arguments,
                            op.span(),
                        )?;
                    }
                    point => panic!(
                        "unexpected program point for split: unstructured control flow requires a \
                         block entry, got {point}"
                    ),
                }
            }
            Predecessor::Region(predecessor) => {
                log::trace!(
                    target: &trace_target,
                    symbol = trace_target.relevant_symbol();
                    "splitting region control flow edge to {} from {predecessor}",
                    split_info.point
                );
                todo!()
            }
            Predecessor::Parent => unimplemented!(
                "support for splits on exit from region branch ops is not yet implemented"
            ),
        }
    }

    // Phase 2: insert the spill and reload pseudo-ops. Placements inside a split edge go right
    // before the terminator of the split block created above.
    for spill in analysis.spills.iter_mut() {
        let ip = match spill.place {
            Placement::Split(split) => {
                let split_block = analysis.splits[split.as_usize()]
                    .split
                    .expect("expected split to have been materialized");
                let terminator = split_block.borrow().terminator().unwrap();
                ProgramPoint::before(terminator)
            }
            Placement::At(ip) => ip,
        };
        log::trace!(
            target: &trace_target,
            symbol = trace_target.relevant_symbol();
            "inserting spill of {} at {ip}",
            spill.value
        );
        builder.set_insertion_point(ip);
        let inst = interface.create_spill(&mut builder, spill.value, spill.span)?;
        spill.inst = Some(inst);
    }
    for reload in analysis.reloads.iter_mut() {
        let ip = match reload.place {
            Placement::Split(split) => {
                let split_block = analysis.splits[split.as_usize()]
                    .split
                    .expect("expected split to have been materialized");
                let terminator = split_block.borrow().terminator().unwrap();
                ProgramPoint::before(terminator)
            }
            Placement::At(ip) => ip,
        };
        log::trace!(
            target: &trace_target,
            symbol = trace_target.relevant_symbol();
            "inserting reload of {} at {ip}",
            reload.value
        );
        builder.set_insertion_point(ip);
        let inst = interface.create_reload(&mut builder, reload.value, reload.span)?;
        reload.inst = Some(inst);
    }
    log::trace!(
        target: &trace_target,
        symbol = trace_target.relevant_symbol();
        "all spills and reloads inserted successfully"
    );
    log::trace!(
        target: &trace_target,
        symbol = trace_target.relevant_symbol(),
        dialect = op.name().dialect().as_str(),
        op = op.name().name().as_str();
        "op after inserting spills: {}",
        op.borrow()
    );

    // Phase 3: rewrite uses of spilled values and lower the pseudo-ops.
    let region = op.borrow().regions().front().as_pointer().unwrap();
    if region.borrow().has_one_block() {
        rewrite_single_block_spills(
            op,
            region,
            analysis,
            interface,
            analysis_manager,
            &trace_target,
        )?;
    } else {
        // Dominance information is only needed for multi-block regions, so compute it lazily.
        let dominfo = analysis_manager.get_analysis::<DominanceInfo>()?;
        rewrite_cfg_spills(
            builder.context_rc(),
            region,
            analysis,
            interface,
            &dominfo,
            analysis_manager,
            &trace_target,
        )?;
    }
    log::trace!(
        target: &trace_target,
        symbol = trace_target.relevant_symbol(),
        dialect = op.name().dialect().as_str(),
        op = op.name().name().as_str();
        "op after rewriting spills: {}",
        op.borrow()
    );
    Ok(PostPassStatus::Changed)
}
/// Rewrites uses of spilled values in `region` when it consists of a single block, then lowers the
/// spill/reload pseudo-ops via [`rewrite_spill_pseudo_instructions`] without dominance info.
///
/// The block is walked bottom-up (from its last op towards its first), keeping, per block, the set
/// of operands that still use a spilled value. When a reload is reached, the pending uses of the
/// value it reloads are rewritten to the reloaded value (see `find_inst_uses_in_op`).
///
/// Ops implementing [`RegionBranchOpInterface`] are visited twice:
/// - on the first visit, the entry blocks of their nested regions are scheduled so that they are
///   walked before the op itself;
/// - on the second visit, pending uses left in the entry blocks of the regions the op branches
///   into from its parent are merged into the enclosing block's set, and the op is then processed
///   like any other op.
///
/// `_analysis_manager` is currently unused.
///
/// # Panics
///
/// Panics if `region` has no entry block, or if a nested region of a region branch op has more
/// than one block.
///
/// # Errors
///
/// Propagates errors from `interface` while lowering the pseudo-ops.
fn rewrite_single_block_spills(
    op: OperationRef,
    region: RegionRef,
    analysis: &mut SpillAnalysis,
    interface: &mut dyn TransformSpillsInterface,
    _analysis_manager: AnalysisManager,
    trace_target: &TraceTarget,
) -> Result<(), Report> {
    /// A cursor over the ops of a block, moving from the last op towards the first.
    struct Node {
        block: BlockRef,
        /// The next op to visit, or `None` once the top of the block has been reached.
        cursor: Option<OperationRef>,
        /// Whether the op at `cursor` is visited for the first time. Only consulted for region
        /// branch ops, whose nested regions are scheduled on their first visit.
        is_first_visit: bool,
    }
    impl Node {
        pub fn new(block: BlockRef) -> Self {
            Self {
                block,
                cursor: block.borrow().body().back().as_pointer(),
                is_first_visit: true,
            }
        }
        pub fn current(&self) -> Option<OperationRef> {
            self.cursor
        }
        pub fn move_next(&mut self) -> Option<OperationRef> {
            let next = self.cursor.take()?;
            self.cursor = next.prev();
            Some(next)
        }
    }
    // Pending (not yet rewritten) uses of spilled values, per block.
    let mut block_states =
        FxHashMap::<BlockRef, SmallDenseMap<ValueRef, SmallSet<OpOperand, 8>, 8>>::default();
    let entry_block = region.borrow().entry_block_ref().unwrap();
    // Used as a LIFO stack (push_back/pop_back) to get a depth-first walk.
    let mut block_q = VecDeque::from([Node::new(entry_block)]);
    while let Some(mut node) = block_q.pop_back() {
        let Some(operation) = node.current() else {
            // Reached the top of the block: spilled block arguments are defined here.
            let block = node.block.borrow();
            let used = block_states.entry(node.block).or_default();
            for arg in ValueRange::<2>::from(block.arguments()) {
                if analysis.is_spilled(&arg) {
                    used.remove(&arg);
                }
            }
            continue;
        };
        let op = operation.borrow();
        if let Some(branch) = op.as_trait::<dyn RegionBranchOpInterface>() {
            if node.is_first_visit {
                // Revisit this op once its nested regions (pushed above it) have been walked.
                node.is_first_visit = false;
                block_q.push_back(node);
                for region in Region::postorder_region_graph_for(branch).into_iter().rev() {
                    let region = region.borrow();
                    assert!(
                        region.has_one_block(),
                        "multi-block regions are not currently supported"
                    );
                    let entry = region.entry();
                    block_q.push_back(Node::new(entry.as_block_ref()));
                }
                continue;
            } else {
                for region in branch.get_successor_regions(RegionBranchPoint::Parent) {
                    let Some(region) = region.into_successor() else {
                        continue;
                    };
                    let region_entry = region.borrow().entry_block_ref().unwrap();
                    if let Some(uses) = block_states.remove(&region_entry) {
                        let parent_uses = block_states.entry(node.block).or_default();
                        for (spilled, users) in uses {
                            let parent_users = parent_uses.entry(spilled).or_default();
                            let merged = users.into_union(parent_users);
                            *parent_users = merged;
                        }
                    }
                }
            }
        }
        let used = block_states.entry(node.block).or_default();
        find_inst_uses(&op, used, analysis, trace_target);
        node.move_next();
        // `is_first_visit` describes the op under the cursor, so reset it when advancing;
        // otherwise the nested regions of any later region branch op in this block would never
        // be scheduled.
        node.is_first_visit = true;
        block_q.push_back(node);
    }
    let context = { op.borrow().context_rc() };
    rewrite_spill_pseudo_instructions(context, analysis, interface, None, trace_target)
}
/// Rewrites uses of spilled values in a multi-block `region`, then lowers the spill/reload
/// pseudo-ops via [`rewrite_spill_pseudo_instructions`] using `dominfo`.
///
/// First, block arguments are inserted at the iterated dominance frontier of the blocks containing
/// reloads of each spilled value ([`insert_required_phis`]). Blocks are then visited in post-order
/// of the dominator tree. Each block:
/// 1. starts from the pending uses left over by its dominator-tree children;
/// 2. walks its own ops bottom-up, rewriting pending uses at reloads;
/// 3. drops its block arguments from the set;
/// 4. rewrites pending uses of values for which a block argument was inserted in this block to
///    that argument.
///
/// `_analysis_manager` is currently unused.
///
/// # Errors
///
/// Propagates errors from `interface` while lowering the pseudo-ops.
fn rewrite_cfg_spills(
    context: Rc<Context>,
    region: RegionRef,
    analysis: &mut SpillAnalysis,
    interface: &mut dyn TransformSpillsInterface,
    dominfo: &DominanceInfo,
    _analysis_manager: AnalysisManager,
    trace_target: &TraceTarget,
) -> Result<(), Report> {
    let domtree = dominfo.dominance(region);
    let domf = DominanceFrontier::new(&domtree);
    let inserted_phis = insert_required_phis(&context, analysis, &domf, trace_target);
    // Pending uses at the top of each already-visited block.
    let mut used_sets =
        SmallDenseMap::<BlockRef, SmallDenseMap<ValueRef, SmallSet<OpOperand, 8>, 8>, 8>::default();
    let mut block_q = VecDeque::from(domtree.postorder());
    while let Some(node) = block_q.pop_front() {
        let Some(block_ref) = node.block() else {
            continue;
        };
        let mut used = SmallDenseMap::<ValueRef, SmallSet<OpOperand, 8>, 8>::default();
        for succ in Rc::<DomTreeNode>::children(node) {
            let Some(succ_block) = succ.block() else {
                continue;
            };
            // A block has exactly one immediate dominator, so each child's pending uses are
            // consumed exactly once: move them out instead of copying them.
            if let Some(usages) = used_sets.remove(&succ_block) {
                for (value, users) in usages {
                    used.entry(value).or_default().extend(users);
                }
            }
        }
        let block = block_ref.borrow();
        for op in block.body().iter().rev() {
            find_inst_uses(&op, &mut used, analysis, trace_target);
        }
        for arg in ValueRange::<2>::from(block.arguments()) {
            used.remove(&arg);
        }
        rewrite_inserted_phi_uses(&inserted_phis, block_ref, &mut used, trace_target);
        used_sets.insert(block_ref, used);
    }
    rewrite_spill_pseudo_instructions(context, analysis, interface, Some(dominfo), trace_target)
}
/// Updates the pending-use set `used` for `op`, including uses made inside `op`'s nested regions.
///
/// Nested-region uses are merged into `used` before `op`'s own reload semantics, results and
/// operands are examined by `find_inst_uses_in_op`.
fn find_inst_uses(
    op: &Operation,
    used: &mut SmallDenseMap<ValueRef, SmallSet<OpOperand, 8>, 8>,
    analysis: &SpillAnalysis,
    trace_target: &TraceTarget,
) {
    // Step 1: uses made by ops nested inside `op`.
    merge_op_nested_region_uses(op, used, analysis, trace_target);

    // Step 2: `op` itself.
    find_inst_uses_in_op(op, used, analysis);
}
/// Merges into `used` the pending uses of spilled values found inside the regions of `op`.
///
/// - Ops implementing `IsolatedFromAbove` are skipped entirely.
/// - Ops implementing [`RegionBranchOpInterface`] are delegated to [`merge_nested_region_uses`].
/// - For any other op, each non-empty single-block region is walked with [`collect_block_uses`]
///   and the result merged into `used`. Multi-block regions are skipped (with a trace message),
///   so uses of spilled values inside them are *not* collected.
fn merge_op_nested_region_uses(
    op: &Operation,
    used: &mut SmallDenseMap<ValueRef, SmallSet<OpOperand, 8>, 8>,
    analysis: &SpillAnalysis,
    trace_target: &TraceTarget,
) {
    if op.implements::<dyn IsolatedFromAbove>() {
        return;
    }
    if let Some(branch) = op.as_trait::<dyn RegionBranchOpInterface>() {
        merge_nested_region_uses(branch, used, analysis, trace_target);
        return;
    }
    for region in op.regions().iter() {
        let region = region.as_region_ref();
        let region_borrowed = region.borrow();
        if region_borrowed.is_empty() {
            continue;
        }
        if !region_borrowed.has_one_block() {
            // Report through the pass's own trace target, like every other message of this pass.
            log::trace!(
                target: trace_target,
                symbol = trace_target.relevant_symbol();
                "skipping multi-block nested region {region} when collecting spill uses"
            );
            continue;
        }
        let entry = region_borrowed
            .entry_block_ref()
            .expect("expected non-empty region to have an entry block");
        // Release the region before walking its block.
        drop(region_borrowed);
        let region_used = collect_block_uses(entry, analysis, trace_target);
        for (value, users) in region_used {
            used.entry(value).or_default().extend(users.iter().copied());
        }
    }
}
/// Merges into `used` the pending uses of spilled values collected from every region of the
/// region branch op `branch`, visited in the order given by `Region::postorder_region_graph_for`.
///
/// # Panics
///
/// Panics if any region has more than one block, or has no entry block.
fn merge_nested_region_uses(
    branch: &dyn RegionBranchOpInterface,
    used: &mut SmallDenseMap<ValueRef, SmallSet<OpOperand, 8>, 8>,
    analysis: &SpillAnalysis,
    trace_target: &TraceTarget,
) {
    for nested in Region::postorder_region_graph_for(branch) {
        // Keep the region borrowed only while reading its entry block.
        let entry_block = {
            let nested = nested.borrow();
            assert!(nested.has_one_block(), "multi-block regions are not currently supported");
            nested.entry_block_ref().expect("expected region to have an entry block")
        };
        collect_block_uses(entry_block, analysis, trace_target)
            .into_iter()
            .for_each(|(value, users)| {
                used.entry(value).or_default().extend(users.iter().copied());
            });
    }
}
/// Walks the block `block_ref` bottom-up with `find_inst_uses` and returns the uses of spilled
/// values still pending at the top of the block, after removing the block's own arguments.
fn collect_block_uses(
    block_ref: BlockRef,
    analysis: &SpillAnalysis,
    trace_target: &TraceTarget,
) -> SmallDenseMap<ValueRef, SmallSet<OpOperand, 8>, 8> {
    let block = block_ref.borrow();
    let mut pending = SmallDenseMap::<ValueRef, SmallSet<OpOperand, 8>, 8>::default();

    // Last op first, so that a reload sees every use that follows it in the block.
    block
        .body()
        .iter()
        .rev()
        .for_each(|op| find_inst_uses(&op, &mut pending, analysis, trace_target));

    // Block arguments are defined at the block entry.
    ValueRange::<2>::from(block.arguments()).into_iter().for_each(|arg| {
        pending.remove(&arg);
    });

    pending
}
/// Inserts block arguments ("phis") where reloads of a spilled value may need to be joined.
///
/// For every spilled value, the blocks containing one of its materialized reloads are collected.
/// A new block argument, with the value's type and span, is then appended to every block in the
/// iterated dominance frontier of that set, at most once per block and value. Every use of such a
/// block by a successor (i.e. each predecessor edge into it) gets the spilled value itself appended
/// as an extra successor argument.
///
/// Returns, for each block that received arguments, a map from spilled value to the block argument
/// inserted for it.
///
/// # Panics
///
/// Panics if a reload has not been materialized, or if its op is not attached to a block.
fn insert_required_phis(
    context: &Context,
    analysis: &SpillAnalysis,
    domf: &DominanceFrontier,
    trace_target: &TraceTarget,
) -> SmallDenseMap<BlockRef, SmallDenseMap<ValueRef, ValueRef, 8>, 8> {
    use midenc_hir::adt::smallmap::Entry;

    // Spilled value -> blocks containing a reload of it.
    let mut required_phis = SmallDenseMap::<ValueRef, SmallSet<BlockRef, 2>, 4>::default();
    for reload in analysis.reloads() {
        let block = reload.inst.unwrap().parent().unwrap();
        log::trace!(
            target: trace_target,
            symbol = trace_target.relevant_symbol();
            "add required_phis for {}",
            reload.value
        );
        let r = required_phis.entry(reload.value).or_default();
        r.insert(block);
    }
    let mut inserted_phis =
        SmallDenseMap::<BlockRef, SmallDenseMap<ValueRef, ValueRef, 8>, 8>::default();
    for (value, domf_r) in required_phis {
        let idf_r = domf.iterate_all(domf_r);
        let (ty, span) = {
            let value = value.borrow();
            (value.ty().clone(), value.span())
        };
        for b in idf_r {
            let phis = inserted_phis.entry(b).or_default();
            if let Entry::Vacant(entry) = phis.entry(value) {
                let phi = context.append_block_argument(b, ty.clone(), span);
                entry.insert(phi);
                // Only the block's use list is read, so a shared borrow is sufficient and cannot
                // conflict with other readers of the block while predecessors are updated.
                let block = b.borrow();
                let mut next_use = block.uses().front().as_pointer();
                while let Some(pred) = next_use.take() {
                    next_use = pred.next();
                    let (mut predecessor, successor_index) = {
                        let pred = pred.borrow();
                        (pred.owner, pred.index as usize)
                    };
                    let operand = context.make_operand(value, predecessor, 0);
                    predecessor.borrow_mut().successor_mut(successor_index).arguments.push(operand);
                }
            }
        }
    }
    inserted_phis
}
/// Updates the pending-use set `used` for the single op `op`, ignoring its nested regions.
///
/// - If `op` implements [`ReloadLike`], every pending use of the value it reloads is rewritten to
///   the reloaded value and removed from `used`. If there were no pending uses of that value,
///   nothing else is done for this op.
/// - Results of `op` that are spilled values are removed from `used`: walking bottom-up, `op` is
///   their definition point.
/// - For ops that are not reloads, every operand referring to a spilled value is added to `used`.
fn find_inst_uses_in_op(
    op: &Operation,
    used: &mut SmallDenseMap<ValueRef, SmallSet<OpOperand, 8>, 8>,
    analysis: &SpillAnalysis,
) {
    let is_reload = match op.as_trait::<dyn ReloadLike>() {
        None => false,
        Some(reload) => {
            let spilled = reload.spilled_value();
            let reloaded = reload.reloaded();
            let Some(pending) = used.remove(&spilled) else {
                return;
            };
            debug_assert!(!pending.is_empty(), "expected empty use sets to be removed");
            for mut operand in pending {
                operand.borrow_mut().set(reloaded);
            }
            true
        }
    };

    for defined in ValueRange::<2>::from(op.results().all()) {
        if !analysis.is_spilled(&defined) {
            continue;
        }
        used.remove(&defined);
    }

    if is_reload {
        return;
    }
    for operand in op.operands().iter().copied() {
        let value = operand.borrow().as_value_ref();
        if !analysis.is_spilled(&value) {
            continue;
        }
        used.entry(value).or_default().insert(operand);
    }
}
/// Rewrites, for each block argument inserted into `block_ref` by [`insert_required_phis`], all
/// pending uses of the corresponding spilled value to that argument, removing them from `used`.
///
/// Logs a warning for any inserted argument whose spilled value has no pending uses.
fn rewrite_inserted_phi_uses(
    inserted_phis: &SmallDenseMap<BlockRef, SmallDenseMap<ValueRef, ValueRef, 8>, 8>,
    block_ref: BlockRef,
    used: &mut SmallDenseMap<ValueRef, SmallSet<OpOperand, 8>, 8>,
    trace_target: &TraceTarget,
) {
    let Some(phis) = inserted_phis.get(&block_ref) else {
        return;
    };
    for (spilled, phi) in phis.iter() {
        match used.remove(spilled) {
            Some(to_rewrite) => {
                debug_assert!(!to_rewrite.is_empty(), "expected empty use sets to be removed");
                for mut user in to_rewrite {
                    user.borrow_mut().set(*phi);
                }
            }
            None => {
                log::warn!(
                    target: trace_target,
                    symbol = trace_target.relevant_symbol();
                    "unused phi {phi} encountered during rewrite phase"
                );
            }
        }
    }
}
fn rewrite_spill_pseudo_instructions(
context: Rc<Context>,
analysis: &mut SpillAnalysis,
interface: &mut dyn TransformSpillsInterface,
dominfo: Option<&DominanceInfo>,
trace_target: &TraceTarget,
) -> Result<(), Report> {
use midenc_hir::{
dominance::Dominates,
patterns::{RewriterImpl, TracingRewriterListener},
};
let mut builder = RewriterImpl::<TracingRewriterListener>::new(context)
.with_listener(TracingRewriterListener);
for spill in analysis.spills() {
let operation = spill.inst.expect("expected spill to have been materialized");
let spilled = {
let op = operation.borrow();
let spill_like = op
.as_trait::<dyn SpillLike>()
.expect("expected materialized spill operation to implement SpillLike");
spill_like.spilled_value()
};
let mut is_used = false;
for rinfo in analysis.reloads() {
if rinfo.value != spilled {
continue;
}
let Some(reload_op) = rinfo.inst else {
continue;
};
let (reload_used, dom_ok) = {
let rop = reload_op.borrow();
let rl = rop
.as_trait::<dyn ReloadLike>()
.expect("expected materialized reload op to implement ReloadLike");
let used = rl.reloaded().borrow().is_used();
let dom_ok = match dominfo {
None => true,
Some(dominfo) => {
let sop = operation.borrow();
sop.dominates(&rop, dominfo)
}
};
(used, dom_ok)
};
if reload_used && dom_ok {
is_used = true;
break;
}
}
if is_used {
builder.set_insertion_point_after(operation);
interface.convert_spill_to_store(&mut builder, operation)?;
} else {
builder.erase_op(operation);
}
}
for reload in analysis.reloads() {
let operation = reload.inst.expect("expected reload to have been materialized");
let op = operation.borrow();
let reload_like = op
.as_trait::<dyn ReloadLike>()
.expect("expected materialized reload op to implement ReloadLike");
let is_used = reload_like.reloaded().borrow().is_used();
drop(op);
if is_used {
log::trace!(
target: trace_target,
symbol = trace_target.relevant_symbol();
"convert reload to load {}",
reload.place
);
builder.set_insertion_point_after(operation);
interface.convert_reload_to_load(&mut builder, operation)?;
} else {
log::trace!(
target: trace_target,
symbol = trace_target.relevant_symbol();
"erase unused reload {}",
reload.value
);
builder.erase_op(operation);
}
}
Ok(())
}