use morok_dtype::DType;
use morok_ir::types::ConstValue;
use morok_ir::{Op, UOp};
use std::sync::Arc;
use crate::rewrite::graph_rewrite;
use crate::symbolic::symbolic_simple;
use crate::z3::verify::verify_equivalence;
/// Adding the constant `0` must rewrite back to the very same `x` node,
/// and Z3 must agree that the original and rewritten graphs are equal.
#[test]
fn test_identity_add_zero() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let sum = var_x.try_add(&UOp::native_const(0i32)).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, sum.clone(), &mut ());

    assert!(Arc::ptr_eq(&rewritten, &var_x), "x + 0 should simplify to x");
    verify_equivalence(&sum, &rewritten).expect("x + 0 should equal x");
}
/// Multiplying by the constant `1` must rewrite back to the very same `x` node,
/// and Z3 must agree that the original and rewritten graphs are equal.
#[test]
fn test_identity_mul_one() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let product = var_x.try_mul(&UOp::native_const(1i32)).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, product.clone(), &mut ());

    assert!(Arc::ptr_eq(&rewritten, &var_x), "x * 1 should simplify to x");
    verify_equivalence(&product, &rewritten).expect("x * 1 should equal x");
}
/// Subtracting the constant `0` must rewrite back to the very same `x` node,
/// and Z3 must agree that the original and rewritten graphs are equal.
#[test]
fn test_identity_sub_zero() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let difference = var_x.try_sub(&UOp::native_const(0i32)).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, difference.clone(), &mut ());

    assert!(Arc::ptr_eq(&rewritten, &var_x), "x - 0 should simplify to x");
    verify_equivalence(&difference, &rewritten).expect("x - 0 should equal x");
}
/// Dividing by the constant `1` must rewrite back to the very same `x` node,
/// and Z3 must agree that the original and rewritten graphs are equal.
#[test]
fn test_identity_div_one() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let quotient = var_x.try_div(&UOp::native_const(1i32)).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, quotient.clone(), &mut ());

    assert!(Arc::ptr_eq(&rewritten, &var_x), "x / 1 should simplify to x");
    verify_equivalence(&quotient, &rewritten).expect("x / 1 should equal x");
}
/// `x % 1` is identically zero.
///
/// Checks two things:
/// - Z3 proves the raw expression equals the constant `0`.
/// - Z3 proves that whatever `symbolic_simple` rewrites it into is still equivalent
///   to the original, so the rewriter is exercised like in the sibling identity tests.
#[test]
fn test_identity_mod_one() {
    let x = UOp::var("x", DType::Int32, 0, 100);
    let one = UOp::native_const(1i32);
    let expr = x.try_mod(&one).unwrap();
    let zero = UOp::native_const(0i32);
    verify_equivalence(&expr, &zero).expect("x % 1 should equal 0");

    // Previously the simplifier was never run here; make sure its output stays sound.
    let matcher = symbolic_simple();
    let simplified = graph_rewrite(&matcher, expr.clone(), &mut ());
    verify_equivalence(&expr, &simplified).expect("x % 1 should stay equivalent after simplification");
}
/// Multiplying by `0` must rewrite to the very `zero` node used as the operand,
/// and Z3 must agree that the original and rewritten graphs are equal.
#[test]
fn test_zero_mul_zero() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let zero = UOp::native_const(0i32);
    let product = var_x.try_mul(&zero).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, product.clone(), &mut ());

    assert!(Arc::ptr_eq(&rewritten, &zero), "x * 0 should simplify to 0");
    verify_equivalence(&product, &rewritten).expect("x * 0 should equal 0");
}
/// Bitwise AND with `0` must rewrite to the very `zero` node used as the operand.
///
/// NOTE(review): unlike most tests here this one has no `verify_equivalence` call —
/// presumably because bitwise ops are not modelled by the Z3 checker; confirm.
#[test]
fn test_zero_and_zero() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let zero = UOp::native_const(0i32);
    let masked = var_x.try_and_op(&zero).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, masked.clone(), &mut ());

    assert!(Arc::ptr_eq(&rewritten, &zero), "x & 0 should simplify to 0");
}
/// `0 / x` is zero whenever the divisor is non-zero.
///
/// `x` is declared with bounds 1 and 100 so the divisor is never zero. The test
/// checks the raw expression against `0` with Z3, then runs `symbolic_simple` and
/// checks that its output stays equivalent to the original.
#[test]
fn test_zero_div_x() {
    let x = UOp::var("x", DType::Int32, 1, 100);
    let zero = UOp::native_const(0i32);
    let expr = zero.try_div(&x).unwrap();
    verify_equivalence(&expr, &zero).expect("0 / x should equal 0");

    // Exercise the rewriter too, consistent with the other tests in this file.
    let matcher = symbolic_simple();
    let simplified = graph_rewrite(&matcher, expr.clone(), &mut ());
    verify_equivalence(&expr, &simplified).expect("0 / x should stay equivalent after simplification");
}
/// `x - x` is identically zero.
///
/// The test checks the raw expression against `0` with Z3, then runs
/// `symbolic_simple` and checks that its output stays equivalent to the original.
#[test]
fn test_self_sub_zero() {
    let x = UOp::var("x", DType::Int32, 0, 100);
    let expr = x.try_sub(&x).unwrap();
    let zero = UOp::native_const(0i32);
    verify_equivalence(&expr, &zero).expect("x - x should equal 0");

    // Exercise the rewriter too, consistent with the other tests in this file.
    let matcher = symbolic_simple();
    let simplified = graph_rewrite(&matcher, expr.clone(), &mut ());
    verify_equivalence(&expr, &simplified).expect("x - x should stay equivalent after simplification");
}
/// `x / x` (with `x` bounded away from zero) must rewrite to the integer constant `1`,
/// and Z3 must agree that the original and rewritten graphs are equal.
#[test]
fn test_self_div_one() {
    let var_x = UOp::var("x", DType::Int32, 1, 100);
    let quotient = var_x.try_div(&var_x).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, quotient.clone(), &mut ());

    let Op::Const(cv) = rewritten.op() else {
        panic!("Expected Const(1), got {:?}", rewritten.op());
    };
    assert_eq!(cv.0, ConstValue::Int(1), "x / x should be 1");

    verify_equivalence(&quotient, &rewritten).expect("x / x should equal 1");
}
/// `x % x` (with `x` bounded away from zero) must rewrite to the integer constant `0`,
/// and Z3 must agree that the original and rewritten graphs are equal.
#[test]
fn test_self_mod_zero() {
    let var_x = UOp::var("x", DType::Int32, 1, 100);
    let remainder = var_x.try_mod(&var_x).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, remainder.clone(), &mut ());

    let Op::Const(cv) = rewritten.op() else {
        panic!("Expected Const(0), got {:?}", rewritten.op());
    };
    assert_eq!(cv.0, ConstValue::Int(0), "x % x should be 0");

    verify_equivalence(&remainder, &rewritten).expect("x % x should equal 0");
}
/// `x & x` must rewrite back to the very same `x` node.
#[test]
fn test_self_and_identity() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let masked = var_x.try_and_op(&var_x).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, masked.clone(), &mut ());

    assert!(Arc::ptr_eq(&rewritten, &var_x), "x & x should simplify to x");
}
/// `(a * b) / b` with a non-zero `b` must cancel back to the very same `a` node,
/// and Z3 must agree that the original and rewritten graphs are equal.
#[test]
fn test_div_cancel_mul() {
    let lhs = UOp::var("a", DType::Int32, 0, 100);
    let divisor = UOp::var("b", DType::Int32, 1, 100);
    let quotient = lhs
        .try_mul(&divisor)
        .unwrap()
        .try_div(&divisor)
        .unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, quotient.clone(), &mut ());

    assert!(Arc::ptr_eq(&rewritten, &lhs), "(a * b) / b should simplify to a");
    verify_equivalence(&quotient, &rewritten).expect("(a * b) / b should equal a");
}
/// `(a / b) / c` with non-zero divisors: whatever `symbolic_simple` produces must
/// stay semantically equivalent to the original chain (checked with Z3).
#[test]
fn test_div_chain() {
    let dividend = UOp::var("a", DType::Int32, 0, 100);
    let first_divisor = UOp::var("b", DType::Int32, 1, 10);
    let second_divisor = UOp::var("c", DType::Int32, 1, 10);
    let chained = dividend
        .try_div(&first_divisor)
        .unwrap()
        .try_div(&second_divisor)
        .unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, chained.clone(), &mut ());

    verify_equivalence(&chained, &rewritten).expect("(a / b) / c should be semantically equivalent");
}
/// `(a * 6) / (b * 6)`: whatever `symbolic_simple` produces must stay semantically
/// equivalent to the original (checked with Z3).
///
/// `b` is bounded below by 1, as every other divisor variable in this file is.
/// With a lower bound of 0 the divisor `b * 6` could be zero, and the check would
/// hinge on the solver's division-by-zero semantics instead of common-factor handling.
#[test]
fn test_div_gcd_factor() {
    let a = UOp::var("a", DType::Int32, 0, 60);
    let b = UOp::var("b", DType::Int32, 1, 10);
    let six = UOp::native_const(6i32);
    let a_mul_6 = a.try_mul(&six).unwrap();
    let b_mul_6 = b.try_mul(&six).unwrap();
    let expr = a_mul_6.try_div(&b_mul_6).unwrap();
    let matcher = symbolic_simple();
    let simplified = graph_rewrite(&matcher, expr.clone(), &mut ());
    verify_equivalence(&expr, &simplified).expect("(a * c) / (b * c) should be semantically equivalent");
}
/// `a % a` must rewrite to the integer constant `0`, and Z3 must agree that the
/// original and rewritten graphs are equal.
///
/// `a` is bounded below by 1, matching `test_self_mod_zero` / `test_self_div_one`.
/// With a lower bound of 0, `a % a` includes a modulo by zero, so "a % a == 0" is
/// not a well-defined claim for every `a` in range.
#[test]
fn test_mod_self_zero() {
    let a = UOp::var("a", DType::Int32, 1, 100);
    let expr = a.try_mod(&a).unwrap();
    let matcher = symbolic_simple();
    let simplified = graph_rewrite(&matcher, expr.clone(), &mut ());
    match simplified.op() {
        Op::Const(cv) => assert_eq!(cv.0, ConstValue::Int(0), "a % a should be 0"),
        other => panic!("Expected Const(0), got {:?}", other),
    }
    verify_equivalence(&expr, &simplified).expect("a % a should equal 0");
}
/// `x + x` may be rewritten into a combined term; the result must stay
/// semantically equivalent to the original (checked with Z3).
#[test]
fn test_term_combine_add() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let doubled = var_x.try_add(&var_x).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, doubled.clone(), &mut ());

    verify_equivalence(&doubled, &rewritten).expect("x + x should be semantically equivalent to 2 * x");
}
/// `(2 * x) + (3 * x)`: whatever `symbolic_simple` produces must stay
/// semantically equivalent to the original sum (checked with Z3).
#[test]
fn test_term_combine_coefficients() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let coeff_two = UOp::native_const(2i32);
    let coeff_three = UOp::native_const(3i32);

    // Constants are the left operand, matching the original `c * x` shape.
    let lhs_term = coeff_two.try_mul(&var_x).unwrap();
    let rhs_term = coeff_three.try_mul(&var_x).unwrap();
    let sum = lhs_term.try_add(&rhs_term).unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, sum.clone(), &mut ());

    verify_equivalence(&sum, &rewritten).expect("(2 * x) + (3 * x) should equal 5 * x");
}
/// `(x + 3) + 5`: whatever `symbolic_simple` produces must stay semantically
/// equivalent to the original (checked with Z3).
#[test]
fn test_const_folding_add() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let c3 = UOp::native_const(3i32);
    let c5 = UOp::native_const(5i32);
    let nested = var_x
        .try_add(&c3)
        .unwrap()
        .try_add(&c5)
        .unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, nested.clone(), &mut ());

    verify_equivalence(&nested, &rewritten).expect("(x + 3) + 5 should equal x + 8");
}
/// `(x * 2) * 3`: whatever `symbolic_simple` produces must stay semantically
/// equivalent to the original (checked with Z3).
#[test]
fn test_const_folding_mul() {
    let var_x = UOp::var("x", DType::Int32, 0, 100);
    let c2 = UOp::native_const(2i32);
    let c3 = UOp::native_const(3i32);
    let nested = var_x
        .try_mul(&c2)
        .unwrap()
        .try_mul(&c3)
        .unwrap();

    let rules = symbolic_simple();
    let rewritten = graph_rewrite(&rules, nested.clone(), &mut ());

    verify_equivalence(&nested, &rewritten).expect("(x * 2) * 3 should equal x * 6");
}