use pliron::{
basic_block::BasicBlock,
builtin::{
op_interfaces::{
IsTerminatorInterface, NOpdsInterface, NResultsInterface, OneOpdInterface,
OneResultInterface, SingleBlockRegionInterface,
},
ops::ModuleOp,
types::{IntegerType, Signedness},
},
context::{Context, Ptr},
identifier::Identifier,
irbuild::{
dialect_conversion::{
DialectConversion, DialectConversionRewriter, OperandsInfo, apply_dialect_conversion,
},
inserter::Inserter,
rewriter::Rewriter,
},
op::Op,
operation::Operation,
result::Result,
r#type::Typed,
};
use pliron::derive::pliron_op;
#[cfg(target_family = "wasm")]
use wasm_bindgen_test::*;
#[pliron_op(
name = "test.producer",
format,
interfaces = [NOpdsInterface<0>, OneResultInterface, NResultsInterface<1>],
verifier = "succ",
)]
pub struct ProducerOp;
impl ProducerOp {
fn new(ctx: &mut Context, width: u32) -> Self {
let ty = IntegerType::get(ctx, width, Signedness::Signed).into();
let op = Operation::new(
ctx,
Self::get_concrete_op_info(),
vec![ty],
vec![],
vec![],
0,
);
Self { op }
}
}
#[pliron_op(
name = "test.consumer",
format,
interfaces = [OneOpdInterface, NOpdsInterface<1>, NResultsInterface<0>],
verifier = "succ",
)]
pub struct ConsumerOp;
impl ConsumerOp {
fn new(ctx: &mut Context, value: pliron::value::Value) -> Self {
let op = Operation::new(
ctx,
Self::get_concrete_op_info(),
vec![],
vec![value],
vec![],
0,
);
Self { op }
}
}
#[pliron_op(
name = "test.forward_to_succ",
format,
interfaces = [
IsTerminatorInterface,
OneOpdInterface,
NOpdsInterface<1>,
NResultsInterface<0>
],
verifier = "succ",
)]
pub struct ForwardToSuccOp;
impl ForwardToSuccOp {
fn new(ctx: &mut Context, value: pliron::value::Value, succ: Ptr<BasicBlock>) -> Self {
let op = Operation::new(
ctx,
Self::get_concrete_op_info(),
vec![],
vec![value],
vec![succ],
0,
);
Self { op }
}
}
#[derive(Default)]
struct WidthConversion {
saw_consumer: bool,
}
impl DialectConversion for WidthConversion {
fn can_convert_op(&mut self, ctx: &Context, op: Ptr<Operation>) -> bool {
if let Some(producer) = Operation::get_op::<ProducerOp>(op, ctx) {
let ty = producer.get_result(ctx).get_type(ctx);
let width = ty
.deref(ctx)
.downcast_ref::<IntegerType>()
.expect("expected integer type")
.width();
return width > 16;
}
Operation::get_op::<ConsumerOp>(op, ctx).is_some()
}
fn rewrite(
&mut self,
ctx: &mut Context,
rewriter: &mut DialectConversionRewriter,
op: Ptr<Operation>,
operands_info: &OperandsInfo,
) -> Result<()> {
if let Some(producer) = Operation::get_op::<ProducerOp>(op, ctx) {
let current_ty = producer.get_result(ctx).get_type(ctx);
let current_width = current_ty
.deref(ctx)
.downcast_ref::<IntegerType>()
.expect("expected integer type")
.width();
let converted = ProducerOp::new(ctx, current_width / 2).get_operation();
rewriter.insert_operation(ctx, converted);
rewriter.replace_operation(ctx, op, converted);
return Ok(());
}
if let Some(consumer) = Operation::get_op::<ConsumerOp>(op, ctx) {
let operand = consumer.get_operand(ctx);
let final_width = operand
.get_type(ctx)
.deref(ctx)
.downcast_ref::<IntegerType>()
.expect("expected integer type")
.width();
assert_eq!(final_width, 16);
let operand_type_history = operands_info.lookup_operand_history(operand);
assert_eq!(operand_type_history.len(), 2);
let previous_widths = operand_type_history
.iter()
.map(|ty| {
ty.deref(ctx)
.downcast_ref::<IntegerType>()
.expect("expected integer type")
.width()
})
.collect::<Vec<u32>>();
assert_eq!(previous_widths, vec![64, 32]);
self.saw_consumer = true;
}
Ok(())
}
}
#[test]
#[cfg_attr(target_family = "wasm", wasm_bindgen_test)]
fn dialect_conversion_defs_before_uses() -> Result<()> {
let ctx = &mut Context::new();
let module = ModuleOp::new(
ctx,
Identifier::try_from("dialect_conversion_test").unwrap(),
);
let body = module.get_body(ctx, 0);
let producer = ProducerOp::new(ctx, 64);
producer.get_operation().insert_at_back(body, ctx);
let consumer = ConsumerOp::new(ctx, producer.get_result(ctx));
consumer.get_operation().insert_at_back(body, ctx);
let mut conversion = WidthConversion::default();
apply_dialect_conversion(ctx, &mut conversion, module.get_operation())?;
assert!(conversion.saw_consumer);
Ok(())
}
#[derive(Default)]
struct ConsumerOnlyConversion {
saw_consumer: bool,
}
impl DialectConversion for ConsumerOnlyConversion {
fn can_convert_op(&mut self, ctx: &Context, op: Ptr<Operation>) -> bool {
Operation::get_op::<ConsumerOp>(op, ctx).is_some()
}
fn can_convert_type(&mut self, ctx: &Context, ty: Ptr<pliron::r#type::TypeObj>) -> bool {
ty.deref(ctx)
.downcast_ref::<IntegerType>()
.is_some_and(|int_ty| int_ty.width() > 16)
}
fn convert_type(
&mut self,
ctx: &mut Context,
ty: Ptr<pliron::r#type::TypeObj>,
) -> Result<Ptr<pliron::r#type::TypeObj>> {
let width = ty
.deref(ctx)
.downcast_ref::<IntegerType>()
.expect("expected integer type")
.width();
Ok(IntegerType::get(ctx, width / 2, Signedness::Signed).into())
}
fn rewrite(
&mut self,
ctx: &mut Context,
_rewriter: &mut DialectConversionRewriter,
op: Ptr<Operation>,
operands_info: &OperandsInfo,
) -> Result<()> {
let consumer = Operation::get_op::<ConsumerOp>(op, ctx).expect("expected consumer op");
let operand = consumer.get_operand(ctx);
let final_width = operand
.get_type(ctx)
.deref(ctx)
.downcast_ref::<IntegerType>()
.expect("expected integer type")
.width();
assert_eq!(final_width, 16);
let previous_width = operands_info
.lookup_most_recent_of_type::<IntegerType>(ctx, operand)
.expect("Previous integer type missing");
assert_eq!(previous_width.width(), 32);
self.saw_consumer = true;
Ok(())
}
}
#[derive(Default)]
struct ForwardOnlyConversion {
saw_forward: bool,
}
impl DialectConversion for ForwardOnlyConversion {
fn can_convert_op(&mut self, ctx: &Context, op: Ptr<Operation>) -> bool {
Operation::get_op::<ForwardToSuccOp>(op, ctx).is_some()
}
fn can_convert_type(&mut self, ctx: &Context, ty: Ptr<pliron::r#type::TypeObj>) -> bool {
ty.deref(ctx)
.downcast_ref::<IntegerType>()
.is_some_and(|int_ty| int_ty.width() > 16)
}
fn convert_type(
&mut self,
ctx: &mut Context,
ty: Ptr<pliron::r#type::TypeObj>,
) -> Result<Ptr<pliron::r#type::TypeObj>> {
let width = ty
.deref(ctx)
.downcast_ref::<IntegerType>()
.expect("expected integer type")
.width();
Ok(IntegerType::get(ctx, width / 2, Signedness::Signed).into())
}
fn rewrite(
&mut self,
ctx: &mut Context,
_rewriter: &mut DialectConversionRewriter,
op: Ptr<Operation>,
_operands_info: &OperandsInfo,
) -> Result<()> {
let succ = op.deref(ctx).get_successor(0);
let succ_arg = succ.deref(ctx).get_argument(0);
let succ_arg_width = succ_arg
.get_type(ctx)
.deref(ctx)
.downcast_ref::<IntegerType>()
.expect("expected integer type")
.width();
assert_eq!(succ_arg_width, 16);
self.saw_forward = true;
Ok(())
}
}
#[test]
#[cfg_attr(target_family = "wasm", wasm_bindgen_test)]
fn dialect_conversion_block_arg_type_conversion() -> Result<()> {
let ctx = &mut Context::new();
let module = ModuleOp::new(
ctx,
Identifier::try_from("block_arg_type_conversion").unwrap(),
);
let body = module.get_body(ctx, 0);
let i64_ty = IntegerType::get(ctx, 64, Signedness::Signed).into();
let arg_idx = body.deref_mut(ctx).add_argument(i64_ty);
let block_arg = body.deref(ctx).get_argument(arg_idx);
let consumer = ConsumerOp::new(ctx, block_arg);
consumer.get_operation().insert_at_back(body, ctx);
let mut conversion = ConsumerOnlyConversion::default();
apply_dialect_conversion(ctx, &mut conversion, module.get_operation())?;
assert!(conversion.saw_consumer);
Ok(())
}
#[test]
#[cfg_attr(target_family = "wasm", wasm_bindgen_test)]
fn dialect_conversion_successor_block_arg_type_conversion_without_uses() -> Result<()> {
let ctx = &mut Context::new();
let module = ModuleOp::new(
ctx,
Identifier::try_from("block_arg_successor_type_conversion").unwrap(),
);
let pred_block = module.get_body(ctx, 0);
let region = pred_block
.deref(ctx)
.get_parent_region()
.expect("module body block must be in a region");
let i64_ty = IntegerType::get(ctx, 64, Signedness::Signed).into();
let producer = ProducerOp::new(ctx, 64);
producer.get_operation().insert_at_back(pred_block, ctx);
let succ_block = BasicBlock::new(
ctx,
Some(Identifier::try_from("succ").unwrap()),
vec![i64_ty],
);
succ_block.insert_at_back(region, ctx);
let forward = ForwardToSuccOp::new(ctx, producer.get_result(ctx), succ_block);
forward.get_operation().insert_at_back(pred_block, ctx);
let mut conversion = ForwardOnlyConversion::default();
apply_dialect_conversion(ctx, &mut conversion, module.get_operation())?;
assert!(conversion.saw_forward);
let succ_arg = succ_block.deref(ctx).get_argument(0);
let succ_arg_width = succ_arg
.get_type(ctx)
.deref(ctx)
.downcast_ref::<IntegerType>()
.expect("expected integer type")
.width();
assert_eq!(succ_arg_width, 16);
Ok(())
}