use std::num::NonZero;
use thiserror::Error;
use pliron::{
arg_err,
attribute::AttrObj,
basic_block::BasicBlock,
builtin::{
attributes::IntegerAttr,
op_interfaces::{BranchOpInterface, OneResultInterface},
},
context::{Context, Ptr},
derive::op_interface_impl,
irbuild::{IRStatus, inserter::Inserter, rewriter::Rewriter},
op::Op,
opts::{
constants::{BranchOpFoldInterface, ConstFoldInterface},
dce::{BlockArgRemoval, SideEffects},
mem2reg::{
AllocInfo, PromotableAllocationInterface, PromotableOpInterface, PromotableOpKind,
},
},
result::Result,
utils::apint::APInt,
value::Value,
};
use crate::attributes::IntegerOverflowFlagsAttr;
use crate::op_interfaces::IntBinArithOpWithOverflowFlag;
use crate::{
op_interfaces::PointerTypeResult,
ops::{
AShrOp, AddOp, AddressOfOp, AllocaOp, AndOp, BitcastOp, BrOp, CondBrOp, ExtractElementOp,
ExtractValueOp, FAddOp, FCmpOp, FDivOp, FMulOp, FNegOp, FPExtOp, FPToSIOp, FPToUIOp,
FPTruncOp, FRemOp, FSubOp, FreezeOp, FuncOp, GetElementPtrOp, ICmpOp, InsertElementOp,
InsertValueOp, IntToPtrOp, LShrOp, LoadOp, MulOp, OrOp, PoisonOp, PtrToIntOp, SDivOp,
SExtOp, SIToFPOp, SRemOp, SelectOp, ShlOp, ShuffleVectorOp, StoreOp, SubOp, SwitchOp,
TruncOp, UDivOp, UIToFPOp, URemOp, UndefOp, XorOp, ZExtOp, ZeroOp,
},
};
#[derive(Error, Debug)]
#[error("Register Promotion: Allocation info provided is not related to this operation")]
pub struct UnrelatedAllocInfo;
#[op_interface_impl]
impl PromotableAllocationInterface for AllocaOp {
    /// Describe the single slot provided by this `alloca`: its result pointer
    /// together with the pointee type.
    fn alloc_info(&self, ctx: &Context) -> Vec<AllocInfo> {
        let ptr = self.get_result(ctx);
        let ty = self.result_pointee_type(ctx);
        vec![AllocInfo { ptr, ty }]
    }

    /// Create a `poison` value of the slot's type, insert it via `inserter`,
    /// and return it as the slot's default value.
    ///
    /// # Errors
    /// [`UnrelatedAllocInfo`] if `alloc_info` does not describe this op's result.
    fn default_value(
        &self,
        ctx: &mut Context,
        inserter: &mut dyn Inserter,
        alloc_info: &AllocInfo,
    ) -> Result<Value> {
        let own_ptr = self.get_result(ctx);
        if alloc_info.ptr != own_ptr {
            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
        }
        let poison_op = PoisonOp::new(ctx, alloc_info.ty);
        let default_val = poison_op.get_result(ctx);
        inserter.insert_op(ctx, &poison_op);
        Ok(default_val)
    }

    /// Erase this `alloca` through `rewriter`.
    ///
    /// # Errors
    /// [`UnrelatedAllocInfo`] unless `alloc_infos` is exactly one entry that
    /// describes this op's result.
    fn promote(
        &self,
        ctx: &mut Context,
        rewriter: &mut dyn Rewriter,
        alloc_infos: &[AllocInfo],
    ) -> Result<()> {
        let is_own_slot = match alloc_infos {
            [info] => info.ptr == self.get_result(ctx),
            _ => false,
        };
        if !is_own_slot {
            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
        }
        rewriter.erase_operation(ctx, self.get_operation());
        Ok(())
    }
}
#[op_interface_impl]
impl PromotableOpInterface for StoreOp {
    /// Classify this `store` with respect to `alloc_info`.
    ///
    /// A store whose address is the slot pointer is promotable and carries
    /// the stored value. Any other relationship is a non-promotable use.
    fn promotion_kind(&self, ctx: &Context, alloc_info: &AllocInfo) -> PromotableOpKind {
        let writes_slot = self.get_operand_address(ctx) == alloc_info.ptr;
        if !writes_slot {
            return PromotableOpKind::NonPromotableUse;
        }
        PromotableOpKind::Store(self.get_operand_value(ctx))
    }

    /// Erase the store. Its value is tracked by the promotion pass rather
    /// than written to memory.
    ///
    /// # Errors
    /// [`UnrelatedAllocInfo`] unless exactly one (slot, reaching-def) pair is
    /// given and that slot is this store's address.
    fn promote(
        &self,
        ctx: &mut Context,
        alloc_info_reaching_defs: &[(AllocInfo, Value)],
        rewriter: &mut dyn Rewriter,
    ) -> Result<()> {
        let [(alloc_info, _)] = alloc_info_reaching_defs else {
            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
        };
        if self.get_operand_address(ctx) != alloc_info.ptr {
            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
        }
        rewriter.erase_operation(ctx, self.get_operation());
        Ok(())
    }
}
#[op_interface_impl]
impl PromotableOpInterface for LoadOp {
    /// Classify this `load` with respect to `alloc_info`.
    ///
    /// Loading from the slot pointer is promotable. Anything else is a
    /// non-promotable use.
    fn promotion_kind(&self, ctx: &Context, alloc_info: &AllocInfo) -> PromotableOpKind {
        let reads_slot = self.get_operand_address(ctx) == alloc_info.ptr;
        if !reads_slot {
            return PromotableOpKind::NonPromotableUse;
        }
        PromotableOpKind::Load
    }

    /// Replace the load's result with the reaching definition of its slot.
    ///
    /// # Errors
    /// [`UnrelatedAllocInfo`] unless exactly one (slot, reaching-def) pair is
    /// given and that slot is this load's address.
    fn promote(
        &self,
        ctx: &mut Context,
        alloc_info_reaching_defs: &[(AllocInfo, Value)],
        rewriter: &mut dyn Rewriter,
    ) -> Result<()> {
        let [(alloc_info, reaching_def)] = alloc_info_reaching_defs else {
            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
        };
        if self.get_operand_address(ctx) != alloc_info.ptr {
            return arg_err!(self.loc(ctx), UnrelatedAllocInfo);
        }
        let replacement = vec![*reaching_def];
        rewriter.replace_operation_with_values(ctx, self.get_operation(), replacement);
        Ok(())
    }
}
/// Implement [`SideEffects`] for every listed op type. The generated
/// `has_side_effects` always returns `false`.
///
/// Accepts a non-empty, comma-separated list of types with an optional
/// trailing comma.
macro_rules! impl_side_effects_false {
    ($($op_ty:ty),+ $(,)?) => {
        $(
            // One interface impl per listed op type.
            #[op_interface_impl]
            impl SideEffects for $op_ty {
                fn has_side_effects(&self, _ctx: &Context) -> bool {
                    false
                }
            }
        )+
    };
}
impl_side_effects_false!(
    // Integer arithmetic, bitwise logic, shifts and comparison.
    AddOp,
    SubOp,
    MulOp,
    UDivOp,
    SDivOp,
    URemOp,
    SRemOp,
    AndOp,
    OrOp,
    XorOp,
    ShlOp,
    LShrOp,
    AShrOp,
    ICmpOp,
    // Floating-point arithmetic and comparison.
    FNegOp,
    FAddOp,
    FSubOp,
    FMulOp,
    FDivOp,
    FRemOp,
    FCmpOp,
    // Casts and conversions.
    BitcastOp,
    IntToPtrOp,
    PtrToIntOp,
    SExtOp,
    ZExtOp,
    TruncOp,
    FPExtOp,
    FPTruncOp,
    FPToSIOp,
    FPToUIOp,
    SIToFPOp,
    UIToFPOp,
    // Constant-like value producers.
    UndefOp,
    PoisonOp,
    FreezeOp,
    ZeroOp,
    AddressOfOp,
    // Aggregate and vector element manipulation.
    InsertValueOp,
    ExtractValueOp,
    InsertElementOp,
    ExtractElementOp,
    ShuffleVectorOp,
    // Stack slots, address computation and selection.
    AllocaOp,
    GetElementPtrOp,
    SelectOp,
);
#[op_interface_impl]
impl BlockArgRemoval for FuncOp {
    /// Block arguments may be removed from every block except this function's
    /// entry block. Presumably this is because the entry block's arguments
    /// are the function parameters; confirm before relying on it.
    fn can_remove_block_args(&self, ctx: &Context, block: Ptr<BasicBlock>) -> bool {
        self.get_entry_block(ctx) != Some(block)
    }
}
/// Extract both operands of a binary integer op as constant [`IntegerAttr`]s.
///
/// Returns `None` if either operand is not a known constant.
///
/// # Panics
/// If `operand_attrs` does not hold exactly two entries, or if a constant
/// operand is not an [`IntegerAttr`] (the op should be type-checked first).
fn get_int_bin_operands(operand_attrs: &[Option<AttrObj>]) -> Option<(IntegerAttr, IntegerAttr)> {
    assert!(operand_attrs.len() == 2);
    let to_integer = |attr: &AttrObj| -> IntegerAttr {
        attr.downcast_ref::<IntegerAttr>()
            .expect("invalid operand type: typecheck before optimizing")
            .clone()
    };
    match operand_attrs {
        [Some(lhs), Some(rhs)] => Some((to_integer(lhs), to_integer(rhs))),
        _ => None,
    }
}
/// Constant-fold a binary integer op that carries `nsw`/`nuw` flags.
///
/// `combine` returns `(result, unsigned_overflow, signed_overflow)`. Folding
/// is refused (a single `None` is returned) when an operand is not constant,
/// or when the computation overflows in a way a set flag forbids. Otherwise
/// the single entry is the folded [`IntegerAttr`], typed like the lhs operand.
fn check_fold_int_bin_op_with_overflow(
    operand_attrs: &[Option<AttrObj>],
    flags: IntegerOverflowFlagsAttr,
    combine: impl Fn(&APInt, &APInt) -> (APInt, bool, bool),
) -> Vec<Option<AttrObj>> {
    let folded = get_int_bin_operands(operand_attrs).and_then(|(lhs, rhs)| {
        let (value, unsigned_overflow, signed_overflow) = combine(&lhs.value(), &rhs.value());
        // An overflow that violates a set nsw/nuw flag must not be folded.
        let violates_flags = (flags.nsw && signed_overflow) || (flags.nuw && unsigned_overflow);
        if violates_flags {
            None
        } else {
            Some(Box::new(IntegerAttr::new(lhs.get_type(), value)) as AttrObj)
        }
    });
    vec![folded]
}
/// Constant-fold a binary integer op with no overflow flags.
///
/// Returns a single `None` when either operand is not constant. Otherwise it
/// returns the single [`IntegerAttr`] produced by `combine`, typed like the
/// lhs operand.
fn check_fold_int_bin_op(
    operand_attrs: &[Option<AttrObj>],
    combine: impl Fn(&APInt, &APInt) -> APInt,
) -> Vec<Option<AttrObj>> {
    let folded = get_int_bin_operands(operand_attrs).map(|(lhs, rhs)| {
        let result_ty = lhs.get_type();
        Box::new(IntegerAttr::new(result_ty, combine(&lhs.value(), &rhs.value()))) as AttrObj
    });
    vec![folded]
}
#[op_interface_impl]
impl ConstFoldInterface for AddOp {
    /// Fold `add` of two integer constants, honouring the op's `nsw`/`nuw` flags.
    fn check_fold(&self, ctx: &Context, ops: &[Option<AttrObj>]) -> Vec<Option<AttrObj>> {
        let flags = self.integer_overflow_flag(ctx);
        check_fold_int_bin_op_with_overflow(ops, flags, APInt::add_overflow)
    }

    /// Delegates to the interface's `fold_with_materialization`.
    fn fold_in_place(
        &self,
        ctx: &mut Context,
        ops: &[Option<AttrObj>],
        rw: &mut dyn Rewriter,
    ) -> IRStatus {
        self.fold_with_materialization(ctx, ops, rw)
    }
}
#[op_interface_impl]
impl ConstFoldInterface for SubOp {
    /// Fold `sub` of two integer constants, honouring the op's `nsw`/`nuw` flags.
    fn check_fold(&self, ctx: &Context, ops: &[Option<AttrObj>]) -> Vec<Option<AttrObj>> {
        let flags = self.integer_overflow_flag(ctx);
        check_fold_int_bin_op_with_overflow(ops, flags, APInt::sub_overflow)
    }

    /// Delegates to the interface's `fold_with_materialization`.
    fn fold_in_place(
        &self,
        ctx: &mut Context,
        ops: &[Option<AttrObj>],
        rw: &mut dyn Rewriter,
    ) -> IRStatus {
        self.fold_with_materialization(ctx, ops, rw)
    }
}
#[op_interface_impl]
impl ConstFoldInterface for MulOp {
    /// Fold `mul` of two integer constants, honouring the op's `nsw`/`nuw` flags.
    fn check_fold(&self, ctx: &Context, ops: &[Option<AttrObj>]) -> Vec<Option<AttrObj>> {
        let flags = self.integer_overflow_flag(ctx);
        check_fold_int_bin_op_with_overflow(ops, flags, APInt::mul_overflow)
    }

    /// Delegates to the interface's `fold_with_materialization`.
    fn fold_in_place(
        &self,
        ctx: &mut Context,
        ops: &[Option<AttrObj>],
        rw: &mut dyn Rewriter,
    ) -> IRStatus {
        self.fold_with_materialization(ctx, ops, rw)
    }
}
#[op_interface_impl]
impl ConstFoldInterface for ShlOp {
    /// Fold `shl` of two integer constants.
    ///
    /// Folding is skipped when the shift amount is not strictly less than the
    /// operand bit width. Otherwise the result honours the op's `nsw`/`nuw` flags.
    fn check_fold(&self, ctx: &Context, ops: &[Option<AttrObj>]) -> Vec<Option<AttrObj>> {
        let Some((lhs, rhs)) = get_int_bin_operands(ops) else {
            return vec![None];
        };
        let shift_amount = rhs.value();
        let width: usize = lhs.value().bw();
        // Express the bit width as an APInt of that same width so it can be
        // compared (unsigned) against the shift amount.
        let width_as_apint = APInt::from_usize(width, NonZero::new(width).unwrap());
        if !shift_amount.ult(&width_as_apint) {
            return vec![None];
        }
        check_fold_int_bin_op_with_overflow(
            ops,
            self.integer_overflow_flag(ctx),
            APInt::shl_overflow,
        )
    }

    /// Delegates to the interface's `fold_with_materialization`.
    fn fold_in_place(
        &self,
        ctx: &mut Context,
        ops: &[Option<AttrObj>],
        rw: &mut dyn Rewriter,
    ) -> IRStatus {
        self.fold_with_materialization(ctx, ops, rw)
    }
}
#[op_interface_impl]
impl ConstFoldInterface for UDivOp {
    /// Fold `udiv` of two integer constants. A zero divisor is never folded,
    /// because LLVM makes division by zero undefined behavior.
    fn check_fold(&self, _ctx: &Context, ops: &[Option<AttrObj>]) -> Vec<Option<AttrObj>> {
        let zero_divisor =
            matches!(get_int_bin_operands(ops), Some((_, rhs)) if rhs.value().is_zero());
        if zero_divisor {
            return vec![None];
        }
        check_fold_int_bin_op(ops, APInt::udiv)
    }

    /// Delegates to the interface's `fold_with_materialization`.
    fn fold_in_place(
        &self,
        ctx: &mut Context,
        ops: &[Option<AttrObj>],
        rw: &mut dyn Rewriter,
    ) -> IRStatus {
        self.fold_with_materialization(ctx, ops, rw)
    }
}
#[op_interface_impl]
impl ConstFoldInterface for SDivOp {
    /// Fold `sdiv` of two integer constants. A zero divisor is never folded,
    /// because LLVM makes division by zero undefined behavior.
    ///
    /// NOTE(review): the signed-overflow case (`INT_MIN / -1`) is not
    /// special-cased here and folds to whatever `APInt::sdiv` yields; confirm
    /// that is intended.
    fn check_fold(&self, _ctx: &Context, ops: &[Option<AttrObj>]) -> Vec<Option<AttrObj>> {
        let zero_divisor =
            matches!(get_int_bin_operands(ops), Some((_, rhs)) if rhs.value().is_zero());
        if zero_divisor {
            return vec![None];
        }
        check_fold_int_bin_op(ops, APInt::sdiv)
    }

    /// Delegates to the interface's `fold_with_materialization`.
    fn fold_in_place(
        &self,
        ctx: &mut Context,
        ops: &[Option<AttrObj>],
        rw: &mut dyn Rewriter,
    ) -> IRStatus {
        self.fold_with_materialization(ctx, ops, rw)
    }
}
#[op_interface_impl]
impl ConstFoldInterface for URemOp {
    /// Fold `urem` of two integer constants. A zero divisor is never folded,
    /// because LLVM makes a remainder by zero undefined behavior.
    fn check_fold(&self, _ctx: &Context, ops: &[Option<AttrObj>]) -> Vec<Option<AttrObj>> {
        let zero_divisor =
            matches!(get_int_bin_operands(ops), Some((_, rhs)) if rhs.value().is_zero());
        if zero_divisor {
            return vec![None];
        }
        check_fold_int_bin_op(ops, APInt::urem)
    }

    /// Delegates to the interface's `fold_with_materialization`.
    fn fold_in_place(
        &self,
        ctx: &mut Context,
        ops: &[Option<AttrObj>],
        rw: &mut dyn Rewriter,
    ) -> IRStatus {
        self.fold_with_materialization(ctx, ops, rw)
    }
}
#[op_interface_impl]
impl ConstFoldInterface for SRemOp {
    /// Fold `srem` of two integer constants. A zero divisor is never folded,
    /// because LLVM makes a remainder by zero undefined behavior.
    ///
    /// NOTE(review): the signed-overflow case (`INT_MIN % -1`) is not
    /// special-cased here and folds to whatever `APInt::srem` yields; confirm
    /// that is intended.
    fn check_fold(&self, _ctx: &Context, ops: &[Option<AttrObj>]) -> Vec<Option<AttrObj>> {
        let zero_divisor =
            matches!(get_int_bin_operands(ops), Some((_, rhs)) if rhs.value().is_zero());
        if zero_divisor {
            return vec![None];
        }
        check_fold_int_bin_op(ops, APInt::srem)
    }

    /// Delegates to the interface's `fold_with_materialization`.
    fn fold_in_place(
        &self,
        ctx: &mut Context,
        ops: &[Option<AttrObj>],
        rw: &mut dyn Rewriter,
    ) -> IRStatus {
        self.fold_with_materialization(ctx, ops, rw)
    }
}
#[op_interface_impl]
impl BranchOpFoldInterface for BrOp {
    /// Report every successor of the branch. Operand values never narrow the
    /// set for an unconditional branch.
    fn check_fold(&self, ctx: &Context, _operands: &[Option<AttrObj>]) -> Vec<Ptr<BasicBlock>> {
        self.get_operation()
            .deref(ctx)
            .successors()
            .collect::<Vec<_>>()
    }

    /// Nothing to rewrite: the op is left unchanged.
    fn fold_in_place(
        &self,
        _ctx: &mut Context,
        _ops: &[Option<AttrObj>],
        _rw: &mut dyn Rewriter,
    ) -> IRStatus {
        IRStatus::Unchanged
    }
}
impl CondBrOp {
    /// Compute which successor indices this conditional branch may transfer
    /// control to, given the operands' constant attributes.
    ///
    /// If the condition (the first operand) is absent or not a known constant,
    /// every successor index is returned. Otherwise exactly one index is
    /// returned: `0` for a non-zero condition and `1` for a zero condition.
    ///
    /// # Panics
    /// If the constant condition is not an [`IntegerAttr`].
    fn possible_successor_indices(
        &self,
        ctx: &Context,
        operands: &[Option<AttrObj>],
    ) -> Vec<usize> {
        // A missing or non-constant condition means any edge may be taken.
        // `and_then` (rather than `unwrap`) avoids panicking on an empty operand
        // list, matching how `SwitchOp::check_fold` treats its condition.
        let Some(cond_attr) = operands.first().and_then(|o| o.as_ref()) else {
            let num_successors = self.get_operation().deref(ctx).successors().count();
            return (0..num_successors).collect();
        };
        let cond_int = cond_attr
            .downcast_ref::<IntegerAttr>()
            .expect("CondBrOp condition operand must be an IntegerAttr");
        // Successor 0 is taken for a non-zero (true) condition, successor 1 otherwise.
        let taken = if cond_int.value().is_zero() { 1 } else { 0 };
        vec![taken]
    }
}
#[op_interface_impl]
impl BranchOpFoldInterface for CondBrOp {
    /// Map the possible successor indices (see `possible_successor_indices`)
    /// to their successor blocks.
    fn check_fold(&self, ctx: &Context, operands: &[Option<AttrObj>]) -> Vec<Ptr<BasicBlock>> {
        let all_successors: Vec<Ptr<BasicBlock>> =
            self.get_operation().deref(ctx).successors().collect();
        self.possible_successor_indices(ctx, operands)
            .into_iter()
            .map(|idx| all_successors[idx])
            .collect()
    }

    /// When the condition selects exactly one successor, replace this op with
    /// an unconditional `BrOp` to that successor, forwarding that successor's
    /// operands. Otherwise leave the op unchanged.
    fn fold_in_place(
        &self,
        ctx: &mut Context,
        ops: &[Option<AttrObj>],
        rewriter: &mut dyn Rewriter,
    ) -> IRStatus {
        let candidates = self.possible_successor_indices(ctx, ops);
        let &[successor_ind] = candidates.as_slice() else {
            return IRStatus::Unchanged;
        };
        let all_successors: Vec<Ptr<BasicBlock>> =
            self.get_operation().deref(ctx).successors().collect();
        let target = all_successors[successor_ind];
        let forwarded_operands = self.successor_operands(ctx, successor_ind);
        let branch = BrOp::new(ctx, target, forwarded_operands).get_operation();
        let cond_branch = self.get_operation();
        rewriter.insert_operation(ctx, branch);
        rewriter.replace_operation(ctx, cond_branch, branch);
        IRStatus::Changed
    }
}
impl SwitchOp {
    /// Index of the successor this `switch` transfers control to when its
    /// condition is a known constant.
    ///
    /// Returns `None` when the condition operand (the first entry of
    /// `operands`) is absent or not constant. Otherwise it returns `i + 1` for
    /// the first case value `i` equal to the condition, or `0` when no case
    /// matches.
    ///
    /// # Panics
    /// If the constant condition is not an [`IntegerAttr`], or the op lacks its
    /// case-values attribute.
    fn folded_successor_index(&self, ctx: &Context, operands: &[Option<AttrObj>]) -> Option<usize> {
        let cond_attr = operands.first().and_then(|o| o.as_ref())?;
        let cond_int = cond_attr
            .downcast_ref::<IntegerAttr>()
            .expect("Switch condition operand must be an IntegerAttr")
            .value();
        let case_values = self
            .get_attr_switch_case_values(ctx)
            .expect("SwitchOp missing case values attribute");
        let taken = case_values
            .0
            .iter()
            .position(|case| case.value() == cond_int)
            .map_or(0, |i| i + 1);
        Some(taken)
    }
}

#[op_interface_impl]
impl BranchOpFoldInterface for SwitchOp {
    /// All successors when the condition is unknown. Otherwise only the
    /// successor selected by the constant condition.
    fn check_fold(&self, ctx: &Context, operands: &[Option<AttrObj>]) -> Vec<Ptr<BasicBlock>> {
        let successors: Vec<Ptr<BasicBlock>> =
            self.get_operation().deref(ctx).successors().collect();
        match self.folded_successor_index(ctx, operands) {
            Some(taken) => vec![successors[taken]],
            None => successors,
        }
    }

    /// When the condition is a known constant, replace this op with an
    /// unconditional `BrOp` to the selected successor, forwarding that
    /// successor's operands. Otherwise leave the op unchanged.
    ///
    /// A missing condition operand is treated as "unknown", the same way
    /// `check_fold` treats it, instead of panicking.
    fn fold_in_place(
        &self,
        ctx: &mut Context,
        ops: &[Option<AttrObj>],
        rewriter: &mut dyn Rewriter,
    ) -> IRStatus {
        let Some(successor_ind) = self.folded_successor_index(ctx, ops) else {
            return IRStatus::Unchanged;
        };
        let successors: Vec<Ptr<BasicBlock>> =
            self.get_operation().deref(ctx).successors().collect();
        let new_op = BrOp::new(
            ctx,
            successors[successor_ind],
            self.successor_operands(ctx, successor_ind),
        )
        .get_operation();
        let old_op = self.get_operation();
        rewriter.insert_operation(ctx, new_op);
        rewriter.replace_operation(ctx, old_op, new_op);
        IRStatus::Changed
    }
}