use std::sync::Arc;
use crate::patterns;
use morok_dtype::DType;
use morok_ir::pattern::RewriteResult;
use morok_ir::{BinaryOp, ConstValue, Op, UOp};
/// Builds a binary `op` node over `lhs` and `rhs`.
///
/// The new node takes the dtype of `lhs`. This helper does not itself check
/// that `rhs` has the same dtype.
fn binary(op: BinaryOp, lhs: Arc<UOp>, rhs: Arc<UOp>) -> Arc<UOp> {
    // The dtype must be read before `lhs` is moved into the op.
    let result_dtype = lhs.dtype();
    let node = Op::Binary(op, lhs, rhs);
    UOp::new(node, result_dtype)
}
/// `Add(x, Const(0))` must rewrite `42 + 0` back to the very same `x` node.
#[test]
fn test_simple_add_zero_pattern() {
    let matcher = patterns! {
        Add(x, Const(0)) ~> x
    };

    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);
    let sum = binary(BinaryOp::Add, Arc::clone(&x), zero);

    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&sum, &mut ()) else {
        panic!("Expected rewrite to succeed");
    };
    assert!(Arc::ptr_eq(&rewritten, &x), "Should rewrite to x");
}
/// `Mul(x, Const(1))` must rewrite `42 * 1` back to the identical `x` node.
#[test]
fn test_mul_one_pattern() {
    let matcher = patterns! {
        Mul(x, Const(1)) ~> x
    };

    let x = UOp::native_const(42i32);
    let one = UOp::native_const(1i32);
    let product = binary(BinaryOp::Mul, Arc::clone(&x), one);

    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&product, &mut ()) else {
        panic!("Expected rewrite to succeed");
    };
    assert!(Arc::ptr_eq(&rewritten, &x), "Should rewrite to x");
}
/// A named sub-pattern (`zero @ Const(0)`) binds the matched operand, so the
/// rewrite can return that exact node.
#[test]
fn test_binding_pattern() {
    let matcher = patterns! {
        Mul(_, zero @ Const(0)) ~> zero
    };

    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);
    let product = binary(BinaryOp::Mul, x, Arc::clone(&zero));

    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&product, &mut ()) else {
        panic!("Expected rewrite to succeed");
    };
    assert!(Arc::ptr_eq(&rewritten, &zero), "Should rewrite to zero");
}
/// Several rules in one matcher: each input must be claimed by the rule
/// whose shape it matches.
#[test]
fn test_multiple_patterns() {
    let matcher = patterns! {
        Add(x, Const(0)) ~> x,
        Mul(x, Const(1)) ~> x,
        Mul(_, zero @ Const(0)) ~> zero,
    };

    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);
    let one = UOp::native_const(1i32);

    // x + 0 -> x
    let add_zero = binary(BinaryOp::Add, Arc::clone(&x), Arc::clone(&zero));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add_zero, &mut ()) else {
        panic!("Add(x, 0) should rewrite to x");
    };
    assert!(Arc::ptr_eq(&r, &x));

    // x * 1 -> x
    let mul_one = binary(BinaryOp::Mul, Arc::clone(&x), one);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_one, &mut ()) else {
        panic!("Mul(x, 1) should rewrite to x");
    };
    assert!(Arc::ptr_eq(&r, &x));

    // x * 0 -> 0. This case moves `x`, so it runs last.
    let mul_zero = binary(BinaryOp::Mul, x, Arc::clone(&zero));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_zero, &mut ()) else {
        panic!("Mul(x, 0) should rewrite to 0");
    };
    assert!(Arc::ptr_eq(&r, &zero));
}
/// `x + 1` has the right shape but the wrong constant, so no rule fires.
#[test]
fn test_no_match() {
    let matcher = patterns! {
        Add(x, Const(0)) ~> x
    };

    let lhs = UOp::native_const(42i32);
    let rhs = UOp::native_const(1i32);
    let sum = binary(BinaryOp::Add, lhs, rhs);

    let outcome = matcher.rewrite(&sum, &mut ());
    assert!(matches!(outcome, RewriteResult::NoMatch), "x + 1 should not match x + 0 pattern");
}
/// `pm1 + pm2` yields a matcher that handles inputs covered by either operand.
#[test]
fn test_pattern_matcher_composition() {
    let pm1 = patterns! {
        Add(x, Const(0)) ~> x
    };
    let pm2 = patterns! {
        Mul(x, Const(1)) ~> x
    };
    let combined = pm1 + pm2;

    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);
    let one = UOp::native_const(1i32);

    // Rule contributed by `pm1`.
    let add_zero = binary(BinaryOp::Add, Arc::clone(&x), zero);
    let RewriteResult::Rewritten(r) = combined.rewrite(&add_zero, &mut ()) else {
        panic!("Combined matcher should handle Add(x, 0)");
    };
    assert!(Arc::ptr_eq(&r, &x));

    // Rule contributed by `pm2`.
    let mul_one = binary(BinaryOp::Mul, Arc::clone(&x), one);
    let RewriteResult::Rewritten(r) = combined.rewrite(&mul_one, &mut ()) else {
        panic!("Combined matcher should handle Mul(x, 1)");
    };
    assert!(Arc::ptr_eq(&r, &x));
}
/// A block guard can inspect the constant directly. Integer `0` and float
/// `0.0` both satisfy it; `1` does not.
#[test]
fn test_complex_guard_with_block() {
    let matcher = patterns! {
        Add(x, c) if {
            match c.op() {
                Op::Const(cv) => matches!(cv.0, ConstValue::Int(0)) || matches!(cv.0, ConstValue::Float(f) if f == 0.0),
                _ => false,
            }
        } ~> x
    };

    // Integer zero satisfies the guard.
    let x = UOp::native_const(42i32);
    let zero_int = UOp::native_const(0i32);
    let add_zero_int = binary(BinaryOp::Add, Arc::clone(&x), zero_int);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add_zero_int, &mut ()) else {
        panic!("Add(x, 0) should match with complex guard");
    };
    assert!(Arc::ptr_eq(&r, &x));

    // Float zero satisfies it as well.
    let x_f32 = UOp::native_const(42.0f32);
    let zero_float = UOp::native_const(0.0f32);
    let add_zero_float = binary(BinaryOp::Add, Arc::clone(&x_f32), zero_float);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add_zero_float, &mut ()) else {
        panic!("Add(x, 0.0) should match with complex guard");
    };
    assert!(Arc::ptr_eq(&r, &x_f32));

    // A non-zero constant is rejected.
    let one = UOp::native_const(1i32);
    let add_one = binary(BinaryOp::Add, Arc::clone(&x), one);
    assert!(
        matches!(matcher.rewrite(&add_one, &mut ()), RewriteResult::NoMatch),
        "Add(x, 1) should NOT match zero guard"
    );
}
/// An explicit guard compares two bindings by pointer identity.
#[test]
fn test_guard_with_pointer_equality() {
    let matcher = patterns! {
        And(x, y) if Arc::ptr_eq(x, y) ~> x
    };

    let a = UOp::native_const(42i32);

    // Both operands are literally the same node.
    let and_same = a.try_and_op(&a).unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&and_same, &mut ()) else {
        panic!("And(x, x) should rewrite to x");
    };
    assert!(Arc::ptr_eq(&r, &a));

    // Distinct operands must fail the guard.
    let b = UOp::native_const(99i32);
    let and_diff = a.try_and_op(&b).unwrap();
    assert!(
        matches!(matcher.rewrite(&and_diff, &mut ()), RewriteResult::NoMatch),
        "And(a, b) with different pointers should NOT match"
    );
}
/// Repeating a variable name (`And(x, x)`) must behave like an explicit
/// pointer-equality guard between the two operands.
#[test]
fn test_auto_ptr_eq_duplicate_variable() {
    let matcher = patterns! {
        And(x, x) ~> x
    };

    let a = UOp::native_const(42i32);

    let and_same = a.try_and_op(&a).unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&and_same, &mut ()) else {
        panic!("And(x, x) should match");
    };
    assert!(Arc::ptr_eq(&r, &a), "And(x, x) should rewrite to x");

    let b = UOp::native_const(99i32);
    let and_diff = a.try_and_op(&b).unwrap();
    assert!(
        matches!(matcher.rewrite(&and_diff, &mut ()), RewriteResult::NoMatch),
        "And(a, b) with different pointers should NOT match"
    );
}
/// A variable repeated three times (`Where(x, x, x)`) requires all three
/// operands to be the same node. A mismatch in any single position must
/// reject the match.
#[test]
fn test_auto_ptr_eq_three_args() {
    let matcher = patterns! {
        Where(x, x, x) ~> x
    };
    let a = UOp::const_(DType::Bool, ConstValue::Bool(true));
    let b = UOp::const_(DType::Bool, ConstValue::Bool(false));

    let where_same = UOp::try_where(a.clone(), a.clone(), a.clone()).unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&where_same, &mut ()) else {
        panic!("Where(x, x, x) should match");
    };
    assert!(Arc::ptr_eq(&r, &a), "Where(x, x, x) should rewrite to x");

    // Place the odd operand in each position in turn: last, middle, first.
    let mismatches = [
        (&a, &a, &b, "Where(a, a, b) should NOT match"),
        (&a, &b, &a, "Where(a, b, a) should NOT match Where(x, x, x)"),
        (&b, &a, &a, "Where(b, a, a) should NOT match Where(x, x, x)"),
    ];
    for (cond, on_true, on_false, msg) in mismatches {
        let node =
            UOp::try_where(Arc::clone(cond), Arc::clone(on_true), Arc::clone(on_false)).unwrap();
        assert!(matches!(matcher.rewrite(&node, &mut ()), RewriteResult::NoMatch), "{}", msg);
    }
}
/// `@zero` matches a zero constant of either integer or float type.
#[test]
fn test_special_constant_zero() {
    let matcher = patterns! {
        Add(x, @zero) ~> x
    };

    // Integer operand.
    let x_int = UOp::native_const(42i32);
    let zero_int = UOp::native_const(0i32);
    let add_zero_int = binary(BinaryOp::Add, Arc::clone(&x_int), zero_int);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add_zero_int, &mut ()) else {
        panic!("Add(x, @zero) should match int 0");
    };
    assert!(Arc::ptr_eq(&r, &x_int));

    // Float operand.
    let x_f32 = UOp::native_const(42.0f32);
    let zero_float = UOp::native_const(0.0f32);
    let add_zero_float = binary(BinaryOp::Add, Arc::clone(&x_f32), zero_float);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add_zero_float, &mut ()) else {
        panic!("Add(x, @zero) should match float 0.0");
    };
    assert!(Arc::ptr_eq(&r, &x_f32));

    // Non-zero constant.
    let one = UOp::native_const(1i32);
    let add_one = binary(BinaryOp::Add, Arc::clone(&x_int), one);
    assert!(
        matches!(matcher.rewrite(&add_one, &mut ()), RewriteResult::NoMatch),
        "Add(x, 1) should NOT match @zero"
    );
}
/// `@one` matches a unit constant of either integer or float type.
#[test]
fn test_special_constant_one() {
    let matcher = patterns! {
        Mul(x, @one) ~> x
    };

    // Integer operand.
    let x_int = UOp::native_const(42i32);
    let one_int = UOp::native_const(1i32);
    let mul_one_int = binary(BinaryOp::Mul, Arc::clone(&x_int), one_int);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_one_int, &mut ()) else {
        panic!("Mul(x, @one) should match int 1");
    };
    assert!(Arc::ptr_eq(&r, &x_int));

    // Float operand.
    let x_f32 = UOp::native_const(42.0f32);
    let one_float = UOp::native_const(1.0f32);
    let mul_one_float = binary(BinaryOp::Mul, Arc::clone(&x_f32), one_float);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_one_float, &mut ()) else {
        panic!("Mul(x, @one) should match float 1.0");
    };
    assert!(Arc::ptr_eq(&r, &x_f32));

    // A constant other than one.
    let two = UOp::native_const(2i32);
    let mul_two = binary(BinaryOp::Mul, Arc::clone(&x_int), two);
    assert!(
        matches!(matcher.rewrite(&mul_two, &mut ()), RewriteResult::NoMatch),
        "Mul(x, 2) should NOT match @one"
    );
}
/// A binding can wrap a special constant (`zero @ @zero`). The rewrite then
/// returns the exact zero node for both int and float inputs.
#[test]
fn test_special_constant_with_binding() {
    let matcher = patterns! {
        Mul(_, zero @ @zero) ~> zero
    };

    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);
    let mul_zero = binary(BinaryOp::Mul, x, Arc::clone(&zero));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_zero, &mut ()) else {
        panic!("Mul(_, zero @ @zero) should return zero");
    };
    assert!(Arc::ptr_eq(&r, &zero));

    let x_f32 = UOp::native_const(42.0f32);
    let zero_f32 = UOp::native_const(0.0f32);
    let mul_zero_f32 = binary(BinaryOp::Mul, x_f32, Arc::clone(&zero_f32));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_zero_f32, &mut ()) else {
        panic!("Mul(_, zero @ @zero) should return float zero");
    };
    assert!(Arc::ptr_eq(&r, &zero_f32));
}
/// Additive and multiplicative identities, plus multiplicative annihilation,
/// written out for both operand orders.
#[test]
fn test_identity_patterns_with_special_constants() {
    let matcher = patterns! {
        Add(x, @zero) ~> x,
        Add(@zero, x) ~> x,
        Mul(x, @one) ~> x,
        Mul(@one, x) ~> x,
        Mul(_, zero @ @zero) ~> zero,
        Mul(zero @ @zero, _) ~> zero,
    };

    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);
    let one = UOp::native_const(1i32);

    // x + 0 and 0 + x -> x
    let add_x_zero = binary(BinaryOp::Add, Arc::clone(&x), Arc::clone(&zero));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add_x_zero, &mut ()) else {
        panic!("Add(x, @zero) failed");
    };
    assert!(Arc::ptr_eq(&r, &x));

    let add_zero_x = binary(BinaryOp::Add, Arc::clone(&zero), Arc::clone(&x));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add_zero_x, &mut ()) else {
        panic!("Add(@zero, x) failed");
    };
    assert!(Arc::ptr_eq(&r, &x));

    // x * 1 and 1 * x -> x
    let mul_x_one = binary(BinaryOp::Mul, Arc::clone(&x), Arc::clone(&one));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_x_one, &mut ()) else {
        panic!("Mul(x, @one) failed");
    };
    assert!(Arc::ptr_eq(&r, &x));

    let mul_one_x = binary(BinaryOp::Mul, Arc::clone(&one), Arc::clone(&x));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_one_x, &mut ()) else {
        panic!("Mul(@one, x) failed");
    };
    assert!(Arc::ptr_eq(&r, &x));

    // x * 0 and 0 * x -> 0
    let mul_x_zero = binary(BinaryOp::Mul, Arc::clone(&x), Arc::clone(&zero));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_x_zero, &mut ()) else {
        panic!("Mul(_, @zero) failed");
    };
    assert!(Arc::ptr_eq(&r, &zero));

    let mul_zero_x = binary(BinaryOp::Mul, Arc::clone(&zero), Arc::clone(&x));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_zero_x, &mut ()) else {
        panic!("Mul(@zero, _) failed");
    };
    assert!(Arc::ptr_eq(&r, &zero));
}
/// Struct-style patterns bind named fields (`src`, `dtype`), and the guard
/// can use those fields.
#[test]
fn test_struct_field_extraction() {
    let matcher = patterns! {
        Cast { src: x, dtype } if *dtype == DType::Float32 ~> x
    };
    let x_int = UOp::native_const(42i32);

    // Guard satisfied.
    let cast_to_f32 = x_int.cast(DType::Float32);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&cast_to_f32, &mut ()) else {
        panic!("Cast {{ src: x, dtype }} with dtype == Float32 should match");
    };
    assert!(Arc::ptr_eq(&r, &x_int));

    // Guard violated.
    let cast_to_i64 = x_int.cast(DType::Int64);
    assert!(
        matches!(matcher.rewrite(&cast_to_i64, &mut ()), RewriteResult::NoMatch),
        "Cast {{ src: x, dtype }} with dtype == Int64 should NOT match Float32 guard"
    );
}
/// A non-UOp field (`axes`) can be bound and inspected in a guard.
#[test]
fn test_struct_field_extraction_permute() {
    let matcher = patterns! {
        Permute { src: x, axes } if axes.len() == 2 ~> x
    };
    let x = UOp::native_const(1.0f32);

    // Two axes: the guard holds.
    let permute_2 = UOp::new(Op::Permute { src: Arc::clone(&x), axes: vec![1, 0] }, DType::Float32);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&permute_2, &mut ()) else {
        panic!("Permute with 2 axes should match");
    };
    assert!(Arc::ptr_eq(&r, &x));

    // Three axes: the guard fails.
    let permute_3 =
        UOp::new(Op::Permute { src: Arc::clone(&x), axes: vec![2, 0, 1] }, DType::Float32);
    assert!(
        matches!(matcher.rewrite(&permute_3, &mut ()), RewriteResult::NoMatch),
        "Permute with 3 axes should NOT match axes.len() == 2 guard"
    );
}
/// Struct patterns nest. `Cast { src: Cast { .. }, .. }` reaches through two
/// casts, and a single cast does not satisfy it.
#[test]
fn test_nested_struct_pattern() {
    let matcher = patterns! {
        Cast { src: Cast { src: x, .. }, dtype } if *dtype == DType::Float32 ~> x
    };
    let x_int = UOp::native_const(42i32);

    // int32 -> int64 -> float32: the outer pattern matches and yields the innermost source.
    let outer_cast = x_int.cast(DType::Int64).cast(DType::Float32);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&outer_cast, &mut ()) else {
        panic!("Nested Cast pattern should match");
    };
    assert!(Arc::ptr_eq(&r, &x_int), "Should extract innermost source");

    // Only one cast level: the inner `Cast` sub-pattern has nothing to match.
    let single_cast = x_int.cast(DType::Float32);
    assert!(
        matches!(matcher.rewrite(&single_cast, &mut ()), RewriteResult::NoMatch),
        "Single Cast should NOT match nested pattern"
    );
}
/// Fields bound at different nesting levels (`ranges` from the inner
/// `Bufferize`, `indices` from the outer `Index`) can be compared in one guard.
#[test]
fn test_nested_struct_field_extraction() {
    use morok_ir::types::{AddrSpace, BufferizeOpts};
    let matcher = patterns! {
        Index { buffer: Bufferize { compute, ranges, .. }, indices, gate: None }
        if ranges.len() == indices.len() ~> compute
    };

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let compute = UOp::native_const(42.0f32);
    let range1 = UOp::range(UOp::index_const(10), 0);
    let range2 = UOp::range(UOp::index_const(20), 1);

    // Two ranges buffered, two indices supplied: the lengths agree.
    let buf = UOp::bufferize(
        Arc::clone(&compute),
        vec![Arc::clone(&range1), Arc::clone(&range2)],
        opts,
    );
    let idx = UOp::index()
        .buffer(buf)
        .indices(vec![Arc::clone(&range1), Arc::clone(&range2)])
        .call()
        .unwrap();

    let RewriteResult::Rewritten(r) = matcher.rewrite(&idx, &mut ()) else {
        panic!("Nested Index(Bufferize) pattern should match with extracted ranges");
    };
    assert!(Arc::ptr_eq(&r, &compute), "Should extract compute from nested pattern");
}
/// Counterpart of the nested-extraction test: two ranges but only one index,
/// so the length guard must reject the match.
#[test]
fn test_nested_struct_field_extraction_mismatch() {
    use morok_ir::types::{AddrSpace, BufferizeOpts};
    let matcher = patterns! {
        Index { buffer: Bufferize { compute, ranges, .. }, indices, gate: None }
        if ranges.len() == indices.len() ~> compute
    };

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let compute = UOp::native_const(42.0f32);
    let range1 = UOp::range(UOp::index_const(10), 0);
    let range2 = UOp::range(UOp::index_const(20), 1);

    let buf = UOp::bufferize(Arc::clone(&compute), vec![Arc::clone(&range1), range2], opts);
    let idx = UOp::index().buffer(buf).indices(vec![range1]).call().unwrap();

    assert!(
        matches!(matcher.rewrite(&idx, &mut ()), RewriteResult::NoMatch),
        "Should NOT match when ranges.len() != indices.len()"
    );
}
/// `for op in unary [..]` expands to one rule per listed op. Ops not in the
/// list are left alone.
#[test]
fn test_for_loop_unary_expansion() {
    // NOTE(review): presumably silences the per-expansion `op` binding that
    // this rewrite body never reads.
    #[allow(unused_variables)]
    let matcher = patterns! {
        for op in unary [Sqrt, Exp2] {
            op(c) ~> {
                Arc::clone(c)
            }
        }
    };
    let x = UOp::native_const(42.0f32);

    let sqrt_x = x.try_sqrt().unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&sqrt_x, &mut ()) else {
        panic!("Sqrt pattern from for-loop should match");
    };
    assert!(Arc::ptr_eq(&r, &x));

    let exp2_x = x.try_exp2().unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&exp2_x, &mut ()) else {
        panic!("Exp2 pattern from for-loop should match");
    };
    assert!(Arc::ptr_eq(&r, &x));

    let sin_x = x.try_sin().unwrap();
    assert!(
        matches!(matcher.rewrite(&sin_x, &mut ()), RewriteResult::NoMatch),
        "Sin should NOT match (not in for-loop list)"
    );
}
/// `for op in binary [..]` expands one rule per listed binary op.
///
/// `Mul(x, 0) -> x` is not an algebraic identity. The rule exists only to
/// exercise the for-loop expansion.
#[test]
fn test_for_loop_binary_expansion() {
    // NOTE(review): presumably silences the per-expansion `op` binding that
    // this rewrite body never reads.
    #[allow(unused_variables)]
    let matcher = patterns! {
        for op in binary [Add, Mul, Sub] {
            op(x, @zero) ~> x
        }
    };
    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);

    let add_zero = binary(BinaryOp::Add, Arc::clone(&x), Arc::clone(&zero));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add_zero, &mut ()) else {
        panic!("Add(x, 0) from for-loop should match");
    };
    assert!(Arc::ptr_eq(&r, &x));

    let mul_zero = binary(BinaryOp::Mul, Arc::clone(&x), Arc::clone(&zero));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_zero, &mut ()) else {
        panic!("Mul(x, 0) from for-loop should match");
    };
    assert!(Arc::ptr_eq(&r, &x));

    let sub_zero = binary(BinaryOp::Sub, Arc::clone(&x), Arc::clone(&zero));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&sub_zero, &mut ()) else {
        panic!("Sub(x, 0) from for-loop should match");
    };
    assert!(Arc::ptr_eq(&r, &x));

    // `And` is not in the list, so it is not expanded.
    let and_zero = x.try_and_op(&zero).unwrap();
    assert!(
        matches!(matcher.rewrite(&and_zero, &mut ()), RewriteResult::NoMatch),
        "And should NOT match (not in for-loop list)"
    );
}
/// `for op in ternary [..]` expands to a three-operand rule per listed op.
/// Here every expansion returns the first operand.
#[test]
fn test_for_loop_ternary_expansion() {
    // NOTE(review): presumably silences the per-expansion `op` binding and the
    // unused `b`/`c` bindings in this rewrite body.
    #[allow(unused_variables)]
    let matcher = patterns! {
        for op in ternary [Where, MulAcc] {
            op(a, b, c) ~> {
                Arc::clone(a)
            }
        }
    };

    // Where(cond, t, f) -> cond
    let cond = UOp::const_(DType::Bool, ConstValue::Bool(true));
    let true_val = UOp::native_const(2.0f32);
    let false_val = UOp::native_const(3.0f32);
    let where_abc =
        UOp::try_where(Arc::clone(&cond), Arc::clone(&true_val), Arc::clone(&false_val)).unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&where_abc, &mut ()) else {
        panic!("Where pattern from for-loop should match");
    };
    assert!(Arc::ptr_eq(&r, &cond));

    // MulAcc(a, b, c) -> a
    let a = UOp::native_const(1.0f32);
    let b = UOp::native_const(2.0f32);
    let c = UOp::native_const(3.0f32);
    let mulacc_abc = UOp::try_mulacc(Arc::clone(&a), Arc::clone(&b), Arc::clone(&c)).unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mulacc_abc, &mut ()) else {
        panic!("MulAcc pattern from for-loop should match");
    };
    assert!(Arc::ptr_eq(&r, &a));
}
/// Inside a for-loop rule the loop variable `op` holds the concrete op, so a
/// single body can behave differently per expansion (Sqrt <-> Exp2 swap).
#[test]
fn test_for_loop_with_op_var_access() {
    use morok_ir::UnaryOp;
    let matcher = patterns! {
        for op in unary [Sqrt, Exp2] {
            op(x) ~> {
                match op {
                    UnaryOp::Sqrt => x.try_exp2().unwrap(),
                    UnaryOp::Exp2 => x.try_sqrt().unwrap(),
                    _ => Arc::clone(x),
                }
            }
        }
    };
    let x = UOp::native_const(42.0f32);

    let sqrt_x = x.try_sqrt().unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&sqrt_x, &mut ()) else {
        panic!("Sqrt pattern should match");
    };
    assert!(matches!(r.op(), Op::Unary(UnaryOp::Exp2, _)), "Sqrt should rewrite to Exp2");

    let exp2_x = x.try_exp2().unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&exp2_x, &mut ()) else {
        panic!("Exp2 pattern should match");
    };
    assert!(matches!(r.op(), Op::Unary(UnaryOp::Sqrt, _)), "Exp2 should rewrite to Sqrt");
}
/// For-loop expansions can sit between ordinary rules in the same matcher.
#[test]
fn test_for_loop_mixed_with_regular_patterns() {
    // NOTE(review): presumably silences the per-expansion `op` binding that
    // the for-loop rewrite body never reads.
    #[allow(unused_variables)]
    let matcher = patterns! {
        Add(x, @zero) ~> x,
        for op in unary [Sqrt, Exp2] {
            op(x) ~> Arc::clone(x)
        },
        Mul(x, @one) ~> x,
    };
    let x = UOp::native_const(42.0f32);
    let zero = UOp::native_const(0.0f32);
    let one = UOp::native_const(1.0f32);

    // Regular rule before the loop.
    let add_zero = binary(BinaryOp::Add, Arc::clone(&x), zero);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add_zero, &mut ()) else {
        panic!("Add(x, 0) should match");
    };
    assert!(Arc::ptr_eq(&r, &x));

    // Rule produced by the loop.
    let sqrt_x = x.try_sqrt().unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&sqrt_x, &mut ()) else {
        panic!("Sqrt(x) from for-loop should match");
    };
    assert!(Arc::ptr_eq(&r, &x));

    // Regular rule after the loop.
    let mul_one = binary(BinaryOp::Mul, Arc::clone(&x), one);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul_one, &mut ()) else {
        panic!("Mul(x, 1) should match");
    };
    assert!(Arc::ptr_eq(&r, &x));
}
/// Guards are carried into every for-loop expansion. Here only constant
/// operands are accepted.
#[test]
fn test_for_loop_with_guard() {
    // NOTE(review): presumably silences the per-expansion `op` binding that
    // this rewrite body never reads.
    #[allow(unused_variables)]
    let matcher = patterns! {
        for op in unary [Sqrt, Exp2] {
            op(c) if matches!(c.op(), Op::Const(_)) ~> Arc::clone(c)
        }
    };

    // Constant operand: the guard holds.
    let c = UOp::native_const(42.0f32);
    let sqrt_c = c.try_sqrt().unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&sqrt_c, &mut ()) else {
        panic!("Sqrt(const) should match with guard");
    };
    assert!(Arc::ptr_eq(&r, &c));

    // Operand is an Add node: the guard fails.
    let x = UOp::native_const(1.0f32);
    let y = UOp::native_const(2.0f32);
    let sqrt_add = binary(BinaryOp::Add, x, y).try_sqrt().unwrap();
    assert!(
        matches!(matcher.rewrite(&sqrt_add, &mut ()), RewriteResult::NoMatch),
        "Sqrt(non-const) should NOT match with const guard"
    );
}
/// Bindings (`inner @ @const`) work inside for-loop rules, and the bound node
/// is returned as-is.
#[test]
fn test_for_loop_with_binding() {
    // NOTE(review): presumably silences the per-expansion `op` binding that
    // this rewrite body never reads.
    #[allow(unused_variables)]
    let matcher = patterns! {
        for op in unary [Sqrt, Exp2] {
            op(inner @ @const) ~> inner
        }
    };
    let c = UOp::native_const(42.0f32);

    let sqrt_c = c.try_sqrt().unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&sqrt_c, &mut ()) else {
        panic!("Sqrt(inner @ @const) should match and return inner");
    };
    assert!(Arc::ptr_eq(&r, &c));

    let exp2_c = c.try_exp2().unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&exp2_c, &mut ()) else {
        panic!("Exp2(inner @ @const) should match and return inner");
    };
    assert!(Arc::ptr_eq(&r, &c));
}
/// `_c@const(cv)` exposes the constant's value as `cv` so a guard can
/// compare it.
#[test]
fn test_const_with_value_extraction() {
    let matcher = patterns! {
        Add(x, _c@const(cv)) if cv == ConstValue::Int(0) ~> x
    };
    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);
    let one = UOp::native_const(1i32);

    let add_zero = binary(BinaryOp::Add, Arc::clone(&x), zero);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add_zero, &mut ()) else {
        panic!("Add(x, 0) should match with cv == 0");
    };
    assert!(Arc::ptr_eq(&r, &x));

    let add_one = binary(BinaryOp::Add, Arc::clone(&x), one);
    assert!(
        matches!(matcher.rewrite(&add_one, &mut ()), RewriteResult::NoMatch),
        "Add(x, 1) should NOT match cv == 0 guard"
    );
}
/// A fallible rewrite (`=>`) returns an `Option`. Here the extracted constant
/// value is cast to Float32 and rebuilt as a constant node.
#[test]
fn test_const_with_value_extraction_fallible() {
    let matcher = patterns! {
        Sqrt(_c@const(cv)) => cv.cast(&DType::Float32).map(|casted| UOp::const_(DType::Float32, casted))
    };
    let sqrt_c = UOp::native_const(42.0f32).try_sqrt().unwrap();

    let RewriteResult::Rewritten(r) = matcher.rewrite(&sqrt_c, &mut ()) else {
        panic!("Sqrt(c@const(cv)) should match and cast the value");
    };
    assert_eq!(r.dtype(), DType::Float32);
}
/// `End(_, ..)` uses a rest pattern, so it matches an END node whatever the
/// number of ranges it closes.
#[test]
fn test_rest_pattern_end() {
    use smallvec::smallvec;
    let matcher = patterns! {
        end_op @ End(_, ..) ~> {
            if let Op::End { computation, .. } = end_op.op() {
                Arc::clone(computation)
            } else {
                unreachable!()
            }
        }
    };

    let computation = UOp::native_const(42i32);
    let range1 = UOp::range(UOp::index_const(10), 0);
    let range2 = UOp::range(UOp::index_const(20), 1);

    // One trailing range.
    let end1 = computation.end(smallvec![Arc::clone(&range1)]);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&end1, &mut ()) else {
        panic!("End(_, ..) should match END with 1 range");
    };
    assert!(Arc::ptr_eq(&r, &computation), "Should rewrite to computation");

    // Two trailing ranges.
    let end2 = computation.end(smallvec![Arc::clone(&range1), Arc::clone(&range2)]);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&end2, &mut ()) else {
        panic!("End(_, ..) should match END with 2 ranges");
    };
    assert!(Arc::ptr_eq(&r, &computation), "Should rewrite to computation");
}
/// `Reduce(_, ..)` matches REDUCE nodes with any number of ranges. The
/// rewrite replaces them with the constant 99.
#[test]
fn test_rest_pattern_reduce() {
    use morok_ir::types::ReduceOp;
    use smallvec::smallvec;
    let matcher = patterns! {
        reduce_op @ Reduce(_, ..) ~> UOp::const_(reduce_op.dtype(), ConstValue::Int(99))
    };

    let src = UOp::native_const(42i32);
    let range1 = UOp::range(UOp::index_const(10), 0);
    let range2 = UOp::range(UOp::index_const(20), 1);

    let reduce1 = src.reduce(smallvec![Arc::clone(&range1)], ReduceOp::Add);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&reduce1, &mut ()) else {
        panic!("Reduce(_, ..) should match REDUCE with 1 range");
    };
    assert!(matches!(r.op(), Op::Const(_)));

    let reduce2 = src.reduce(smallvec![Arc::clone(&range1), Arc::clone(&range2)], ReduceOp::Add);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&reduce2, &mut ()) else {
        panic!("Reduce(_, ..) should match REDUCE with 2 ranges");
    };
    assert!(matches!(r.op(), Op::Const(_)));
}
/// A rest pattern combined with a guard on the whole node. Only REDUCE nodes
/// whose reduce op is `Add` are rewritten.
#[test]
fn test_rest_pattern_with_guard() {
    use morok_ir::types::ReduceOp;
    use smallvec::smallvec;
    let matcher = patterns! {
        reduce_op @ Reduce(_, ..) if {
            matches!(reduce_op.op(), Op::Reduce { reduce_op: ReduceOp::Add, .. })
        } ~> UOp::const_(reduce_op.dtype(), ConstValue::Int(0))
    };

    let src = UOp::native_const(42i32);
    let range = UOp::range(UOp::index_const(10), 0);

    let reduce_add = src.reduce(smallvec![Arc::clone(&range)], ReduceOp::Add);
    let RewriteResult::Rewritten(_) = matcher.rewrite(&reduce_add, &mut ()) else {
        panic!("Should match REDUCE Add");
    };

    let reduce_mul = src.reduce(smallvec![Arc::clone(&range)], ReduceOp::Mul);
    assert!(
        matches!(matcher.rewrite(&reduce_mul, &mut ()), RewriteResult::NoMatch),
        "Should NOT match REDUCE Mul"
    );
}
/// `Bufferize { compute: c, .. }` ignores the ranges, so 0, 1, and 2 ranges
/// all match.
#[test]
fn test_bufferize_variable_ranges() {
    use morok_ir::types::{AddrSpace, BufferizeOpts};
    let matcher = patterns! {
        Bufferize { compute: c, .. } if matches!(c.op(), Op::Const(_)) ~> c
    };

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let const_val = UOp::native_const(42.0f32);
    let range1 = UOp::range(UOp::index_const(10), 0);
    let range2 = UOp::range(UOp::index_const(20), 1);

    let buf0 = UOp::bufferize(Arc::clone(&const_val), vec![], opts.clone());
    let RewriteResult::Rewritten(r) = matcher.rewrite(&buf0, &mut ()) else {
        panic!("Bufferize {{ compute: c, .. }} should match with 0 ranges");
    };
    assert!(Arc::ptr_eq(&r, &const_val), "Should rewrite to const_val with 0 ranges");

    let buf1 = UOp::bufferize(Arc::clone(&const_val), vec![Arc::clone(&range1)], opts.clone());
    let RewriteResult::Rewritten(r) = matcher.rewrite(&buf1, &mut ()) else {
        panic!("Bufferize {{ compute: c, .. }} should match with 1 range");
    };
    assert!(Arc::ptr_eq(&r, &const_val), "Should rewrite to const_val with 1 range");

    let buf2 = UOp::bufferize(Arc::clone(&const_val), vec![range1, range2], opts);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&buf2, &mut ()) else {
        panic!("Bufferize {{ compute: c, .. }} should match with 2 ranges");
    };
    assert!(Arc::ptr_eq(&r, &const_val), "Should rewrite to const_val with 2 ranges");
}
/// `Index { buffer: c, .. }` ignores the remaining fields. It must therefore
/// match however many indices there are and whether or not a gate is present.
#[test]
fn test_index_variable_indices() {
    let matcher = patterns! {
        Index { buffer: c, .. } if matches!(c.op(), Op::Const(_)) ~> c
    };
    let const_val = UOp::native_const(42.0f32);
    let idx1 = UOp::index_const(0);
    let idx2 = UOp::index_const(1);

    // One index.
    let index1 = UOp::index().buffer(const_val.clone()).indices(vec![idx1.clone()]).call().unwrap();
    match matcher.rewrite(&index1, &mut ()) {
        RewriteResult::Rewritten(r) => {
            assert!(Arc::ptr_eq(&r, &const_val), "Should rewrite to const_val with 1 index");
        }
        _ => panic!("Index {{ buffer: c, .. }} should match with 1 index"),
    }

    // Two indices.
    let index2 = UOp::index()
        .buffer(const_val.clone())
        .indices(vec![idx1.clone(), idx2.clone()])
        .call()
        .unwrap();
    match matcher.rewrite(&index2, &mut ()) {
        RewriteResult::Rewritten(r) => {
            assert!(Arc::ptr_eq(&r, &const_val), "Should rewrite to const_val with 2 indices");
        }
        _ => panic!("Index {{ buffer: c, .. }} should match with 2 indices"),
    }

    // Gated index. The gate is a Bool-typed constant, so its value must be a
    // `ConstValue::Bool`. The file's other Bool constants are built the same way;
    // pairing `DType::Bool` with `ConstValue::Int(1)` gives a dtype/value mismatch.
    let gate = UOp::const_(DType::Bool, ConstValue::Bool(true));
    let index_gated =
        UOp::index().buffer(const_val.clone()).indices(vec![idx1]).gate(gate).call().unwrap();
    match matcher.rewrite(&index_gated, &mut ()) {
        RewriteResult::Rewritten(r) => {
            assert!(Arc::ptr_eq(&r, &const_val), "Should rewrite to const_val with gate");
        }
        _ => panic!("Index {{ buffer: c, .. }} should match with gate"),
    }
}
/// A bare `gate` field binding exposes the optional gate. The rewrite returns
/// the gate when one is present and the buffer otherwise.
#[test]
fn test_index_gate_bare_binding() {
    let matcher = patterns! {
        Index { buffer: b, indices: _, gate } => {
            match gate {
                Some(g) => Some(g.clone()),
                None => Some(b.clone()),
            }
        }
    };
    let buffer = UOp::native_const(42.0f32);
    let idx = UOp::index_const(0);
    // Bool-typed gate constant built from a Bool value, matching the file's
    // other Bool constants. `ConstValue::Int(1)` mismatched the declared dtype.
    let gate_val = UOp::const_(DType::Bool, ConstValue::Bool(true));

    // No gate: falls back to the buffer.
    let ungated = UOp::index().buffer(buffer.clone()).indices(vec![idx.clone()]).call().unwrap();
    match matcher.rewrite(&ungated, &mut ()) {
        RewriteResult::Rewritten(r) => {
            assert!(Arc::ptr_eq(&r, &buffer), "Should return buffer when no gate");
        }
        _ => panic!("Pattern should match ungated Index"),
    }

    // Gate present: the gate node itself is returned.
    let gated = UOp::index()
        .buffer(buffer.clone())
        .indices(vec![idx])
        .gate(gate_val.clone())
        .call()
        .unwrap();
    match matcher.rewrite(&gated, &mut ()) {
        RewriteResult::Rewritten(r) => {
            assert!(Arc::ptr_eq(&r, &gate_val), "Should return gate when present");
        }
        _ => panic!("Pattern should match gated Index"),
    }
}
/// `..` gives prefix semantics. `Bufferize { compute: c, .. }` must match
/// Bufferize nodes carrying 0, 1, 2, or 3 ranges, not only an exact arity.
#[test]
fn test_tuple_prefix_semantics_vs_exact() {
    use morok_ir::types::{AddrSpace, BufferizeOpts};
    let matcher = patterns! {
        Bufferize { compute: c, .. } ~> c
    };

    let opts = BufferizeOpts { device: None, addrspace: AddrSpace::Global, removable: true };
    let const_val = UOp::native_const(42.0f32);
    let range1 = UOp::range(UOp::index_const(10), 0);
    let range2 = UOp::range(UOp::index_const(20), 1);
    let range3 = UOp::range(UOp::index_const(30), 2);

    // Try each prefix in turn: [], [r1], [r1, r2], [r1, r2, r3].
    let all_ranges = [range1, range2, range3];
    for n in 0..=all_ranges.len() {
        let ranges = all_ranges[..n].to_vec();
        let buf = UOp::bufferize(Arc::clone(&const_val), ranges, opts.clone());
        let RewriteResult::Rewritten(r) = matcher.rewrite(&buf, &mut ()) else {
            panic!("Bufferize {{ compute: c, .. }} should match with {} ranges (prefix semantics)", n);
        };
        assert!(Arc::ptr_eq(&r, &const_val), "Should rewrite with {} ranges", n);
    }
}
/// `(A | B) ~> rhs` accepts either alternative, and both bind the same `x`.
/// Ops outside the alternatives are rejected.
#[test]
fn test_alternative_patterns_basic() {
    let matcher = patterns! {
        (Add(x, _y) | Mul(x, _y)) ~> x
    };
    let a = UOp::native_const(5i32);
    let b = UOp::native_const(3i32);

    let add = binary(BinaryOp::Add, Arc::clone(&a), Arc::clone(&b));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add, &mut ()) else {
        panic!("Add should match alternative pattern");
    };
    assert!(Arc::ptr_eq(&r, &a), "Add should rewrite to x");

    let mul = binary(BinaryOp::Mul, Arc::clone(&a), Arc::clone(&b));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul, &mut ()) else {
        panic!("Mul should match alternative pattern");
    };
    assert!(Arc::ptr_eq(&r, &a), "Mul should rewrite to x");

    let sub = binary(BinaryOp::Sub, Arc::clone(&a), Arc::clone(&b));
    assert!(
        matches!(matcher.rewrite(&sub, &mut ()), RewriteResult::NoMatch),
        "Sub should NOT match (Add | Mul) pattern"
    );
}
/// Op-level alternation shorthand: `(Add | Mul)(x, @zero)` shares one
/// argument list between both ops.
#[test]
fn test_alternative_patterns_op_shorthand() {
    let matcher = patterns! {
        (Add | Mul)(x, @zero) ~> x
    };
    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);

    let add = binary(BinaryOp::Add, Arc::clone(&x), Arc::clone(&zero));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add, &mut ()) else {
        panic!("Add(x, 0) should match");
    };
    assert!(Arc::ptr_eq(&r, &x), "Add(x, 0) should rewrite to x");

    let mul = binary(BinaryOp::Mul, Arc::clone(&x), Arc::clone(&zero));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul, &mut ()) else {
        panic!("Mul(x, 0) should match");
    };
    assert!(Arc::ptr_eq(&r, &x), "Mul(x, 0) should rewrite to x");
}
/// Parenthesized full-pattern alternatives, each binding `x`.
///
/// NOTE(review): the pattern text is identical to the one in
/// `test_alternative_patterns_basic`. Confirm whether a different grouping
/// syntax was meant to be covered here.
#[test]
fn test_alternative_patterns_grouped() {
    let matcher = patterns! {
        (Add(x, _y) | Mul(x, _y)) ~> x
    };
    let a = UOp::native_const(5i32);
    let b = UOp::native_const(3i32);

    let add = binary(BinaryOp::Add, Arc::clone(&a), Arc::clone(&b));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add, &mut ()) else {
        panic!("Add should match");
    };
    assert!(Arc::ptr_eq(&r, &a), "Add(x, y) should rewrite to x");

    let mul = binary(BinaryOp::Mul, Arc::clone(&a), Arc::clone(&b));
    let RewriteResult::Rewritten(r) = matcher.rewrite(&mul, &mut ()) else {
        panic!("Mul should match");
    };
    assert!(Arc::ptr_eq(&r, &a), "Mul(x, y) should rewrite to x");
}
/// Alternatives built from special constants: `@zero` or `@one` is accepted,
/// and `2` is neither.
#[test]
fn test_alternative_patterns_with_special_const() {
    let matcher = patterns! {
        (Add(x, @zero) | Add(x, @one)) ~> x
    };
    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);
    let one = UOp::native_const(1i32);
    let two = UOp::native_const(2i32);

    let add0 = binary(BinaryOp::Add, Arc::clone(&x), zero);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add0, &mut ()) else {
        panic!("Add(x, 0) should match @zero");
    };
    assert!(Arc::ptr_eq(&r, &x), "Add(x, @zero) should rewrite to x");

    let add1 = binary(BinaryOp::Add, Arc::clone(&x), one);
    let RewriteResult::Rewritten(r) = matcher.rewrite(&add1, &mut ()) else {
        panic!("Add(x, 1) should match @one");
    };
    assert!(Arc::ptr_eq(&r, &x), "Add(x, @one) should rewrite to x");

    let add2 = binary(BinaryOp::Add, Arc::clone(&x), two);
    assert!(
        matches!(matcher.rewrite(&add2, &mut ()), RewriteResult::NoMatch),
        "Add(x, 2) should NOT match (neither @zero nor @one)"
    );
}
/// The square-bracket form `Add[x, @const]` matches both operand orders.
///
/// `x` is itself a constant here, so presumably either operand could satisfy
/// `@const`. The test therefore checks only that a rewrite happens, not
/// which node it yields.
#[test]
fn test_permutation_pattern_basic() {
    let matcher = patterns! {
        Add[x, @const] ~> x
    };
    let x = UOp::native_const(42i32);
    let c = UOp::native_const(5i32);

    let add1 = binary(BinaryOp::Add, Arc::clone(&x), Arc::clone(&c));
    let RewriteResult::Rewritten(_) = matcher.rewrite(&add1, &mut ()) else {
        panic!("Add(x, c) should match permutation pattern");
    };

    let add2 = binary(BinaryOp::Add, Arc::clone(&c), Arc::clone(&x));
    let RewriteResult::Rewritten(_) = matcher.rewrite(&add2, &mut ()) else {
        panic!("Add(c, x) should match permutation pattern");
    };
}
/// `Add[x, Const(0)]` folds both `x + 0` and `0 + x` to `x`: a single
/// permutation rule covers both operand orders.
#[test]
fn test_permutation_pattern_commutative_const_folding() {
    let matcher = patterns! {
        Add[x, Const(0)] ~> x
    };
    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);

    let orderings = [
        ((x.clone(), zero.clone()), "Add(x, 0) should rewrite to x", "Add(x, 0) should match"),
        ((zero.clone(), x.clone()), "Add(0, x) should rewrite to x", "Add(0, x) should match"),
    ];
    for ((lhs, rhs), rewrite_msg, match_msg) in orderings {
        let expr = binary(BinaryOp::Add, lhs, rhs);
        let RewriteResult::Rewritten(r) = matcher.rewrite(&expr, &mut ()) else {
            panic!("{match_msg}");
        };
        assert!(Arc::ptr_eq(&r, &x), "{rewrite_msg}");
    }
}
/// Struct-style `Copy { src: c, .. }` patterns bind the copied source and honour
/// a guard: copying a constant to a device rewrites back to that constant.
#[test]
fn test_copy_struct_pattern() {
    use morok_device::DeviceSpec;
    let matcher = patterns! {
        Copy { src: c, .. } if matches!(c.op(), Op::Const(_)) ~> c
    };
    let const_val = UOp::native_const(42.0f32);
    let copy_op = const_val.copy_to_device(DeviceSpec::Cuda { device_id: 0 });
    // `rewrite` takes the node by reference.
    match matcher.rewrite(&copy_op, &mut ()) {
        RewriteResult::Rewritten(r) => {
            assert!(Arc::ptr_eq(&r, &const_val), "Copy {{ src: c, .. }} should rewrite to c");
        }
        _ => panic!("Copy {{ src: c, .. }} should match when src is constant"),
    }
}
/// Minimal mutable context for the `@context` tests: a single counter that rule
/// bodies bump through [`TestContext::increment`].
#[derive(Default)]
struct TestContext {
    /// Number of `increment` calls made so far (starts at 0 via `Default`).
    counter: u32,
}
impl TestContext {
    /// Adds one to the counter and returns the new value.
    fn increment(&mut self) -> u32 {
        let next = self.counter + 1;
        self.counter = next;
        next
    }
}
/// `@context T;` gives rule bodies a mutable `ctx: T`. Each rewrite of a
/// constant bumps the counter exactly once.
#[test]
fn test_context_declaration() {
    let matcher = patterns! {
        @context TestContext;
        x if matches!(x.op(), Op::Const(_)) => {
            let count = ctx.increment();
            // `increment` always returns at least 1 here, so the `None` arm is
            // never taken in this test.
            if count > 0 {
                Some(Arc::clone(x))
            } else {
                None
            }
        }
    };
    let c = UOp::native_const(42i32);
    let mut ctx = TestContext::default();
    assert_eq!(ctx.counter, 0);

    for expected in 1..=2 {
        let result = matcher.rewrite(&c, &mut ctx);
        assert!(matches!(result, RewriteResult::Rewritten(_)));
        assert_eq!(ctx.counter, expected);
    }
}
/// A context-carrying matcher driven by `graph_rewrite`: `x + 0` simplifies to
/// `x`, and the rule body runs exactly once.
#[test]
fn test_context_with_graph_rewrite() {
    use crate::rewrite::graph_rewrite;
    let matcher = patterns! {
        @context TestContext;
        Add(x, @zero) => {
            ctx.increment();
            Some(Arc::clone(x))
        }
    };
    let x = UOp::native_const(5i32);
    let expr = binary(BinaryOp::Add, x.clone(), UOp::native_const(0i32));

    let mut ctx = TestContext::default();
    let simplified = graph_rewrite(&matcher, expr, &mut ctx);

    assert!(Arc::ptr_eq(&simplified, &x));
    assert_eq!(ctx.counter, 1);
}
/// Two `@context` matchers with the same context type compose with `+`. Each
/// rule's effect on the shared context is kept: `+1` for the `Add` rule and
/// `+2` for the `Mul` rule.
#[test]
fn test_context_pattern_composition() {
    let add_rules = patterns! {
        @context TestContext;
        Add(x, @zero) => {
            ctx.increment();
            Some(Arc::clone(x))
        }
    };
    let mul_rules = patterns! {
        @context TestContext;
        Mul(x, @one) => {
            ctx.increment();
            ctx.increment();
            Some(Arc::clone(x))
        }
    };
    let combined = add_rules + mul_rules;

    let x = UOp::native_const(5i32);
    let add_zero = binary(BinaryOp::Add, x.clone(), UOp::native_const(0i32));
    let mul_one = binary(BinaryOp::Mul, x.clone(), UOp::native_const(1i32));

    // The counter is cumulative across both rewrites.
    let mut ctx = TestContext::default();
    for (expr, expected) in [(&add_zero, 1), (&mul_one, 3)] {
        let result = combined.rewrite(expr, &mut ctx);
        assert!(matches!(result, RewriteResult::Rewritten(_)));
        assert_eq!(ctx.counter, expected);
    }
}
/// `Add[x, @zero]` with a symbolic `x`: both `x + 0` and `0 + x` reduce to `x`,
/// the latter only via the permuted operand order.
#[test]
fn test_commutative_pattern_with_special_zero() {
    let matcher = patterns! {
        Add[x, @zero] ~> x
    };
    let x = UOp::var("a", morok_dtype::DType::Int32, 0, i64::MAX);
    let zero = UOp::native_const(0i32);

    let orderings = [
        (
            (x.clone(), zero.clone()),
            "Add(x, 0) should rewrite to x",
            "Add[x, @zero] should match Add(x, 0)",
        ),
        (
            (zero.clone(), x.clone()),
            "Add(0, x) should rewrite to x",
            "Add[x, @zero] should match Add(0, x) via commutativity",
        ),
    ];
    for ((lhs, rhs), rewrite_msg, match_msg) in orderings {
        let expr = binary(BinaryOp::Add, lhs, rhs);
        let RewriteResult::Rewritten(r) = matcher.rewrite(&expr, &mut ()) else {
            panic!("{match_msg}");
        };
        assert!(Arc::ptr_eq(&r, &x), "{rewrite_msg}");
    }
}
/// `graph_rewrite` applies the commutative `Add[x, @zero]` rule when the zero is
/// the left operand.
#[test]
fn test_commutative_pattern_with_graph_rewrite() {
    use crate::rewrite::graph_rewrite;
    let matcher = patterns! {
        Add[x, @zero] ~> x
    };
    let x = UOp::var("a", morok_dtype::DType::Int32, 0, i64::MAX);
    let zero_plus_x = binary(BinaryOp::Add, UOp::native_const(0i32), x.clone());

    let simplified = graph_rewrite(&matcher, zero_plus_x, &mut ());
    assert!(Arc::ptr_eq(&simplified, &x), "graph_rewrite(Add(0, x)) should simplify to x");
}
/// The rule sets from `crate::symbolic::patterns` (constant folding plus
/// identity/zero rules), combined with `+`, simplify `0 + x` to `x` under
/// `graph_rewrite`.
#[test]
fn test_symbolic_simple_add_zero() {
    use crate::rewrite::graph_rewrite;
    use crate::symbolic::patterns::{constant_folding_dsl_patterns, identity_and_zero_patterns};
    let matcher = constant_folding_dsl_patterns() + identity_and_zero_patterns();

    let x = UOp::var("a", morok_dtype::DType::Int32, 0, i64::MAX);
    let zero_plus_x = binary(BinaryOp::Add, UOp::native_const(0i32), x.clone());

    let simplified = graph_rewrite(&matcher, zero_plus_x, &mut ());
    assert!(
        Arc::ptr_eq(&simplified, &x),
        "combined patterns + graph_rewrite(Add(0, x)) should simplify to x"
    );
}
/// `gate: None` in a struct pattern matches only ungated `Index` nodes and
/// binds the buffer.
#[test]
fn test_option_none_pattern() {
    let matcher = patterns! {
        Index { buffer: b, indices: _, gate: None } ~> b
    };
    let buffer = UOp::native_const(42.0f32);
    let idx = UOp::index_const(0);

    // Positive case: no gate, so the buffer is extracted.
    let ungated = UOp::index()
        .buffer(buffer.clone())
        .indices(vec![idx.clone()])
        .call()
        .unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&ungated, &mut ()) else {
        panic!("Index with gate: None should match");
    };
    assert!(Arc::ptr_eq(&r, &buffer), "Should extract buffer from ungated Index");

    // Negative case: a gate is present, so the `gate: None` pattern must not fire.
    let gate = UOp::const_(DType::Bool, ConstValue::Int(1));
    let gated = UOp::index()
        .buffer(buffer.clone())
        .indices(vec![idx])
        .gate(gate)
        .call()
        .unwrap();
    assert!(
        matches!(matcher.rewrite(&gated, &mut ()), RewriteResult::NoMatch),
        "Index with gate: Some(_) should NOT match gate: None pattern"
    );
}
/// `gate: Some(g)` binds the gate of a gated `Index` and rejects ungated ones.
#[test]
fn test_option_some_pattern() {
    let matcher = patterns! {
        Index { buffer: _, indices: _, gate: Some(g) } ~> g
    };
    let buffer = UOp::native_const(42.0f32);
    let idx = UOp::index_const(0);
    let gate = UOp::const_(DType::Bool, ConstValue::Int(1));

    // Positive case: the gate is present and becomes the rewrite result.
    let gated = UOp::index()
        .buffer(buffer.clone())
        .indices(vec![idx.clone()])
        .gate(gate.clone())
        .call()
        .unwrap();
    let RewriteResult::Rewritten(r) = matcher.rewrite(&gated, &mut ()) else {
        panic!("Index with gate: Some(g) should match");
    };
    assert!(Arc::ptr_eq(&r, &gate), "Should extract gate from gated Index");

    // Negative case: there is no gate to bind.
    let ungated = UOp::index().buffer(buffer.clone()).indices(vec![idx]).call().unwrap();
    assert!(
        matches!(matcher.rewrite(&ungated, &mut ()), RewriteResult::NoMatch),
        "Index with gate: None should NOT match gate: Some(g) pattern"
    );
}
/// Nested struct patterns: `Index(Index(buf, [i]), [j])` with no gate at either
/// level collapses to `Index(buf, [i])`. A gate on the outer index blocks the
/// rule.
#[test]
fn test_nested_index_with_gate_none() {
    let matcher = patterns! {
        Index {
            buffer: Index { buffer: real_buffer, indices: inner_indices, gate: None },
            indices: outer_indices,
            gate: None
        } if outer_indices.len() == 1 && inner_indices.len() == 1 => |real_buffer, inner_indices| {
            UOp::index().buffer(real_buffer.clone()).indices(vec![inner_indices[0].clone()]).call().ok()
        }
    };
    let real_buffer = UOp::native_const(42.0f32);
    let idx1 = UOp::index_const(5);
    let idx2 = UOp::index_const(10);
    let inner_idx = UOp::index()
        .buffer(real_buffer.clone())
        .indices(vec![idx1.clone()])
        .call()
        .unwrap();
    let outer_idx = UOp::index()
        .buffer(inner_idx.clone())
        .indices(vec![idx2.clone()])
        .call()
        .unwrap();

    // Positive case: the result is a single-level Index over the real buffer.
    let RewriteResult::Rewritten(r) = matcher.rewrite(&outer_idx, &mut ()) else {
        panic!("Nested Index pattern should match");
    };
    let Op::Index { buffer, indices, gate } = r.op() else {
        panic!("Result should be Index op");
    };
    assert!(Arc::ptr_eq(buffer, &real_buffer), "Buffer should be real_buffer");
    assert_eq!(indices.len(), 1, "Should have 1 index");
    assert!(Arc::ptr_eq(&indices[0], &idx1), "Index should be idx1 from inner");
    assert!(gate.is_none(), "Gate should be None");

    // Negative case: the outer level carries a gate.
    let outer_gate = UOp::const_(DType::Bool, ConstValue::Int(1));
    let gated_outer = UOp::index()
        .buffer(inner_idx.clone())
        .indices(vec![idx2.clone()])
        .gate(outer_gate)
        .call()
        .unwrap();
    assert!(
        matches!(matcher.rewrite(&gated_outer, &mut ()), RewriteResult::NoMatch),
        "Should NOT match when outer gate is Some"
    );
}
/// `for op in binary [*]` expands the rule for every binary operator, so
/// `op(x, @zero)` fires for Add, Mul and Xor alike.
#[test]
fn test_for_loop_binary_wildcard() {
    // NOTE(review): presumably the `for` expansion produces bindings that are
    // unused in some generated rules; confirm before dropping this allow.
    #[allow(unused_variables)]
    let matcher = patterns! {
        for op in binary [*] {
            op(x, @zero) ~> x
        }
    };
    let x = UOp::native_const(42i32);
    let zero = UOp::native_const(0i32);

    let add_zero = binary(BinaryOp::Add, x.clone(), zero.clone());
    let mul_zero = binary(BinaryOp::Mul, x.clone(), zero.clone());
    let xor_zero = x.try_xor_op(&zero).unwrap();

    for (expr, match_msg) in [
        (add_zero, "Add(x, 0) from binary [*] should match"),
        (mul_zero, "Mul(x, 0) from binary [*] should match"),
        (xor_zero, "Xor(x, 0) from binary [*] should match"),
    ] {
        let RewriteResult::Rewritten(r) = matcher.rewrite(&expr, &mut ()) else {
            panic!("{match_msg}");
        };
        assert!(Arc::ptr_eq(&r, &x));
    }
}
/// `for op in unary [*]` covers every unary operator: Neg, Sqrt and Exp2 of a
/// constant all rewrite back to that constant.
#[test]
fn test_for_loop_unary_wildcard() {
    // NOTE(review): presumably the `for` expansion produces bindings that are
    // unused in some generated rules; confirm before dropping this allow.
    #[allow(unused_variables)]
    let matcher = patterns! {
        for op in unary [*] {
            op(c) if matches!(c.op(), Op::Const(_)) ~> c
        }
    };
    let c = UOp::native_const(42.0f32);

    let neg_c = UOp::new(Op::Unary(morok_ir::UnaryOp::Neg, c.clone()), c.dtype());
    let sqrt_c = c.try_sqrt().unwrap();
    let exp2_c = c.try_exp2().unwrap();

    for (expr, match_msg) in [
        (neg_c, "Neg(const) from unary [*] should match"),
        (sqrt_c, "Sqrt(const) from unary [*] should match"),
        (exp2_c, "Exp2(const) from unary [*] should match"),
    ] {
        let RewriteResult::Rewritten(r) = matcher.rewrite(&expr, &mut ()) else {
            panic!("{match_msg}");
        };
        assert!(Arc::ptr_eq(&r, &c));
    }
}
/// `for op in ternary [*]` covers every ternary operator. The rule returns the
/// first operand, checked here for both Where and MulAcc.
#[test]
fn test_for_loop_ternary_wildcard() {
    // `b` and `c` are bound by the pattern but unused by the rewrite body.
    #[allow(unused_variables)]
    let matcher = patterns! {
        for op in ternary [*] {
            op(a, b, c) ~> {
                Arc::clone(a)
            }
        }
    };
    let cond = UOp::const_(DType::Bool, ConstValue::Bool(true));
    let where_abc =
        UOp::try_where(cond.clone(), UOp::native_const(2.0f32), UOp::native_const(3.0f32)).unwrap();

    let a = UOp::native_const(1.0f32);
    let b = UOp::native_const(2.0f32);
    let c = UOp::native_const(3.0f32);
    let mulacc_abc = UOp::try_mulacc(a.clone(), b, c).unwrap();

    for (expr, first_operand, match_msg) in [
        (where_abc, &cond, "Where from ternary [*] should match"),
        (mulacc_abc, &a, "MulAcc from ternary [*] should match"),
    ] {
        let RewriteResult::Rewritten(r) = matcher.rewrite(&expr, &mut ()) else {
            panic!("{match_msg}");
        };
        assert!(Arc::ptr_eq(&r, first_operand));
    }
}