use std::sync::Arc;
use proptest::prelude::*;
use morok_dtype::DType;
use morok_ir::types::{BinaryOp, ConstValue};
use morok_ir::uop::cached_property::CachedProperty;
use morok_ir::uop::properties::VminVmaxProperty;
use morok_ir::{Op, UOp};
use crate::rewrite::graph_rewrite;
use crate::symbolic::{symbolic, symbolic_simple};
use morok_ir::test::property::generators::*;
// Identity laws: combining `x` with the operation's neutral element must hand back
// the exact same node (pointer equality), not merely a structurally equal copy.
proptest! {
    #![proptest_config(ProptestConfig::with_cases(1000))]

    #[test]
    fn identity_add_zero_right(x in arb_simple_uop(DType::Int32)) {
        let neutral = UOp::native_const(0i32);
        let sum = x.try_add(&neutral).unwrap();
        let rewritten = graph_rewrite(&symbolic_simple(), sum, &mut ());
        prop_assert!(Arc::ptr_eq(&rewritten, &x), "x + 0 should simplify to x");
    }

    #[test]
    fn identity_add_zero_left(x in arb_simple_uop(DType::Int32)) {
        let neutral = UOp::native_const(0i32);
        let sum = neutral.try_add(&x).unwrap();
        let rewritten = graph_rewrite(&symbolic_simple(), sum, &mut ());
        prop_assert!(Arc::ptr_eq(&rewritten, &x), "0 + x should simplify to x");
    }

    #[test]
    fn identity_sub_zero(x in arb_simple_uop(DType::Int32)) {
        let neutral = UOp::native_const(0i32);
        let difference = x.try_sub(&neutral).unwrap();
        let rewritten = graph_rewrite(&symbolic_simple(), difference, &mut ());
        prop_assert!(Arc::ptr_eq(&rewritten, &x), "x - 0 should simplify to x");
    }

    #[test]
    fn identity_mul_one_right(x in arb_simple_uop(DType::Int32)) {
        let unit = UOp::native_const(1i32);
        let product = x.try_mul(&unit).unwrap();
        let rewritten = graph_rewrite(&symbolic_simple(), product, &mut ());
        prop_assert!(Arc::ptr_eq(&rewritten, &x), "x * 1 should simplify to x");
    }

    #[test]
    fn identity_mul_one_left(x in arb_simple_uop(DType::Int32)) {
        let unit = UOp::native_const(1i32);
        let product = unit.try_mul(&x).unwrap();
        let rewritten = graph_rewrite(&symbolic_simple(), product, &mut ());
        prop_assert!(Arc::ptr_eq(&rewritten, &x), "1 * x should simplify to x");
    }

    #[test]
    fn identity_idiv_one(x in arb_simple_uop(DType::Int32)) {
        let unit = UOp::native_const(1i32);
        let quotient = x.try_div(&unit).unwrap();
        let rewritten = graph_rewrite(&symbolic_simple(), quotient, &mut ());
        prop_assert!(Arc::ptr_eq(&rewritten, &x), "x / 1 should simplify to x");
    }

    #[test]
    fn identity_or_zero_right(x in arb_simple_uop(DType::Int32)) {
        let neutral = UOp::native_const(0i32);
        let disjunction = x.try_or_op(&neutral).unwrap();
        let rewritten = graph_rewrite(&symbolic_simple(), disjunction, &mut ());
        prop_assert!(Arc::ptr_eq(&rewritten, &x), "x | 0 should simplify to x");
    }

    #[test]
    fn identity_xor_zero_right(x in arb_simple_uop(DType::Int32)) {
        let neutral = UOp::native_const(0i32);
        let exclusive = x.try_xor_op(&neutral).unwrap();
        let rewritten = graph_rewrite(&symbolic_simple(), exclusive, &mut ());
        prop_assert!(Arc::ptr_eq(&rewritten, &x), "x ^ 0 should simplify to x");
    }
}
// Annihilator laws: multiplying or AND-ing with zero must collapse to the very zero
// constant that was passed in (pointer equality).
proptest! {
    #![proptest_config(ProptestConfig::with_cases(1000))]

    #[test]
    fn zero_mul_right(x in arb_simple_uop(DType::Int32)) {
        let absorbing = UOp::native_const(0i32);
        let product = x.try_mul(&absorbing).unwrap();
        let result = graph_rewrite(&symbolic_simple(), product, &mut ());
        prop_assert!(Arc::ptr_eq(&result, &absorbing), "x * 0 should simplify to 0");
    }

    #[test]
    fn zero_mul_left(x in arb_simple_uop(DType::Int32)) {
        let absorbing = UOp::native_const(0i32);
        let product = absorbing.try_mul(&x).unwrap();
        let result = graph_rewrite(&symbolic_simple(), product, &mut ());
        prop_assert!(Arc::ptr_eq(&result, &absorbing), "0 * x should simplify to 0");
    }

    #[test]
    fn zero_and_right(x in arb_simple_uop(DType::Int32)) {
        let absorbing = UOp::native_const(0i32);
        let masked = x.try_and_op(&absorbing).unwrap();
        let result = graph_rewrite(&symbolic_simple(), masked, &mut ());
        prop_assert!(Arc::ptr_eq(&result, &absorbing), "x & 0 should simplify to 0");
    }

    #[test]
    fn zero_and_left(x in arb_simple_uop(DType::Int32)) {
        let absorbing = UOp::native_const(0i32);
        let masked = absorbing.try_and_op(&x).unwrap();
        let result = graph_rewrite(&symbolic_simple(), masked, &mut ());
        prop_assert!(Arc::ptr_eq(&result, &absorbing), "0 & x should simplify to 0");
    }
}
// Self-application laws: an operation whose two operands are the same node.
// Division and comparisons must fold to a constant; AND/OR must return the operand.
proptest! {
    #![proptest_config(ProptestConfig::with_cases(1000))]

    #[test]
    fn self_idiv_one(x in arb_var_uop(DType::Int32)) {
        let quotient = x.try_div(&x).unwrap();
        let folded = graph_rewrite(&symbolic_simple(), quotient, &mut ());
        if let Op::Const(cv) = folded.op() {
            prop_assert_eq!(cv.0, ConstValue::Int(1), "x / x should be 1");
        } else {
            prop_assert!(false, "x / x should simplify to Const(1), got {:?}", folded.op());
        }
    }

    #[test]
    fn self_and_identity(x in arb_simple_uop(DType::Int32)) {
        let conjunction = x.try_and_op(&x).unwrap();
        let folded = graph_rewrite(&symbolic_simple(), conjunction, &mut ());
        prop_assert!(Arc::ptr_eq(&folded, &x), "x & x should simplify to x");
    }

    #[test]
    fn self_or_identity(x in arb_simple_uop(DType::Int32)) {
        let disjunction = x.try_or_op(&x).unwrap();
        let folded = graph_rewrite(&symbolic_simple(), disjunction, &mut ());
        prop_assert!(Arc::ptr_eq(&folded, &x), "x | x should simplify to x");
    }

    #[test]
    fn self_lt_false(x in arb_var_uop(DType::Int32)) {
        let comparison = x.try_cmplt(&x).unwrap();
        let folded = graph_rewrite(&symbolic_simple(), comparison, &mut ());
        if let Op::Const(cv) = folded.op() {
            prop_assert_eq!(cv.0, ConstValue::Bool(false), "x < x should be false");
        } else {
            prop_assert!(false, "x < x should simplify to Const(false), got {:?}", folded.op());
        }
    }

    // Note: this one runs the full `symbolic()` rule set rather than `symbolic_simple()`.
    #[test]
    fn self_eq_true(x in arb_var_uop(DType::Int32)) {
        let comparison = x.try_cmpeq(&x).unwrap();
        let folded = graph_rewrite(&symbolic(), comparison, &mut ());
        if let Op::Const(cv) = folded.op() {
            prop_assert_eq!(cv.0, ConstValue::Bool(true), "x == x should be true");
        } else {
            prop_assert!(false, "x == x should simplify to Const(true), got {:?}", folded.op());
        }
    }

    #[test]
    fn self_ne_false(x in arb_var_uop(DType::Int32)) {
        let comparison = x.try_cmpne(&x).unwrap();
        let folded = graph_rewrite(&symbolic_simple(), comparison, &mut ());
        if let Op::Const(cv) = folded.op() {
            prop_assert_eq!(cv.0, ConstValue::Bool(false), "x != x should be false");
        } else {
            prop_assert!(false, "x != x should simplify to Const(false), got {:?}", folded.op());
        }
    }
}
// Constant folding: a binary op whose operands are both constants must fold to a
// single constant node. For add and mul, the folded value is also checked.
proptest! {
    #![proptest_config(ProptestConfig::with_cases(2000))]

    #[test]
    fn const_fold_add(a in arb_small_int(), b in arb_small_int()) {
        let a_uop = UOp::const_(DType::Int32, a);
        let b_uop = UOp::const_(DType::Int32, b);
        let expr = a_uop.try_add(&b_uop).unwrap();
        let matcher = symbolic_simple();
        let simplified = graph_rewrite(&matcher, expr, &mut ());
        match simplified.op() {
            Op::Const(cv) => {
                if let (ConstValue::Int(av), ConstValue::Int(bv)) = (a, b) {
                    let expected = av.wrapping_add(bv);
                    // The operands are Int32, so only the low 32 bits have to agree.
                    match cv.0 {
                        ConstValue::Int(result) => prop_assert_eq!(
                            result as i32,
                            expected as i32,
                            "{} + {} should equal {}",
                            av,
                            bv,
                            expected
                        ),
                        // An Int32 sum folding to a non-integer constant is a failure,
                        // not something to skip.
                        other => prop_assert!(
                            false,
                            "{} + {} folded to non-integer constant {:?}",
                            av,
                            bv,
                            other
                        ),
                    }
                }
            }
            _ => prop_assert!(false, "Constant addition should fold to constant"),
        }
    }

    #[test]
    fn const_fold_mul(a in arb_small_int(), b in arb_small_int()) {
        let a_uop = UOp::const_(DType::Int32, a);
        let b_uop = UOp::const_(DType::Int32, b);
        let expr = a_uop.try_mul(&b_uop).unwrap();
        let matcher = symbolic_simple();
        let simplified = graph_rewrite(&matcher, expr, &mut ());
        prop_assert!(matches!(simplified.op(), Op::Const(_)),
            "Constant multiplication should fold to constant");
        // Beyond "is a constant": the folded value must be the product (mod 2^32,
        // mirroring `const_fold_add`).
        if let (Op::Const(cv), ConstValue::Int(av), ConstValue::Int(bv)) = (simplified.op(), a, b) {
            let expected = av.wrapping_mul(bv);
            if let ConstValue::Int(result) = cv.0 {
                prop_assert_eq!(result as i32, expected as i32,
                    "{} * {} should equal {}", av, bv, expected);
            } else {
                prop_assert!(false, "{} * {} folded to non-integer constant {:?}", av, bv, cv.0);
            }
        }
    }

    // Only checks that a constant is produced: the rounding convention of Idiv on
    // negative operands is not pinned here.
    #[test]
    fn const_fold_idiv(a in arb_small_int(), b in nonzero_int()) {
        let a_uop = UOp::const_(DType::Int32, a);
        let b_uop = UOp::const_(DType::Int32, b);
        let expr = a_uop.try_div(&b_uop).unwrap();
        let matcher = symbolic_simple();
        let simplified = graph_rewrite(&matcher, expr, &mut ());
        prop_assert!(matches!(simplified.op(), Op::Const(_)),
            "Constant division should fold to constant");
    }
}
// Algebraic-structure checks: commutativity of addition and idempotence of AND.
proptest! {
    #![proptest_config(ProptestConfig::with_cases(1000))]

    // Accepts either an identical simplified node, or both orders still being an Add.
    #[test]
    fn commutativity_add(
        x in arb_simple_uop(DType::Int32),
        y in arb_simple_uop(DType::Int32),
    ) {
        let forward_expr = x.try_add(&y).unwrap();
        let backward_expr = y.try_add(&x).unwrap();
        let rules = symbolic_simple();
        let forward = graph_rewrite(&rules, forward_expr, &mut ());
        let backward = graph_rewrite(&rules, backward_expr, &mut ());
        let both_still_add = matches!(
            (forward.op(), backward.op()),
            (Op::Binary(BinaryOp::Add, _, _), Op::Binary(BinaryOp::Add, _, _))
        );
        prop_assert!(
            Arc::ptr_eq(&forward, &backward) || both_still_add,
            "Addition should commute after optimization"
        );
    }

    #[test]
    fn idempotent_and(x in arb_simple_uop(DType::Int32)) {
        let twice = x.try_and_op(&x).unwrap();
        let thrice = twice.try_and_op(&x).unwrap();
        let rules = symbolic_simple();
        let twice_simplified = graph_rewrite(&rules, twice, &mut ());
        let thrice_simplified = graph_rewrite(&rules, thrice, &mut ());
        let same_form = Arc::ptr_eq(&twice_simplified, &thrice_simplified);
        let collapsed_to_x = Arc::ptr_eq(&twice_simplified, &x);
        prop_assert!(same_form || collapsed_to_x,
            "x & x should be idempotent: either both simplify to same form or to x");
    }
}
// Nested-operation collapse and inverse laws on variables. A chain of the same
// operation with constant operands must fold into one operation whose constant
// operand is the combined value, with the original variable node preserved.
proptest! {
    #![proptest_config(ProptestConfig::with_cases(500))]

    // (a // b) // c  ==>  a // (b * c)
    #[test]
    fn nested_div_collapse(
        a in arb_var_uop(DType::Int32),
        b in 2..8i32,
        c in 2..8i32,
    ) {
        // Skip variables whose upper bound is below b * c. Presumably those cases
        // simplify further (e.g. to a constant) than the shape asserted here.
        let (_, vmax) = VminVmaxProperty::get(&a);
        if let ConstValue::Int(max) = vmax {
            prop_assume!(*max >= (b as i64) * (c as i64));
        }
        let b_uop = UOp::native_const(b);
        let c_uop = UOp::native_const(c);
        let div1 = a.try_div(&b_uop).unwrap();
        let div2 = div1.try_div(&c_uop).unwrap();
        let matcher = symbolic();
        let simplified = graph_rewrite(&matcher, div2, &mut ());
        if let Op::Binary(BinaryOp::Idiv, var, divisor) = simplified.op() {
            prop_assert!(Arc::ptr_eq(var, &a), "Variable should be preserved");
            if let Op::Const(cv) = divisor.op() {
                let expected = (b as i64) * (c as i64);
                prop_assert_eq!(cv.0, ConstValue::Int(expected),
                    "(a // {}) // {} should simplify to a // {}", b, c, expected);
            } else {
                prop_assert!(false, "Divisor should be constant");
            }
        } else {
            prop_assert!(false, "Should simplify to Idiv");
        }
    }

    // (a * b) * c  ==>  a * (b * c)
    #[test]
    fn nested_mul_collapse(
        a in arb_var_uop(DType::Int32),
        b in 2..20i32,
        c in 2..20i32,
    ) {
        let b_uop = UOp::native_const(b);
        let c_uop = UOp::native_const(c);
        let mul1 = a.try_mul(&b_uop).unwrap();
        let mul2 = mul1.try_mul(&c_uop).unwrap();
        let matcher = symbolic();
        let simplified = graph_rewrite(&matcher, mul2, &mut ());
        if let Op::Binary(BinaryOp::Mul, var, multiplier) = simplified.op() {
            prop_assert!(Arc::ptr_eq(var, &a), "Variable should be preserved");
            if let Op::Const(cv) = multiplier.op() {
                let expected = (b as i64) * (c as i64);
                prop_assert_eq!(cv.0, ConstValue::Int(expected),
                    "(a * {}) * {} should simplify to a * {}", b, c, expected);
            } else {
                prop_assert!(false, "Multiplier should be constant");
            }
        } else {
            prop_assert!(false, "Should simplify to Mul");
        }
    }

    // (a % b) % b  ==>  a % b, reusing both the variable and the divisor node.
    #[test]
    fn mod_idempotence(
        a in arb_var_uop(DType::Int32),
        b in 2..100i32,
    ) {
        let (_, vmax) = VminVmaxProperty::get(&a);
        if let ConstValue::Int(max) = vmax {
            prop_assume!(*max >= b as i64);
        }
        let b_uop = UOp::native_const(b);
        let mod1 = a.try_mod(&b_uop).unwrap();
        let mod2 = mod1.try_mod(&b_uop).unwrap();
        let matcher = symbolic_simple();
        let simplified = graph_rewrite(&matcher, mod2, &mut ());
        if let Op::Binary(BinaryOp::Mod, var, divisor) = simplified.op() {
            prop_assert!(Arc::ptr_eq(var, &a), "Variable should be preserved");
            prop_assert!(Arc::ptr_eq(divisor, &b_uop), "Divisor should be preserved");
        } else {
            prop_assert!(false, "Should simplify to Mod(a, b)");
        }
    }

    // (a + b) + c  ==>  a + (b + c), a - (-(b + c)), or just `a` when b + c == 0.
    #[test]
    fn nested_add_collapse(
        a in arb_var_uop(DType::Int32),
        b in -100..100i32,
        c in -100..100i32,
    ) {
        let b_uop = UOp::native_const(b);
        let c_uop = UOp::native_const(c);
        let add1 = a.try_add(&b_uop).unwrap();
        let add2 = add1.try_add(&c_uop).unwrap();
        let matcher = symbolic();
        let simplified = graph_rewrite(&matcher, add2, &mut ());
        let expected_sum = (b as i64) + (c as i64);
        match simplified.op() {
            Op::Binary(BinaryOp::Add, var, addend) => {
                prop_assert!(Arc::ptr_eq(var, &a), "Variable should be preserved");
                // A non-constant addend means the two constants were not combined.
                // Fail here, consistent with the sibling collapse tests.
                if let Op::Const(cv) = addend.op() {
                    prop_assert_eq!(cv.0, ConstValue::Int(expected_sum),
                        "(a + {}) + {} should simplify to a + {}", b, c, expected_sum);
                } else {
                    prop_assert!(false, "Addend should be constant");
                }
            }
            Op::Binary(BinaryOp::Sub, var, subtrahend) => {
                prop_assert!(Arc::ptr_eq(var, &a), "Variable should be preserved");
                if let Op::Const(cv) = subtrahend.op() {
                    prop_assert_eq!(cv.0, ConstValue::Int(-expected_sum),
                        "(a + {}) + {} should simplify to a - {}", b, c, -expected_sum);
                } else {
                    prop_assert!(false, "Subtrahend should be constant");
                }
            }
            Op::DefineVar { .. } => {
                prop_assert!(Arc::ptr_eq(&simplified, &a),
                    "(a + {}) + {} = a + 0 should simplify to a", b, c);
                prop_assert_eq!(expected_sum, 0,
                    "DefineVar result should only happen when sum is 0");
            }
            _ => prop_assert!(false, "Should simplify to Add, Sub, or identity (when sum is 0)"),
        }
    }

    // (a - b) - c  ==>  a - (b + c)
    #[test]
    fn nested_sub_collapse(
        a in arb_var_uop(DType::Int32),
        b in 1..100i32,
        c in 1..100i32,
    ) {
        let b_uop = UOp::native_const(b);
        let c_uop = UOp::native_const(c);
        let sub1 = a.try_sub(&b_uop).unwrap();
        let sub2 = sub1.try_sub(&c_uop).unwrap();
        let matcher = symbolic();
        let simplified = graph_rewrite(&matcher, sub2, &mut ());
        if let Op::Binary(BinaryOp::Sub, var, subtrahend) = simplified.op() {
            prop_assert!(Arc::ptr_eq(var, &a), "Variable should be preserved");
            if let Op::Const(cv) = subtrahend.op() {
                let expected = (b as i64) + (c as i64);
                prop_assert_eq!(cv.0, ConstValue::Int(expected),
                    "(a - {}) - {} should simplify to a - {}", b, c, expected);
            } else {
                prop_assert!(false, "Subtrahend should be constant");
            }
        } else {
            prop_assert!(false, "Should simplify to Sub");
        }
    }

    // (a * b) // b  ==>  a
    #[test]
    fn mul_div_inverse(
        a in arb_var_uop(DType::Int32),
        b in 1..100i32,
    ) {
        let b_uop = UOp::native_const(b);
        let mul = a.try_mul(&b_uop).unwrap();
        let div = mul.try_div(&b_uop).unwrap();
        let matcher = symbolic_simple();
        let simplified = graph_rewrite(&matcher, div, &mut ());
        prop_assert!(Arc::ptr_eq(&simplified, &a),
            "(a * {}) // {} should simplify to a", b, b);
    }
}
/// Concretely evaluates an integer-valued `UOp` tree under the variable assignment `vars`.
///
/// Booleans are encoded as `1`/`0`. This covers both boolean constants and boolean
/// results of binary ops, so folded and unfolded comparisons evaluate identically.
/// `Op::Bind` evaluates its bound variable. `Op::Invalid` maps to the sentinel
/// `i64::MIN` (presumably so that two trees yielding `Invalid` at the same point
/// compare equal).
///
/// Returns `None` in three cases: the tree contains an op this evaluator does not
/// model, a variable is missing from `vars`, or `eval_binary_op` declines to
/// evaluate. Callers treat `None` as "skip this point".
fn eval_uop(expr: &Arc<UOp>, vars: &std::collections::HashMap<String, i64>) -> Option<i64> {
    use morok_ir::uop::eval::eval_binary_op;
    match expr.op() {
        Op::Const(cv) => match cv.0 {
            ConstValue::Int(v) => Some(v),
            // Simplification can fold comparisons into boolean constants. Without this
            // arm, any such simplified tree would evaluate to `None` and escape checking.
            ConstValue::Bool(flag) => Some(flag as i64),
            _ => None,
        },
        Op::DefineVar { name, .. } => vars.get(name.as_str()).copied(),
        Op::Bind { var, .. } => eval_uop(var, vars),
        Op::Binary(op, lhs, rhs) => {
            let lhs_val = eval_uop(lhs, vars)?;
            let rhs_val = eval_uop(rhs, vars)?;
            match eval_binary_op(*op, ConstValue::Int(lhs_val), ConstValue::Int(rhs_val))? {
                ConstValue::Int(v) => Some(v),
                ConstValue::Bool(flag) => Some(flag as i64),
                _ => None,
            }
        }
        Op::Ternary(morok_ir::TernaryOp::Where, cond, t, f) => {
            // Any non-zero condition selects the true branch.
            let cond_val = eval_uop(cond, vars)?;
            if cond_val != 0 {
                eval_uop(t, vars)
            } else {
                eval_uop(f, vars)
            }
        }
        Op::Unary(morok_ir::UnaryOp::Not, x) => {
            let v = eval_uop(x, vars)?;
            Some(if v == 0 { 1 } else { 0 })
        }
        Op::Invalid => Some(i64::MIN),
        _ => None,
    }
}
/// Strategy producing a random affine index expression `k + fa*a + fb*b`, where
/// `a ∈ [0, a_max]` and `b ∈ [0, b_max]` are `DType::Index` variables.
///
/// The expression is paired with a divisor in `2..16` and the inclusive ranges of
/// the two variables, as `(name, lo, hi)`. A term whose coefficient is zero is
/// omitted rather than emitted as `0 * var`.
fn arb_divmod_expr() -> impl Strategy<Value = (Arc<UOp>, i64, Vec<(String, i64, i64)>)> {
    // Order: (fa, fb, k, div, a_max, b_max).
    (-20i64..20, -20i64..20, -20i64..20, 2i64..16, 1i64..8, 1i64..8).prop_map(
        |(coeff_a, coeff_b, offset, divisor, a_hi, b_hi)| {
            let var_a = UOp::var("a", DType::Index, 0, a_hi);
            let var_b = UOp::var("b", DType::Index, 0, b_hi);
            let expr = [(coeff_a, &var_a), (coeff_b, &var_b)]
                .iter()
                .copied()
                .filter(|&(coeff, _)| coeff != 0)
                .fold(UOp::index_const(offset), |acc, (coeff, var)| {
                    let term = UOp::index_const(coeff).try_mul(var).unwrap();
                    acc.try_add(&term).unwrap()
                });
            let ranges = vec![("a".to_string(), 0, a_hi), ("b".to_string(), 0, b_hi)];
            (expr, divisor, ranges)
        },
    )
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(2000))]
#[test]
fn divmod_mod_soundness((expr, div, var_ranges) in arb_divmod_expr()) {
let c = UOp::index_const(div);
let modded = expr.try_mod(&c).unwrap();
let matcher = symbolic_simple();
let simplified = graph_rewrite(&matcher, modded.clone(), &mut ());
for a_val in var_ranges[0].1..=var_ranges[0].2 {
for b_val in var_ranges[1].1..=var_ranges[1].2 {
let mut vars = std::collections::HashMap::new();
vars.insert("a".into(), a_val);
vars.insert("b".into(), b_val);
if let (Some(orig), Some(simp)) = (eval_uop(&modded, &vars), eval_uop(&simplified, &vars)) {
prop_assert_eq!(orig, simp,
"Mod mismatch at a={}, b={}, div={}.\n Original: {}\n Simplified: {}",
a_val, b_val, div, modded.tree(), simplified.tree());
}
}
}
}
#[test]
fn divmod_idiv_soundness((expr, div, var_ranges) in arb_divmod_expr()) {
let c = UOp::index_const(div);
let divided = expr.try_div(&c).unwrap();
let matcher = symbolic_simple();
let simplified = graph_rewrite(&matcher, divided.clone(), &mut ());
for a_val in var_ranges[0].1..=var_ranges[0].2 {
for b_val in var_ranges[1].1..=var_ranges[1].2 {
let mut vars = std::collections::HashMap::new();
vars.insert("a".into(), a_val);
vars.insert("b".into(), b_val);
if let (Some(orig), Some(simp)) = (eval_uop(÷d, &vars), eval_uop(&simplified, &vars)) {
prop_assert_eq!(orig, simp,
"Idiv mismatch at a={}, b={}, div={}.\n Original: {}\n Simplified: {}",
a_val, b_val, div, divided.tree(), simplified.tree());
}
}
}
}
}