/// Tag identifying the element type of a value.
///
/// Each variant corresponds to one host Rust scalar type through `HostDtype`:
/// `F32` ↔ `f32`, `F16` ↔ `half::f16`, `U32` ↔ `u32`, `I32` ↔ `i32`, `I8` ↔ `i8`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Dtype {
    /// 4-byte element; host type `f32`.
    F32,
    /// 2-byte element; host type `half::f16`.
    F16,
    /// 4-byte element; host type `u32`.
    U32,
    /// 4-byte element; host type `i32`.
    I32,
    /// 1-byte element; host type `i8`.
    I8,
}

impl Dtype {
    /// Number of bytes one element of this type occupies.
    ///
    /// This is a `const fn`, so it can be evaluated in `const` items and other
    /// compile-time contexts.
    pub const fn bytes_per_elem(self) -> usize {
        match self {
            Self::I8 => 1,
            Self::F16 => 2,
            Self::F32 | Self::U32 | Self::I32 => 4,
        }
    }

    /// Lowercase short name of this type, spelled like the Rust scalar type
    /// (for example `"f32"` for [`Dtype::F32`]).
    pub const fn name(self) -> &'static str {
        match self {
            Self::I8 => "i8",
            Self::F16 => "f16",
            Self::F32 => "f32",
            Self::U32 => "u32",
            Self::I32 => "i32",
        }
    }
}
/// A host-side Rust scalar type that is described by a [`Dtype`] tag.
///
/// Implementors must be `Copy`, `Send`, `Sync` and `'static`. They are therefore
/// plain owned values that can be copied freely and shared or sent between threads.
pub trait HostDtype
where
    Self: Copy + Send + Sync + 'static,
{
    /// The [`Dtype`] variant that describes `Self`.
    const DTYPE: Dtype;
}
impl HostDtype for u32 {
const DTYPE: Dtype = Dtype::U32;
}
impl HostDtype for i32 {
const DTYPE: Dtype = Dtype::I32;
}
impl HostDtype for f32 {
const DTYPE: Dtype = Dtype::F32;
}
impl HostDtype for half::f16 {
const DTYPE: Dtype = Dtype::F16;
}
impl HostDtype for i8 {
const DTYPE: Dtype = Dtype::I8;
}