/// Category tag for the element type described by a [`DataType`].
///
/// The enum is `#[repr(u8)]` and every variant carries an explicit discriminant, so each
/// variant's numeric value is fixed; existing values must not be renumbered.
// NOTE(review): the codes and their numbering appear to mirror DLPack's `DLDataTypeCode`;
// confirm against the upstream header before adding or changing variants.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DataTypeCode {
    /// Signed integer.
    Int = 0,
    /// Unsigned integer.
    UInt = 1,
    /// Floating point.
    Float = 2,
    /// Opaque handle.
    OpaqueHandle = 3,
    /// Brain floating point (bfloat).
    Bfloat = 4,
    /// Complex number.
    Complex = 5,
    /// Boolean.
    Bool = 6,
    // The remaining codes are narrow floating-point formats. The suffix presumably encodes the
    // exponent/mantissa bit split and special-value conventions (e.g. `e4m3fn`) — TODO confirm.
    Float8E3m4 = 7,
    Float8E4m3 = 8,
    Float8E4m3b11fnuz = 9,
    Float8E4m3fn = 10,
    Float8E4m3fnuz = 11,
    Float8E5m2 = 12,
    Float8E5m2fnuz = 13,
    Float8E8m0fnu = 14,
    Float6E2m3fn = 15,
    Float6E3m2fn = 16,
    Float4E2m1fn = 17,
}

/// Description of an element type: a type code, a bit width per lane, and a lane count.
///
/// `#[repr(C)]` fixes the field order and layout: `u8` code, `u8` bits, `u16` lanes, 4 bytes
/// in total.
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct DataType {
    /// Category of the element type.
    pub code: DataTypeCode,
    /// Number of bits occupied by a single lane.
    pub bits: u8,
    /// Number of lanes. All predefined constants use `1`; larger values presumably describe
    /// vector types — TODO confirm.
    pub lanes: u16,
}
impl From<(DataTypeCode, u8, u16)> for DataType {
    /// Builds a [`DataType`] from a `(code, bits, lanes)` tuple, taken in that order.
    fn from((code, bits, lanes): (DataTypeCode, u8, u16)) -> Self {
        Self { code, bits, lanes }
    }
}
impl Default for DataType {
fn default() -> Self {
Self::F32
}
}
impl DataType {
pub const BF16: Self = Self {
code: DataTypeCode::Bfloat,
bits: 16,
lanes: 1,
};
pub const BOOL: Self = Self {
code: DataTypeCode::Bool,
bits: 8,
lanes: 1,
};
pub const F16: Self = Self {
code: DataTypeCode::Float,
bits: 16,
lanes: 1,
};
pub const F32: Self = Self {
code: DataTypeCode::Float,
bits: 32,
lanes: 1,
};
pub const F64: Self = Self {
code: DataTypeCode::Float,
bits: 64,
lanes: 1,
};
pub const I128: Self = Self {
code: DataTypeCode::Int,
bits: 128,
lanes: 1,
};
pub const I16: Self = Self {
code: DataTypeCode::Int,
bits: 16,
lanes: 1,
};
pub const I32: Self = Self {
code: DataTypeCode::Int,
bits: 32,
lanes: 1,
};
pub const I64: Self = Self {
code: DataTypeCode::Int,
bits: 64,
lanes: 1,
};
pub const I8: Self = Self {
code: DataTypeCode::Int,
bits: 8,
lanes: 1,
};
pub const U128: Self = Self {
code: DataTypeCode::UInt,
bits: 128,
lanes: 1,
};
pub const U16: Self = Self {
code: DataTypeCode::UInt,
bits: 16,
lanes: 1,
};
pub const U32: Self = Self {
code: DataTypeCode::UInt,
bits: 32,
lanes: 1,
};
pub const U64: Self = Self {
code: DataTypeCode::UInt,
bits: 64,
lanes: 1,
};
pub const U8: Self = Self {
code: DataTypeCode::UInt,
bits: 8,
lanes: 1,
};
}
impl DataType {
    /// Returns the number of bytes needed to store one value of this type.
    ///
    /// This is the total bit count (`bits * lanes`) rounded up to a whole byte, so sub-byte
    /// types such as a 4-bit scalar report `1`.
    pub const fn size(&self) -> usize {
        // Widen both factors before multiplying: u8::MAX * u16::MAX fits in u32, so the
        // product cannot overflow.
        let total_bits = self.lanes as u32 * self.bits as u32;
        total_bits.div_ceil(8) as usize
    }
}
#[cfg(test)]
mod tests {
    // A plain (non-`pub`) glob import is the idiomatic way to pull the parent module into tests.
    use super::*;

    #[test]
    fn test_size() {
        assert_eq!(DataType::F32.size(), 4);
        assert_eq!(DataType::BOOL.size(), 1);
    }

    /// Scalar constants occupy exactly `bits / 8` bytes.
    #[test]
    fn test_size_of_scalar_constants() {
        assert_eq!(DataType::I8.size(), 1);
        assert_eq!(DataType::U16.size(), 2);
        assert_eq!(DataType::BF16.size(), 2);
        assert_eq!(DataType::F64.size(), 8);
        assert_eq!(DataType::I128.size(), 16);
        assert_eq!(DataType::U128.size(), 16);
    }

    /// Multi-lane and sub-byte types round the total bit count up to whole bytes.
    #[test]
    fn test_size_rounds_up_and_scales_with_lanes() {
        assert_eq!(DataType::from((DataTypeCode::Float, 32, 4)).size(), 16);
        assert_eq!(DataType::from((DataTypeCode::Float4E2m1fn, 4, 1)).size(), 1);
        assert_eq!(DataType::from((DataTypeCode::Float4E2m1fn, 4, 3)).size(), 2);
        assert_eq!(DataType::from((DataTypeCode::Float6E2m3fn, 6, 4)).size(), 3);
        // The widest representable type must not overflow the intermediate product.
        assert_eq!(
            DataType::from((DataTypeCode::UInt, u8::MAX, u16::MAX)).size(),
            2_088_929
        );
    }

    #[test]
    fn test_default_and_from() {
        assert_eq!(DataType::default(), DataType::F32);
        assert_eq!(DataType::from((DataTypeCode::Int, 32, 1)), DataType::I32);
    }

    /// `#[repr(C)]` with `u8`, `u8`, `u16` fields packs into 4 bytes.
    #[test]
    fn test_layout() {
        assert_eq!(std::mem::size_of::<DataType>(), 4);
    }
}