use core::panic;
use std::collections::hash_map;
use pliron_derive::op_interface;
use rustc_hash::{FxHashMap, FxHashSet};
use crate::{
basic_block::BasicBlock,
builtin::op_interfaces::BranchOpInterface,
common_traits::Named,
context::{Context, Ptr},
debug_info::set_block_arg_name,
graph::{
dominance::{DomFrontierMap, DomTree, compute_dominator_tree},
walkers::{IRNode, WALKCONFIG_PREORDER_FORWARD, uninterruptible::immutable::walk_op},
},
irbuild::{
inserter::{IRInserter, Inserter},
listener::{Recorder, RecorderEvent},
rewriter::IRRewriter,
},
linked_list::ContainsLinkedList,
op::{Op, op_cast, op_impls},
operation::{OpDbg, Operation},
opts::OptStatus,
region::Region,
result::Result,
r#type::TypeObj,
value::Value,
};
/// One promotable memory slot reported by an allocation op.
///
/// An op implementing [`PromotableAllocationInterface`] may report several of
/// these; each one is analysed and promoted independently.
#[derive(Clone)]
pub struct AllocInfo {
    /// Pointer value whose uses (loads, stores, ...) are rewritten by promotion.
    pub ptr: Value,
    /// Type of the value held in the slot; block arguments created as phis
    /// for this slot get this type.
    pub ty: Ptr<TypeObj>,
}
/// Interface for ops that allocate memory slots which `mem2reg` may promote
/// to SSA values.
#[op_interface]
pub trait PromotableAllocationInterface {
    /// Returns one [`AllocInfo`] per slot this op allocates. Each returned
    /// entry becomes a separate promotion candidate.
    fn alloc_info(&self, ctx: &Context) -> Vec<AllocInfo>;
    /// Creates the value observed by a read of `alloc_info`'s slot when no
    /// store (or phi) reaches that read.
    ///
    /// `mem2reg` positions `inserter` immediately before this allocation op,
    /// and caches the result per pointer, so this is called at most once per
    /// slot.
    fn default_value(
        &self,
        ctx: &mut Context,
        inserter: &mut IRInserter<Recorder>,
        alloc_info: &AllocInfo,
    ) -> Result<Value>;
    /// Finishes promotion of this op, called after every user of the promoted
    /// slots has been rewritten.
    ///
    /// `alloc_infos` lists only the slots of this op that were actually
    /// promoted; `rewriter`'s insertion point is set before this op. The
    /// callback must not erase or unlink blocks or erase regions: `mem2reg`
    /// panics if the recorder reports such events.
    fn promote(
        &self,
        ctx: &mut Context,
        rewriter: &mut IRRewriter<Recorder>,
        alloc_infos: &[AllocInfo],
    ) -> Result<()>;
    /// Interface verifier; the default implementation accepts every op.
    fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
    where
        Self: Sized,
    {
        Ok(())
    }
}
/// How an op implementing [`PromotableOpInterface`] relates to one allocation
/// slot.
pub enum PromotableOpKind {
    /// Reads the slot; on promotion it is handed the reaching definition.
    Load,
    /// Writes the carried value to the slot; that value becomes the reaching
    /// definition for subsequent (and dominated) code.
    Store(Value),
    /// A use handled like [`PromotableOpKind::Load`]: it counts as a read for
    /// liveness and is handed the reaching definition on promotion.
    /// NOTE(review): by its name, presumably a use that promotion removes.
    EliminatableUse,
    /// A use that blocks promotion of the slot. `mem2reg` also treats this as
    /// "nothing to do" when it queries ops unrelated to the slot.
    NonPromotableUse,
}
/// Interface for ops that access (read, write or otherwise use) a promotable
/// allocation slot.
#[op_interface]
pub trait PromotableOpInterface {
    /// Classifies how this op relates to the slot described by `alloc_info`.
    ///
    /// `mem2reg` asks this of every promotable op in a block, including ops
    /// that may not touch the slot at all, and treats
    /// [`PromotableOpKind::NonPromotableUse`] as "nothing to do" for them. For
    /// an actual user of the pointer, the same answer blocks promotion.
    /// NOTE(review): implementations presumably return `NonPromotableUse` for
    /// unrelated slots — confirm.
    fn promotion_kind(&self, ctx: &Context, alloc_info: &AllocInfo) -> PromotableOpKind;
    /// Rewrites this op once the slot(s) it touches are promoted.
    ///
    /// Each `(AllocInfo, Value)` pair names a slot and its value:
    /// - for `Load` / `EliminatableUse`, the definition reaching this op;
    /// - for `Store`, the stored value.
    ///
    /// `rewriter`'s insertion point is set before this op. The callback must not
    /// erase or unlink blocks or erase regions: `mem2reg` panics if the
    /// recorder reports such events.
    fn promote(
        &self,
        ctx: &mut Context,
        alloc_info_reaching_defs: &[(AllocInfo, Value)],
        rewriter: &mut IRRewriter<Recorder>,
    ) -> Result<()>;
    /// Interface verifier; the default implementation accepts every op.
    fn verify(_op: &dyn Op, _ctx: &Context) -> Result<()>
    where
        Self: Sized,
    {
        Ok(())
    }
}
/// A promotion candidate: one slot together with the op that allocated it.
#[derive(Clone)]
struct AllocCandidate {
    /// The allocating op (implements [`PromotableAllocationInterface`]).
    alloc_op: Ptr<Operation>,
    /// Which slot of `alloc_op` this candidate stands for.
    alloc_info: AllocInfo,
}
/// Walks every operation under `root` (pre-order, forward) and returns one
/// [`AllocCandidate`] for each [`AllocInfo`] reported by ops implementing
/// [`PromotableAllocationInterface`], in walk order.
fn collect_alloc_candidates(root: Ptr<Operation>, ctx: &Context) -> Vec<AllocCandidate> {
    let mut found = Vec::<AllocCandidate>::new();
    walk_op(
        ctx,
        &mut found,
        &WALKCONFIG_PREORDER_FORWARD,
        root,
        |ctx, found, node| {
            // Only operations can allocate; skip blocks/regions.
            let IRNode::Operation(op) = node else {
                return;
            };
            let op_dyn = Operation::get_op_dyn(op, ctx);
            let Some(alloc_iface) =
                op_cast::<dyn PromotableAllocationInterface>(op_dyn.as_ref())
            else {
                return;
            };
            found.extend(
                alloc_iface
                    .alloc_info(ctx)
                    .into_iter()
                    .map(|alloc_info| AllocCandidate {
                        alloc_op: op,
                        alloc_info,
                    }),
            );
        },
    );
    found
}
/// Keeps only the candidates that can be fully promoted.
///
/// A candidate survives when every use of its pointer is made by an op that
/// implements [`PromotableOpInterface`], lives in the same region as the
/// allocation op, and does not report [`PromotableOpKind::NonPromotableUse`]
/// for this slot.
///
/// # Panics
/// If the allocation op or a using op is not inside a region.
fn prune_candidates(candidates: &mut Vec<AllocCandidate>, ctx: &Context) {
    let fully_promotable = |cand: &AllocCandidate| -> bool {
        let home_region = cand
            .alloc_op
            .deref(ctx)
            .get_parent_region(ctx)
            .expect("Alloc op must be in a region");
        for ptr_use in cand.alloc_info.ptr.uses(ctx).iter() {
            let user = ptr_use.user_op();
            let user_dyn = Operation::get_op_dyn(user, ctx);
            let Some(promotable) = op_cast::<dyn PromotableOpInterface>(user_dyn.as_ref())
            else {
                // A user that knows nothing about promotion blocks it.
                return false;
            };
            let user_region = user
                .deref(ctx)
                .get_parent_region(ctx)
                .expect("Use op must be in a region");
            let kind = promotable.promotion_kind(ctx, &cand.alloc_info);
            // Uses from nested/other regions are out of scope for this pass.
            if user_region != home_region
                || matches!(kind, PromotableOpKind::NonPromotableUse)
            {
                return false;
            }
        }
        true
    };
    candidates.retain(fully_promotable);
}
/// Computes, for one candidate slot, the blocks at whose entry the slot's
/// value is live (`live_in`) and the blocks that define it
/// (`defining_blocks`). Returns `(live_in, defining_blocks)`.
///
/// A block is a defining block if it stores to the slot. The block that holds
/// the allocation op is always a defining block: the allocation defines the
/// slot with its default value.
///
/// Why the allocation counts as a definition: if it were not, an allocation
/// outside the region's entry block that is read before being stored in its
/// own block would make that block "live-in". Liveness would then propagate
/// to predecessors the allocation does not dominate. Phis could then be
/// placed whose incoming edges come from such blocks, and renaming would feed
/// them the default value, which is created right before the allocation op
/// and so does not dominate those edges.
///
/// A block is seeded as live-in when some use of the slot (a load or an
/// eliminatable use) occurs before any definition in that block. Liveness is
/// then propagated backwards through predecessors, stopping at defining
/// blocks.
fn compute_candidate_live_in_and_defining_blocks(
    ctx: &Context,
    cand: &AllocCandidate,
) -> (FxHashSet<Ptr<BasicBlock>>, FxHashSet<Ptr<BasicBlock>>) {
    let ptr = cand.alloc_info.ptr;
    let alloc_block = cand.alloc_op.deref(ctx).get_parent_block();
    let mut defining_blocks: FxHashSet<Ptr<BasicBlock>> = FxHashSet::default();
    let mut live_in: FxHashSet<Ptr<BasicBlock>> = FxHashSet::default();
    let mut live_in_worklist: Vec<Ptr<BasicBlock>> = Vec::new();

    // The allocation itself (re)initialises the slot each time it executes.
    if let Some(alloc_block) = alloc_block {
        defining_blocks.insert(alloc_block);
    }

    // Only blocks containing a use of the pointer need to be scanned.
    let mut user_blocks: FxHashSet<Ptr<BasicBlock>> = FxHashSet::default();
    for u in ptr.uses(ctx) {
        if let Some(block) = u.user_op().deref(ctx).get_parent_block() {
            user_blocks.insert(block);
        }
    }

    for block in user_blocks {
        // In the allocating block, the allocation acts as an initial store:
        // any use there necessarily follows it (SSA dominance of `ptr`).
        let mut has_store = alloc_block == Some(block);
        let mut load_before_store = false;
        for op in block.deref(ctx).iter(ctx) {
            let op_obj = Operation::get_op_dyn(op, ctx);
            let Some(op_promotable) = op_cast::<dyn PromotableOpInterface>(op_obj.as_ref())
            else {
                continue;
            };
            match op_promotable.promotion_kind(ctx, &cand.alloc_info) {
                PromotableOpKind::Load | PromotableOpKind::EliminatableUse => {
                    if !has_store {
                        load_before_store = true;
                    }
                }
                PromotableOpKind::Store(_) => {
                    has_store = true;
                }
                // Unrelated to this slot (pruning already rejected real
                // non-promotable uses).
                PromotableOpKind::NonPromotableUse => {}
            }
        }
        if has_store {
            defining_blocks.insert(block);
        }
        if load_before_store {
            live_in_worklist.push(block);
        }
    }

    // Backward propagation: a value live at a block's entry is live at the
    // exit of each predecessor, and live at that predecessor's entry unless
    // the predecessor defines it.
    while let Some(live_in_block) = live_in_worklist.pop() {
        if !live_in.insert(live_in_block) {
            continue;
        }
        for pred in live_in_block.preds(ctx) {
            if !defining_blocks.contains(&pred) {
                live_in_worklist.push(pred);
            }
        }
    }
    (live_in, defining_blocks)
}
/// Returns the blocks that need a phi (a new block argument) for one slot.
///
/// This is the iterated dominance frontier of `defining_blocks`, pruned to
/// blocks in `live_in` (no phi is placed where the value is dead).
fn compute_candidate_phi_blocks(
    df_map: &DomFrontierMap<Ptr<Region>, Context>,
    live_in: &FxHashSet<Ptr<BasicBlock>>,
    defining_blocks: &FxHashSet<Ptr<BasicBlock>>,
) -> FxHashSet<Ptr<BasicBlock>> {
    let mut placed = FxHashSet::<Ptr<BasicBlock>>::default();
    let mut pending: Vec<Ptr<BasicBlock>> = defining_blocks.iter().copied().collect();
    while let Some(def_block) = pending.pop() {
        let frontier = df_map.frontier(&def_block).into_iter().copied();
        for target in frontier.filter(|b| live_in.contains(b)) {
            // A newly placed phi is itself a definition, so its own frontier must
            // be explored, unless it was already seeded as a defining block.
            if placed.insert(target) && !defining_blocks.contains(&target) {
                pending.push(target);
            }
        }
    }
    placed
}
/// Drops candidates whose phi blocks cannot receive incoming values.
///
/// Incoming values are added as successor operands through
/// [`BranchOpInterface`]. If any predecessor of any of a candidate's phi
/// blocks has no terminator, or a terminator that does not implement that
/// interface, the candidate is removed. Its entry in `phi_blocks` is removed
/// as well.
fn prune_candidates_with_unknown_branch_from_pred(
    ctx: &Context,
    alloc_candidates: &mut Vec<AllocCandidate>,
    phi_blocks: &mut FxHashMap<Value, FxHashSet<Ptr<BasicBlock>>>,
) {
    let is_unknown_branch_pred = |pred: Ptr<BasicBlock>| -> bool {
        match pred.deref(ctx).get_terminator(ctx) {
            None => true,
            Some(term) => {
                !op_impls::<dyn BranchOpInterface>(Operation::get_op_dyn(term, ctx).as_ref())
            }
        }
    };

    let invalid_ptrs: FxHashSet<Value> = alloc_candidates
        .iter()
        .map(|cand| cand.alloc_info.ptr)
        .filter(|ptr| {
            phi_blocks.get(ptr).is_some_and(|blocks| {
                blocks
                    .iter()
                    .any(|&phi_block| phi_block.preds(ctx).into_iter().any(&is_unknown_branch_pred))
            })
        })
        .collect();

    alloc_candidates.retain(|cand| !invalid_ptrs.contains(&cand.alloc_info.ptr));
    phi_blocks.retain(|ptr, _| !invalid_ptrs.contains(ptr));
}
fn get_or_create_default_def(
alloc_cand: &AllocCandidate,
ctx: &mut Context,
default_defs: &mut FxHashMap<Value, Value>,
) -> Result<Value> {
match default_defs.entry(alloc_cand.alloc_info.ptr) {
hash_map::Entry::Occupied(entry) => Ok(*entry.get()),
hash_map::Entry::Vacant(entry) => {
let alloc_op = alloc_cand.alloc_op;
let alloc_obj = Operation::get_op_dyn(alloc_op, ctx);
let alloc_iface = op_cast::<dyn PromotableAllocationInterface>(alloc_obj.as_ref())
.expect("Alloc op must implement PromotableAllocationInterface");
let default_val = alloc_iface.default_value(
ctx,
&mut IRInserter::new_before_operation(alloc_op),
&alloc_cand.alloc_info,
)?;
entry.insert(default_val);
Ok(default_val)
}
}
}
/// Drains all events recorded during a promotion callback and adds every
/// erased operation to `erased`.
///
/// # Panics
/// If a block was erased or unlinked, or a region erased: promotion callbacks
/// must not change the control flow this pass has already analysed.
fn note_erased_ops(recorder: &mut Recorder, erased: &mut FxHashSet<Ptr<Operation>>) {
    recorder.events.drain(..).for_each(|event| match event {
        RecorderEvent::ErasedOperation(op) => {
            erased.insert(op);
        }
        // Harmless to the analysis: nothing to track.
        RecorderEvent::InsertedOperation(_)
        | RecorderEvent::InsertedBlock(_)
        | RecorderEvent::ReplacedValueUses { .. }
        | RecorderEvent::ValueTypeChanged { .. }
        | RecorderEvent::UnlinkedOperation(_, _) => {}
        RecorderEvent::ErasedBlock(_)
        | RecorderEvent::ErasedRegion(_)
        | RecorderEvent::UnlinkedBlock(_, _) => {
            panic!("mem2reg rewrite (promotion) call backs must not alter control flow")
        }
    });
}
/// Renames `block` and, recursively, every block it dominates: promotable
/// ops are rewritten to use SSA values instead of memory slots.
///
/// This is the renaming phase of SSA construction, a pre-order walk of the
/// dominator tree starting at `block`.
///
/// * `new_phis_in_block`: for each block, the block arguments appended as
///   phis and the candidate each one belongs to.
/// * `reaching_def_map`: the parent's reaching-definition stacks, keyed by
///   slot pointer. Only the top of each stack is inherited.
/// * `default_def_map`: cache of default values already created per pointer
///   (see `get_or_create_default_def`).
/// * `alloc_candidates`: every candidate being promoted in this region.
///
/// # Errors
/// Propagates errors from `default_value` and `PromotableOpInterface::promote`.
///
/// # Panics
/// In these cases:
/// - a block with phi-carrying successors has no terminator;
/// - that terminator does not implement `BranchOpInterface`;
/// - an appended successor operand's index differs from the phi's argument
///   index;
/// - a promotion callback alters control flow (see `note_erased_ops`).
///
/// The `unwrap`s also assume every candidate pointer is a key of
/// `reaching_def_map`.
///
/// NOTE(review): recursion depth equals the region's dominator-tree depth;
/// very deep CFGs could exhaust the stack.
fn rename_block(
    ctx: &mut Context,
    block: Ptr<BasicBlock>,
    dom_tree: &DomTree<Ptr<Region>, Context>,
    new_phis_in_block: &FxHashMap<Ptr<BasicBlock>, Vec<(AllocCandidate, usize)>>,
    reaching_def_map: &FxHashMap<Value, Vec<Value>>,
    default_def_map: &mut FxHashMap<Value, Value>,
    alloc_candidates: &[AllocCandidate],
) -> Result<()> {
    // Fresh per-block copy holding only each parent stack's top, so that
    // definitions pushed here do not leak into sibling subtrees.
    let mut reaching_def_map = reaching_def_map
        .iter()
        .map(|(&ptr, stack)| {
            (ptr, {
                let mut new_stack = Vec::new();
                if let Some(&val) = stack.last() {
                    new_stack.push(val);
                }
                new_stack
            })
        })
        .collect::<FxHashMap<_, _>>();
    // Phis (block arguments) placed in this block are the reaching
    // definitions at block entry.
    for &(ref cand, arg_idx) in new_phis_in_block.get(&block).into_iter().flatten() {
        let new_val = block.deref(ctx).get_argument(arg_idx);
        reaching_def_map
            .get_mut(&cand.alloc_info.ptr)
            .unwrap()
            .push(new_val);
    }
    // Snapshot the op list, because promotion may insert or erase ops.
    // Ops inserted during promotion are not visited; erased ones are
    // skipped below.
    let ops: Vec<Ptr<Operation>> = block.deref(ctx).iter(ctx).collect();
    let mut erased_ops = FxHashSet::default();
    for &op in &ops {
        if erased_ops.contains(&op) {
            continue;
        }
        let op_obj = Operation::get_op_dyn(op, ctx);
        let Some(piface) = op_cast::<dyn PromotableOpInterface>(op_obj.as_ref()) else {
            continue;
        };
        // For each candidate this op reads or writes, record the value the op
        // should be rewritten with.
        let mut promote_queue = Vec::new();
        for cand in alloc_candidates {
            let ptr = cand.alloc_info.ptr;
            match piface.promotion_kind(ctx, &cand.alloc_info) {
                PromotableOpKind::Load | PromotableOpKind::EliminatableUse => {
                    let reaching_def_stack = reaching_def_map.get_mut(&ptr).unwrap();
                    // No store or phi reaches this read: fall back to the
                    // allocation's default value.
                    if reaching_def_stack.is_empty() {
                        let default_val = get_or_create_default_def(cand, ctx, default_def_map)?;
                        reaching_def_stack.push(default_val);
                    }
                    let current_def = *reaching_def_stack.last().unwrap();
                    promote_queue.push((cand.alloc_info.clone(), current_def));
                }
                PromotableOpKind::Store(stored_val) => {
                    // The stored value reaches later ops in this block and
                    // dominated blocks.
                    reaching_def_map.get_mut(&ptr).unwrap().push(stored_val);
                    promote_queue.push((cand.alloc_info.clone(), stored_val));
                }
                PromotableOpKind::NonPromotableUse => {}
            }
        }
        if !promote_queue.is_empty() {
            let rewriter = &mut IRRewriter::default();
            rewriter.set_insertion_point_before_operation(op);
            log::trace!("Promoting op {}", OpDbg { op, ctx });
            piface.promote(ctx, &promote_queue, rewriter)?;
            // Remember erased ops so later iterations skip them.
            note_erased_ops(rewriter.get_listener_mut(), &mut erased_ops);
        }
    }
    // Supply incoming values to the phis of each successor by appending
    // operands to this block's terminator for that successor edge.
    let succs = block.deref(ctx).succs(ctx);
    for (succ_idx, new_phis_in_succ) in succs.iter().enumerate().filter_map(|(succ_idx, succ)| {
        new_phis_in_block
            .get(succ)
            .map(|new_phis| (succ_idx, new_phis))
    }) {
        let term = block
            .deref(ctx)
            .get_terminator(ctx)
            .expect("Block has successors but no terminator");
        let term_obj = Operation::get_op_dyn(term, ctx);
        let branch_iface = op_cast::<dyn BranchOpInterface>(term_obj.as_ref())
            .expect("Terminator must implement BranchOpInterface for phi blocks");
        for &(ref cand, arg_idx) in new_phis_in_succ {
            let reaching_def_stack = reaching_def_map.get_mut(&cand.alloc_info.ptr).unwrap();
            if reaching_def_stack.is_empty() {
                let default_val = get_or_create_default_def(cand, ctx, default_def_map)?;
                reaching_def_stack.push(default_val);
            }
            let current_def = *reaching_def_stack.last().unwrap();
            let succ_opd_idx = branch_iface.add_successor_operand(ctx, succ_idx, current_def);
            // Operands must line up with the block arguments pushed by `mem2reg`.
            assert!(succ_opd_idx == arg_idx, "Mismatched phi argument index");
        }
    }
    // Continue the walk into dominator-tree children, which inherit this
    // block's reaching definitions.
    for child in dom_tree.children(&block) {
        rename_block(
            ctx,
            child,
            dom_tree,
            new_phis_in_block,
            &reaching_def_map,
            default_def_map,
            alloc_candidates,
        )?;
    }
    Ok(())
}
/// Promotes memory slots to SSA values ("mem2reg") for every promotable
/// allocation found under `root`.
///
/// Steps:
/// 1. Collect ops implementing [`PromotableAllocationInterface`]. Keep only
///    slots whose every use is a promotable op in the allocation's own region.
/// 2. Per region, compute live-in / defining blocks for each slot. Place
///    block-argument phis at the live-pruned iterated dominance frontier.
/// 3. Drop slots whose phi blocks have a predecessor with a missing terminator
///    or a terminator that does not implement [`BranchOpInterface`].
/// 4. Add the phis and rename along the dominator tree, rewriting
///    loads/stores. Then let each allocation op finish its own promotion.
///
/// Returns [`OptStatus::IRChanged`] if any region had a slot left to promote
/// after pruning, otherwise [`OptStatus::IRUnchanged`].
///
/// # Errors
/// Propagates errors from the interface callbacks (`default_value`, `promote`).
///
/// # Panics
/// In these cases:
/// - a candidate's allocation op or one of its users is not inside a region;
/// - a region holding candidates has no entry block;
/// - a promotion callback erases or unlinks blocks or erases regions;
/// - an allocation op is erased while promoting another one.
pub fn mem2reg(root: Ptr<Operation>, ctx: &mut Context) -> Result<OptStatus> {
    let mut candidates = collect_alloc_candidates(root, ctx);
    prune_candidates(&mut candidates, ctx);
    if candidates.is_empty() {
        return Ok(OptStatus::IRUnchanged);
    }

    // Dominance information is per region, so process candidates grouped by
    // the region that holds their allocation op.
    let mut by_region: FxHashMap<Ptr<Region>, Vec<AllocCandidate>> = FxHashMap::default();
    for cand in candidates {
        let region = cand
            .alloc_op
            .deref(ctx)
            .get_parent_region(ctx)
            .expect("Alloc op must be in a region");
        by_region.entry(region).or_default().push(cand);
    }

    let mut opt_status = OptStatus::IRUnchanged;
    for (region, mut alloc_candidates) in by_region {
        let dom_tree: DomTree<Ptr<Region>, Context> = compute_dominator_tree(ctx, &region);
        let df_map = DomFrontierMap::new(ctx, &region, &dom_tree);

        // Phi placement for every candidate of this region.
        let mut phi_blocks: FxHashMap<Value, FxHashSet<Ptr<BasicBlock>>> = FxHashMap::default();
        for cand in alloc_candidates.iter() {
            let ptr = cand.alloc_info.ptr;
            let (live_in, defining_blocks) =
                compute_candidate_live_in_and_defining_blocks(ctx, cand);
            let candidate_phi_blocks =
                compute_candidate_phi_blocks(&df_map, &live_in, &defining_blocks);
            phi_blocks.insert(ptr, candidate_phi_blocks);
        }
        prune_candidates_with_unknown_branch_from_pred(ctx, &mut alloc_candidates, &mut phi_blocks);
        if alloc_candidates.is_empty() {
            continue;
        }
        opt_status |= OptStatus::IRChanged;

        // Materialize phis as new block arguments named after the slot pointer.
        let mut new_phis_in_block: FxHashMap<Ptr<BasicBlock>, Vec<(AllocCandidate, usize)>> =
            FxHashMap::default();
        for cand in alloc_candidates.iter() {
            let ptr = cand.alloc_info.ptr;
            if let Some(needed_blocks) = phi_blocks.get(&ptr) {
                let needed_blocks: Vec<Ptr<BasicBlock>> = needed_blocks.iter().cloned().collect();
                for phi_block in needed_blocks {
                    let arg_idx = BasicBlock::push_argument(phi_block, ctx, cand.alloc_info.ty);
                    set_block_arg_name(ctx, phi_block, arg_idx, ptr.given_name(ctx));
                    new_phis_in_block
                        .entry(phi_block)
                        .or_default()
                        .push((cand.clone(), arg_idx));
                }
            }
        }

        // Rewrite all loads/stores by walking the dominator tree from the entry.
        let reaching_def_map: FxHashMap<Value, Vec<Value>> = alloc_candidates
            .iter()
            .map(|c| (c.alloc_info.ptr, Vec::new()))
            .collect();
        let mut default_def_map: FxHashMap<Value, Value> = FxHashMap::default();
        let entry_block = region
            .deref(ctx)
            .get_head()
            .expect("No entry block in region");
        rename_block(
            ctx,
            entry_block,
            &dom_tree,
            &new_phis_in_block,
            &reaching_def_map,
            &mut default_def_map,
            &alloc_candidates,
        )?;

        // Finally let each allocation op promote itself, once, with all of its
        // promoted slots.
        let mut alloc_op_to_infos: FxHashMap<Ptr<Operation>, Vec<AllocInfo>> = FxHashMap::default();
        for cand in alloc_candidates.iter() {
            alloc_op_to_infos
                .entry(cand.alloc_op)
                .or_default()
                .push(cand.alloc_info.clone());
        }
        let rewriter = &mut IRRewriter::default();
        let mut erased_ops = FxHashSet::default();
        for (op, infos) in alloc_op_to_infos {
            if erased_ops.contains(&op) {
                panic!("Alloc op was already erased during promotion of another candidate");
            }
            rewriter.set_insertion_point_before_operation(op);
            let op_dyn = Operation::get_op_dyn(op, ctx);
            let piface = op_cast::<dyn PromotableAllocationInterface>(op_dyn.as_ref())
                .expect("Alloc op must implement PromotableAllocationInterface");
            log::trace!(
                "Promoting allocation {}",
                OpDbg {
                    op: op_dyn.get_operation(),
                    ctx
                }
            );
            piface.promote(ctx, rewriter, &infos)?;
            note_erased_ops(rewriter.get_listener_mut(), &mut erased_ops);
        }
    }
    Ok(opt_status)
}