use std::f32::consts::PI;
use crate::{AxisId, ConstValue, Op, UOp};
use morok_dtype::DType;
/// A constant's range is a single point: `vmin` and `vmax` both equal the constant itself.
/// Covers a positive int, a negative int, a float and a bool.
#[test]
fn test_vmin_vmax_const() {
    // Positive integer constant.
    let pos = ConstValue::Int(5);
    assert_eq!(UOp::native_const(5i32).vmin(), &pos);
    assert_eq!(UOp::native_const(5i32).vmax(), &pos);

    // Negative integer constant.
    let negative = ConstValue::Int(-3);
    assert_eq!(UOp::native_const(-3i32).vmin(), &negative);
    assert_eq!(UOp::native_const(-3i32).vmax(), &negative);

    // f32 constant; the expected value is the f32 widened to f64.
    let pi = ConstValue::Float(f64::from(PI));
    assert_eq!(UOp::native_const(PI).vmin(), &pi);
    assert_eq!(UOp::native_const(PI).vmax(), &pi);

    // Boolean constant.
    let truth = ConstValue::Bool(true);
    assert_eq!(UOp::native_const(true).vmin(), &truth);
    assert_eq!(UOp::native_const(true).vmax(), &truth);
}
/// Adding two integer constants (2 + 3) must give the point range [5, 5].
#[test]
fn test_vmin_vmax_add() {
    let lhs = UOp::native_const(2i32);
    let rhs = UOp::native_const(3i32);
    let total = lhs.try_add(&rhs).unwrap();

    let expected = ConstValue::Int(5);
    assert_eq!(total.vmin(), &expected);
    assert_eq!(total.vmax(), &expected);
}
/// Subtracting two integer constants (10 - 3) must give the point range [7, 7].
#[test]
fn test_vmin_vmax_sub() {
    let minuend = UOp::native_const(10i32);
    let subtrahend = UOp::native_const(3i32);
    let difference = minuend.try_sub(&subtrahend).unwrap();

    let expected = ConstValue::Int(7);
    assert_eq!(difference.vmin(), &expected);
    assert_eq!(difference.vmax(), &expected);
}
/// Multiplying constants of opposite sign (-2 * 3) must give the point range [-6, -6].
#[test]
fn test_vmin_vmax_mul() {
    let lhs = UOp::native_const(-2i32);
    let rhs = UOp::native_const(3i32);
    let product = lhs.try_mul(&rhs).unwrap();

    let expected = ConstValue::Int(-6);
    assert_eq!(product.vmin(), &expected);
    assert_eq!(product.vmax(), &expected);
}
/// Multiplying two bounded variables, `a` in [0, 3] and `b` in [0, 4],
/// is expected to produce the range [0, 12].
#[test]
fn test_vmin_vmax_mul_range() {
    let lhs = UOp::define_var(String::from("a"), 0, 3);
    let rhs = UOp::define_var(String::from("b"), 0, 4);
    let product = lhs.try_mul(&rhs).unwrap();

    let (lower, upper) = (ConstValue::Int(0), ConstValue::Int(12));
    assert_eq!(product.vmin(), &lower);
    assert_eq!(product.vmax(), &upper);
}
/// `max(5, 10)` over constants must give the point range [10, 10].
#[test]
fn test_vmin_vmax_max() {
    let smaller = UOp::native_const(5i32);
    let larger = UOp::native_const(10i32);
    let maximum = smaller.try_max(&larger).unwrap();

    let expected = ConstValue::Int(10);
    assert_eq!(maximum.vmin(), &expected);
    assert_eq!(maximum.vmax(), &expected);
}
/// Integer division of constants (15 / 3) must give the point range [5, 5].
#[test]
fn test_vmin_vmax_idiv() {
    let dividend = UOp::native_const(15i32);
    let divisor = UOp::native_const(3i32);
    let quotient = dividend.try_div(&divisor).unwrap();

    let expected = ConstValue::Int(5);
    assert_eq!(quotient.vmin(), &expected);
    assert_eq!(quotient.vmax(), &expected);
}
/// Remainder of constants (17 % 5) must give the point range [2, 2].
#[test]
fn test_vmin_vmax_mod() {
    let dividend = UOp::native_const(17i32);
    let divisor = UOp::native_const(5i32);
    let remainder = dividend.try_mod(&divisor).unwrap();

    let expected = ConstValue::Int(2);
    assert_eq!(remainder.vmin(), &expected);
    assert_eq!(remainder.vmax(), &expected);
}
/// Negating the constant 5 must give the point range [-5, -5].
#[test]
fn test_vmin_vmax_neg() {
    let operand = UOp::native_const(5i32);
    let negated = operand.neg();

    let expected = ConstValue::Int(-5);
    assert_eq!(negated.vmin(), &expected);
    assert_eq!(negated.vmax(), &expected);
}
/// Negating a variable bounded to [0, 5] is expected to flip the range to [-5, 0].
#[test]
fn test_vmin_vmax_neg_range() {
    let x = UOp::define_var(String::from("x"), 0, 5);
    let negated = x.neg();

    let (lower, upper) = (ConstValue::Int(-5), ConstValue::Int(0));
    assert_eq!(negated.vmin(), &lower);
    assert_eq!(negated.vmax(), &upper);
}
/// `5 < 10` over constants is always true, so both bounds must be `Bool(true)`.
#[test]
fn test_vmin_vmax_cmplt() {
    let lhs = UOp::native_const(5i32);
    let rhs = UOp::native_const(10i32);
    let less_than = lhs.try_cmplt(&rhs).unwrap();

    let always_true = ConstValue::Bool(true);
    assert_eq!(less_than.vmin(), &always_true);
    assert_eq!(less_than.vmax(), &always_true);
}
/// `5 == 5` over constants is always true, so both bounds must be `Bool(true)`.
#[test]
fn test_vmin_vmax_eq() {
    let lhs = UOp::native_const(5i32);
    let rhs = UOp::native_const(5i32);
    let equal = lhs.try_cmpeq(&rhs).unwrap();

    let always_true = ConstValue::Bool(true);
    assert_eq!(equal.vmin(), &always_true);
    assert_eq!(equal.vmax(), &always_true);
}
/// Boolean AND of constants (`true & false`) must give the point range [false, false].
#[test]
fn test_vmin_vmax_and_bool() {
    let yes = UOp::native_const(true);
    let no = UOp::native_const(false);
    let conjunction = yes.try_and_op(&no).unwrap();

    let expected = ConstValue::Bool(false);
    assert_eq!(conjunction.vmin(), &expected);
    assert_eq!(conjunction.vmax(), &expected);
}
/// Boolean OR of constants (`true | false`) must give the point range [true, true].
#[test]
fn test_vmin_vmax_or_bool() {
    let yes = UOp::native_const(true);
    let no = UOp::native_const(false);
    let disjunction = yes.try_or_op(&no).unwrap();

    let expected = ConstValue::Bool(true);
    assert_eq!(disjunction.vmin(), &expected);
    assert_eq!(disjunction.vmax(), &expected);
}
/// Bitwise AND of integer constants (`15 & 7`) must give the point range [7, 7].
#[test]
fn test_vmin_vmax_and_int() {
    let lhs = UOp::native_const(15i32);
    let mask = UOp::native_const(7i32);
    let masked = lhs.try_and_op(&mask).unwrap();

    let expected = ConstValue::Int(7);
    assert_eq!(masked.vmin(), &expected);
    assert_eq!(masked.vmax(), &expected);
}
/// Left shift of constants (`3 << 2`) must give the point range [12, 12].
#[test]
fn test_vmin_vmax_shl() {
    let value = UOp::native_const(3i32);
    let amount = UOp::native_const(2i32);
    let shifted = value.try_shl_op(&amount).unwrap();

    let expected = ConstValue::Int(12);
    assert_eq!(shifted.vmin(), &expected);
    assert_eq!(shifted.vmax(), &expected);
}
/// Right shift of constants (`12 >> 2`) must give the point range [3, 3].
#[test]
fn test_vmin_vmax_shr() {
    let value = UOp::native_const(12i32);
    let amount = UOp::native_const(2i32);
    let shifted = value.try_shr_op(&amount).unwrap();

    let expected = ConstValue::Int(3);
    assert_eq!(shifted.vmin(), &expected);
    assert_eq!(shifted.vmax(), &expected);
}
/// A variable defined with bounds (0, 20) must report exactly those bounds.
#[test]
fn test_vmin_vmax_define_var() {
    let x = UOp::define_var(String::from("x"), 0, 20);

    let (lower, upper) = (ConstValue::Int(0), ConstValue::Int(20));
    assert_eq!(x.vmin(), &lower);
    assert_eq!(x.vmax(), &upper);
}
/// A variable defined with a non-zero lower bound (5, 20) must report exactly those bounds.
#[test]
fn test_vmin_vmax_define_var_with_min() {
    let x = UOp::define_var(String::from("x"), 5, 20);

    let (lower, upper) = (ConstValue::Int(5), ConstValue::Int(20));
    assert_eq!(x.vmin(), &lower);
    assert_eq!(x.vmax(), &upper);
}
/// A `Range` op whose `end` is the constant 10 is expected to span [0, 9],
/// i.e. the upper bound is `end - 1`.
#[test]
fn test_vmin_vmax_range() {
    use crate::types::AxisType;

    // Build the op first, then wrap it in a UOp with an Int32 dtype.
    let loop_op = Op::Range {
        end: UOp::native_const(10i32),
        axis_id: AxisId::Renumbered(0),
        axis_type: AxisType::Loop,
        deps: smallvec::SmallVec::new(),
    };
    let range = UOp::new(loop_op, DType::Int32);

    let (lower, upper) = (ConstValue::Int(0), ConstValue::Int(9));
    assert_eq!(range.vmin(), &lower);
    assert_eq!(range.vmax(), &upper);
}
/// Casting the float constant 5.7 to `Int32` is expected to give the point range [5, 5].
#[test]
fn test_vmin_vmax_cast() {
    let source = UOp::native_const(5.7f32);
    let as_int = source.cast(DType::Int32);

    let expected = ConstValue::Int(5);
    assert_eq!(as_int.vmin(), &expected);
    assert_eq!(as_int.vmax(), &expected);
}
/// Casting a variable bounded to [0, 1000] down to `Int8` is expected to give
/// the range [0, 127]: the upper bound no longer exceeds what `Int8` can hold.
#[test]
fn test_vmin_vmax_cast_range() {
    let wide = UOp::define_var(String::from("x"), 0, 1000);
    let narrowed = wide.cast(DType::Int8);

    let (lower, upper) = (ConstValue::Int(0), ConstValue::Int(127));
    assert_eq!(narrowed.vmin(), &lower);
    assert_eq!(narrowed.vmax(), &upper);
}
/// With a constant-true condition, `where(true, 10, 5)` must take the range
/// of the true branch only: [10, 10].
#[test]
fn test_vmin_vmax_where_true() {
    let condition = UOp::native_const(true);
    let when_true = UOp::native_const(10i32);
    let when_false = UOp::native_const(5i32);
    let selected = UOp::try_where(condition, when_true, when_false).unwrap();

    let expected = ConstValue::Int(10);
    assert_eq!(selected.vmin(), &expected);
    assert_eq!(selected.vmax(), &expected);
}
/// With a constant-false condition, `where(false, 10, 5)` must take the range
/// of the false branch only: [5, 5].
#[test]
fn test_vmin_vmax_where_false() {
    let condition = UOp::native_const(false);
    let when_true = UOp::native_const(10i32);
    let when_false = UOp::native_const(5i32);
    let selected = UOp::try_where(condition, when_true, when_false).unwrap();

    let expected = ConstValue::Int(5);
    assert_eq!(selected.vmin(), &expected);
    assert_eq!(selected.vmax(), &expected);
}
/// When the condition is not constant (`cond > 0` with `cond` in [0, 1]),
/// the result must cover both branches: [5, 10].
#[test]
fn test_vmin_vmax_where_range() {
    // Condition that can be either true or false depending on the variable.
    let flag = UOp::define_var(String::from("cond"), 0, 1);
    let zero = UOp::const_(DType::Index, ConstValue::Int(0));
    let is_positive = flag.try_cmpgt(&zero).unwrap();

    let selected =
        UOp::try_where(is_positive, UOp::native_const(10i32), UOp::native_const(5i32)).unwrap();

    let (lower, upper) = (ConstValue::Int(5), ConstValue::Int(10));
    assert_eq!(selected.vmin(), &lower);
    assert_eq!(selected.vmax(), &upper);
}
/// Multiply-accumulate over constants (3 * 4 + 5) must give the point range [17, 17].
#[test]
fn test_vmin_vmax_mulacc() {
    let fused = UOp::try_mulacc(
        UOp::native_const(3i32),
        UOp::native_const(4i32),
        UOp::native_const(5i32),
    )
    .unwrap();

    let expected = ConstValue::Int(17);
    assert_eq!(fused.vmin(), &expected);
    assert_eq!(fused.vmax(), &expected);
}
/// Ranges must propagate through a compound expression: for `x` in [0, 10],
/// `(x + 5) * 2` is expected to span [10, 30].
#[test]
fn test_vmin_vmax_complex_expression() {
    let x = UOp::var("x", DType::Int32, 0, 10);
    let offset = UOp::native_const(5i32);
    let scale = UOp::native_const(2i32);

    let shifted = x.try_add(&offset).unwrap();
    let scaled = shifted.try_mul(&scale).unwrap();

    let (lower, upper) = (ConstValue::Int(10), ConstValue::Int(30));
    assert_eq!(scaled.vmin(), &lower);
    assert_eq!(scaled.vmax(), &upper);
}
/// Nested maxima over constants, `max(max(3, 7), 5)`, must give the point range [7, 7].
#[test]
fn test_vmin_vmax_nested_max() {
    let first = UOp::native_const(3i32);
    let second = UOp::native_const(7i32);
    let third = UOp::native_const(5i32);

    let inner = first.try_max(&second).unwrap();
    let outer = inner.try_max(&third).unwrap();

    let expected = ConstValue::Int(7);
    assert_eq!(outer.vmin(), &expected);
    assert_eq!(outer.vmax(), &expected);
}
/// Range analysis over float constants (2.5 and 1.5) for add, sub, mul and div.
///
/// Sum, difference and product are exactly representable, so they are compared
/// for equality. The quotient (5/3) is not, so it is compared with a tolerance.
/// Both `vmin` and `vmax` of the quotient are checked, so a wrong upper bound
/// for float division cannot slip through.
#[test]
fn test_vmin_vmax_float_ops() {
    let a = UOp::native_const(2.5f32);
    let b = UOp::native_const(1.5f32);

    let sum = a.try_add(&b).unwrap();
    assert_eq!(sum.vmin(), &ConstValue::Float(4.0));
    assert_eq!(sum.vmax(), &ConstValue::Float(4.0));

    let diff = a.try_sub(&b).unwrap();
    assert_eq!(diff.vmin(), &ConstValue::Float(1.0));
    assert_eq!(diff.vmax(), &ConstValue::Float(1.0));

    let prod = a.try_mul(&b).unwrap();
    assert_eq!(prod.vmin(), &ConstValue::Float(3.75));
    assert_eq!(prod.vmax(), &ConstValue::Float(3.75));

    let div = a.try_div(&b).unwrap();
    // A constant quotient is a point range, so both bounds must match 5/3.
    for bound in [div.vmin(), div.vmax()] {
        match bound {
            ConstValue::Float(value) => {
                assert!((value - 1.6666666666666667).abs() < 1e-10);
            }
            _ => panic!("Expected float result"),
        }
    }
}
/// When the divisor's range [0, 1] includes zero, the quotient's bounds are
/// expected to fall back to the full `i32` range.
#[test]
fn test_vmin_vmax_division_by_zero_range() {
    let dividend = UOp::native_const(10i32);
    // Divisor may be zero.
    let divisor = UOp::var("b", DType::Int32, 0, 1);
    let quotient = dividend.try_div(&divisor).unwrap();

    let lowest = ConstValue::Int(i32::MIN as i64);
    let highest = ConstValue::Int(i32::MAX as i64);
    assert_eq!(quotient.vmin(), &lowest);
    assert_eq!(quotient.vmax(), &highest);
}
/// When the modulus's range [0, 1] includes zero, the remainder's bounds are
/// expected to fall back to the full `i32` range.
#[test]
fn test_vmin_vmax_mod_by_zero_range() {
    let dividend = UOp::native_const(10i32);
    // Modulus may be zero.
    let modulus = UOp::var("b", DType::Int32, 0, 1);
    let remainder = dividend.try_mod(&modulus).unwrap();

    let lowest = ConstValue::Int(i32::MIN as i64);
    let highest = ConstValue::Int(i32::MAX as i64);
    assert_eq!(remainder.vmin(), &lowest);
    assert_eq!(remainder.vmax(), &highest);
}
/// Shifting left by 64 is expected to make the bounds fall back to the full
/// `i32` range instead of a computed value.
#[test]
fn test_vmin_vmax_shift_overflow() {
    let value = UOp::native_const(1i32);
    // Shift amount larger than the operand's bit width.
    let amount = UOp::native_const(64i32);
    let shifted = value.try_shl_op(&amount).unwrap();

    let lowest = ConstValue::Int(i32::MIN as i64);
    let highest = ConstValue::Int(i32::MAX as i64);
    assert_eq!(shifted.vmin(), &lowest);
    assert_eq!(shifted.vmax(), &highest);
}