use super::DType;
pub fn promote(lhs: DType, rhs: DType) -> DType {
use DType::*;
if lhs == rhs {
return lhs;
}
let priority = |dt: DType| -> u8 {
match dt {
Complex128 => 120,
Complex64 => 110,
F64 => 100,
F32 => 90,
BF16 => 85,
F16 => 80,
FP8E4M3 => 75, FP8E5M2 => 70, I64 => 65,
U64 => 60,
I32 => 55,
U32 => 50,
I16 => 45,
U16 => 40,
I8 => 35,
U8 => 30,
Bool => 25,
}
};
if lhs.is_signed_int() && rhs.is_unsigned_int() {
return match (lhs, rhs) {
(I64, _) => I64,
(I32, U64) => I64,
(I32, U32) => I64,
(I32, _) => I32,
(I16, U64 | U32 | U16) => I32,
(I16, _) => I16,
(I8, U64 | U32 | U16 | U8) => I16,
_ => I64,
};
}
if rhs.is_signed_int() && lhs.is_unsigned_int() {
return promote(rhs, lhs);
}
if priority(lhs) >= priority(rhs) {
lhs
} else {
rhs
}
}
#[allow(dead_code)]
pub fn can_cast_safely(from: DType, to: DType) -> bool {
use DType::*;
if from == to {
return true;
}
match (from, to) {
(FP8E4M3 | FP8E5M2, F16 | BF16 | F32 | F64) => true,
(F16 | BF16, F32 | F64) => true,
(F32, F64) => true,
(Complex64, Complex128) => true,
(F32 | F16 | BF16 | FP8E4M3 | FP8E5M2, Complex64) => true,
(F64 | F32 | F16 | BF16 | FP8E4M3 | FP8E5M2, Complex128) => true,
(I8 | U8 | I16 | U16, F32 | F64) => true,
(I32 | U32, F64) => true,
(I8 | U8 | I16 | U16 | I32 | U32 | I64 | U64, Complex64 | Complex128) => true,
(I8, I16 | I32 | I64) => true,
(I16, I32 | I64) => true,
(I32, I64) => true,
(U8, U16 | U32 | U64 | I16 | I32 | I64) => true,
(U16, U32 | U64 | I32 | I64) => true,
(U32, U64 | I64) => true,
(Bool, _) if to.is_int() || to.is_float() || to.is_complex() => true,
_ => false,
}
}
#[cfg(test)]
mod tests {
use super::*;
use DType::*;
#[test]
fn test_same_type_promotion() {
assert_eq!(promote(F32, F32), F32);
assert_eq!(promote(I64, I64), I64);
}
#[test]
fn test_float_promotion() {
assert_eq!(promote(F32, F64), F64);
assert_eq!(promote(F16, F32), F32);
assert_eq!(promote(BF16, F32), F32);
}
#[test]
fn test_int_float_promotion() {
assert_eq!(promote(I64, F32), F32);
assert_eq!(promote(I32, F64), F64);
}
#[test]
fn test_signed_unsigned_promotion() {
assert_eq!(promote(I32, U32), I64);
assert_eq!(promote(I16, U16), I32);
assert_eq!(promote(I8, U8), I16);
}
#[test]
fn test_safe_cast() {
assert!(can_cast_safely(I32, I64));
assert!(can_cast_safely(F32, F64));
assert!(can_cast_safely(U8, I32));
assert!(!can_cast_safely(I64, I32));
assert!(!can_cast_safely(F64, F32));
}
#[test]
fn test_fp8_same_type_promotion() {
assert_eq!(promote(FP8E4M3, FP8E4M3), FP8E4M3);
assert_eq!(promote(FP8E5M2, FP8E5M2), FP8E5M2);
}
#[test]
fn test_fp8_to_wider_float_promotion() {
assert_eq!(promote(FP8E4M3, F16), F16);
assert_eq!(promote(FP8E4M3, BF16), BF16);
assert_eq!(promote(FP8E4M3, F32), F32);
assert_eq!(promote(FP8E4M3, F64), F64);
assert_eq!(promote(FP8E5M2, F16), F16);
assert_eq!(promote(FP8E5M2, BF16), BF16);
assert_eq!(promote(FP8E5M2, F32), F32);
assert_eq!(promote(FP8E5M2, F64), F64);
}
#[test]
fn test_fp8_e4m3_vs_e5m2_promotion() {
assert_eq!(promote(FP8E4M3, FP8E5M2), FP8E4M3);
assert_eq!(promote(FP8E5M2, FP8E4M3), FP8E4M3);
}
#[test]
fn test_fp8_vs_int_promotion() {
assert_eq!(promote(FP8E4M3, I64), FP8E4M3);
assert_eq!(promote(FP8E5M2, I32), FP8E5M2);
}
#[test]
fn test_fp8_safe_cast() {
assert!(can_cast_safely(FP8E4M3, F16));
assert!(can_cast_safely(FP8E4M3, BF16));
assert!(can_cast_safely(FP8E4M3, F32));
assert!(can_cast_safely(FP8E4M3, F64));
assert!(can_cast_safely(FP8E5M2, F16));
assert!(can_cast_safely(FP8E5M2, F32));
assert!(can_cast_safely(FP8E5M2, F64));
assert!(!can_cast_safely(F32, FP8E4M3));
assert!(!can_cast_safely(F16, FP8E5M2));
}
}