/// Scalar data types, with byte sizes, classification helpers and binary
/// type promotion ([`DType::promote`]).
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit float (`f32`).
    F32,
    /// 16-bit float (`f16`).
    F16,
    /// 16-bit bfloat (`bf16`).
    BF16,
    /// 64-bit float (`f64`).
    F64,
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// Unsigned 8-bit integer.
    U8,
    /// Unsigned 32-bit integer.
    U32,
    /// Boolean, occupying one byte.
    Bool,
    /// Complex value occupying 8 bytes (presumably two `f32` components — confirm).
    C64,
}

impl DType {
    /// Size of a single element of this type, in bytes.
    pub const fn size_bytes(self) -> usize {
        match self {
            Self::Bool | Self::I8 | Self::U8 => 1,
            Self::F16 | Self::BF16 | Self::I16 => 2,
            Self::F32 | Self::I32 | Self::U32 => 4,
            Self::F64 | Self::I64 | Self::C64 => 8,
        }
    }

    /// Returns `true` for the real floating-point types (`F16`, `BF16`, `F32`, `F64`).
    ///
    /// `C64` is not counted as a float; see [`DType::is_complex`].
    pub const fn is_float(self) -> bool {
        matches!(self, Self::F32 | Self::F16 | Self::BF16 | Self::F64)
    }

    /// Returns `true` for the complex type (`C64`).
    pub const fn is_complex(self) -> bool {
        matches!(self, Self::C64)
    }

    /// Returns `true` for signed and unsigned integer types. `Bool` is not an integer.
    pub const fn is_int(self) -> bool {
        matches!(
            self,
            Self::I8 | Self::I16 | Self::I32 | Self::I64 | Self::U8 | Self::U32
        )
    }

    /// Returns `true` for the unsigned integer types (`U8`, `U32`).
    const fn is_unsigned_int(self) -> bool {
        matches!(self, Self::U8 | Self::U32)
    }

    /// Coarse ordering used by [`DType::promote`] when no more specific rule
    /// applies (bool, complex, float/float, and same-signedness integer mixes).
    ///
    /// Distinct types may share a rank, so this value alone does not define a
    /// total order.
    pub const fn promotion_rank(self) -> u8 {
        match self {
            Self::Bool => 0,
            Self::U8 | Self::I8 => 1,
            Self::I16 | Self::BF16 => 2,
            Self::F16 => 3,
            Self::U32 | Self::I32 => 4,
            Self::I64 => 5,
            Self::F32 => 6,
            Self::F64 => 7,
            Self::C64 => 8,
        }
    }

    /// Returns the type that `self` and `other` are both converted to when combined.
    ///
    /// Rules, applied in order:
    /// - identical types promote to themselves;
    /// - `F16` with `BF16` promotes to `F32`;
    /// - integer with integer:
    ///   - same signedness picks the wider type;
    ///   - mixed signedness picks the signed type if it is strictly wider than the
    ///     unsigned one, otherwise the smallest signed type holding both (`I16` for
    ///     `U8`, `I64` for `U32`);
    /// - integer with real float gives `F64` if the float is `F64` or the integer is
    ///   `I64`, otherwise `F32`;
    /// - anything else (bool, complex, float/float) picks the higher
    ///   [`DType::promotion_rank`].
    ///
    /// The result is symmetric: `a.promote(b) == b.promote(a)` for every pair.
    #[must_use]
    pub fn promote(self, other: Self) -> Self {
        if self == other {
            return self;
        }
        // Neither half-precision format is a superset of the other, so widen to F32.
        if matches!(
            (self, other),
            (Self::F16, Self::BF16) | (Self::BF16, Self::F16)
        ) {
            return Self::F32;
        }
        match (
            self.is_int(),
            other.is_int(),
            self.is_float(),
            other.is_float(),
        ) {
            (true, true, ..) => Self::promote_ints(self, other),
            (true, false, false, true) => Self::promote_int_with_float(self, other),
            (false, true, true, false) => Self::promote_int_with_float(other, self),
            _ => Self::higher_rank(self, other),
        }
    }

    /// Picks the operand with the higher promotion rank, preferring `a` on ties.
    fn higher_rank(a: Self, b: Self) -> Self {
        if a.promotion_rank() >= b.promotion_rank() {
            a
        } else {
            b
        }
    }

    /// Promotes two distinct integer types to one that can hold every value of both.
    fn promote_ints(a: Self, b: Self) -> Self {
        let (unsigned, signed) = match (a.is_unsigned_int(), b.is_unsigned_int()) {
            (true, false) => (a, b),
            (false, true) => (b, a),
            // Same signedness: ranks are strictly ordered by width, and the
            // wider type holds every value of the narrower one.
            _ => return Self::higher_rank(a, b),
        };
        if signed.size_bytes() > unsigned.size_bytes() {
            signed
        } else if matches!(unsigned, Self::U8) {
            Self::I16
        } else {
            // Only U32 remains; I64 holds every u32 and every narrower signed value.
            Self::I64
        }
    }

    /// Promotes an integer type combined with a real float type.
    ///
    /// `float` must satisfy [`DType::is_float`].
    fn promote_int_with_float(int: Self, float: Self) -> Self {
        if matches!(float, Self::F64) || matches!(int, Self::I64) {
            Self::F64
        } else {
            Self::F32
        }
    }
}
/// Describes how a single stored value is laid out and interpreted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Element {
    /// Storage data type (e.g. the FP8 constructors store values as `DType::U8`).
    pub dtype: DType,
    /// Refines how values of `dtype` are interpreted; `Standard` means plain `dtype` values.
    pub subtype: ElementSubtype,
    /// NOTE(review): presumably selects clamping instead of overflow/wrap on
    /// conversion into this element — the flag is only stored here; confirm
    /// against the code that consumes it.
    pub saturating: bool,
}
/// Interpretation applied on top of an [`Element`]'s storage `dtype`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementSubtype {
    /// No reinterpretation: values are plain values of the element's `dtype`.
    Standard,
    /// FP8 "E4M3" encoding (presumably 4 exponent / 3 mantissa bits — confirm);
    /// `Element::fp8_e4m3` stores it in a `DType::U8`.
    Fp8E4m3,
    /// FP8 "E5M2" encoding (presumably 5 exponent / 2 mantissa bits — confirm);
    /// `Element::fp8_e5m2` stores it in a `DType::U8`.
    Fp8E5m2,
}
impl Element {
    /// Creates a plain (`ElementSubtype::Standard`), non-saturating element of `dtype`.
    #[must_use]
    pub const fn new(dtype: DType) -> Self {
        Self {
            dtype,
            subtype: ElementSubtype::Standard,
            saturating: false,
        }
    }

    /// Creates an FP8 E4M3 element, stored as `DType::U8`, with saturation enabled.
    #[must_use]
    pub const fn fp8_e4m3() -> Self {
        Self {
            dtype: DType::U8,
            subtype: ElementSubtype::Fp8E4m3,
            saturating: true,
        }
    }

    /// Creates an FP8 E5M2 element, stored as `DType::U8`, with saturation enabled.
    #[must_use]
    pub const fn fp8_e5m2() -> Self {
        Self {
            dtype: DType::U8,
            subtype: ElementSubtype::Fp8E5m2,
            saturating: true,
        }
    }

    /// Returns a copy of `self` with `saturating` set to `true`; other fields are kept.
    ///
    /// `self` is taken by value and never modified in place, so a bare
    /// `elem.saturating();` statement has no effect. `#[must_use]` makes the
    /// compiler warn about that mistake.
    #[must_use]
    pub const fn saturating(self) -> Self {
        Self {
            saturating: true,
            ..self
        }
    }
}
impl std::fmt::Display for DType {
    /// Writes the lowercase short name of the type (`f32`, `bf16`, `u8`, `c64`, ...).
    ///
    /// The name is emitted through [`std::fmt::Formatter::pad`], so width, fill,
    /// alignment and precision flags are honoured. For example,
    /// `format!("{:<4}|", DType::I8)` yields `"i8  |"`. Without a format spec the
    /// output is just the bare name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::F32 => "f32",
            Self::F16 => "f16",
            Self::BF16 => "bf16",
            Self::F64 => "f64",
            Self::I8 => "i8",
            Self::I16 => "i16",
            Self::I32 => "i32",
            Self::I64 => "i64",
            Self::U8 => "u8",
            Self::U32 => "u32",
            Self::Bool => "bool",
            Self::C64 => "c64",
        };
        f.pad(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each constructor fills in dtype, subtype and the saturation flag.
    #[test]
    fn element_constructors() {
        let standard = Element::new(DType::F32);
        assert_eq!(standard.dtype, DType::F32);
        assert_eq!(standard.subtype, ElementSubtype::Standard);
        assert!(!standard.saturating);

        let fp8 = Element::fp8_e4m3();
        assert_eq!(fp8.subtype, ElementSubtype::Fp8E4m3);
        assert!(fp8.saturating);
        assert_eq!(fp8.dtype, DType::U8);

        let clamped = Element::new(DType::I32).saturating();
        assert!(clamped.saturating);
        assert_eq!(clamped.dtype, DType::I32);
    }

    // Promoting a type with itself is the identity.
    #[test]
    fn promote_same() {
        for dt in [DType::F32, DType::I8] {
            assert_eq!(dt.promote(dt), dt);
        }
    }

    // Mixed-width integers of the same signedness pick the wider one.
    #[test]
    fn promote_int_widening() {
        let cases = [
            (DType::I8, DType::I16, DType::I16),
            (DType::I32, DType::I64, DType::I64),
        ];
        for (lhs, rhs, expected) in cases {
            assert_eq!(lhs.promote(rhs), expected);
        }
    }

    // Integer/float mixes land on a float wide enough for the integer.
    #[test]
    fn promote_int_to_float() {
        let cases = [
            (DType::I32, DType::F32, DType::F32),
            (DType::I64, DType::F32, DType::F64),
            (DType::I8, DType::F16, DType::F32),
        ];
        for (lhs, rhs, expected) in cases {
            assert_eq!(lhs.promote(rhs), expected);
        }
    }

    // The two half-precision formats meet at f32, in either order.
    #[test]
    fn promote_f16_bf16_goes_to_f32() {
        for (lhs, rhs) in [(DType::F16, DType::BF16), (DType::BF16, DType::F16)] {
            assert_eq!(lhs.promote(rhs), DType::F32);
        }
    }

    // Operand order must not affect the promoted type.
    #[test]
    fn promote_is_commutative_for_well_defined_pairs() {
        let pairs = [
            (DType::F32, DType::F16),
            (DType::I32, DType::F64),
            (DType::Bool, DType::I8),
        ];
        pairs.iter().for_each(|&(a, b)| {
            assert_eq!(
                a.promote(b),
                b.promote(a),
                "promote({a},{b}) should equal promote({b},{a})"
            );
        });
    }
}