mod index_lowering;
use crate::{
pattern::RewriteResult,
symbolic::{sym, symbolic, symbolic_simple},
};
use morok_dtype::DType;
use morok_ir::{BinaryOp, ConstValue, Op, TernaryOp, UOp, UnaryOp};
use std::{f32::consts::PI, sync::Arc};
/// `symbolic_simple` folds the identities `x+0`, `0+x`, `x*1`, `x-0`, `x/1`
/// (int and float), `x|0` and `x^0` down to the non-identity operand.
#[test]
fn test_symbolic_simple_identity_folding() {
    let matcher = symbolic_simple();
    let five = UOp::native_const(5i32);
    let zero = UOp::native_const(0i32);
    let one = UOp::native_const(1i32);

    // Rewrites `expr` once and asserts it became the constant `expected`.
    // Checking the value (not merely that a rewrite fired) catches rules that
    // rewrite to the wrong operand.
    let assert_folds_to = |label: &str, expr: &Arc<UOp>, expected: ConstValue| match matcher.rewrite(expr, &mut ()) {
        RewriteResult::Rewritten(rewritten) => match rewritten.op() {
            Op::Const(cv) => assert_eq!(cv.0, expected, "{}: wrong folded value", label),
            other => panic!("{}: expected Const op, got {:?}", label, other),
        },
        other => panic!("{}: expected a rewrite, got {:?}", label, other),
    };

    // x + 0 must return the original node itself, not an equal copy.
    let add = five.try_add(&zero).unwrap();
    match matcher.rewrite(&add, &mut ()) {
        RewriteResult::Rewritten(rewritten) => assert!(Arc::ptr_eq(&rewritten, &five)),
        other => panic!("5 + 0: expected a rewrite, got {:?}", other),
    }

    assert_folds_to("0 + 5", &zero.try_add(&five).unwrap(), ConstValue::Int(5));
    assert_folds_to("5 * 1", &five.try_mul(&one).unwrap(), ConstValue::Int(5));
    assert_folds_to("5 - 0", &five.try_sub(&zero).unwrap(), ConstValue::Int(5));
    assert_folds_to("5 / 1", &five.try_div(&one).unwrap(), ConstValue::Int(5));

    let five_f = UOp::native_const(5.0f32);
    let one_f = UOp::native_const(1.0f32);
    assert_folds_to("5.0 / 1.0", &five_f.try_div(&one_f).unwrap(), ConstValue::Float(5.0));

    assert_folds_to("5 | 0", &five.try_or_op(&zero).unwrap(), ConstValue::Int(5));
    assert_folds_to("5 ^ 0", &five.try_xor_op(&zero).unwrap(), ConstValue::Int(5));
}
/// `symbolic_simple` propagates annihilating zeros: `x*0`, `0*x`, `x&0` and
/// `0&x` all fold to the integer constant 0.
#[test]
fn test_symbolic_simple_zero_propagation() {
    let matcher = symbolic_simple();
    let five = UOp::native_const(5i32);
    let zero = UOp::native_const(0i32);

    // Rewrites `expr` once and asserts it became `Const(Int(0))`. Every case
    // checks the folded value, not just that some rewrite fired.
    let assert_folds_to_zero = |label: &str, expr: &Arc<UOp>| match matcher.rewrite(expr, &mut ()) {
        RewriteResult::Rewritten(rewritten) => match rewritten.op() {
            Op::Const(cv) => assert_eq!(cv.0, ConstValue::Int(0), "{}: wrong folded value", label),
            other => panic!("{}: expected Const op, got {:?}", label, other),
        },
        other => panic!("{}: expected a rewrite, got {:?}", label, other),
    };

    assert_folds_to_zero("5 * 0", &five.try_mul(&zero).unwrap());
    assert_folds_to_zero("0 * 5", &zero.try_mul(&five).unwrap());
    assert_folds_to_zero("5 & 0", &five.try_and_op(&zero).unwrap());
    assert_folds_to_zero("0 & 5", &zero.try_and_op(&five).unwrap());
}
/// Pure-constant arithmetic is evaluated: `5 + 3` becomes 8 and `5 * 2` becomes 10.
#[test]
fn test_symbolic_simple_const_folding() {
    let matcher = symbolic_simple();
    let lhs = UOp::native_const(5i32);
    let three = UOp::native_const(3i32);

    let sum = lhs.try_add(&three).unwrap();
    let sum_outcome = matcher.rewrite(&sum, &mut ());
    assert!(matches!(sum_outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded_sum) = sum_outcome else { unreachable!() };
    match folded_sum.op() {
        Op::Const(cv) => assert_eq!(cv.0, ConstValue::Int(8)),
        other => panic!("Expected Const(Int(8)), got {:?}", other),
    }

    let two = UOp::native_const(2i32);
    let product = lhs.try_mul(&two).unwrap();
    let product_outcome = matcher.rewrite(&product, &mut ());
    assert!(matches!(product_outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded_product) = product_outcome else { unreachable!() };
    match folded_product.op() {
        Op::Const(cv) => assert_eq!(cv.0, ConstValue::Int(10)),
        other => panic!("Expected Const(Int(10)), got {:?}", other),
    }
}
/// Integer `x / x` folds to the constant 1.
#[test]
fn test_self_division() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let quotient = x.try_div(&x).unwrap();

    let outcome = matcher.rewrite(&quotient, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Const(cv) = folded.op() else {
        panic!("Expected Const(1), got {:?}", folded.op());
    };
    assert_eq!(cv.0, ConstValue::Int(1));
}
/// A full graph rewrite turns `x / -1` into `x * -1`.
#[test]
fn test_division_by_neg_one() {
    use crate::rewrite::graph_rewrite;
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let minus_one = UOp::native_const(-1i32);
    let quotient = x.try_div(&minus_one).unwrap();

    let rewritten = graph_rewrite(&matcher, quotient, &mut ());
    match rewritten.op() {
        Op::Binary(BinaryOp::Mul, lhs, rhs) => {
            assert!(Arc::ptr_eq(lhs, &x));
            assert!(matches!(rhs.op(), Op::Const(cv) if cv.0.is_neg_one()));
        }
        other => panic!("Expected MUL(x, -1), got {:?}", other),
    }
}
/// `(x % y) % y` collapses to the inner `x % y`.
#[test]
fn test_idempotent_modulo() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let y = UOp::var("y", DType::Int32, 0, i64::MAX);
    let once = x.try_mod(&y).unwrap();
    let twice = once.try_mod(&y).unwrap();

    let outcome = matcher.rewrite(&twice, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(collapsed) = outcome else { unreachable!() };
    match collapsed.op() {
        Op::Binary(BinaryOp::Mod, num, den) => {
            assert!(Arc::ptr_eq(num, &x));
            assert!(Arc::ptr_eq(den, &y));
        }
        other => panic!("Expected Binary(Mod, x, y), got {:?}", other),
    }
}
/// `x & x` simplifies to the very same `x` node.
#[test]
fn test_idempotent_and() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let conj = x.try_and_op(&x).unwrap();

    let outcome = matcher.rewrite(&conj, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(simplified) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&simplified, &x));
}
/// `x | x` simplifies to the very same `x` node.
#[test]
fn test_idempotent_or() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let disj = x.try_or_op(&x).unwrap();

    let outcome = matcher.rewrite(&disj, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(simplified) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&simplified, &x));
}
/// `x & y` over two distinct variables is left untouched by `symbolic_simple`.
#[test]
fn test_non_idempotent_and() {
    let matcher = symbolic_simple();
    let (x, y) = (
        UOp::var("x", DType::Int32, 0, i64::MAX),
        UOp::var("y", DType::Int32, 0, i64::MAX),
    );
    let conj = x.try_and_op(&y).unwrap();
    let outcome = matcher.rewrite(&conj, &mut ());
    assert!(matches!(outcome, RewriteResult::NoMatch));
}
/// `x < x` folds to the boolean constant `false`.
#[test]
fn test_self_comparison_lt() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let less = x.try_cmplt(&x).unwrap();

    let outcome = matcher.rewrite(&less, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Const(cv) = folded.op() else {
        panic!("Expected Const(Bool(false)), got {:?}", folded.op());
    };
    assert_eq!(cv.0, ConstValue::Bool(false));
}
/// `x % x` folds to the integer constant 0.
#[test]
fn test_self_modulo() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let remainder = x.try_mod(&x).unwrap();

    let outcome = matcher.rewrite(&remainder, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Const(cv) = folded.op() else {
        panic!("Expected Const(0), got {:?}", folded.op());
    };
    assert_eq!(cv.0, ConstValue::Int(0));
}
/// For an Int32 variable, `x != x` folds to `false`.
#[test]
fn test_self_inequality_int() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let differs = x.try_cmpne(&x).unwrap();

    let outcome = matcher.rewrite(&differs, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Const(cv) = folded.op() else {
        panic!("Expected Const(Bool(false)), got {:?}", folded.op());
    };
    assert_eq!(cv.0, ConstValue::Bool(false));
}
/// For a Float32 variable, `x != x` must NOT be folded (presumably because
/// NaN != NaN — confirm against the rule's guard).
#[test]
fn test_self_inequality_float_no_fold() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Float32, 0, i64::MAX);
    let differs = x.try_cmpne(&x).unwrap();
    let outcome = matcher.rewrite(&differs, &mut ());
    assert!(matches!(outcome, RewriteResult::NoMatch));
}
/// For a Float32 variable, `x / x` folds to the float constant 1.0.
#[test]
fn test_float_self_division() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Float32, 0, i64::MAX);
    let quotient = x.try_div(&x).unwrap();

    let outcome = matcher.rewrite(&quotient, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Const(cv) = folded.op() else {
        panic!("Expected Const(1.0), got {:?}", folded.op());
    };
    assert_eq!(cv.0, ConstValue::Float(1.0));
}
/// Float `(x * y) / y` cancels back to `x`.
#[test]
fn test_division_cancel_multiplication() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Float32, 0, i64::MAX);
    let y = UOp::var("y", DType::Float32, 0, i64::MAX);
    let product = x.try_mul(&y).unwrap();
    let quotient = product.try_div(&y).unwrap();

    let outcome = matcher.rewrite(&quotient, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(cancelled) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&cancelled, &x));
}
/// Integer `(x * y) / y` cancels back to `x`.
#[test]
fn test_int_division_cancel_multiplication() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let y = UOp::var("y", DType::Int32, 0, i64::MAX);
    let product = x.try_mul(&y).unwrap();
    let quotient = product.try_div(&y).unwrap();

    let outcome = matcher.rewrite(&quotient, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(cancelled) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&cancelled, &x));
}
/// Casting the int constant 42 to Float32 folds to a Float32 constant 42.0.
#[test]
fn test_cast_int_to_float_constant() {
    let matcher = symbolic_simple();
    let source = UOp::native_const(42i32);
    let converted = source.cast(DType::Float32);

    let outcome = matcher.rewrite(&converted, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    match folded.op() {
        Op::Const(cv) => assert_eq!(cv.0, ConstValue::Float(42.0)),
        other => panic!("Expected Const(Float(42.0)), got {:?}", other),
    }
    assert_eq!(folded.dtype(), DType::Float32);
}
/// Casting the float constant π to Int32 folds to an Int32 constant 3.
#[test]
fn test_cast_float_to_int_constant() {
    let matcher = symbolic_simple();
    let source = UOp::native_const(PI);
    let converted = source.cast(DType::Int32);

    let outcome = matcher.rewrite(&converted, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    match folded.op() {
        Op::Const(cv) => assert_eq!(cv.0, ConstValue::Int(3)),
        other => panic!("Expected Const(Int(3)), got {:?}", other),
    }
    assert_eq!(folded.dtype(), DType::Int32);
}
/// Casting the boolean constant `true` to Int32 folds to the integer 1.
#[test]
fn test_cast_bool_to_int_constant() {
    let matcher = symbolic_simple();
    let flag = UOp::native_const(true);
    let converted = flag.cast(DType::Int32);

    let outcome = matcher.rewrite(&converted, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Const(cv) = folded.op() else {
        panic!("Expected Const(Int(1)), got {:?}", folded.op());
    };
    assert_eq!(cv.0, ConstValue::Int(1));
}
/// A cast to the value's own dtype is removed, leaving the original node.
#[test]
fn test_noop_cast_same_dtype() {
    use crate::rewrite::graph_rewrite;
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let same_dtype_cast = x.cast(DType::Int32);
    let simplified = graph_rewrite(&matcher, same_dtype_cast, &mut ());
    assert!(Arc::ptr_eq(&simplified, &x), "Noop cast should be eliminated");
}
/// An Int16 -> Int32 -> Int16 round trip collapses back to the original Int16 node.
#[test]
fn test_double_cast_collapse_safe() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int16, 0, i64::from(i16::MAX));
    let widened = x.cast(DType::Int32);
    let narrowed = widened.cast(DType::Int16);

    let outcome = matcher.rewrite(&narrowed, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(collapsed) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&collapsed, &x), "Expected x, got {:?}", collapsed.op());
}
/// An Int32 -> Float32 -> Int32 round trip must be kept as-is; the test
/// treats it as an unsafe collapse.
#[test]
fn test_double_cast_no_collapse_unsafe() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let round_trip = x.cast(DType::Float32).cast(DType::Int32);
    let outcome = matcher.rewrite(&round_trip, &mut ());
    assert!(matches!(outcome, RewriteResult::NoMatch), "Unsafe double cast should NOT collapse: Int32->Float32->Int32");
}
/// Casting a non-constant variable is not folded by `symbolic_simple`.
#[test]
fn test_cast_non_constant_no_fold() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let as_float = x.cast(DType::Float32);
    let outcome = matcher.rewrite(&as_float, &mut ());
    assert!(matches!(outcome, RewriteResult::NoMatch));
}
/// Under `symbolic`, `x + x` combines into `2 * x` (constant on the left).
#[test]
fn test_combine_identical_terms() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let add = x.try_add(&x).unwrap();

    // Put the diagnostics in the panic message instead of printing them to
    // stderr before a bare `assert!`, so they show up with the failure.
    let rewritten = match matcher.rewrite(&add, &mut ()) {
        RewriteResult::Rewritten(rewritten) => rewritten,
        other => panic!("x + x didn't match. Result: {:?}; Add op: {:?}", other, add.op()),
    };

    let Op::Binary(BinaryOp::Mul, c, var) = rewritten.op() else {
        panic!("Expected Mul, got {:?}", rewritten.op());
    };
    let Op::Const(cv) = c.op() else {
        panic!("Expected constant, got {:?}", c.op());
    };
    assert_eq!(cv.0, ConstValue::Int(2));
    assert!(Arc::ptr_eq(var, &x));
}
/// `3*x + 5*x` combines into `x * 8` (variable on the left).
#[test]
fn test_combine_terms_with_coefficients() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let three = UOp::native_const(3i32);
    let five = UOp::native_const(5i32);
    let lhs_term = three.try_mul(&x).unwrap();
    let rhs_term = five.try_mul(&x).unwrap();
    let total = lhs_term.try_add(&rhs_term).unwrap();

    let outcome = matcher.rewrite(&total, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(combined) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Mul, base, coeff) = combined.op() else {
        panic!("Expected Mul, got {:?}", combined.op());
    };
    assert!(Arc::ptr_eq(base, &x));
    let Op::Const(cv) = coeff.op() else {
        panic!("Expected constant, got {:?}", coeff.op());
    };
    assert_eq!(cv.0, ConstValue::Int(8));
}
/// `x*3 + x*5` (constants on the right) also combines into `x * 8`.
#[test]
fn test_combine_terms_reversed_multiplication() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let three = UOp::native_const(3i32);
    let five = UOp::native_const(5i32);
    let lhs_term = x.try_mul(&three).unwrap();
    let rhs_term = x.try_mul(&five).unwrap();
    let total = lhs_term.try_add(&rhs_term).unwrap();

    let outcome = matcher.rewrite(&total, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(combined) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Mul, base, coeff) = combined.op() else {
        panic!("Expected Mul, got {:?}", combined.op());
    };
    assert!(Arc::ptr_eq(base, &x));
    let Op::Const(cv) = coeff.op() else {
        panic!("Expected constant, got {:?}", coeff.op());
    };
    assert_eq!(cv.0, ConstValue::Int(8));
}
/// `3*x + 5*y` over distinct variables has nothing to combine.
///
/// NOTE(review): this runs `symbolic_simple`, while the term-combining tests
/// in this file run `symbolic()`. If the combine rule lives only in
/// `symbolic`, this check may be vacuous — confirm the intended matcher.
#[test]
fn test_no_combine_different_variables() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let y = UOp::var("y", DType::Int32, 0, i64::MAX);
    let three = UOp::native_const(3i32);
    let five = UOp::native_const(5i32);
    let total = three
        .try_mul(&x)
        .unwrap()
        .try_add(&five.try_mul(&y).unwrap())
        .unwrap();
    let outcome = matcher.rewrite(&total, &mut ());
    assert!(matches!(outcome, RewriteResult::NoMatch));
}
/// `(x + 3) + 5` folds its constants into `x + 8`.
#[test]
fn test_alu_fold_addition_chain() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let (three, five) = (UOp::native_const(3i32), UOp::native_const(5i32));
    let inner = x.try_add(&three).unwrap();
    let chain = inner.try_add(&five).unwrap();

    let outcome = matcher.rewrite(&chain, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Add, base, offset) = folded.op() else {
        panic!("Expected Add, got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(base, &x));
    let Op::Const(cv) = offset.op() else {
        panic!("Expected constant, got {:?}", offset.op());
    };
    assert_eq!(cv.0, ConstValue::Int(8));
}
/// `(x * 2) * 3` folds its constants into `x * 6`.
#[test]
fn test_alu_fold_multiplication_chain() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let (two, three) = (UOp::native_const(2i32), UOp::native_const(3i32));
    let inner = x.try_mul(&two).unwrap();
    let chain = inner.try_mul(&three).unwrap();

    let outcome = matcher.rewrite(&chain, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Mul, base, factor) = folded.op() else {
        panic!("Expected Mul, got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(base, &x));
    let Op::Const(cv) = factor.op() else {
        panic!("Expected constant, got {:?}", factor.op());
    };
    assert_eq!(cv.0, ConstValue::Int(6));
}
/// `(x - 3) + 5` folds into `x + 2`.
#[test]
fn test_alu_fold_sub_then_add_positive() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let (three, five) = (UOp::native_const(3i32), UOp::native_const(5i32));
    let diff = x.try_sub(&three).unwrap();
    let expr = diff.try_add(&five).unwrap();

    let outcome = matcher.rewrite(&expr, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Add, base, offset) = folded.op() else {
        panic!("Expected Add, got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(base, &x));
    let Op::Const(cv) = offset.op() else {
        panic!("Expected constant, got {:?}", offset.op());
    };
    assert_eq!(cv.0, ConstValue::Int(2));
}
/// `(x - 5) + 3` folds into `x - 2`.
#[test]
fn test_alu_fold_sub_then_add_negative() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let (five, three) = (UOp::native_const(5i32), UOp::native_const(3i32));
    let diff = x.try_sub(&five).unwrap();
    let expr = diff.try_add(&three).unwrap();

    let outcome = matcher.rewrite(&expr, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Sub, base, offset) = folded.op() else {
        panic!("Expected Sub, got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(base, &x));
    let Op::Const(cv) = offset.op() else {
        panic!("Expected constant, got {:?}", offset.op());
    };
    assert_eq!(cv.0, ConstValue::Int(2));
}
/// `(x + 5) - 3` folds into `x + 2`.
#[test]
fn test_alu_fold_add_then_sub_positive() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let (five, three) = (UOp::native_const(5i32), UOp::native_const(3i32));
    let sum = x.try_add(&five).unwrap();
    let expr = sum.try_sub(&three).unwrap();

    let outcome = matcher.rewrite(&expr, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Add, base, offset) = folded.op() else {
        panic!("Expected Add, got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(base, &x));
    let Op::Const(cv) = offset.op() else {
        panic!("Expected constant, got {:?}", offset.op());
    };
    assert_eq!(cv.0, ConstValue::Int(2));
}
/// `(x + 3) - 5` folds into `x - 2`.
#[test]
fn test_alu_fold_add_then_sub_negative() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let (three, five) = (UOp::native_const(3i32), UOp::native_const(5i32));
    let sum = x.try_add(&three).unwrap();
    let expr = sum.try_sub(&five).unwrap();

    let outcome = matcher.rewrite(&expr, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Sub, base, offset) = folded.op() else {
        panic!("Expected Sub, got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(base, &x));
    let Op::Const(cv) = offset.op() else {
        panic!("Expected constant, got {:?}", offset.op());
    };
    assert_eq!(cv.0, ConstValue::Int(2));
}
/// Integer `(a * b) / b` cancels back to `a`.
#[test]
fn test_division_cancel_with_multiplication() {
    let matcher = symbolic_simple();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let b = UOp::var("b", DType::Int32, 0, i64::MAX);
    let product = a.try_mul(&b).unwrap();
    let quotient = product.try_div(&b).unwrap();

    let outcome = matcher.rewrite(&quotient, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(cancelled) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&cancelled, &a));
}
/// `(a / 2) / 3` folds into the single integer division `a / 6`.
#[test]
fn test_division_chain_folding() {
    let matcher = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let (two, three) = (UOp::native_const(2i32), UOp::native_const(3i32));
    let halved = a.try_div(&two).unwrap();
    let chain = halved.try_div(&three).unwrap();

    let outcome = matcher.rewrite(&chain, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Idiv, base, divisor) = folded.op() else {
        panic!("Expected Idiv, got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(base, &a));
    let Op::Const(cv) = divisor.op() else {
        panic!("Expected constant, got {:?}", divisor.op());
    };
    assert_eq!(cv.0, ConstValue::Int(6));
}
/// `(12 * x) / 3` divides exactly into `4 * x` (constant on the left).
#[test]
fn test_exact_division_with_divides_helper() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let twelve = UOp::native_const(12i32);
    let three = UOp::native_const(3i32);
    let scaled = twelve.try_mul(&x).unwrap();
    let quotient = scaled.try_div(&three).unwrap();

    let outcome = matcher.rewrite(&quotient, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(reduced) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Mul, coeff, base) = reduced.op() else {
        panic!("Expected Mul, got {:?}", reduced.op());
    };
    let Op::Const(cv) = coeff.op() else {
        panic!("Expected constant, got {:?}", coeff.op());
    };
    assert_eq!(cv.0, ConstValue::Int(4));
    assert!(Arc::ptr_eq(base, &x));
}
/// `(6*x + y) % 3` drops the divisible left term, leaving `y % 3`.
#[test]
fn test_modulo_with_divisible_left_operand() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let y = UOp::var("y", DType::Int32, 0, i64::MAX);
    let six = UOp::native_const(6i32);
    let c3 = UOp::native_const(3i32);
    let scaled = six.try_mul(&x).unwrap();
    let sum = scaled.try_add(&y).unwrap();
    let remainder = sum.try_mod(&c3).unwrap();

    let outcome = matcher.rewrite(&remainder, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(reduced) = outcome else { unreachable!() };
    match reduced.op() {
        Op::Binary(BinaryOp::Mod, num, den) => {
            assert!(Arc::ptr_eq(num, &y));
            assert!(Arc::ptr_eq(den, &c3));
        }
        other => panic!("Expected Mod, got {:?}", other),
    }
}
/// `(x + 9*y) % 3` drops the divisible right term, leaving `x % 3`.
#[test]
fn test_modulo_with_divisible_right_operand() {
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let y = UOp::var("y", DType::Int32, 0, i64::MAX);
    let nine = UOp::native_const(9i32);
    let c3 = UOp::native_const(3i32);
    let scaled = nine.try_mul(&y).unwrap();
    let sum = x.try_add(&scaled).unwrap();
    let remainder = sum.try_mod(&c3).unwrap();

    let outcome = matcher.rewrite(&remainder, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(reduced) = outcome else { unreachable!() };
    match reduced.op() {
        Op::Binary(BinaryOp::Mod, num, den) => {
            assert!(Arc::ptr_eq(num, &x));
            assert!(Arc::ptr_eq(den, &c3));
        }
        other => panic!("Expected Mod, got {:?}", other),
    }
}
/// `(x + y) % 3` has no divisible term, so `symbolic_simple` leaves it alone.
#[test]
fn test_modulo_no_simplification() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let y = UOp::var("y", DType::Int32, 0, i64::MAX);
    let three = UOp::native_const(3i32);
    let remainder = x.try_add(&y).unwrap().try_mod(&three).unwrap();
    let outcome = matcher.rewrite(&remainder, &mut ());
    assert!(matches!(outcome, RewriteResult::NoMatch));
}
/// `(6*x + 9*y) / 3` distributes into `2*x + 3*y` under `symbolic`.
#[test]
fn test_distribute_division_over_addition() {
    /// Asserts `term` is `Mul(Const(coeff), var)` with `var` pointer-equal to
    /// `expected_var`. Panics (instead of silently skipping) on any other shape.
    fn assert_scaled_term(term: &Arc<UOp>, coeff: ConstValue, expected_var: &Arc<UOp>, side: &str) {
        let Op::Binary(BinaryOp::Mul, c, var) = term.op() else {
            panic!("Expected Mul on {}, got {:?}", side, term.op());
        };
        let Op::Const(cv) = c.op() else {
            panic!("Expected constant coefficient on {}, got {:?}", side, c.op());
        };
        assert_eq!(cv.0, coeff);
        assert!(Arc::ptr_eq(var, expected_var));
    }

    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let y = UOp::var("y", DType::Int32, 0, i64::MAX);
    let c6 = UOp::native_const(6i32);
    let c9 = UOp::native_const(9i32);
    let c3 = UOp::native_const(3i32);
    let term1 = c6.try_mul(&x).unwrap();
    let term2 = c9.try_mul(&y).unwrap();
    let add = term1.try_add(&term2).unwrap();
    let div = add.try_div(&c3).unwrap();
    let result = matcher.rewrite(&div, &mut ());
    assert!(matches!(result, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(rewritten) = result {
        if let Op::Binary(BinaryOp::Add, left, right) = rewritten.op() {
            assert_scaled_term(left, ConstValue::Int(2), &x, "left");
            assert_scaled_term(right, ConstValue::Int(3), &y, "right");
        } else {
            panic!("Expected Add, got {:?}", rewritten.op());
        }
    }
}
/// `(12*x - 6*y) / 3` distributes into `4*x - 2*y` under `symbolic`.
#[test]
fn test_distribute_division_over_subtraction() {
    /// Asserts `term` is `Mul(Const(coeff), var)` with `var` pointer-equal to
    /// `expected_var`. Panics on any other shape; without this the test
    /// passed vacuously when an operand was not a `Mul`.
    fn assert_scaled_term(term: &Arc<UOp>, coeff: ConstValue, expected_var: &Arc<UOp>, side: &str) {
        let Op::Binary(BinaryOp::Mul, c, var) = term.op() else {
            panic!("Expected Mul on {}, got {:?}", side, term.op());
        };
        let Op::Const(cv) = c.op() else {
            panic!("Expected constant coefficient on {}, got {:?}", side, c.op());
        };
        assert_eq!(cv.0, coeff);
        assert!(Arc::ptr_eq(var, expected_var));
    }

    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let y = UOp::var("y", DType::Int32, 0, i64::MAX);
    let c12 = UOp::native_const(12i32);
    let c6 = UOp::native_const(6i32);
    let c3 = UOp::native_const(3i32);
    let term1 = c12.try_mul(&x).unwrap();
    let term2 = c6.try_mul(&y).unwrap();
    let sub = term1.try_sub(&term2).unwrap();
    let div = sub.try_div(&c3).unwrap();
    let result = matcher.rewrite(&div, &mut ());
    assert!(matches!(result, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(rewritten) = result {
        if let Op::Binary(BinaryOp::Sub, left, right) = rewritten.op() {
            assert_scaled_term(left, ConstValue::Int(4), &x, "left");
            assert_scaled_term(right, ConstValue::Int(2), &y, "right");
        } else {
            panic!("Expected Sub, got {:?}", rewritten.op());
        }
    }
}
/// `2 * (x + 5)` over Index values distributes into `2*x + 2*5`, reusing the
/// original constant node for both products.
#[test]
fn test_distribute_multiplication_over_addition() {
    /// Asserts `term` is `Mul(expected_lhs, expected_rhs)` by node identity.
    /// Panics on any other shape; without this the test passed vacuously
    /// when an operand was not a `Mul`.
    fn assert_mul_of(term: &Arc<UOp>, expected_lhs: &Arc<UOp>, expected_rhs: &Arc<UOp>, side: &str) {
        let Op::Binary(BinaryOp::Mul, lhs, rhs) = term.op() else {
            panic!("Expected Mul on {}, got {:?}", side, term.op());
        };
        assert!(Arc::ptr_eq(lhs, expected_lhs), "unexpected left factor on {}", side);
        assert!(Arc::ptr_eq(rhs, expected_rhs), "unexpected right factor on {}", side);
    }

    let matcher = symbolic();
    let x = UOp::var("x", DType::Index, 0, i64::MAX);
    let y = UOp::index_const(5);
    let c2 = UOp::index_const(2);
    let add = x.try_add(&y).unwrap();
    let mul = c2.try_mul(&add).unwrap();
    let result = matcher.rewrite(&mul, &mut ());
    assert!(matches!(result, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(rewritten) = result {
        if let Op::Binary(BinaryOp::Add, left, right) = rewritten.op() {
            assert_mul_of(left, &c2, &x, "left");
            assert_mul_of(right, &c2, &y, "right");
        } else {
            panic!("Expected Add, got {:?}", rewritten.op());
        }
    }
}
/// `(x + y) * 3` (constant on the right) ends up as an `Add` after a full
/// `sym()` graph rewrite.
#[test]
fn test_distribute_multiplication_over_addition_reversed() {
    use crate::rewrite::graph_rewrite;
    let x = UOp::var("x", DType::Index, 0, i64::MAX);
    let y = UOp::var("y", DType::Index, 0, i64::MAX);
    let factor = UOp::index_const(3);
    let product = x.try_add(&y).unwrap().try_mul(&factor).unwrap();
    let distributed = graph_rewrite(sym(), product, &mut ());
    let is_add = matches!(distributed.op(), Op::Binary(BinaryOp::Add, ..));
    assert!(is_add, "Expected Add, got {:?}", distributed.op());
}
/// `(x + y) * 100` still ends up as an `Add` after a full `sym()` graph rewrite.
#[test]
fn test_distribute_large_constant() {
    use crate::rewrite::graph_rewrite;
    let x = UOp::var("x", DType::Index, 0, i64::MAX);
    let y = UOp::var("y", DType::Index, 0, i64::MAX);
    let factor = UOp::index_const(100);
    let product = x.try_add(&y).unwrap().try_mul(&factor).unwrap();
    let distributed = graph_rewrite(sym(), product, &mut ());
    let is_add = matches!(distributed.op(), Op::Binary(BinaryOp::Add, ..));
    assert!(is_add, "Expected Add, got {:?}", distributed.op());
}
/// Compares optimising `((0 + a) * 2) * 2` in one go against optimising its
/// two factors separately and then the recombined product. The second route
/// should be within one op of the first.
#[test]
#[ignore = "Distribution patterns conflict with compositional optimization"]
fn test_compositional_optimization_minimal_failure() {
    use crate::rewrite::graph_rewrite;

    /// Counts Unary/Binary/Ternary nodes reachable from `node`; every other
    /// op contributes zero and is not descended into.
    fn count_ops(node: &Arc<UOp>) -> usize {
        match node.op() {
            Op::Unary(_, src) => 1 + count_ops(src),
            Op::Binary(_, lhs, rhs) => 1 + count_ops(lhs) + count_ops(rhs),
            Op::Ternary(_, p, q, r) => 1 + [p, q, r].into_iter().map(count_ops).sum::<usize>(),
            _ => 0,
        }
    }

    let matcher = symbolic_simple();
    let a_var = UOp::var("a", DType::Int32, 0, 1);
    let zero = UOp::native_const(0i32);
    let two = UOp::native_const(2i32);

    // scaled = (0 + a) * 2, factor = 2, whole = scaled * factor.
    let shifted = zero.try_add(&a_var).unwrap();
    let scaled = shifted.try_mul(&two).unwrap();
    let factor = two.clone();
    let whole = scaled.try_mul(&factor).unwrap();

    // Route 1: optimise the whole expression at once.
    let direct_opt = graph_rewrite(&matcher, whole, &mut ());
    // Route 2: optimise each factor, recombine, then optimise the product.
    let opt_a = graph_rewrite(&matcher, scaled.clone(), &mut ());
    let opt_b = graph_rewrite(&matcher, factor.clone(), &mut ());
    let recombined = opt_a.try_mul(&opt_b).unwrap();
    let final_opt = graph_rewrite(&matcher, recombined, &mut ());

    let (direct_count, final_count) = (count_ops(&direct_opt), count_ops(&final_opt));

    println!("=== COMPOSITIONAL OPTIMIZATION DEBUG ===");
    println!("Original a: (0 + var(\"a\")) * 2");
    println!("Original b: 2");
    println!("Full expr: ((0 + var(\"a\")) * 2) * 2");
    println!();
    println!("Optimized a: {:?}", opt_a.op());
    println!("Optimized b: {:?}", opt_b.op());
    println!();
    println!("Direct optimization: {} ops -> {:?}", direct_count, direct_opt.op());
    println!("Compositional optimization: {} ops -> {:?}", final_count, final_opt.op());
    println!();

    let within_one_op = final_count <= direct_count + 1;
    assert!(
        within_one_op,
        "Compositional optimization ({} ops) should be nearly as good as direct ({} ops)",
        final_count,
        direct_count
    );
}
/// `(a * 2) * 2` folds into `a * 4`; prints the rewrite result for debugging.
#[test]
fn test_multiplication_chain_folding() {
    let matcher = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let two = UOp::native_const(2i32);
    let chain = a.try_mul(&two).unwrap().try_mul(&two).unwrap();
    let outcome = matcher.rewrite(&chain, &mut ());

    println!("=== MULTIPLICATION CHAIN TEST ===");
    println!("Input: (var(\"a\") * 2) * 2");
    if let RewriteResult::Rewritten(r) = &outcome {
        println!("Result: {:?}", r.op());
    } else {
        println!("Result: No rewrite");
    }

    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Mul, base, factor) = folded.op() else {
        panic!("Expected Binary(Mul, a, 4), got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(base, &a), "Variable should be unchanged");
    let Op::Const(cv) = factor.op() else {
        panic!("Expected constant 4, got {:?}", factor.op());
    };
    assert_eq!(cv.0, ConstValue::Int(4), "Constant should be folded to 4");
}
/// For a Bool variable, `!!x` simplifies back to the original `x` node.
#[test]
fn test_double_not_elimination() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Bool, 0, 1);
    let negated_twice = x.not().not();

    let outcome = matcher.rewrite(&negated_twice, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(simplified) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&simplified, &x));
}
/// For an Int32 variable, `not(not(x))` simplifies back to the original `x` node.
#[test]
fn test_double_not_int() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let negated_twice = x.not().not();

    let outcome = matcher.rewrite(&negated_twice, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(simplified) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&simplified, &x));
}
/// `x ^ x` cancels to the integer constant 0.
#[test]
fn test_xor_self_cancellation() {
    let matcher = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, 100);
    let xor_self = x.try_xor_op(&x).unwrap();
    let result = matcher.rewrite(&xor_self, &mut ());
    assert!(matches!(result, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(rewritten) = result {
        // Include the offending op in the message so a failure can be diagnosed.
        let Op::Const(cv) = rewritten.op() else {
            panic!("Expected constant 0, got {:?}", rewritten.op());
        };
        assert_eq!(cv.0, ConstValue::Int(0));
    }
}
/// For an Int32 variable, `-(-x)` graph-rewrites back to the original `x` node.
#[test]
fn test_double_neg_elimination() {
    use crate::rewrite::graph_rewrite;
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let negated_twice = x.neg().neg();
    let simplified = graph_rewrite(&matcher, negated_twice, &mut ());
    assert!(
        Arc::ptr_eq(&simplified, &x),
        "double neg should simplify back to x, got: {}",
        simplified.tree()
    );
}
/// For a Float32 variable, `-(-x)` graph-rewrites back to the original `x` node.
#[test]
fn test_double_neg_float() {
    use crate::rewrite::graph_rewrite;
    let matcher = symbolic();
    let x = UOp::var("x", DType::Float32, 0, i64::MAX);
    let negated_twice = x.neg().neg();
    let simplified = graph_rewrite(&matcher, negated_twice, &mut ());
    assert!(
        Arc::ptr_eq(&simplified, &x),
        "double neg should simplify back to x, got: {}",
        simplified.tree()
    );
}
/// Negating `WHERE(c, x, Invalid)` under `propagate_invalid` moves the
/// negation into the true branch. The result is a `WHERE` with the same
/// condition, a `Mul` true branch, and an `Invalid` false branch.
#[test]
fn test_propagate_invalid_through_neg() {
    use crate::rewrite::graph_rewrite;
    use crate::symbolic::patterns::propagate_invalid;

    let matcher = propagate_invalid();
    let cond = UOp::var("c", DType::Bool, 0, 1);
    let x = UOp::var("x", DType::Index, 0, 100);
    let invalid = UOp::new(Op::Invalid, DType::Index);
    let guarded = UOp::try_where(cond.clone(), x.clone(), invalid.clone()).unwrap();

    let rewritten = graph_rewrite(&matcher, guarded.neg(), &mut ());
    let Op::Ternary(TernaryOp::Where, gate, on_true, on_false) = rewritten.op() else {
        panic!("Expected WHERE, got: {}", rewritten.tree());
    };
    assert!(Arc::ptr_eq(gate, &cond), "condition should be preserved");
    assert!(matches!(on_false.op(), Op::Invalid), "false branch should be Invalid");
    let true_branch_is_mul = matches!(on_true.op(), Op::Binary(BinaryOp::Mul, _, _));
    assert!(true_branch_is_mul, "true branch should be MUL(x, -1), got: {}", on_true.tree());
}
#[test]
fn test_max_self_identity() {
    // max(x, x) is x for an integer variable.
    let rules = symbolic();
    let x = UOp::var("x", DType::Int32, 0, 100);
    let same_max = x.try_max(&x).unwrap();
    let outcome = rules.rewrite(&same_max, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &x));
}
#[test]
fn test_max_self_float() {
    // max(x, x) is x for a Float32 variable too.
    let rules = symbolic();
    let x = UOp::var("x", DType::Float32, 0, i64::MAX);
    let same_max = x.try_max(&x).unwrap();
    let outcome = rules.rewrite(&same_max, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &x));
}
#[test]
fn test_pow_zero_is_one() {
    // x ** 0 folds to the integer constant 1.
    let rules = symbolic_simple();
    let base = UOp::var("x", DType::Int32, 0, 100);
    let exponent = UOp::native_const(0i32);
    let powered = base.try_pow(&exponent).unwrap();
    let outcome = rules.rewrite(&powered, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(folded) = outcome {
        let Op::Const(value) = folded.op() else { panic!("Expected constant 1") };
        assert_eq!(value.0, ConstValue::Int(1));
    }
}
#[test]
fn test_pow_one_is_identity() {
    // x ** 1 is just x.
    let rules = symbolic_simple();
    let base = UOp::var("x", DType::Int32, 0, 100);
    let exponent = UOp::native_const(1i32);
    let powered = base.try_pow(&exponent).unwrap();
    let outcome = rules.rewrite(&powered, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &base));
}
#[test]
fn test_pow_float_zero() {
    // x ** 0.0 on a Float32 variable folds to the float constant 1.0.
    let rules = symbolic_simple();
    let base = UOp::var("x", DType::Float32, 0, 100);
    let exponent = UOp::native_const(0.0f32);
    let powered = base.try_pow(&exponent).unwrap();
    let outcome = rules.rewrite(&powered, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(folded) = outcome {
        match folded.op() {
            Op::Const(value) => assert_eq!(value.0, ConstValue::Float(1.0)),
            _ => panic!("Expected constant 1.0"),
        }
    }
}
#[test]
fn test_where_same_branches() {
    // where(cond, x, x) does not depend on cond and reduces to x.
    let rules = symbolic_simple();
    let cond = UOp::var("cond", DType::Bool, 0, 1);
    let x = UOp::var("x", DType::Int32, 0, 100);
    let select = UOp::try_where(cond, x.clone(), x.clone()).unwrap();
    let outcome = rules.rewrite(&select, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &x));
}
#[test]
fn test_where_bool_true_false() {
    // where(x, true, false) on a boolean x is x itself.
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Bool, 0, 1);
    let yes = UOp::native_const(true);
    let no = UOp::native_const(false);
    let select = UOp::try_where(x.clone(), yes, no).unwrap();
    let outcome = rules.rewrite(&select, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &x));
}
#[test]
fn test_where_bool_false_true() {
    // where(x, false, true) on a boolean x becomes Not(x).
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Bool, 0, 1);
    let no = UOp::native_const(false);
    let yes = UOp::native_const(true);
    let select = UOp::try_where(x.clone(), no, yes).unwrap();
    let outcome = rules.rewrite(&select, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(folded) = outcome {
        match folded.op() {
            Op::Unary(UnaryOp::Not, inner) => assert!(Arc::ptr_eq(inner, &x)),
            _ => panic!("Expected Not(x)"),
        }
    }
}
#[test]
fn test_where_negated_condition() {
    // where(!cond, t, f) is rewritten to where(cond, f, t): the negation is
    // dropped and the branches are swapped.
    let rules = symbolic();
    let cond = UOp::var("cond", DType::Bool, 0, 1);
    let flipped = cond.not();
    let t = UOp::var("t", DType::Int32, 0, 100);
    let f = UOp::var("f", DType::Int32, 0, 100);
    let select = UOp::try_where(flipped, t.clone(), f.clone()).unwrap();
    let outcome = rules.rewrite(&select, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(folded) = outcome {
        match folded.op() {
            Op::Ternary(TernaryOp::Where, new_cond, new_t, new_f) => {
                assert!(Arc::ptr_eq(new_cond, &cond));
                assert!(Arc::ptr_eq(new_t, &f));
                assert!(Arc::ptr_eq(new_f, &t));
            }
            _ => panic!("Expected Where with swapped branches"),
        }
    }
}
#[test]
fn test_where_const_true_condition() {
    // A constant-true condition selects the true branch.
    let rules = symbolic_simple();
    let always = UOp::native_const(true);
    let t = UOp::var("t", DType::Int32, 0, 100);
    let f = UOp::var("f", DType::Int32, 0, 100);
    let select = UOp::try_where(always, t.clone(), f).unwrap();
    let outcome = rules.rewrite(&select, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &t));
}
#[test]
fn test_where_const_false_condition() {
    // A constant-false condition selects the false branch.
    let rules = symbolic_simple();
    let never = UOp::native_const(false);
    let t = UOp::var("t", DType::Int32, 0, 100);
    let f = UOp::var("f", DType::Int32, 0, 100);
    let select = UOp::try_where(never, t, f.clone()).unwrap();
    let outcome = rules.rewrite(&select, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &f));
}
#[test]
fn test_lt_bounds_always_true() {
    // a is bounded to [0, 8], so `a < 77` holds for every value and folds to true.
    let matcher = symbolic();
    let a = UOp::var("a", DType::Int32, 0, 8);
    let c77 = UOp::native_const(77i32);
    // Named `cmp`: the previous `&lt` borrow had been mangled into `<`.
    let cmp = a.try_cmplt(&c77).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected a < 77 to be folded from bounds");
    };
    let Op::Const(cv) = rewritten.op() else {
        panic!("Expected Const(Bool(true)), got {:?}", rewritten.op());
    };
    assert_eq!(cv.0, ConstValue::Bool(true));
}
#[test]
fn test_lt_bounds_always_true_edge() {
    // Boundary case: a's maximum is 8, so `a < 9` is still always true.
    let matcher = symbolic();
    let a = UOp::var("a", DType::Int32, 0, 8);
    let c9 = UOp::native_const(9i32);
    let cmp = a.try_cmplt(&c9).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected a < 9 to be folded from bounds");
    };
    let Op::Const(cv) = rewritten.op() else {
        panic!("Expected Const(Bool(true)), got {:?}", rewritten.op());
    };
    assert_eq!(cv.0, ConstValue::Bool(true));
}
#[test]
fn test_lt_bounds_indeterminate() {
    // a is bounded to [0, 8]; `a < 5` can be either true or false, so no rule
    // may fold it.
    // NOTE(review): unlike the sibling bound tests this uses symbolic_simple();
    // confirm whether it should also check that symbolic() leaves this untouched.
    let matcher = symbolic_simple();
    let a = UOp::var("a", DType::Int32, 0, 8);
    let c5 = UOp::native_const(5i32);
    let cmp = a.try_cmplt(&c5).unwrap();
    let result = matcher.rewrite(&cmp, &mut ());
    assert!(matches!(result, RewriteResult::NoMatch));
}
#[test]
fn test_lt_bounds_always_false() {
    // a's minimum is 0, so `a < 0` can never hold and folds to false.
    let matcher = symbolic();
    let a = UOp::var("a", DType::Int32, 0, 8);
    let c0 = UOp::native_const(0i32);
    let cmp = a.try_cmplt(&c0).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected a < 0 to be folded from bounds");
    };
    let Op::Const(cv) = rewritten.op() else {
        panic!("Expected Const(Bool(false)), got {:?}", rewritten.op());
    };
    assert_eq!(cv.0, ConstValue::Bool(false));
}
#[test]
fn test_lt_two_vars_always_true() {
    // a is in [0, 4] and b = b_base + 5 is in [5, 10], so `a < b` is always true.
    let matcher = symbolic();
    let a = UOp::var("a", DType::Int32, 0, 4);
    let b_base = UOp::var("b", DType::Int32, 0, 5);
    let c5 = UOp::native_const(5i32);
    let b = b_base.try_add(&c5).unwrap();
    let cmp = a.try_cmplt(&b).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected a < b to be folded from bounds");
    };
    let Op::Const(cv) = rewritten.op() else {
        panic!("Expected Const(Bool(true)), got {:?}", rewritten.op());
    };
    assert_eq!(cv.0, ConstValue::Bool(true));
}
#[test]
fn test_lt_two_vars_always_false() {
    // a = a_base + 5 is in [5, 10] and b is in [0, 4], so `a < b` is always false.
    let matcher = symbolic();
    let a_base = UOp::var("a", DType::Int32, 0, 5);
    let c5 = UOp::native_const(5i32);
    let a = a_base.try_add(&c5).unwrap();
    let b = UOp::var("b", DType::Int32, 0, 4);
    let cmp = a.try_cmplt(&b).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected a < b to be folded from bounds");
    };
    let Op::Const(cv) = rewritten.op() else {
        panic!("Expected Const(Bool(false)), got {:?}", rewritten.op());
    };
    assert_eq!(cv.0, ConstValue::Bool(false));
}
#[test]
fn test_ge_bounds_always_true() {
    // a = a_base + 3 is in [3, 8], so `a >= 3` always holds. This is checked
    // through its Lt form: `a < 3` must fold to false.
    let matcher = symbolic();
    let a_base = UOp::var("a", DType::Int32, 0, 5);
    let c3 = UOp::native_const(3i32);
    let a = a_base.try_add(&c3).unwrap();
    let cmp = a.try_cmplt(&c3).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected a < 3 to be folded from bounds");
    };
    let Op::Const(cv) = rewritten.op() else {
        panic!("Expected Const(Bool(false)), got {:?}", rewritten.op());
    };
    assert_eq!(cv.0, ConstValue::Bool(false));
}
#[test]
fn test_eq_bounds_always_false() {
    // a is in [0, 4] and b = b_base + 10 is in [10, 20]; the ranges never
    // overlap, so `a == b` folds to false.
    let rules = symbolic();
    let a = UOp::var("a", DType::Int32, 0, 4);
    let b_base = UOp::var("b", DType::Int32, 0, 10);
    let ten = UOp::native_const(10i32);
    let b = b_base.try_add(&ten).unwrap();
    let equal = a.try_cmpeq(&b).unwrap();
    let outcome = rules.rewrite(&equal, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(folded) = outcome {
        match folded.op() {
            Op::Const(value) => assert_eq!(value.0, ConstValue::Bool(false)),
            other => panic!("Expected Const(Bool(false)), got {:?}", other),
        }
    }
}
#[test]
fn test_ne_bounds_always_true() {
    // With disjoint ranges a in [0, 4] and b + 10 in [10, 20], `a != b + 10`
    // folds to true.
    let rules = symbolic();
    let a = UOp::var("a", DType::Int32, 0, 4);
    let b_base = UOp::var("b", DType::Int32, 0, 10);
    let ten = UOp::native_const(10i32);
    let b = b_base.try_add(&ten).unwrap();
    let differ = a.try_cmpne(&b).unwrap();
    let outcome = rules.rewrite(&differ, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(folded) = outcome {
        match folded.op() {
            Op::Const(value) => assert_eq!(value.0, ConstValue::Bool(true)),
            other => panic!("Expected Const(Bool(true)), got {:?}", other),
        }
    }
}
#[test]
fn test_nested_div_div() {
    // (a // 10) // 9 combines into a single a // 90.
    let rules = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let ten = UOp::native_const(10i32);
    let nine = UOp::native_const(9i32);
    let nested = a.try_div(&ten).unwrap().try_div(&nine).unwrap();
    let outcome = rules.rewrite(&nested, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    match folded.op() {
        Op::Binary(BinaryOp::Idiv, var, c) => {
            assert!(Arc::ptr_eq(var, &a));
            match c.op() {
                Op::Const(cv) => assert_eq!(cv.0, ConstValue::Int(90)),
                other => panic!("Expected constant 90, got {:?}", other),
            }
        }
        other => panic!("Expected Idiv, got {:?}", other),
    }
}
#[test]
fn test_nested_mul_mul() {
    // (a * 10) * 9 combines into a single a * 90.
    let rules = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let ten = UOp::native_const(10i32);
    let nine = UOp::native_const(9i32);
    let nested = a.try_mul(&ten).unwrap().try_mul(&nine).unwrap();
    let outcome = rules.rewrite(&nested, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    match folded.op() {
        Op::Binary(BinaryOp::Mul, var, c) => {
            assert!(Arc::ptr_eq(var, &a));
            match c.op() {
                Op::Const(cv) => assert_eq!(cv.0, ConstValue::Int(90)),
                other => panic!("Expected constant 90, got {:?}", other),
            }
        }
        other => panic!("Expected Mul, got {:?}", other),
    }
}
#[test]
fn test_nested_mod_mod_same_divisor() {
    // (a % 5) % 5 is idempotent and reduces to the inner a % 5.
    let rules = symbolic_simple();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let five = UOp::native_const(5i32);
    let nested = a.try_mod(&five).unwrap().try_mod(&five).unwrap();
    let outcome = rules.rewrite(&nested, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Mod, var, c) = folded.op() else {
        panic!("Expected Mod(a, 5), got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(var, &a));
    assert!(Arc::ptr_eq(c, &five));
}
#[test]
fn test_nested_add_add() {
    // (a + 3) + 5 combines the constants into a + 8.
    let rules = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let three = UOp::native_const(3i32);
    let five = UOp::native_const(5i32);
    let nested = a.try_add(&three).unwrap().try_add(&five).unwrap();
    let outcome = rules.rewrite(&nested, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Add, var, c) = folded.op() else {
        panic!("Expected Add, got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(var, &a));
    let Op::Const(cv) = c.op() else { panic!("Expected constant 8, got {:?}", c.op()) };
    assert_eq!(cv.0, ConstValue::Int(8));
}
#[test]
fn test_nested_sub_sub() {
    // (a - 3) - 5 combines the constants into a - 8.
    let rules = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let three = UOp::native_const(3i32);
    let five = UOp::native_const(5i32);
    let nested = a.try_sub(&three).unwrap().try_sub(&five).unwrap();
    let outcome = rules.rewrite(&nested, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Sub, var, c) = folded.op() else {
        panic!("Expected Sub, got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(var, &a));
    let Op::Const(cv) = c.op() else { panic!("Expected constant 8, got {:?}", c.op()) };
    assert_eq!(cv.0, ConstValue::Int(8));
}
#[test]
fn test_bool_or_not_tautology() {
    // x | !x is a tautology for a boolean x and folds to true.
    let matcher = symbolic();
    let x = UOp::var("x", DType::Bool, 0, 1);
    // Named `x_not`: the previous `&not_x` borrow had been mangled into `¬_x`.
    let x_not = x.not();
    let or_op = x.try_or_op(&x_not).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&or_op, &mut ()) else {
        panic!("Expected x | !x to be rewritten");
    };
    let Op::Const(cv) = rewritten.op() else {
        panic!("Expected Const(Bool(true)), got {:?}", rewritten.op());
    };
    assert_eq!(cv.0, ConstValue::Bool(true));
}
#[test]
fn test_bool_and_not_contradiction() {
    // x & !x is a contradiction for a boolean x and folds to false.
    let matcher = symbolic();
    let x = UOp::var("x", DType::Bool, 0, 1);
    // Named `x_not`: the previous `&not_x` borrow had been mangled into `¬_x`.
    let x_not = x.not();
    let and_op = x.try_and_op(&x_not).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&and_op, &mut ()) else {
        panic!("Expected x & !x to be rewritten");
    };
    let Op::Const(cv) = rewritten.op() else {
        panic!("Expected Const(Bool(false)), got {:?}", rewritten.op());
    };
    assert_eq!(cv.0, ConstValue::Bool(false));
}
#[test]
fn test_bool_or_true_absorb() {
    // true | x absorbs x and folds to true.
    let rules = symbolic();
    let x = UOp::var("x", DType::Bool, 0, 1);
    let always = UOp::const_(DType::Bool, ConstValue::Bool(true));
    let disjunction = always.try_or_op(&x).unwrap();
    let outcome = rules.rewrite(&disjunction, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(folded) = outcome {
        match folded.op() {
            Op::Const(value) => assert_eq!(value.0, ConstValue::Bool(true)),
            other => panic!("Expected Const(Bool(true)), got {:?}", other),
        }
    }
}
#[test]
fn test_bool_and_false_absorb() {
    // false & x absorbs x and folds to false.
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Bool, 0, 1);
    let never = UOp::const_(DType::Bool, ConstValue::Bool(false));
    let conjunction = never.try_and_op(&x).unwrap();
    let outcome = rules.rewrite(&conjunction, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    if let RewriteResult::Rewritten(folded) = outcome {
        match folded.op() {
            Op::Const(value) => assert_eq!(value.0, ConstValue::Bool(false)),
            other => panic!("Expected Const(Bool(false)), got {:?}", other),
        }
    }
}
#[test]
fn test_bool_and_true_identity() {
    // true & x is just x.
    let rules = symbolic();
    let x = UOp::var("x", DType::Bool, 0, 1);
    let always = UOp::const_(DType::Bool, ConstValue::Bool(true));
    let conjunction = always.try_and_op(&x).unwrap();
    let outcome = rules.rewrite(&conjunction, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &x));
}
#[test]
fn test_bool_or_false_identity() {
    // false | x is just x.
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Bool, 0, 1);
    let never = UOp::const_(DType::Bool, ConstValue::Bool(false));
    let disjunction = never.try_or_op(&x).unwrap();
    let outcome = rules.rewrite(&disjunction, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &x));
}
#[test]
fn test_lt_const_offset() {
    // (a + 2) < 5 moves the constant across the comparison: a < 3.
    let matcher = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let c2 = UOp::native_const(2i32);
    let c5 = UOp::native_const(5i32);
    let add = a.try_add(&c2).unwrap();
    // Named `cmp`: the previous `&lt` borrow had been mangled into `<`.
    let cmp = add.try_cmplt(&c5).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected (a + 2) < 5 to be rewritten");
    };
    let Op::Binary(BinaryOp::Lt, var, c) = rewritten.op() else {
        panic!("Expected Lt, got {:?}", rewritten.op());
    };
    assert!(Arc::ptr_eq(var, &a));
    let Op::Const(cv) = c.op() else {
        panic!("Expected constant 3, got {:?}", c.op());
    };
    assert_eq!(cv.0, ConstValue::Int(3));
}
#[test]
fn test_lt_const_offset_negative() {
    // (a + 10) < 5 moves the constant across the comparison, giving a negative
    // bound: a < -5.
    let matcher = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let c10 = UOp::native_const(10i32);
    let c5 = UOp::native_const(5i32);
    let add = a.try_add(&c10).unwrap();
    let cmp = add.try_cmplt(&c5).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected (a + 10) < 5 to be rewritten");
    };
    let Op::Binary(BinaryOp::Lt, var, c) = rewritten.op() else {
        panic!("Expected Lt, got {:?}", rewritten.op());
    };
    assert!(Arc::ptr_eq(var, &a));
    let Op::Const(cv) = c.op() else {
        panic!("Expected constant -5, got {:?}", c.op());
    };
    assert_eq!(cv.0, ConstValue::Int(-5));
}
#[test]
fn test_lt_negation_flip() {
    // -a < -b is equivalent to b < a: the negations are dropped and the
    // operands swapped.
    let matcher = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let b = UOp::var("b", DType::Int32, 0, i64::MAX);
    let neg_a = a.neg();
    let neg_b = b.neg();
    let cmp = neg_a.try_cmplt(&neg_b).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected -a < -b to be rewritten");
    };
    let Op::Binary(BinaryOp::Lt, lhs, rhs) = rewritten.op() else {
        panic!("Expected Lt(b, a), got {:?}", rewritten.op());
    };
    assert!(Arc::ptr_eq(lhs, &b));
    assert!(Arc::ptr_eq(rhs, &a));
}
#[test]
fn test_div_mod_recombine() {
    // x % 4 + (x // 4) * 4 reconstructs x exactly.
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let n = UOp::native_const(4i32);
    let remainder = x.try_mod(&n).unwrap();
    let scaled_quotient = x.try_div(&n).unwrap().try_mul(&n).unwrap();
    let rebuilt = remainder.try_add(&scaled_quotient).unwrap();
    let outcome = rules.rewrite(&rebuilt, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &x));
}
#[test]
fn test_div_mod_recombine_commutative() {
    // The recombine rule also fires with the operands of the outer Add swapped:
    // (x // 4) * 4 + x % 4 -> x.
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let n = UOp::native_const(4i32);
    let scaled_quotient = x.try_div(&n).unwrap().try_mul(&n).unwrap();
    let remainder = x.try_mod(&n).unwrap();
    let rebuilt = scaled_quotient.try_add(&remainder).unwrap();
    let outcome = rules.rewrite(&rebuilt, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    assert!(Arc::ptr_eq(&folded, &x));
}
#[test]
fn test_nested_div_const() {
    // (a // 2 + 1) // 2 is rewritten to (a + 2) // 4.
    let rules = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let two = UOp::native_const(2i32);
    let one = UOp::native_const(1i32);
    let halved = a.try_div(&two).unwrap();
    let bumped = halved.try_add(&one).unwrap();
    let outer = bumped.try_div(&two).unwrap();
    let outcome = rules.rewrite(&outer, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Idiv, lhs, rhs) = folded.op() else {
        panic!("Expected Idiv, got {:?}", folded.op());
    };
    let Op::Binary(BinaryOp::Add, var, c) = lhs.op() else {
        panic!("Expected Add, got {:?}", lhs.op());
    };
    assert!(Arc::ptr_eq(var, &a));
    let Op::Const(offset) = c.op() else { panic!("Expected constant 2, got {:?}", c.op()) };
    assert_eq!(offset.0, ConstValue::Int(2));
    let Op::Const(divisor) = rhs.op() else { panic!("Expected constant 4, got {:?}", rhs.op()) };
    assert_eq!(divisor.0, ConstValue::Int(4));
}
#[test]
fn test_nested_div_const_larger() {
    use crate::rewrite::graph_rewrite;
    // (a // 3 + 5) // 4 must not survive a full rewrite in its original
    // Idiv(Add(Idiv(..), _), _) shape.
    let rules = symbolic();
    let a = UOp::var("a", DType::Int32, 0, i64::MAX);
    let three = UOp::native_const(3i32);
    let five = UOp::native_const(5i32);
    let four = UOp::native_const(4i32);
    let inner = a.try_div(&three).unwrap();
    let shifted = inner.try_add(&five).unwrap();
    let outer = shifted.try_div(&four).unwrap();
    let result = graph_rewrite(&rules, outer, &mut ());
    let still_nested = match result.op() {
        Op::Binary(BinaryOp::Idiv, lhs, _) => match lhs.op() {
            Op::Binary(BinaryOp::Add, inner, _) => matches!(inner.op(), Op::Binary(BinaryOp::Idiv, _, _)),
            _ => false,
        },
        _ => false,
    };
    assert!(
        !still_nested,
        "Nested (a//c1 + c2) // c3 should be simplified, got: {}",
        result.tree()
    );
}
#[test]
fn test_div_mod_recombine_different_n() {
    // x % 4 + (x // 5) * 4 mixes divisors, so it must NOT collapse back to x.
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let four = UOp::native_const(4i32);
    let five = UOp::native_const(5i32);
    let remainder = x.try_mod(&four).unwrap();
    let mismatched = x.try_div(&five).unwrap().try_mul(&four).unwrap();
    let sum = remainder.try_add(&mismatched).unwrap();
    let outcome = rules.rewrite(&sum, &mut ());
    let collapsed_to_x = matches!(&outcome, RewriteResult::Rewritten(r) if Arc::ptr_eq(r, &x));
    assert!(!collapsed_to_x);
}
#[test]
fn test_div_mod_property_identity() {
    // Sanity check of the arithmetic identity behind the div/mod recombine
    // rewrite tested above: `x % n + (x / n) * n == x` for every non-zero `n`.
    // Rust's `/` truncates toward zero and `%` takes the dividend's sign, so the
    // identity also holds for negative operands. The grid below covers those
    // cases, not just a single positive pair.
    let x_val = 17i32;
    let n_val = 5i32;
    assert_eq!(x_val % n_val + (x_val / n_val) * n_val, x_val);
    for x in -20i32..=20 {
        for n in (-7i32..=7).filter(|&n| n != 0) {
            let recombined = x % n + (x / n) * n;
            assert_eq!(recombined, x, "identity failed for x={x}, n={n}");
        }
    }
}
#[test]
fn test_where_merge_branches() {
    // where(a, where(b, c, d), d) merges into where(a & b, c, d) because both
    // levels share the false value d.
    let rules = symbolic_simple();
    let a = UOp::var("a", DType::Bool, 0, 1);
    let b = UOp::var("b", DType::Bool, 0, 1);
    let c = UOp::var("c", DType::Int32, 0, i64::MAX);
    let d = UOp::var("d", DType::Int32, 0, i64::MAX);
    let inner = UOp::try_where(Arc::clone(&b), Arc::clone(&c), Arc::clone(&d)).unwrap();
    let outer = UOp::try_where(Arc::clone(&a), inner, Arc::clone(&d)).unwrap();
    let outcome = rules.rewrite(&outer, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Ternary(TernaryOp::Where, cond, true_val, false_val) = folded.op() else {
        panic!("Expected Where, got {:?}", folded.op());
    };
    let Op::Binary(BinaryOp::And, lhs, rhs) = cond.op() else {
        panic!("Expected And condition, got {:?}", cond.op());
    };
    // The conjunction may list a and b in either order.
    assert!(Arc::ptr_eq(lhs, &a) || Arc::ptr_eq(lhs, &b));
    assert!(Arc::ptr_eq(rhs, &a) || Arc::ptr_eq(rhs, &b));
    assert!(Arc::ptr_eq(true_val, &c));
    assert!(Arc::ptr_eq(false_val, &d));
}
#[test]
fn test_where_merge_branches_no_match() {
    // where(a, where(b, c, d), e) has different false values (d vs e), so the
    // conditions must not be merged into an And.
    let rules = symbolic_simple();
    let a = UOp::var("a", DType::Bool, 0, 1);
    let b = UOp::var("b", DType::Bool, 0, 1);
    let c = UOp::var("c", DType::Int32, 0, i64::MAX);
    let d = UOp::var("d", DType::Int32, 0, i64::MAX);
    let e = UOp::var("e", DType::Int32, 0, i64::MAX);
    let inner = UOp::try_where(Arc::clone(&b), Arc::clone(&c), Arc::clone(&d)).unwrap();
    let outer = UOp::try_where(Arc::clone(&a), Arc::clone(&inner), Arc::clone(&e)).unwrap();
    let outcome = rules.rewrite(&outer, &mut ());
    let merged_with_and = match &outcome {
        RewriteResult::Rewritten(folded) => match folded.op() {
            Op::Ternary(TernaryOp::Where, cond, _, _) => matches!(cond.op(), Op::Binary(BinaryOp::And, _, _)),
            _ => false,
        },
        _ => false,
    };
    assert!(!merged_with_and, "Should not merge branches when false values differ");
}
#[test]
fn test_cast_where_push() {
    // cast(where(s, 1, 0)) pushes the cast into both branches:
    // where(s, cast(1), cast(0)).
    let rules = sym();
    let s = UOp::var("s", DType::Bool, 0, 1);
    let one = UOp::native_const(1i32);
    let zero = UOp::native_const(0i32);
    let select = UOp::try_where(Arc::clone(&s), Arc::clone(&one), Arc::clone(&zero)).unwrap();
    let casted = select.cast(DType::Float32);
    let outcome = rules.rewrite(&casted, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    match folded.op() {
        Op::Ternary(TernaryOp::Where, cond, true_val, false_val) => {
            assert!(Arc::ptr_eq(cond, &s));
            assert!(matches!(true_val.op(), Op::Cast { .. }));
            assert!(matches!(false_val.op(), Op::Cast { .. }));
        }
        other => panic!("Expected Where, got {:?}", other),
    }
}
#[test]
fn test_vmin_vmax_collapse_addition() {
    use crate::rewrite::graph_rewrite;
    // x is declared with bounds [5, 5], so x + 3 collapses to the constant 8.
    let rules = symbolic();
    let pinned = UOp::var("x", DType::Int32, 5, 5);
    let three = UOp::native_const(3i32);
    let sum = pinned.try_add(&three).unwrap();
    let folded = graph_rewrite(&rules, sum, &mut ());
    match folded.op() {
        Op::Const(value) => assert_eq!(value.0, ConstValue::Int(8)),
        other => panic!("Expected const 8, got {:?}", other),
    }
}
#[test]
fn test_bool_mul_is_and() {
    // On booleans, x * y is rewritten to x & y.
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Bool, 0, 1);
    let y = UOp::var("y", DType::Bool, 0, 1);
    let product = x.try_mul(&y).unwrap();
    let outcome = rules.rewrite(&product, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    match folded.op() {
        Op::Binary(BinaryOp::And, lhs, rhs) => {
            assert!(Arc::ptr_eq(lhs, &x));
            assert!(Arc::ptr_eq(rhs, &y));
        }
        other => panic!("Expected And, got {:?}", other),
    }
}
#[test]
fn test_bool_add_is_or() {
    // On booleans, x + y is rewritten to x | y.
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Bool, 0, 1);
    let y = UOp::var("y", DType::Bool, 0, 1);
    let sum = x.try_add(&y).unwrap();
    let outcome = rules.rewrite(&sum, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    match folded.op() {
        Op::Binary(BinaryOp::Or, lhs, rhs) => {
            assert!(Arc::ptr_eq(lhs, &x));
            assert!(Arc::ptr_eq(rhs, &y));
        }
        other => panic!("Expected Or, got {:?}", other),
    }
}
#[test]
fn test_bool_max_is_or() {
    // On booleans, max(x, y) is rewritten to x | y.
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Bool, 0, 1);
    let y = UOp::var("y", DType::Bool, 0, 1);
    let larger = x.try_max(&y).unwrap();
    let outcome = rules.rewrite(&larger, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    match folded.op() {
        Op::Binary(BinaryOp::Or, lhs, rhs) => {
            assert!(Arc::ptr_eq(lhs, &x));
            assert!(Arc::ptr_eq(rhs, &y));
        }
        other => panic!("Expected Or, got {:?}", other),
    }
}
#[test]
fn test_bool_mul_non_bool_no_match() {
    // The mul -> and rewrite is boolean-only; Int32 x * y must never become And.
    let rules = symbolic_simple();
    let x = UOp::var("x", DType::Int32, 0, 100);
    let y = UOp::var("y", DType::Int32, 0, 100);
    let product = x.try_mul(&y).unwrap();
    let outcome = rules.rewrite(&product, &mut ());
    let became_and = match &outcome {
        RewriteResult::Rewritten(folded) => matches!(folded.op(), Op::Binary(BinaryOp::And, ..)),
        _ => false,
    };
    assert!(!became_and);
}
#[test]
fn test_term_combine_x_plus_xc() {
    // x + x * 3 combines like terms into x * 4.
    let rules = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let three = UOp::native_const(3i32);
    let tripled = x.try_mul(&three).unwrap();
    let sum = x.try_add(&tripled).unwrap();
    let outcome = rules.rewrite(&sum, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    let Op::Binary(BinaryOp::Mul, var, c) = folded.op() else {
        panic!("Expected Mul, got {:?}", folded.op());
    };
    assert!(Arc::ptr_eq(var, &x));
    let Op::Const(cv) = c.op() else { panic!("Expected const 4, got {:?}", c.op()) };
    assert_eq!(cv.0, ConstValue::Int(4));
}
#[test]
fn test_term_combine_y_plus_x_plus_x() {
    // (y + x) + x folds the repeated term: y + x * 2.
    let matcher = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let y = UOp::var("y", DType::Int32, 0, i64::MAX);
    let yx = y.try_add(&x).unwrap();
    let add = yx.try_add(&x).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&add, &mut ()) else {
        panic!("Expected (y + x) + x to be rewritten");
    };
    let Op::Binary(BinaryOp::Add, lhs, rhs) = rewritten.op() else {
        panic!("Expected Add, got {:?}", rewritten.op());
    };
    assert!(Arc::ptr_eq(lhs, &y));
    // Require the rhs to be exactly x * 2. A different shape must fail the test
    // instead of silently skipping the coefficient check.
    let Op::Binary(BinaryOp::Mul, var, c) = rhs.op() else {
        panic!("Expected Mul(x, 2), got {:?}", rhs.op());
    };
    assert!(Arc::ptr_eq(var, &x));
    let Op::Const(cv) = c.op() else {
        panic!("Expected const 2, got {:?}", c.op());
    };
    assert_eq!(cv.0, ConstValue::Int(2));
}
#[test]
fn test_neg_one_times_x_plus_const() {
    use crate::rewrite::graph_rewrite;
    // -1 * (x + 3) distributes to (-x) + (-3) after a full rewrite.
    let rules = symbolic();
    let x = UOp::var("x", DType::Int32, 0, i64::MAX);
    let minus_one = UOp::native_const(-1i32);
    let three = UOp::native_const(3i32);
    let shifted = x.try_add(&three).unwrap();
    let product = minus_one.try_mul(&shifted).unwrap();
    let folded = graph_rewrite(&rules, product, &mut ());
    let Op::Binary(BinaryOp::Add, lhs, rhs) = folded.op() else {
        panic!("Expected Add after full rewrite, got {:?}", folded.op());
    };
    match lhs.op() {
        Op::Unary(UnaryOp::Neg, inner) => assert!(Arc::ptr_eq(inner, &x)),
        // A multiplication form of the negation is also accepted (operands unchecked).
        Op::Binary(BinaryOp::Mul, _, _) => {}
        other => panic!("Expected Neg(x) or Mul(-1, x), got {:?}", other),
    }
    let Op::Const(cv) = rhs.op() else { panic!("Expected const -3, got {:?}", rhs.op()) };
    assert_eq!(cv.0, ConstValue::Int(-3));
}
#[test]
fn test_range_mod_end() {
    use crate::rewrite::graph_rewrite;
    // range(8) % 8: the result must be either the range itself or a constant.
    let rules = symbolic();
    let end = UOp::index_const(8);
    let counter = UOp::range(Arc::clone(&end), 0);
    let wrapped = counter.try_mod(&end).unwrap();
    let out = graph_rewrite(&rules, wrapped, &mut ());
    let acceptable = matches!(out.op(), Op::Range { .. } | Op::Const(_));
    assert!(acceptable);
}
#[test]
fn test_range_div_end() {
    use crate::rewrite::graph_rewrite;
    // range(8) // 8 always yields 0 for in-range values, so it folds to const 0.
    let rules = symbolic();
    let end = UOp::index_const(8);
    let counter = UOp::range(Arc::clone(&end), 0);
    let quotient = counter.try_div(&end).unwrap();
    let out = graph_rewrite(&rules, quotient, &mut ());
    match out.op() {
        Op::Const(value) => assert_eq!(value.0, ConstValue::Int(0)),
        other => panic!("Expected const 0, got {:?}", other),
    }
}
#[test]
fn test_mul_lt_ceil_div() {
    // 3 * x < 10 on a non-negative index becomes x < ceil(10 / 3) = x < 4.
    let matcher = symbolic();
    let x = UOp::var("x", DType::Index, 0, i64::MAX);
    let c3 = UOp::index_const(3);
    let c10 = UOp::index_const(10);
    let mul = c3.try_mul(&x).unwrap();
    // Named `cmp`: the previous `&lt` borrow had been mangled into `<`.
    let cmp = mul.try_cmplt(&c10).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected 3 * x < 10 to be rewritten");
    };
    let Op::Binary(BinaryOp::Lt, var, c) = rewritten.op() else {
        panic!("Expected Lt, got {:?}", rewritten.op());
    };
    assert!(Arc::ptr_eq(var, &x));
    let Op::Const(cv) = c.op() else {
        panic!("Expected const 4, got {:?}", c.op());
    };
    assert_eq!(cv.0, ConstValue::Int(4));
}
#[test]
fn test_mul_lt_exact_div() {
    // 4 * x < 12 divides exactly: x < 3.
    let matcher = symbolic();
    let x = UOp::var("x", DType::Index, 0, 100);
    let c4 = UOp::index_const(4);
    let c12 = UOp::index_const(12);
    let mul = c4.try_mul(&x).unwrap();
    let cmp = mul.try_cmplt(&c12).unwrap();
    let RewriteResult::Rewritten(rewritten) = matcher.rewrite(&cmp, &mut ()) else {
        panic!("Expected 4 * x < 12 to be rewritten");
    };
    let Op::Binary(BinaryOp::Lt, var, c) = rewritten.op() else {
        panic!("Expected Lt, got {:?}", rewritten.op());
    };
    assert!(Arc::ptr_eq(var, &x));
    let Op::Const(cv) = c.op() else {
        panic!("Expected const 3, got {:?}", c.op());
    };
    assert_eq!(cv.0, ConstValue::Int(3));
}
#[test]
fn test_where_alu_combine_add() {
    // where(c, 1, b) + where(c, 2, e) on the same condition becomes a single
    // WHERE on c whose false branch is an Add (presumably b + e; operands unchecked).
    let rules = symbolic();
    let c = UOp::var("c", DType::Bool, 0, 1);
    let one = UOp::native_const(1i32);
    let b = UOp::var("b", DType::Int32, 0, 100);
    let two = UOp::native_const(2i32);
    let e = UOp::var("e", DType::Int32, 0, 100);
    let first = UOp::try_where(Arc::clone(&c), one, Arc::clone(&b)).unwrap();
    let second = UOp::try_where(Arc::clone(&c), two, Arc::clone(&e)).unwrap();
    let sum = first.try_add(&second).unwrap();
    let outcome = rules.rewrite(&sum, &mut ());
    assert!(matches!(outcome, RewriteResult::Rewritten(_)));
    let RewriteResult::Rewritten(folded) = outcome else { unreachable!() };
    match folded.op() {
        Op::Ternary(TernaryOp::Where, cond, _, false_br) => {
            assert!(Arc::ptr_eq(cond, &c));
            assert!(matches!(false_br.op(), Op::Binary(BinaryOp::Add, ..)));
        }
        other => panic!("Expected Where, got {:?}", other),
    }
}
/// `(y + where(c, 1, b)) + where(c, 2, e)`: the two WHEREs are not direct siblings. A full
/// `graph_rewrite` with `symbolic()` is still expected to leave at most one WHERE.
#[test]
fn test_where_alu_combine_associative_add() {
    use crate::rewrite::graph_rewrite;
    let cond = UOp::var("c", DType::Bool, 0, 1);
    let y = UOp::var("y", DType::Int32, 0, 100);
    let one = UOp::native_const(1i32);
    let b_var = UOp::var("b", DType::Int32, 0, 100);
    let two = UOp::native_const(2i32);
    let e_var = UOp::var("e", DType::Int32, 0, 100);
    let first = UOp::try_where(cond.clone(), one, b_var).unwrap();
    let second = UOp::try_where(cond, two, e_var).unwrap();
    let expr = y.try_add(&first).unwrap().try_add(&second).unwrap();

    let rewritten = graph_rewrite(&symbolic(), expr, &mut ());
    let mut where_count = 0;
    for node in rewritten.toposort().iter() {
        if matches!(node.op(), Op::Ternary(TernaryOp::Where, ..)) {
            where_count += 1;
        }
    }
    assert!(where_count <= 1, "Expected WHERE nodes to be combined, got {where_count}");
}
/// WHEREs on different conditions (`c1` vs `c2`) must not be merged into a WHERE by
/// `symbolic_simple`. Not rewriting at all also passes.
#[test]
fn test_where_alu_combine_different_cond_no_match() {
    let matcher = symbolic_simple();
    let cond_a = UOp::var("c1", DType::Bool, 0, 1);
    let cond_b = UOp::var("c2", DType::Bool, 0, 1);
    let a = UOp::var("a", DType::Int32, 0, 100);
    let b = UOp::var("b", DType::Int32, 0, 100);
    let d = UOp::var("d", DType::Int32, 0, 100);
    let e = UOp::var("e", DType::Int32, 0, 100);
    let sum = UOp::try_where(cond_a, a, b)
        .unwrap()
        .try_add(&UOp::try_where(cond_b, d, e).unwrap())
        .unwrap();

    if let RewriteResult::Rewritten(node) = matcher.rewrite(&sum, &mut ()) {
        assert!(!matches!(node.op(), Op::Ternary(TernaryOp::Where, ..)));
    }
}
/// `x < 10 && x < 5` for a range `x` over `[0, 20)`: `x < 10` is implied by `x < 5`, so any
/// simplification `simplify_valid` produces must not be larger than the input.
#[test]
fn test_simplify_valid_redundant_upper_bounds() {
    use crate::symbolic::valid_simplification::simplify_valid;
    let x = UOp::range_const(20, 0);
    let c10 = UOp::index_const(10);
    let c5 = UOp::index_const(5);
    // Avoid naming the clauses `lt10`/`lt5`: borrowing them spells an HTML entity that escaping
    // tools have mangled into invalid Rust.
    let below_10 = x.lt(&c10);
    let below_5 = x.lt(&c5);
    let combined = below_10.and_(&below_5);
    // NOTE(review): this only checks that the result does not grow, and it passes vacuously
    // when `simplify_valid` returns `None`. Consider asserting that the redundant clause is
    // actually dropped.
    if let Some(simplified) = simplify_valid(&combined) {
        assert!(simplified.node_count() <= combined.node_count(), "Simplified result should not be larger");
    }
}
/// `true && true` contains no `Lt` clause, so `simplify_valid` has nothing to work with and
/// should return `None`.
#[test]
fn test_simplify_valid_no_parseable_clauses() {
    use crate::symbolic::valid_simplification::simplify_valid;
    let combined = UOp::native_const(true).and_(&UOp::native_const(true));
    assert!(simplify_valid(&combined).is_none(), "Non-Lt clauses should not be simplified");
}
/// The gate is `r0 < 5 && r1 < 15`, but the gated value `r0 + 1` never uses `r1`.
/// `pm_drop_and_clauses` should therefore change the graph while keeping a WHERE at the root.
#[test]
fn test_drop_and_clauses_irrelevant_removed() {
    use crate::rewrite::graph_rewrite;
    use crate::symbolic::valid_simplification::pm_drop_and_clauses;
    let r0 = UOp::range_const(10, 0);
    let r1 = UOp::range_const(20, 1);
    let five = UOp::index_const(5);
    let fifteen = UOp::index_const(15);
    let r0_clause = r0.lt(&five);
    let r1_clause = r1.lt(&fifteen);
    let cond = r0_clause.try_and_op(&r1_clause).unwrap();
    let value = r0.try_add(&UOp::index_const(1)).unwrap();
    let gated = UOp::try_where(cond, value, UOp::invalid_marker()).unwrap();

    let rewritten = graph_rewrite(pm_drop_and_clauses(), gated.clone(), &mut ());
    assert!(!Arc::ptr_eq(&rewritten, &gated), "Expected clause dropping");
    assert!(matches!(rewritten.op(), Op::Ternary(TernaryOp::Where, ..)));
}
/// Both clauses (`r0 < 5`, `r0 < 8`) refer to `r0`, which the gated value uses, so
/// `pm_drop_and_clauses` should return the identical node.
#[test]
fn test_drop_and_clauses_all_relevant_kept() {
    use crate::rewrite::graph_rewrite;
    use crate::symbolic::valid_simplification::pm_drop_and_clauses;
    let r0 = UOp::range_const(10, 0);
    let five = UOp::index_const(5);
    let eight = UOp::index_const(8);
    let cond = r0.lt(&five).try_and_op(&r0.lt(&eight)).unwrap();
    let value = r0.try_add(&UOp::index_const(1)).unwrap();
    let gated = UOp::try_where(cond, value, UOp::invalid_marker()).unwrap();

    let rewritten = graph_rewrite(pm_drop_and_clauses(), gated.clone(), &mut ());
    assert!(Arc::ptr_eq(&rewritten, &gated), "Both clauses relevant, should not change");
}
/// A gate that is a single comparison (not an AND) gives `pm_drop_and_clauses` nothing to
/// drop, so the node must come back pointer-identical.
#[test]
fn test_drop_and_clauses_single_clause_no_change() {
    use crate::rewrite::graph_rewrite;
    use crate::symbolic::valid_simplification::pm_drop_and_clauses;
    let r0 = UOp::range_const(10, 0);
    let five = UOp::index_const(5);
    let only_clause = r0.lt(&five);
    let value = r0.try_add(&UOp::index_const(1)).unwrap();
    let gated = UOp::try_where(only_clause, value, UOp::invalid_marker()).unwrap();

    let rewritten = graph_rewrite(pm_drop_and_clauses(), gated.clone(), &mut ());
    assert!(Arc::ptr_eq(&rewritten, &gated), "Single clause should not change");
}
/// A constant's sound bounds collapse to the constant itself.
#[test]
fn test_sound_vmin_vmax_const() {
    use morok_ir::uop::range_eval::compute_sound_vmin_vmax;
    let forty_two = UOp::native_const(42i32);
    let expected = Some((ConstValue::Int(42), ConstValue::Int(42)));
    assert_eq!(compute_sound_vmin_vmax(&forty_two), expected);
}
/// A range of extent 10 spans `[0, 9]` inclusive.
#[test]
fn test_sound_vmin_vmax_range() {
    use morok_ir::uop::range_eval::compute_sound_vmin_vmax;
    let range = UOp::range_const(10, 0);
    let expected = Some((ConstValue::Int(0), ConstValue::Int(9)));
    assert_eq!(compute_sound_vmin_vmax(&range), expected);
}
/// Adding the constant 5 to a `[0, 9]` range shifts both bounds, giving `[5, 14]`.
#[test]
fn test_sound_vmin_vmax_add() {
    use morok_ir::uop::range_eval::compute_sound_vmin_vmax;
    let range = UOp::range_const(10, 0);
    let shifted = range.try_add(&UOp::index_const(5)).unwrap();
    let expected = Some((ConstValue::Int(5), ConstValue::Int(14)));
    assert_eq!(compute_sound_vmin_vmax(&shifted), expected);
}
/// `cast<i32>(range(100)) & 7` is bounded to `[0, 7]` by the constant mask.
#[test]
fn test_sound_vmin_vmax_and_const_mask() {
    use morok_dtype::ScalarDType;
    use morok_ir::uop::range_eval::compute_sound_vmin_vmax;
    let range = UOp::range_const(100, 0);
    let mask = UOp::native_const(7i32);
    let masked = range.cast(DType::Scalar(ScalarDType::Int32)).and_(&mask);
    let expected = Some((ConstValue::Int(0), ConstValue::Int(7)));
    assert_eq!(compute_sound_vmin_vmax(&masked), expected);
}
/// When both AND operands are non-constant, `compute_sound_vmin_vmax` is expected to refuse
/// to produce bounds and return `None`.
#[test]
fn test_sound_vmin_vmax_and_variable_mask_unsound() {
    use morok_dtype::ScalarDType;
    use morok_ir::uop::range_eval::compute_sound_vmin_vmax;
    let outer = UOp::range_const(100, 0);
    let inner = UOp::range_const(50, 1);
    let outer_i32 = outer.cast(DType::Scalar(ScalarDType::Int32));
    let inner_i32 = inner.cast(DType::Scalar(ScalarDType::Int32));
    let anded = outer_i32.and_(&inner_i32);
    assert!(compute_sound_vmin_vmax(&anded).is_none(), "AND with variable mask should be unsound");
}
/// A LOAD from a buffer has no statically known value, so `compute_sound_vmin_vmax` is
/// expected to return `None`.
#[test]
fn test_sound_vmin_vmax_load_unsound() {
    use morok_dtype::{DeviceSpec, ScalarDType};
    use morok_ir::uop::range_eval::compute_sound_vmin_vmax;
    let buffer = UOp::new_buffer(DeviceSpec::Cpu, 100, DType::Scalar(ScalarDType::Float32));
    let zero = UOp::index_const(0);
    let indexed = UOp::index().buffer(buffer.clone()).indices(vec![zero]).call().unwrap();
    let load = UOp::load().buffer(buffer).index(indexed).call();
    assert!(compute_sound_vmin_vmax(&load).is_none(), "LOAD should be unsound");
}
/// With the constant on the left, `3 + range(10)` still yields sound bounds `[3, 12]`.
#[test]
fn test_sound_vmin_vmax_nested_sound() {
    use morok_ir::uop::range_eval::compute_sound_vmin_vmax;
    let three = UOp::index_const(3);
    let range = UOp::range_const(10, 0);
    let total = three.try_add(&range).unwrap();
    let expected = Some((ConstValue::Int(3), ConstValue::Int(12)));
    assert_eq!(compute_sound_vmin_vmax(&total), expected);
}
/// `-1 * (x + y)` rewritten under `sym()` should distribute the negation, so the root is no
/// longer a `Mul`.
#[test]
fn test_sym_phase3_neg_distribution() {
    use crate::rewrite::graph_rewrite;
    use morok_dtype::ScalarDType;
    let x = UOp::range_const(10, 0);
    let y = UOp::range_const(20, 1);
    let x_i32 = x.cast(DType::Scalar(ScalarDType::Int32));
    let y_i32 = y.cast(DType::Scalar(ScalarDType::Int32));
    let sum = x_i32.try_add(&y_i32).unwrap();
    let product = UOp::native_const(-1i32).try_mul(&sum).unwrap();

    let rewritten = graph_rewrite(sym(), product, &mut ());
    assert!(!matches!(rewritten.op(), Op::Binary(BinaryOp::Mul, ..)), "Expected negation distribution, got Mul");
}
/// Substituting `r0 -> 42` in `r0 + r1` replaces `r0` and leaves `r1` pointer-identical.
#[test]
fn test_substitute_gated_skips_irrelevant_subtrees() {
    use morok_ir::UOpKey;
    use std::collections::HashMap;
    let r0 = UOp::range_const(10, 0);
    let r1 = UOp::range_const(20, 1);
    let replacement = UOp::index_const(42);
    let sum = r0.try_add(&r1).unwrap();
    // NOTE(review): clippy's `mutable_key_type` fires on `UOpKey`, presumably because `UOp`
    // carries interior-mutable state. The key's hash/eq is assumed not to depend on it.
    #[allow(clippy::mutable_key_type)]
    let map = HashMap::from([(UOpKey(r0.clone()), replacement.clone())]);
    let result = sum.substitute_gated(&map);
    // Print the op itself (as the other tests do) rather than an opaque `Discriminant(N)`.
    let Op::Binary(BinaryOp::Add, lhs, rhs) = result.op() else {
        panic!("Expected Add, got {:?}", result.op());
    };
    // Operand order may be canonicalized, so accept the nodes on either side.
    assert!(Arc::ptr_eq(lhs, &replacement) || Arc::ptr_eq(rhs, &replacement), "Expected replacement in result");
    assert!(Arc::ptr_eq(lhs, &r1) || Arc::ptr_eq(rhs, &r1), "Expected r1 preserved in result");
}
/// With an empty substitution map, `substitute_gated` must hand back the very same node.
#[test]
fn test_substitute_gated_empty_map() {
    use morok_ir::UOpKey;
    use std::collections::HashMap;
    let r0 = UOp::range_const(10, 0);
    // NOTE(review): `UOpKey` trips clippy's `mutable_key_type`; keying is assumed to ignore
    // any interior-mutable state inside `UOp`.
    #[allow(clippy::mutable_key_type)]
    let empty: HashMap<UOpKey, Arc<UOp>> = HashMap::new();
    assert!(Arc::ptr_eq(&r0.substitute_gated(&empty), &r0), "Empty map should return original");
}