use morok_dtype::DType;
use morok_ir::types::{BinaryOp, ConstValue};
use morok_ir::{Op, UOp};
use std::sync::Arc;
use crate::rewrite::graph_rewrite;
use crate::symbolic::{symbolic, symbolic_simple};
#[test]
fn test_lt_always_true() {
let idx = UOp::var("idx", DType::Int32, 0, 15);
let size = UOp::native_const(32i32);
let check = idx.try_cmplt(&size).unwrap();
let matcher = symbolic();
let result = graph_rewrite(&matcher, check, &mut ());
match result.op() {
Op::Const(c) => assert_eq!(c.0, ConstValue::Bool(true)),
other => panic!("Expected Op::Const, got {:?}", other),
}
}
#[test]
fn test_lt_unknown() {
let idx = UOp::var("idx", DType::Int32, 0, 100);
let size = UOp::native_const(50i32);
let check = idx.try_cmplt(&size).unwrap();
let matcher = symbolic_simple();
let result = graph_rewrite(&matcher, check, &mut ());
match result.op() {
Op::Binary(BinaryOp::Lt, a, b) => {
assert!(Arc::ptr_eq(a, &idx));
assert!(Arc::ptr_eq(b, &size));
}
other => panic!("Expected Op::Binary(Lt, _, _), got {:?}", other),
}
}
#[test]
fn test_eq_same_var() {
let x = UOp::var("x", DType::Int32, 0, 100);
let check = x.try_cmpeq(&x).unwrap();
let matcher = symbolic();
let result = graph_rewrite(&matcher, check, &mut ());
match result.op() {
Op::Const(c) => assert_eq!(c.0, ConstValue::Bool(true)),
other => panic!("Expected Op::Const, got {:?}", other),
}
}
#[test]
fn test_ne_same_var() {
let x = UOp::var("x", DType::Int32, 0, 100);
let check = x.try_cmpne(&x).unwrap();
let matcher = symbolic_simple();
let result = graph_rewrite(&matcher, check, &mut ());
match result.op() {
Op::Const(c) => assert_eq!(c.0, ConstValue::Bool(false)),
other => panic!("Expected Op::Const, got {:?}", other),
}
}
#[test]
fn test_cascading_bounds_elimination() {
let idx = UOp::var("idx", DType::Int32, 0, 10);
let size = UOp::native_const(20i32);
let bounds_check = idx.try_cmplt(&size).unwrap();
let safe_val = UOp::native_const(42i32);
let error_val = UOp::native_const(-1i32);
let where_op = UOp::try_where(bounds_check, safe_val.clone(), error_val).unwrap();
let matcher = symbolic_simple();
let result = graph_rewrite(&matcher, where_op, &mut ());
assert!(Arc::ptr_eq(&result, &safe_val));
}