use hmll_sys::hmll_dtype;
/// Data type tag for values described by the raw `hmll_sys::hmll_dtype` enumeration.
///
/// Build one from a raw value with [`DType::from_raw`]. Raw values that have no
/// dedicated variant collapse to [`DType::Unknown`]. Because the enum is
/// `#[non_exhaustive]`, code in other crates must keep a wildcard arm when matching.
///
/// Variant order is significant: it fixes the discriminants observed through `as` casts.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum DType {
    /// Boolean; occupies 8 bits (`BOOL`).
    Bool,
    /// bfloat16, 16 bits (`BF16`).
    BFloat16,
    /// Complex value occupying 64 bits in total (`C64`).
    Complex64,
    /// 4-bit floating point (`F4`).
    Float4,
    /// 6-bit floating point, `E2M3` variant (`F6_E2M3`).
    Float6E2M3,
    /// 6-bit floating point, `E3M2` variant (`F6_E3M2`).
    Float6E3M2,
    /// 8-bit floating point, `E4M3` variant (`F8_E4M3`).
    Float8E4M3,
    /// 8-bit floating point, `E5M2` variant (`F8_E5M2`).
    Float8E5M2,
    /// 8-bit floating point, `E8M0` variant (`F8_E8M0`).
    Float8E8M0,
    /// 16-bit floating point (`F16`).
    Float16,
    /// 32-bit floating point (`F32`).
    Float32,
    /// Signed 4-bit integer (`I4`).
    Int4,
    /// Signed 8-bit integer (`I8`).
    Int8,
    /// Signed 16-bit integer (`I16`).
    Int16,
    /// Signed 32-bit integer (`I32`).
    Int32,
    /// Signed 64-bit integer (`I64`).
    Int64,
    /// Unsigned 4-bit integer (`U4`).
    UInt4,
    /// Unsigned 8-bit integer (`U8`).
    UInt8,
    /// Unsigned 16-bit integer (`U16`).
    UInt16,
    /// Unsigned 32-bit integer (`U32`).
    UInt32,
    /// Unsigned 64-bit integer (`U64`).
    UInt64,
    /// Any raw dtype value without a dedicated variant; reports a width of 0 bits.
    Unknown,
}
impl DType {
    /// Converts a raw `hmll_dtype` value into a [`DType`].
    ///
    /// Each recognised `hmll_dtype` constant maps to its matching variant. Any other
    /// value yields [`DType::Unknown`] rather than panicking, so the conversion is total.
    #[inline]
    #[must_use]
    pub const fn from_raw(dtype: hmll_dtype) -> Self {
        match dtype {
            hmll_dtype::HMLL_DTYPE_BOOL => DType::Bool,
            hmll_dtype::HMLL_DTYPE_BFLOAT16 => DType::BFloat16,
            hmll_dtype::HMLL_DTYPE_COMPLEX => DType::Complex64,
            hmll_dtype::HMLL_DTYPE_FLOAT4 => DType::Float4,
            hmll_dtype::HMLL_DTYPE_FLOAT6_E2M3 => DType::Float6E2M3,
            hmll_dtype::HMLL_DTYPE_FLOAT6_E3M2 => DType::Float6E3M2,
            hmll_dtype::HMLL_DTYPE_FLOAT8_E4M3 => DType::Float8E4M3,
            hmll_dtype::HMLL_DTYPE_FLOAT8_E5M2 => DType::Float8E5M2,
            hmll_dtype::HMLL_DTYPE_FLOAT8_E8M0 => DType::Float8E8M0,
            hmll_dtype::HMLL_DTYPE_FLOAT16 => DType::Float16,
            hmll_dtype::HMLL_DTYPE_FLOAT32 => DType::Float32,
            hmll_dtype::HMLL_DTYPE_SIGNED_INT4 => DType::Int4,
            hmll_dtype::HMLL_DTYPE_SIGNED_INT8 => DType::Int8,
            hmll_dtype::HMLL_DTYPE_SIGNED_INT16 => DType::Int16,
            hmll_dtype::HMLL_DTYPE_SIGNED_INT32 => DType::Int32,
            hmll_dtype::HMLL_DTYPE_SIGNED_INT64 => DType::Int64,
            hmll_dtype::HMLL_DTYPE_UNSIGNED_INT4 => DType::UInt4,
            hmll_dtype::HMLL_DTYPE_UNSIGNED_INT8 => DType::UInt8,
            hmll_dtype::HMLL_DTYPE_UNSIGNED_INT16 => DType::UInt16,
            hmll_dtype::HMLL_DTYPE_UNSIGNED_INT32 => DType::UInt32,
            hmll_dtype::HMLL_DTYPE_UNSIGNED_INT64 => DType::UInt64,
            // Raw values without a dedicated variant are not an error here.
            _ => DType::Unknown,
        }
    }

    /// Returns the width of one value of this type, in bits.
    ///
    /// - Sub-byte types report 4 (`Float4`, `Int4`, `UInt4`) or 6 (`Float6E2M3`, `Float6E3M2`).
    /// - `Bool` reports 8.
    /// - `Complex64` reports 64 for the whole value.
    /// - `Unknown` reports 0 because it has no defined width.
    #[inline]
    #[must_use]
    pub const fn bits(&self) -> u8 {
        match self {
            DType::Bool | DType::Int8 | DType::UInt8 => 8,
            DType::Float8E4M3 | DType::Float8E5M2 | DType::Float8E8M0 => 8,
            DType::Float4 | DType::Int4 | DType::UInt4 => 4,
            DType::Float6E2M3 | DType::Float6E3M2 => 6,
            DType::BFloat16 | DType::Float16 | DType::Int16 | DType::UInt16 => 16,
            DType::Float32 | DType::Int32 | DType::UInt32 => 32,
            DType::Complex64 | DType::Int64 | DType::UInt64 => 64,
            DType::Unknown => 0,
        }
    }

    /// Returns `true` for real floating-point types.
    ///
    /// This covers `BFloat16`, `Float16`, `Float32` and the 4-, 6- and 8-bit float
    /// variants. `Complex64` is deliberately not counted.
    #[inline]
    #[must_use]
    pub const fn is_float(&self) -> bool {
        matches!(
            self,
            DType::BFloat16
                | DType::Float4
                | DType::Float6E2M3
                | DType::Float6E3M2
                | DType::Float8E4M3
                | DType::Float8E5M2
                | DType::Float8E8M0
                | DType::Float16
                | DType::Float32
        )
    }

    /// Returns `true` for the signed integer types (`Int4` through `Int64`).
    #[inline]
    #[must_use]
    pub const fn is_signed_int(&self) -> bool {
        matches!(
            self,
            DType::Int4 | DType::Int8 | DType::Int16 | DType::Int32 | DType::Int64
        )
    }

    /// Returns `true` for the unsigned integer types (`UInt4` through `UInt64`).
    #[inline]
    #[must_use]
    pub const fn is_unsigned_int(&self) -> bool {
        matches!(
            self,
            DType::UInt4 | DType::UInt8 | DType::UInt16 | DType::UInt32 | DType::UInt64
        )
    }

    /// Returns `true` for any integer type, signed or unsigned. `Bool` is not included.
    #[inline]
    #[must_use]
    pub const fn is_int(&self) -> bool {
        self.is_signed_int() || self.is_unsigned_int()
    }
}
/// Formats the dtype as its short upper-case tag, e.g. `F32`, `BF16` or `F8_E4M3`.
///
/// The tag is written with [`std::fmt::Formatter::pad`], so width, fill, alignment and
/// precision flags are honoured exactly as they are for `&str`. For example,
/// `format!("{:>5}", DType::Float32)` yields `"  F32"`. Without flags the output is the
/// bare tag.
impl std::fmt::Display for DType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match self {
            DType::Bool => "BOOL",
            DType::BFloat16 => "BF16",
            DType::Complex64 => "C64",
            DType::Float4 => "F4",
            DType::Float6E2M3 => "F6_E2M3",
            DType::Float6E3M2 => "F6_E3M2",
            DType::Float8E4M3 => "F8_E4M3",
            DType::Float8E5M2 => "F8_E5M2",
            DType::Float8E8M0 => "F8_E8M0",
            DType::Float16 => "F16",
            DType::Float32 => "F32",
            DType::Int4 => "I4",
            DType::Int8 => "I8",
            DType::Int16 => "I16",
            DType::Int32 => "I32",
            DType::Int64 => "I64",
            DType::UInt4 => "U4",
            DType::UInt8 => "U8",
            DType::UInt16 => "U16",
            DType::UInt32 => "U32",
            DType::UInt64 => "U64",
            DType::Unknown => "UNKNOWN",
        };
        // `write_str` would silently drop width/alignment flags; `pad` applies them.
        f.pad(tag)
    }
}