use core::panic;
use crate::common::{ReturnOp, const_ret_in_mod};
use common::ConstantOp;
use expect_test::expect;
use pliron::{
builtin::{
attributes::IntegerAttr,
op_interfaces::{
IsolatedFromAboveInterface, NOpdsInterface, NResultsInterface, OneOpdInterface,
OneRegionInterface, OneResultInterface, SingleBlockRegionInterface, SymbolOpInterface,
},
ops::FuncOp,
types::{IntegerType, Signedness},
},
common_traits::Verify,
context::{Context, Ptr},
identifier::Identifier,
irbuild::{
inserter::{BlockInsertionPoint, Inserter, OpInsertionPoint},
match_rewrite::{MatchRewrite, MatchRewriter, apply_match_rewrite},
rewriter::{Rewriter, ScopedRewriter},
},
op::Op,
operation::Operation,
printable::Printable,
result::Result,
r#type::TypeObj,
value::Value,
};
use pliron::derive::{pliron_op, pliron_type};
#[cfg(target_family = "wasm")]
use wasm_bindgen_test::*;
mod common;
/// Applies a pattern that replaces any integer constant below 2 with a new
/// constant one larger, and checks the printed module once the driver stops.
///
/// The fixture's constant starts at 0, so the pattern fires until the value
/// reaches 2, which is what the expected output shows.
#[test]
#[cfg_attr(target_family = "wasm", wasm_bindgen_test)]
fn replace_c0_with_c1() -> Result<()> {
    let ctx = &mut Context::new();
    let (module_op, _, const_op, _) = const_ret_in_mod(ctx).unwrap();

    // Sanity check: the constant produced by the fixture is 0.
    let initial_value = const_op
        .get_value(ctx)
        .downcast_ref::<IntegerAttr>()
        .unwrap()
        .value()
        .to_u64();
    assert!(initial_value == 0);

    /// Matches integer constants below 2 and replaces them with `value + 1`.
    struct BumpConstBelowTwo;

    impl MatchRewrite for BumpConstBelowTwo {
        fn r#match(&mut self, ctx: &Context, op: Ptr<Operation>) -> bool {
            Operation::get_op::<ConstantOp>(op, ctx).is_some_and(|constant| {
                constant
                    .get_value(ctx)
                    .downcast_ref::<IntegerAttr>()
                    .is_some_and(|int_attr| int_attr.value().to_u64() < 2)
            })
        }

        fn rewrite(
            &mut self,
            ctx: &mut Context,
            rewriter: &mut MatchRewriter,
            op: Ptr<Operation>,
        ) -> Result<()> {
            // `rewrite` is only expected to run on ops that `match` accepted.
            let value_attr = match Operation::get_op::<ConstantOp>(op, ctx) {
                Some(constant) => constant.get_value(ctx),
                None => panic!("Expected ConstantOp"),
            };
            let cur_val = value_attr
                .downcast_ref::<IntegerAttr>()
                .unwrap()
                .value()
                .to_u64();
            assert!(
                cur_val < 2,
                "Expected match only on constant value less than 2"
            );

            // Insert the successor constant, then swap it in for the old op.
            let replacement = ConstantOp::new(ctx, cur_val + 1).get_operation();
            rewriter.insert_operation(ctx, replacement);
            rewriter.replace_operation(ctx, op, replacement);
            Ok(())
        }
    }

    apply_match_rewrite(ctx, BumpConstBelowTwo, module_op.get_operation())?;
    module_op.get_operation().verify(ctx)?;

    let printed = module_op.disp(ctx).to_string();
    expect![[r#"
builtin.module @bar
{
^block1v1():
builtin.func @foo: builtin.function <()->(builtin.integer si64)>
{
^entry_block2v1():
op3v3_res0 = test.constant builtin.integer <2: si64>;
test.return op3v3_res0
}
}"#]]
    .assert_eq(&printed);
    Ok(())
}
/// Test-only op `test.global`.
///
/// Declared with zero operands and exactly one result, as a symbol op, and
/// carrying an [`IntegerAttr`] under the attribute name
/// `test_global_op_const_val`. Its verifier always succeeds (`"succ"`).
///
/// NOTE(review): `SingleBlockRegionInterface` is listed, yet `GlobalOp::new`
/// passes `0` as the last argument of `Operation::new` (presumably the region
/// count) — confirm that this interface is intended here.
#[pliron_op(
name = "test.global",
format,
interfaces = [
IsolatedFromAboveInterface,
NOpdsInterface<0>,
NResultsInterface<1>,
OneResultInterface,
SymbolOpInterface,
SingleBlockRegionInterface,
],
attributes = (test_global_op_const_val: IntegerAttr),
verifier = "succ",
)]
pub struct GlobalOp;
impl GlobalOp {
    /// Creates a `test.global` op named `name` that stores `val` in its
    /// `test_global_op_const_val` attribute.
    ///
    /// The op is built with a single [`PointerType`] result and empty vectors
    /// for the two remaining list arguments of `Operation::new` (presumably
    /// operands and successors — TODO confirm against `Operation::new`).
    pub fn new(ctx: &mut Context, name: Identifier, val: IntegerAttr) -> Self {
        let result_ty = PointerType::get(ctx).into();
        let raw_op = Operation::new(
            ctx,
            Self::get_concrete_op_info(),
            vec![result_ty],
            vec![],
            vec![],
            0,
        );

        let global = GlobalOp { op: raw_op };
        global.set_symbol_name(ctx, name);
        global.set_attr_test_global_op_const_val(ctx, val);
        global
    }
}
/// Test-only type `test.ptr`, used as the result type of [`GlobalOp`].
///
/// It has no fields and its verifier always succeeds (`"succ"`).
/// `generate_get = true` presumably provides the `PointerType::get(ctx)`
/// constructor that `GlobalOp::new` calls — confirm against the macro docs.
#[pliron_type(name = "test.ptr", format, generate_get = true, verifier = "succ")]
#[derive(Hash, PartialEq, Eq, Debug)]
pub struct PointerType;
/// Test-only op `test.load`, declared with exactly one operand and one result.
///
/// Its verifier always succeeds (`"succ"`).
///
/// NOTE(review): the backticks in the custom `format` string look unbalanced,
/// and the expected output in `scoped_rewriter_test` shows no type after the
/// operand — confirm the format string against the format-macro syntax.
#[pliron_op(
name = "test.load",
format = "$0 ` ` : ` type($0)",
interfaces = [OneResultInterface, OneOpdInterface, NResultsInterface<1>, NOpdsInterface<1>],
verifier = "succ",
)]
pub struct LoadOp;
impl LoadOp {
    /// Creates a `test.load` op that takes `ptr` as its only operand and
    /// produces a single result of type `res_ty`.
    pub fn new(ctx: &mut Context, ptr: Value, res_ty: Ptr<TypeObj>) -> Self {
        let result_types = vec![res_ty];
        let operands = vec![ptr];
        let op = Operation::new(
            ctx,
            Self::get_concrete_op_info(),
            result_types,
            operands,
            vec![],
            0,
        );
        Self { op }
    }
}
/// Rewrites every `test.constant` into a module-level `test.global` plus a
/// `test.load` of that global at the constant's original position.
///
/// The global is inserted through a [`ScopedRewriter`] pointed at the start of
/// the module body, while the load goes through the outer rewriter.
#[test]
#[cfg_attr(target_family = "wasm", wasm_bindgen_test)]
fn scoped_rewriter_test() -> Result<()> {
    let ctx = &mut Context::new();
    let (module_op, _func_op, _, _) = const_ret_in_mod(ctx).unwrap();

    /// Matches any `ConstantOp` and replaces it with a load from a new global.
    struct ConstToGlobal;

    impl MatchRewrite for ConstToGlobal {
        fn r#match(&mut self, ctx: &Context, op: Ptr<Operation>) -> bool {
            Operation::get_op::<ConstantOp>(op, ctx).is_some()
        }

        fn rewrite(
            &mut self,
            ctx: &mut Context,
            rewriter: &mut MatchRewriter,
            op: Ptr<Operation>,
        ) -> Result<()> {
            let const_op = Operation::get_op::<ConstantOp>(op, ctx).unwrap();
            let val = const_op.get_value(ctx).downcast::<IntegerAttr>().unwrap();

            // Walk up: constant -> enclosing function -> enclosing module.
            let func_op = op.deref(ctx).get_parent_op(ctx).unwrap();
            let module_op = func_op.deref(ctx).get_parent_op(ctx).unwrap();
            let module_op =
                Operation::get_op::<pliron::builtin::ops::ModuleOp>(module_op, ctx).unwrap();

            // The global's symbol name is derived from the constant's value.
            let name =
                Identifier::try_from("global_".to_string() + &val.value().to_string(10, true))
                    .unwrap();
            let global_op = GlobalOp::new(ctx, name, *val);
            {
                // Insert the global at module scope. The inner block ends the
                // scoped rewriter's borrow so `rewriter` can be used again below.
                let mut module_inserter = ScopedRewriter::new(
                    rewriter,
                    OpInsertionPoint::AtBlockStart(module_op.get_body(ctx, 0)),
                );
                module_inserter.insert_operation(ctx, global_op.get_operation());
            }

            // The load stands in for a signed 64-bit constant whose value is
            // returned from a function typed `() -> si64`, so its result must
            // be si64 as well (it was previously created as a 32-bit type).
            let int_ty = IntegerType::get(ctx, 64, Signedness::Signed).into();
            let load_op = LoadOp::new(ctx, global_op.get_result(ctx), int_ty).get_operation();
            rewriter.insert_operation(ctx, load_op);
            rewriter.replace_operation(ctx, op, load_op);
            Ok(())
        }
    }

    apply_match_rewrite(ctx, ConstToGlobal, module_op.get_operation())?;
    module_op.get_operation().verify(ctx)?;

    let printed = format!("{}", module_op.disp(ctx));
    expect![[r#"
builtin.module @bar
{
^block1v1():
op5v1_res0 = test.global () [] [test_global_op_const_val: builtin.integer <0: si64>, sym_name: builtin.identifier global_0]: <() -> (test.ptr )>;
builtin.func @foo: builtin.function <()->(builtin.integer si64)>
{
^entry_block2v1():
op6v1_res0 = test.load op5v1_res0 ;
test.return op6v1_res0
}
}"#]]
    .assert_eq(&printed);
    Ok(())
}
/// Erases the function that encloses a zero constant and checks that the
/// module body is left empty.
///
/// Before erasing, the pattern inserts two fresh constants (1 and 0) at the
/// rewriter's current insertion point; neither shows up in the expected
/// output.
#[test]
#[cfg_attr(target_family = "wasm", wasm_bindgen_test)]
fn erase_func_with_const_zero() -> Result<()> {
    let ctx = &mut Context::new();
    let (module_op, _func_op, _, _) = const_ret_in_mod(ctx).unwrap();

    /// Matches integer constants equal to 0 and erases their parent op.
    struct EraseEnclosingFunc;

    impl MatchRewrite for EraseEnclosingFunc {
        fn r#match(&mut self, ctx: &Context, op: Ptr<Operation>) -> bool {
            let Some(constant) = Operation::get_op::<ConstantOp>(op, ctx) else {
                return false;
            };
            let value = constant.get_value(ctx);
            value
                .downcast_ref::<IntegerAttr>()
                .is_some_and(|int_attr| int_attr.value().to_u64() == 0)
        }

        fn rewrite(
            &mut self,
            ctx: &mut Context,
            rewriter: &mut MatchRewriter,
            op: Ptr<Operation>,
        ) -> Result<()> {
            // Creation order is kept as-is: one first, then zero.
            let one = ConstantOp::new(ctx, 1).get_operation();
            rewriter.insert_operation(ctx, one);
            let zero = ConstantOp::new(ctx, 0).get_operation();
            rewriter.insert_operation(ctx, zero);

            // Finally remove the op that contains the matched constant.
            let enclosing_func = op.deref(ctx).get_parent_op(ctx).unwrap();
            rewriter.erase_operation(ctx, enclosing_func);
            Ok(())
        }
    }

    apply_match_rewrite(ctx, EraseEnclosingFunc, module_op.get_operation())?;
    module_op.get_operation().verify(ctx)?;

    let printed = module_op.disp(ctx).to_string();
    expect![[r#"
builtin.module @bar
{
^block1v1():
}"#]]
    .assert_eq(&printed);
    Ok(())
}
/// Splits the entry block right after a zero constant and checks the result.
///
/// The new block gets a constant 1 that takes over all uses of the zero
/// constant, and the original block is terminated with a new return of the
/// zero constant.
#[test]
#[cfg_attr(target_family = "wasm", wasm_bindgen_test)]
fn split_block_after_const_zero() -> Result<()> {
    let ctx = &mut Context::new();
    let (module_op, _func_op, _, _) = const_ret_in_mod(ctx).unwrap();

    /// Matches integer constants equal to 0 and splits their block after them.
    struct SplitAfterZero;

    impl MatchRewrite for SplitAfterZero {
        fn r#match(&mut self, ctx: &Context, op: Ptr<Operation>) -> bool {
            Operation::get_op::<ConstantOp>(op, ctx).is_some_and(|constant| {
                constant
                    .get_value(ctx)
                    .downcast_ref::<IntegerAttr>()
                    .is_some_and(|int_attr| int_attr.value().to_u64() == 0)
            })
        }

        fn rewrite(
            &mut self,
            ctx: &mut Context,
            rewriter: &mut MatchRewriter,
            op: Ptr<Operation>,
        ) -> Result<()> {
            let zero_const = Operation::get_op::<ConstantOp>(op, ctx).unwrap();
            let original_block = op.deref(ctx).get_parent_block().unwrap();

            // Everything after the matched constant moves into `tail_block`.
            let split_point = OpInsertionPoint::AfterOperation(op);
            let tail_block = rewriter.split_block(ctx, original_block, split_point);

            // Place a constant 1 at the start of the tail block.
            let one_const = ConstantOp::new(ctx, 1).get_operation();
            let one_result = one_const.deref(ctx).get_result(0);
            rewriter.set_insertion_point(OpInsertionPoint::AtBlockStart(tail_block));
            rewriter.insert_operation(ctx, one_const);

            // Redirect existing users of the zero constant to the new one.
            zero_const
                .get_result(ctx)
                .replace_all_uses_with(ctx, &one_result);

            // Created after the redirection above, so this return keeps
            // using the zero constant.
            let new_return = ReturnOp::new(ctx, zero_const.get_result(ctx)).get_operation();
            {
                let mut block_end_rewriter = ScopedRewriter::new(
                    rewriter,
                    OpInsertionPoint::AtBlockEnd(original_block),
                );
                block_end_rewriter.insert_operation(ctx, new_return);
            }
            Ok(())
        }
    }

    apply_match_rewrite(ctx, SplitAfterZero, module_op.get_operation())?;
    module_op.get_operation().verify(ctx)?;

    let printed = module_op.disp(ctx).to_string();
    expect![[r#"
builtin.module @bar
{
^block1v1():
builtin.func @foo: builtin.function <()->(builtin.integer si64)>
{
^entry_block2v1():
c0_op3v1_res0 = test.constant builtin.integer <0: si64>;
test.return c0_op3v1_res0
^entry_split_block3v1():
op5v1_res0 = test.constant builtin.integer <1: si64>;
test.return op5v1_res0
}
}"#]]
    .assert_eq(&printed);
    Ok(())
}
/// Builds two identical modules and, on finding a zero constant in the second
/// one, inlines the region of its enclosing function at the end of the first
/// module's function region.
///
/// Afterwards the first function is expected to hold both blocks and the
/// second function to be empty.
#[test]
#[cfg_attr(target_family = "wasm", wasm_bindgen_test)]
fn inline_region_on_const_zero() -> Result<()> {
    let ctx = &mut Context::new();
    let (module_op1, func_op1, _, _) = const_ret_in_mod(ctx).unwrap();
    let (module_op2, _func_op2, _, _) = const_ret_in_mod(ctx).unwrap();

    /// Holds the destination function; matches integer constants equal to 0.
    struct InlineIntoTarget(FuncOp);

    impl MatchRewrite for InlineIntoTarget {
        fn r#match(&mut self, ctx: &Context, op: Ptr<Operation>) -> bool {
            let Some(constant) = Operation::get_op::<ConstantOp>(op, ctx) else {
                return false;
            };
            let value = constant.get_value(ctx);
            value
                .downcast_ref::<IntegerAttr>()
                .is_some_and(|int_attr| int_attr.value().to_u64() == 0)
        }

        fn rewrite(
            &mut self,
            ctx: &mut Context,
            rewriter: &mut MatchRewriter,
            op: Ptr<Operation>,
        ) -> Result<()> {
            // Source: region 0 of the op enclosing the matched constant.
            let enclosing_func = op.deref(ctx).get_parent_op(ctx).unwrap();
            let source_region = enclosing_func.deref(ctx).get_region(0);

            // Destination: the end of the stored function's region.
            let destination = BlockInsertionPoint::AtRegionEnd(self.0.get_region(ctx));
            rewriter.inline_region(ctx, source_region, destination);
            Ok(())
        }
    }

    // The pattern runs over the second module and inlines into the first.
    apply_match_rewrite(
        ctx,
        InlineIntoTarget(func_op1),
        module_op2.get_operation(),
    )?;
    module_op1.get_operation().verify(ctx)?;
    module_op2.get_operation().verify(ctx)?;

    let printed_first = module_op1.disp(ctx).to_string();
    expect![[r#"
builtin.module @bar
{
^block1v1():
builtin.func @foo: builtin.function <()->(builtin.integer si64)>
{
^entry_block2v1():
c0_op3v1_res0 = test.constant builtin.integer <0: si64>;
test.return c0_op3v1_res0
^entry_block4v1():
c0_op7v1_res0 = test.constant builtin.integer <0: si64>;
test.return c0_op7v1_res0
}
}"#]]
    .assert_eq(&printed_first);

    let printed_second = module_op2.disp(ctx).to_string();
    expect![[r#"
builtin.module @bar
{
^block3v1():
builtin.func @foo: builtin.function <()->(builtin.integer si64)>
{
}
}"#]]
    .assert_eq(&printed_second);
    Ok(())
}