use std::fmt;
/// Scalar element data type.
///
/// Each variant has a fixed per-element storage size ([`DType::size_of`]) and a
/// short lowercase name ([`DType::name`]), which is also its `Display` form.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit floating point (`"f32"`, 4 bytes).
    F32,
    /// 16-bit floating point (`"f16"`, 2 bytes).
    F16,
    /// 16-bit bfloat16 floating point (`"bf16"`, 2 bytes).
    BF16,
    /// Unsigned 8-bit integer (`"u8"`, 1 byte).
    U8,
    /// Unsigned 16-bit integer (`"u16"`, 2 bytes).
    U16,
    /// Unsigned 32-bit integer (`"u32"`, 4 bytes).
    U32,
    /// Signed 32-bit integer (`"i32"`, 4 bytes).
    I32,
}

impl DType {
    /// Returns the size in bytes of a single element of this type.
    ///
    /// This is a `const fn`, so it can be evaluated at compile time, e.g. to
    /// size a fixed array: `[0u8; DType::U16.size_of()]`.
    #[inline]
    pub const fn size_of(self) -> usize {
        match self {
            DType::F32 | DType::U32 | DType::I32 => 4,
            DType::F16 | DType::BF16 | DType::U16 => 2,
            DType::U8 => 1,
        }
    }

    /// Returns the short lowercase name of this type (e.g. `"f32"`, `"bf16"`).
    ///
    /// This is a `const fn`, so the name is also available in const contexts.
    #[inline]
    pub const fn name(self) -> &'static str {
        match self {
            DType::F32 => "f32",
            DType::F16 => "f16",
            DType::BF16 => "bf16",
            DType::U8 => "u8",
            DType::U16 => "u16",
            DType::U32 => "u32",
            DType::I32 => "i32",
        }
    }
}

impl fmt::Display for DType {
    /// Writes [`DType::name`], honouring the formatter's width, fill,
    /// alignment and precision flags.
    ///
    /// For example, `format!("{:>5}", DType::U8)` yields `"   u8"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` (unlike `write_str`) applies padding/alignment options, so a
        // DType lines up in formatted tables exactly like a plain `&str` would.
        f.pad(self.name())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with its expected byte size and name, so both
    /// tests cover the full enum rather than a hand-picked subset.
    const CASES: [(DType, usize, &str); 7] = [
        (DType::F32, 4, "f32"),
        (DType::F16, 2, "f16"),
        (DType::BF16, 2, "bf16"),
        (DType::U8, 1, "u8"),
        (DType::U16, 2, "u16"),
        (DType::U32, 4, "u32"),
        (DType::I32, 4, "i32"),
    ];

    #[test]
    fn test_dtype_size_of() {
        for &(dtype, size, _) in CASES.iter() {
            assert_eq!(dtype.size_of(), size, "size_of({:?})", dtype);
        }
    }

    #[test]
    fn test_dtype_display() {
        for &(dtype, _, name) in CASES.iter() {
            assert_eq!(dtype.name(), name, "name({:?})", dtype);
            // Display must agree with `name()` for every variant.
            assert_eq!(format!("{}", dtype), name, "Display({:?})", dtype);
        }
    }
}