use std::f64::consts::PI;
use morok_dtype::{DType, ScalarDType};
use test_case::test_case;
use crate::types::ConstValue;
// Casting a constant to the dtype its variant already represents must hand the
// same constant back unchanged.
#[test_case(ConstValue::Bool(true), DType::Bool, ConstValue::Bool(true); "bool_to_bool_true")]
#[test_case(ConstValue::Bool(false), DType::Bool, ConstValue::Bool(false); "bool_to_bool_false")]
#[test_case(ConstValue::Int(42), DType::Int64, ConstValue::Int(42); "int64_to_int64")]
#[test_case(ConstValue::UInt(42), DType::UInt64, ConstValue::UInt(42); "uint64_to_uint64")]
#[test_case(ConstValue::Float(PI), DType::Float64, ConstValue::Float(PI); "float64_to_float64")]
fn test_identity_cast(input: ConstValue, dtype: DType, expected: ConstValue) {
    let round_tripped = input.cast(&dtype);
    assert_eq!(round_tripped, Some(expected));
}
// Booleans cast to numeric dtypes as 1 (`true`) or 0 (`false`), landing in the
// `Int`, `UInt` or `Float` variant that matches the target's kind.
#[test_case(ConstValue::Bool(true), DType::Int8, ConstValue::Int(1); "bool_true_to_int8")]
#[test_case(ConstValue::Bool(false), DType::Int8, ConstValue::Int(0); "bool_false_to_int8")]
#[test_case(ConstValue::Bool(true), DType::Int64, ConstValue::Int(1); "bool_true_to_int64")]
#[test_case(ConstValue::Bool(false), DType::Int64, ConstValue::Int(0); "bool_false_to_int64")]
#[test_case(ConstValue::Bool(true), DType::UInt8, ConstValue::UInt(1); "bool_true_to_uint8")]
#[test_case(ConstValue::Bool(false), DType::UInt8, ConstValue::UInt(0); "bool_false_to_uint8")]
#[test_case(ConstValue::Bool(true), DType::UInt64, ConstValue::UInt(1); "bool_true_to_uint64")]
#[test_case(ConstValue::Bool(false), DType::UInt64, ConstValue::UInt(0); "bool_false_to_uint64")]
#[test_case(ConstValue::Bool(true), DType::Float32, ConstValue::Float(1.0); "bool_true_to_float32")]
#[test_case(ConstValue::Bool(false), DType::Float32, ConstValue::Float(0.0); "bool_false_to_float32")]
#[test_case(ConstValue::Bool(true), DType::Float64, ConstValue::Float(1.0); "bool_true_to_float64")]
#[test_case(ConstValue::Bool(false), DType::Float64, ConstValue::Float(0.0); "bool_false_to_float64")]
fn test_bool_cast(input: ConstValue, dtype: DType, expected: ConstValue) {
    let converted = input.cast(&dtype);
    assert_eq!(converted, Some(expected));
}
/// A small `Int` cast to every signed integer width keeps both its value and the
/// wide `Int` variant.
#[test]
fn test_int64_to_smaller_signed() {
    let small = ConstValue::Int(42);
    for target in [DType::Int8, DType::Int16, DType::Int32, DType::Int64] {
        assert_eq!(small.cast(&target), Some(ConstValue::Int(42)));
    }
}
/// Widening a value that fits in `i8` (here `i8::MAX`) to `Int64` preserves it.
#[test]
fn test_small_int_widening() {
    let i8_max = ConstValue::Int(127);
    let widened = i8_max.cast(&DType::Int64);
    assert_eq!(widened, Some(ConstValue::Int(127)));
}
/// Signed-to-unsigned casts are permitted and move the value into the `UInt` variant.
#[test]
fn test_int64_to_uint64_allowed() {
    let expected = Some(ConstValue::UInt(42));
    assert_eq!(ConstValue::Int(42).cast(&DType::UInt64), expected);
}
/// Unsigned-to-signed casts are permitted and move the value into the `Int` variant.
#[test]
fn test_uint64_to_int64_allowed() {
    let expected = Some(ConstValue::Int(42));
    assert_eq!(ConstValue::UInt(42).cast(&DType::Int64), expected);
}
/// An integer cast to either float width becomes the equivalent `Float` value.
#[test]
fn test_int64_to_float_allowed() {
    let integer = ConstValue::Int(42);
    [DType::Float32, DType::Float64]
        .iter()
        .for_each(|target| assert_eq!(integer.cast(target), Some(ConstValue::Float(42.0))));
}
/// Float sources may be cast to integer and boolean dtypes.
///
/// For integers the fractional part is dropped (`42.7 -> 42`). For `Bool`, the
/// non-zero `42.7` maps to `true` and `0.0` maps to `false`.
#[test]
fn test_float_to_int_allowed() {
    let fractional = ConstValue::Float(42.7);
    assert_eq!(fractional.cast(&DType::Int64), Some(ConstValue::Int(42)));
    assert_eq!(fractional.cast(&DType::UInt64), Some(ConstValue::UInt(42)));
    assert_eq!(fractional.cast(&DType::Bool), Some(ConstValue::Bool(true)));

    assert_eq!(ConstValue::Float(0.0).cast(&DType::Bool), Some(ConstValue::Bool(false)));
}
/// Casting `PI` between float widths leaves the stored value untouched.
///
/// Even the `Float32` target yields the full-precision `PI` constant.
#[test]
fn test_float_to_float_allowed() {
    for target in [DType::Float32, DType::Float64] {
        assert_eq!(ConstValue::Float(PI).cast(&target), Some(ConstValue::Float(PI)));
    }
}
/// Casting a constant to the `Void` scalar is rejected with `None`.
#[test]
fn test_unsupported_void_cast() {
    let target = DType::Scalar(ScalarDType::Void);
    let result = ConstValue::Int(42).cast(&target);
    assert!(result.is_none());
}
/// Integer constants may be cast to the `Index` scalar and stay in the `Int` variant.
#[test]
fn test_index_cast_from_int() {
    let cast = ConstValue::Int(42).cast(&DType::Scalar(ScalarDType::Index));
    assert_eq!(cast, Some(ConstValue::Int(42)));
}
/// Casting to a vector dtype (here four `Int32` lanes) yields `None`.
#[test]
fn test_vector_dtype_returns_none() {
    let int32x4 = DType::Scalar(ScalarDType::Int32).vec(4);
    assert!(ConstValue::Int(42).cast(&int32x4).is_none());
}
/// Neither FP8 format (E4M3 or E5M2) is accepted as a cast target; both yield `None`.
#[test]
fn test_unsupported_fp8_cast() {
    let flag = ConstValue::Bool(true);
    for fp8 in [ScalarDType::FP8E4M3, ScalarDType::FP8E5M2] {
        assert_eq!(flag.cast(&DType::Scalar(fp8)), None);
    }
}
/// Every signed integer width maps `true` to `Int(1)` and `false` to `Int(0)`.
#[test]
fn test_bool_to_all_int_types() {
    for target in [DType::Int8, DType::Int16, DType::Int32, DType::Int64] {
        assert_eq!(ConstValue::Bool(true).cast(&target), Some(ConstValue::Int(1)));
        assert_eq!(ConstValue::Bool(false).cast(&target), Some(ConstValue::Int(0)));
    }
}
/// Every unsigned integer width maps `true` to `UInt(1)` and `false` to `UInt(0)`.
#[test]
fn test_bool_to_all_uint_types() {
    for target in [DType::UInt8, DType::UInt16, DType::UInt32, DType::UInt64] {
        assert_eq!(ConstValue::Bool(true).cast(&target), Some(ConstValue::UInt(1)));
        assert_eq!(ConstValue::Bool(false).cast(&target), Some(ConstValue::UInt(0)));
    }
}
/// Every float width maps `true` to `Float(1.0)` and `false` to `Float(0.0)`.
#[test]
fn test_bool_to_float_types() {
    for target in [DType::Float16, DType::Float32, DType::Float64] {
        assert_eq!(ConstValue::Bool(true).cast(&target), Some(ConstValue::Float(1.0)));
        assert_eq!(ConstValue::Bool(false).cast(&target), Some(ConstValue::Float(0.0)));
    }
}
/// Float-to-integer casts discard the fractional part instead of rounding.
///
/// `test_float_to_int_allowed` already shows `Float(42.7) -> Int(42)`, which rules
/// out round-to-nearest. This test pins the same behaviour for more inputs, and
/// for both integer variants.
///
/// Only positive values are used: there, truncation toward zero and flooring
/// agree. The rounding direction for negative inputs is not established here.
#[test]
fn test_cast_logic_truncation() {
    // (float input, expected Int payload, expected UInt payload)
    let cases = [(0.9, 0, 0), (1.5, 1, 1), (42.99, 42, 42), (99.5, 99, 99)];
    for (input, signed, unsigned) in cases {
        let value = ConstValue::Float(input);
        assert_eq!(value.cast(&DType::Int64), Some(ConstValue::Int(signed)));
        assert_eq!(value.cast(&DType::UInt64), Some(ConstValue::UInt(unsigned)));
    }
}
#[test]
fn test_const_value_always_uses_wide_storage() {
let value = ConstValue::Int(42);
assert_eq!(value.dtype(), DType::Int64);
let value = ConstValue::UInt(42);
assert_eq!(value.dtype(), DType::UInt64);
let value = ConstValue::Float(PI);
assert_eq!(value.dtype(), DType::Float64);
}