use std::f32::consts::PI;
use morok_dtype::DType;
use crate::{ConstValue, UOp, error::Error};
/// Adding two `i32` constants keeps the `Int32` dtype.
#[test]
fn test_add_same_type() {
    let lhs = UOp::native_const(5i32);
    let rhs = UOp::native_const(3i32);
    let sum = lhs.try_add(&rhs).unwrap();
    assert_eq!(sum.dtype(), DType::Int32);
}
/// Subtracting two `f32` constants keeps the `Float32` dtype.
#[test]
fn test_sub_same_type() {
    let minuend = UOp::native_const(10.0f32);
    let subtrahend = UOp::native_const(3.0f32);
    let difference = minuend.try_sub(&subtrahend).unwrap();
    assert_eq!(difference.dtype(), DType::Float32);
}
/// Multiplying two `i32` constants keeps the `Int32` dtype.
#[test]
fn test_mul_same_type() {
    let multiplicand = UOp::native_const(4i32);
    let multiplier = UOp::native_const(5i32);
    let product = multiplicand.try_mul(&multiplier).unwrap();
    assert_eq!(product.dtype(), DType::Int32);
}
/// Taking the modulus of two `i32` constants keeps the `Int32` dtype.
#[test]
fn test_mod_same_type() {
    let dividend = UOp::native_const(10i32);
    let modulus = UOp::native_const(3i32);
    let remainder = dividend.try_mod(&modulus).unwrap();
    assert_eq!(remainder.dtype(), DType::Int32);
}
/// `try_div` on two `i32` constants keeps the `Int32` dtype.
#[test]
fn test_idiv_same_type() {
    let dividend = UOp::native_const(10i32);
    let divisor = UOp::native_const(3i32);
    let quotient = dividend.try_div(&divisor).unwrap();
    assert_eq!(quotient.dtype(), DType::Int32);
}
/// `try_div` on two `f32` constants keeps the `Float32` dtype.
#[test]
fn test_fdiv_same_type() {
    let dividend = UOp::native_const(10.0f32);
    let divisor = UOp::native_const(3.0f32);
    let quotient = dividend.try_div(&divisor).unwrap();
    assert_eq!(quotient.dtype(), DType::Float32);
}
/// `try_max` on two `i32` constants keeps the `Int32` dtype.
#[test]
fn test_max_same_type() {
    let first = UOp::native_const(10i32);
    let second = UOp::native_const(20i32);
    let larger = first.try_max(&second).unwrap();
    assert_eq!(larger.dtype(), DType::Int32);
}
/// Raising an `f32` constant to an `f32` power keeps the `Float32` dtype.
#[test]
fn test_pow_same_type() {
    let base = UOp::native_const(2.0f32);
    let exponent = UOp::native_const(3.0f32);
    let power = base.try_pow(&exponent).unwrap();
    assert_eq!(power.dtype(), DType::Float32);
}
/// Negating an `i32` constant keeps the `Int32` dtype.
#[test]
fn test_neg_int() {
    let operand = UOp::native_const(5i32);
    let negated = operand.neg();
    assert_eq!(negated.dtype(), DType::Int32);
}
/// Negating an `f32` constant (π) keeps the `Float32` dtype.
#[test]
fn test_neg_float() {
    let operand = UOp::native_const(PI);
    let negated = operand.neg();
    assert_eq!(negated.dtype(), DType::Float32);
}
/// Adding an `i32` constant to an `f32` constant promotes the result to `Float32`.
#[test]
fn test_add_type_promotion_int_to_float() {
    let integer = UOp::native_const(5i32);
    let floating = UOp::native_const(PI);
    let promoted = integer.try_add(&floating).unwrap();
    assert_eq!(promoted.dtype(), DType::Float32);
}
/// Multiplying an `Int8` constant by an `Int32` constant widens the result to `Int32`.
#[test]
fn test_mul_type_promotion_smaller_to_larger() {
    let narrow = UOp::const_(DType::Int8, ConstValue::Int(5));
    let wide = UOp::native_const(10i32);
    let widened = narrow.try_mul(&wide).unwrap();
    assert_eq!(widened.dtype(), DType::Int32);
}
/// Subtracting an `Int32` constant from a `UInt8` constant promotes the result to `Int32`.
#[test]
fn test_sub_type_promotion_uint_to_int() {
    let unsigned = UOp::const_(DType::UInt8, ConstValue::UInt(5));
    let signed = UOp::native_const(10i32);
    let difference = unsigned.try_sub(&signed).unwrap();
    assert_eq!(difference.dtype(), DType::Int32);
}
/// `try_div` of an `i32` constant by a zero `i32` constant fails with
/// `Error::DivisionByZero`.
#[test]
fn test_idiv_by_zero() {
    let dividend = UOp::native_const(10i32);
    let divisor = UOp::native_const(0i32);
    assert!(matches!(dividend.try_div(&divisor), Err(Error::DivisionByZero)));
}
/// `try_div` of an `f32` constant by a zero `f32` constant fails with
/// `Error::DivisionByZero`.
#[test]
fn test_fdiv_by_zero() {
    let dividend = UOp::native_const(10.0f32);
    let divisor = UOp::native_const(0.0f32);
    assert!(matches!(dividend.try_div(&divisor), Err(Error::DivisionByZero)));
}
/// `try_mod` of an `i32` constant by a zero `i32` constant fails with
/// `Error::DivisionByZero`.
#[test]
fn test_mod_by_zero() {
    let dividend = UOp::native_const(10i32);
    let modulus = UOp::native_const(0i32);
    assert!(matches!(dividend.try_mod(&modulus), Err(Error::DivisionByZero)));
}
/// Adding a `Void`-typed constant to an `Int32` constant fails with `Error::VoidTypeInOp`.
#[test]
fn test_add_void_type() {
    let void_operand = UOp::const_(DType::Void, ConstValue::Int(0));
    let int_operand = UOp::const_(DType::Int32, ConstValue::Int(5));
    assert!(matches!(void_operand.try_add(&int_operand), Err(Error::VoidTypeInOp)));
}
/// Multiplying a `Void`-typed constant by a `Float32` constant fails with
/// `Error::VoidTypeInOp`.
#[test]
fn test_mul_void_type() {
    let void_operand = UOp::const_(DType::Void, ConstValue::Int(0));
    // `f64::from` is the lossless f32 -> f64 widening, yielding the same value as `PI as f64`.
    let float_operand = UOp::const_(DType::Float32, ConstValue::Float(f64::from(PI)));
    assert!(matches!(void_operand.try_mul(&float_operand), Err(Error::VoidTypeInOp)));
}
/// Adding a `bool` constant to an `i32` constant yields an `Int32` result.
#[test]
fn test_add_bool_and_int() {
    let flag = UOp::native_const(true);
    let number = UOp::native_const(5i32);
    let sum = flag.try_add(&number).unwrap();
    assert_eq!(sum.dtype(), DType::Int32);
}
/// Multiplying a `Float16` constant by a `Float32` constant promotes the result to `Float32`.
#[test]
fn test_mul_different_float_types() {
    // Named `half`/`single` rather than `f16`/`f32` to avoid shadowing primitive type names.
    let half = UOp::const_(DType::Float16, ConstValue::Float(2.0));
    let single = UOp::const_(DType::Float32, ConstValue::Float(3.0));
    let product = half.try_mul(&single).unwrap();
    assert_eq!(product.dtype(), DType::Float32);
}
/// `(10 + 5) * 2` over `i32` constants stays `Int32` through the whole chain.
#[test]
fn test_chained_operations() {
    let ten = UOp::native_const(10i32);
    let five = UOp::native_const(5i32);
    let two = UOp::native_const(2i32);
    let product = ten.try_add(&five).unwrap().try_mul(&two).unwrap();
    assert_eq!(product.dtype(), DType::Int32);
}
/// Promotion carries through a chain: `i32 + f32` yields `Float32`, and multiplying
/// that result by the original `i32` constant still yields `Float32`.
#[test]
fn test_chained_with_promotion() {
    let integer = UOp::native_const(10i32);
    let floating = UOp::native_const(2.5f32);

    let promoted = integer.try_add(&floating).unwrap();
    assert_eq!(promoted.dtype(), DType::Float32);

    let scaled = promoted.try_mul(&integer).unwrap();
    assert_eq!(scaled.dtype(), DType::Float32);
}