use morok_dtype::DType;
use crate::UOp;
#[test]
fn test_where_basic() {
assert_eq!(
UOp::try_where(UOp::native_const(true), UOp::native_const(1.0f32), UOp::native_const(0.0f32)).unwrap().dtype(),
DType::Float32
);
}
#[test]
fn test_where_int32() {
assert_eq!(
UOp::try_where(UOp::native_const(false), UOp::native_const(100i32), UOp::native_const(200i32)).unwrap().dtype(),
DType::Int32
);
}
#[test]
fn test_where_with_comparison() {
assert_eq!(
UOp::try_where(
UOp::native_const(5i32).try_cmplt(&UOp::native_const(10i32)).unwrap(),
UOp::native_const(1i32),
UOp::native_const(0i32)
)
.unwrap()
.dtype(),
DType::Int32
);
}
#[test]
fn test_where_same_branches() {
let value = UOp::native_const(42.0f32);
let result = UOp::try_where(UOp::native_const(true), value.clone(), value).unwrap();
assert_eq!(result.dtype(), DType::Float32);
}
#[test]
fn test_where_const_true_condition() {
assert_eq!(
UOp::try_where(UOp::native_const(true), UOp::native_const(100i32), UOp::native_const(200i32)).unwrap().dtype(),
DType::Int32
);
}
#[test]
fn test_where_const_false_condition() {
assert_eq!(
UOp::try_where(UOp::native_const(false), UOp::native_const(100i32), UOp::native_const(200i32)).unwrap().dtype(),
DType::Int32
);
}
#[test]
fn test_where_nested() {
let inner = UOp::try_where(UOp::native_const(false), UOp::native_const(2.0f32), UOp::native_const(3.0f32)).unwrap();
let result = UOp::try_where(UOp::native_const(true), UOp::native_const(1.0f32), inner).unwrap();
assert_eq!(result.dtype(), DType::Float32);
}
#[test]
fn test_where_with_different_dtypes() {
assert_eq!(
UOp::try_where(UOp::native_const(true), UOp::native_const(5i32), UOp::native_const(5.0f32)).unwrap().dtype(),
DType::Int32
);
}
#[test]
fn test_where_bool_branches() {
assert_eq!(
UOp::try_where(UOp::native_const(false), UOp::native_const(true), UOp::native_const(false)).unwrap().dtype(),
DType::Bool
);
}
#[test]
fn test_where_with_zero() {
assert_eq!(
UOp::try_where(UOp::native_const(true), UOp::native_const(1.0f32), UOp::native_const(0.0f32)).unwrap().dtype(),
DType::Float32
);
}
#[test]
fn test_mulacc_basic() {
assert_eq!(
UOp::try_mulacc(UOp::native_const(2.0f32), UOp::native_const(3.0f32), UOp::native_const(4.0f32))
.unwrap()
.dtype(),
DType::Float32
);
}
#[test]
fn test_mulacc_int32() {
assert_eq!(
UOp::try_mulacc(UOp::native_const(5i32), UOp::native_const(6i32), UOp::native_const(7i32)).unwrap().dtype(),
DType::Int32
);
}
#[test]
fn test_mulacc_with_zero_multiplier() {
assert_eq!(
UOp::try_mulacc(UOp::native_const(0.0f32), UOp::native_const(100.0f32), UOp::native_const(5.0f32))
.unwrap()
.dtype(),
DType::Float32
);
}
#[test]
fn test_mulacc_with_zero_accumulator() {
assert_eq!(
UOp::try_mulacc(UOp::native_const(2.0f32), UOp::native_const(3.0f32), UOp::native_const(0.0f32))
.unwrap()
.dtype(),
DType::Float32
);
}
#[test]
fn test_mulacc_with_one() {
assert_eq!(
UOp::try_mulacc(UOp::native_const(1.0f32), UOp::native_const(5.0f32), UOp::native_const(3.0f32))
.unwrap()
.dtype(),
DType::Float32
);
}
#[test]
fn test_mulacc_negative_values() {
assert_eq!(
UOp::try_mulacc(UOp::native_const(-2i32), UOp::native_const(3i32), UOp::native_const(10i32)).unwrap().dtype(),
DType::Int32
);
}
#[test]
fn test_mulacc_vs_separate_ops() {
let a = UOp::native_const(2.0f32);
let b = UOp::native_const(3.0f32);
let c = UOp::native_const(4.0f32);
let fused = UOp::try_mulacc(a.clone(), b.clone(), c.clone()).unwrap();
let mul = a.try_mul(&b).unwrap();
let separate = mul.try_add(&c).unwrap();
assert_eq!(fused.dtype(), separate.dtype());
assert_eq!(fused.dtype(), DType::Float32);
}
#[test]
fn test_mulacc_chained() {
let a = UOp::native_const(2.0f32);
let b = UOp::native_const(3.0f32);
let c = UOp::native_const(4.0f32);
let d = UOp::native_const(5.0f32);
let result1 = UOp::try_mulacc(a.clone(), b.clone(), c).unwrap();
let result2 = result1.try_mul(&d).unwrap();
assert_eq!(result2.dtype(), DType::Float32);
}