use std::fmt;
use super::complex::{Complex64, Complex128};
use super::fp8::{FP8E4M3, FP8E5M2};
/// Scalar element type, stored as a single `u8` discriminant.
///
/// Discriminants are explicit and grouped by family, with gaps between the
/// groups: floats `0..=5`, signed integers `10..=13`, unsigned integers
/// `20..=23`, `Bool` at `30`, complex `40..=41`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[repr(u8)]
#[non_exhaustive]
pub enum DType {
    // ---- floating point -------------------------------------------------
    /// 64-bit IEEE 754 binary float.
    F64 = 0,
    /// 32-bit IEEE 754 binary float.
    F32 = 1,
    /// 16-bit IEEE 754 half-precision float.
    F16 = 2,
    /// 16-bit bfloat16 float.
    BF16 = 3,
    /// 8-bit float with 4 exponent bits and 3 mantissa bits.
    FP8E4M3 = 4,
    /// 8-bit float with 5 exponent bits and 2 mantissa bits.
    FP8E5M2 = 5,

    // ---- signed integers ------------------------------------------------
    /// Signed 64-bit integer.
    I64 = 10,
    /// Signed 32-bit integer.
    I32 = 11,
    /// Signed 16-bit integer.
    I16 = 12,
    /// Signed 8-bit integer.
    I8 = 13,

    // ---- unsigned integers ----------------------------------------------
    /// Unsigned 64-bit integer.
    U64 = 20,
    /// Unsigned 32-bit integer.
    U32 = 21,
    /// Unsigned 16-bit integer.
    U16 = 22,
    /// Unsigned 8-bit integer.
    U8 = 23,

    // ---- boolean ---------------------------------------------------------
    /// Boolean value, one byte per element.
    Bool = 30,

    // ---- complex ---------------------------------------------------------
    /// Complex number with two `f32` components.
    Complex64 = 40,
    /// Complex number with two `f64` components.
    Complex128 = 41,
}
impl DType {
    /// Largest finite bfloat16 value.
    ///
    /// Its bit pattern is `0x7F7F`, i.e. `255 * 2^120` (≈ 3.3895e38). This is
    /// below `f32::MAX` because bf16 keeps only 7 explicit mantissa bits. Do not
    /// approximate it as `3.4e38`: that is past the bf16 overflow threshold
    /// (`255.5 * 2^120`) and rounds to infinity under round-to-nearest.
    const BF16_MAX: f64 = 3.3895313892515355e38;

    /// Size in bytes of one element of this type.
    ///
    /// Complex types count both components (`Complex64` = 2 × `f32`,
    /// `Complex128` = 2 × `f64`). `Bool` occupies one byte.
    #[inline]
    pub const fn size_in_bytes(self) -> usize {
        match self {
            Self::Complex128 => 16,
            Self::F64 | Self::I64 | Self::U64 | Self::Complex64 => 8,
            Self::F32 | Self::I32 | Self::U32 => 4,
            Self::F16 | Self::BF16 | Self::I16 | Self::U16 => 2,
            Self::FP8E4M3 | Self::FP8E5M2 | Self::I8 | Self::U8 | Self::Bool => 1,
        }
    }

    /// Returns `true` for real floating-point types, including both FP8 formats.
    ///
    /// Complex types are *not* included; see [`DType::is_complex`].
    #[inline]
    pub const fn is_float(self) -> bool {
        matches!(
            self,
            Self::F64 | Self::F32 | Self::F16 | Self::BF16 | Self::FP8E4M3 | Self::FP8E5M2
        )
    }

    /// Returns `true` for `Complex64` and `Complex128`.
    #[inline]
    pub const fn is_complex(self) -> bool {
        matches!(self, Self::Complex64 | Self::Complex128)
    }

    /// For a complex type, the dtype of each of its two components.
    ///
    /// Returns `Some(F32)` for `Complex64`, `Some(F64)` for `Complex128`, and
    /// `None` for every non-complex type.
    #[inline]
    pub const fn complex_component_dtype(self) -> Option<Self> {
        match self {
            Self::Complex64 => Some(Self::F32),
            Self::Complex128 => Some(Self::F64),
            _ => None,
        }
    }

    /// Returns `true` for the signed integer types `I8`..`I64`.
    #[inline]
    pub const fn is_signed_int(self) -> bool {
        matches!(self, Self::I64 | Self::I32 | Self::I16 | Self::I8)
    }

    /// Returns `true` for the unsigned integer types `U8`..`U64`.
    #[inline]
    pub const fn is_unsigned_int(self) -> bool {
        matches!(self, Self::U64 | Self::U32 | Self::U16 | Self::U8)
    }

    /// Returns `true` for any signed or unsigned integer type. `Bool` is not an integer.
    #[inline]
    pub const fn is_int(self) -> bool {
        self.is_signed_int() || self.is_unsigned_int()
    }

    /// Returns `true` only for `Bool`.
    #[inline]
    pub const fn is_bool(self) -> bool {
        matches!(self, Self::Bool)
    }

    /// Returns `true` for types that can hold negative values.
    ///
    /// That is every float type, every signed integer type, and both complex types.
    #[inline]
    pub const fn is_signed(self) -> bool {
        self.is_float() || self.is_signed_int() || self.is_complex()
    }

    /// The default floating-point dtype (`F32`).
    #[inline]
    pub const fn default_float() -> Self {
        Self::F32
    }

    /// The default integer dtype (`I64`).
    #[inline]
    pub const fn default_int() -> Self {
        Self::I64
    }

    /// Short lowercase name of the type (e.g. `"f32"`, `"bf16"`, `"c64"`).
    pub const fn short_name(self) -> &'static str {
        match self {
            Self::F64 => "f64",
            Self::F32 => "f32",
            Self::F16 => "f16",
            Self::BF16 => "bf16",
            Self::FP8E4M3 => "fp8e4m3",
            Self::FP8E5M2 => "fp8e5m2",
            Self::I64 => "i64",
            Self::I32 => "i32",
            Self::I16 => "i16",
            Self::I8 => "i8",
            Self::U64 => "u64",
            Self::U32 => "u32",
            Self::U16 => "u16",
            Self::U8 => "u8",
            Self::Bool => "bool",
            Self::Complex64 => "c64",
            Self::Complex128 => "c128",
        }
    }

    /// Smallest finite value representable by this type, as an `f64`.
    ///
    /// Unsigned integers and `Bool` return `0.0`. For complex types this is the
    /// minimum of a single component. `i64::MIN` is a power of two and is
    /// represented exactly.
    pub fn min_value(self) -> f64 {
        match self {
            Self::F64 => f64::MIN,
            Self::F32 => f32::MIN as f64,
            Self::F16 => -65504.0,
            Self::BF16 => -Self::BF16_MAX,
            Self::FP8E4M3 => -448.0,
            Self::FP8E5M2 => -57344.0,
            Self::I64 => i64::MIN as f64,
            Self::I32 => i32::MIN as f64,
            Self::I16 => i16::MIN as f64,
            Self::I8 => i8::MIN as f64,
            Self::U64 => 0.0,
            Self::U32 => 0.0,
            Self::U16 => 0.0,
            Self::U8 => 0.0,
            Self::Bool => 0.0,
            Self::Complex64 => f32::MIN as f64,
            Self::Complex128 => f64::MIN,
        }
    }

    /// Encodes `value` as this dtype and returns `count` back-to-back copies of
    /// its byte encoding.
    ///
    /// For types converted through `bytemuck`, the bytes are the in-memory
    /// (native-endian) representation of the element.
    ///
    /// Conversion rules:
    /// - `F64`: stored unchanged.
    /// - `F32`, `F16`, `BF16`, `FP8*`: first narrowed to `f32`. The 16-bit types
    ///   are then encoded by `half_from_f32_util` (flag `true` for `F16`,
    ///   `false` for `BF16`), and the FP8 types by their `from_f32`/`to_bits`.
    /// - Integer types: Rust `as` cast. It truncates toward zero, saturates at
    ///   the type's bounds, and maps NaN to 0.
    /// - `Bool`: `1` for any value that is not `0.0` (NaN included), else `0`.
    /// - Complex types: `value` becomes the real part; the imaginary part is `0`.
    pub fn fill_bytes_impl(self, value: f64, count: usize) -> Vec<u8> {
        /// Byte image of a single element, repeated `count` times.
        ///
        /// This repeats one element's bytes instead of building a temporary
        /// `Vec<T>` of `count` elements and then copying it into a `Vec<u8>`.
        #[inline]
        fn repeat_bytes<T: bytemuck::NoUninit>(element: T, count: usize) -> Vec<u8> {
            bytemuck::cast_slice::<T, u8>(std::slice::from_ref(&element)).repeat(count)
        }

        match self {
            DType::F64 => repeat_bytes(value, count),
            DType::F32 => repeat_bytes(value as f32, count),
            DType::F16 => {
                let bits = crate::dtype::half_from_f32_util(value as f32, true);
                repeat_bytes(bits, count)
            }
            DType::BF16 => {
                let bits = crate::dtype::half_from_f32_util(value as f32, false);
                repeat_bytes(bits, count)
            }
            DType::FP8E4M3 => vec![FP8E4M3::from_f32(value as f32).to_bits(); count],
            DType::FP8E5M2 => vec![FP8E5M2::from_f32(value as f32).to_bits(); count],
            DType::I64 => repeat_bytes(value as i64, count),
            DType::I32 => repeat_bytes(value as i32, count),
            DType::I16 => repeat_bytes(value as i16, count),
            DType::I8 => repeat_bytes(value as i8, count),
            DType::U64 => repeat_bytes(value as u64, count),
            DType::U32 => repeat_bytes(value as u32, count),
            DType::U16 => repeat_bytes(value as u16, count),
            DType::U8 => vec![value as u8; count],
            DType::Bool => vec![if value != 0.0 { 1u8 } else { 0u8 }; count],
            DType::Complex64 => repeat_bytes(Complex64::new(value as f32, 0.0), count),
            DType::Complex128 => repeat_bytes(Complex128::new(value, 0.0), count),
        }
    }

    /// Largest finite value representable by this type, as an `f64`.
    ///
    /// `Bool` returns `1.0`. For complex types this is the maximum of a single
    /// component. The 64-bit integer maxima are not representable in `f64`, and
    /// the `as` cast rounds them up (`i64::MAX` → 2^63, `u64::MAX` → 2^64).
    pub fn max_value(self) -> f64 {
        match self {
            Self::F64 => f64::MAX,
            Self::F32 => f32::MAX as f64,
            Self::F16 => 65504.0,
            Self::BF16 => Self::BF16_MAX,
            Self::FP8E4M3 => 448.0,
            Self::FP8E5M2 => 57344.0,
            Self::I64 => i64::MAX as f64,
            Self::I32 => i32::MAX as f64,
            Self::I16 => i16::MAX as f64,
            Self::I8 => i8::MAX as f64,
            Self::U64 => u64::MAX as f64,
            Self::U32 => u32::MAX as f64,
            Self::U16 => u16::MAX as f64,
            Self::U8 => u8::MAX as f64,
            Self::Bool => 1.0,
            Self::Complex64 => f32::MAX as f64,
            Self::Complex128 => f64::MAX,
        }
    }
}
impl fmt::Display for DType {
    /// Writes the dtype's short name (e.g. `f32`, `bf16`, `c64`).
    ///
    /// Uses [`fmt::Formatter::pad`], so width, alignment, fill and precision
    /// flags are honoured exactly as they are for `str`. For example,
    /// `format!("{:>5}", DType::F32)` yields `"  f32"`. A plain `{}` still
    /// produces exactly the short name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.short_name())
    }
}