use cubecl_common::{e2m1, e2m1x2, e3m2, e5m2};
use cubecl_core::{
ir::{BarrierLevel, ElemType, FloatKind, IntKind, UIntKind},
tf32,
};
use half::{bf16, f16};
use std::fmt::Display;
use super::Dialect;
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum Elem<D: Dialect> {
TF32,
F32,
F64,
F16,
F16x2,
BF16,
BF16x2,
FP4(FP4Kind),
FP4x2(FP4Kind),
FP6(FP6Kind),
FP6x2(FP6Kind),
FP8(FP8Kind),
FP8x2(FP8Kind),
I8,
I16,
I32,
I64,
U8,
U16,
U32,
U64,
Bool,
Barrier(BarrierLevel),
Atomic(AtomicKind<D>),
_Dialect(std::marker::PhantomData<D>),
}
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP4Kind {
E2M1,
}
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP6Kind {
E2M3,
E3M2,
}
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP8Kind {
E4M3,
E5M2,
UE8M0,
}
impl Display for FP4Kind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let name = match self {
FP4Kind::E2M1 => "e2m1",
};
f.write_str(name)
}
}
impl Display for FP6Kind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let name = match self {
FP6Kind::E2M3 => "e2m3",
FP6Kind::E3M2 => "e3m2",
};
f.write_str(name)
}
}
impl Display for FP8Kind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let name = match self {
FP8Kind::E4M3 => "e4m3",
FP8Kind::E5M2 => "e5m2",
FP8Kind::UE8M0 => "e8m0",
};
f.write_str(name)
}
}
impl<D: Dialect> Display for Elem<D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
D::compile_elem(f, self, false)
}
}
impl<D: Dialect> Elem<D> {
pub const fn size(&self) -> usize {
match self {
Elem::FP4(_) => core::mem::size_of::<e2m1>(),
Elem::FP4x2(_) => core::mem::size_of::<e2m1x2>(),
Elem::FP6(_) => core::mem::size_of::<e3m2>(),
Elem::FP6x2(_) => 2 * core::mem::size_of::<e3m2>(),
Elem::FP8(_) => core::mem::size_of::<e5m2>(),
Elem::FP8x2(_) => 2 * core::mem::size_of::<e5m2>(),
Elem::F16 => core::mem::size_of::<f16>(),
Elem::F16x2 => 2 * core::mem::size_of::<f16>(),
Elem::BF16x2 => 2 * core::mem::size_of::<bf16>(),
Elem::BF16 => core::mem::size_of::<bf16>(),
Elem::TF32 => core::mem::size_of::<tf32>(),
Elem::F32 => core::mem::size_of::<f32>(),
Elem::F64 => core::mem::size_of::<f64>(),
Elem::I8 => core::mem::size_of::<i8>(),
Elem::I16 => core::mem::size_of::<i16>(),
Elem::I32 => core::mem::size_of::<i32>(),
Elem::I64 => core::mem::size_of::<i64>(),
Elem::U8 => core::mem::size_of::<u8>(),
Elem::U16 => core::mem::size_of::<u16>(),
Elem::U32 => core::mem::size_of::<u32>(),
Elem::U64 => core::mem::size_of::<u64>(),
Elem::Bool => core::mem::size_of::<bool>(),
Elem::Barrier(_) => core::mem::size_of::<u64>(),
Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
Elem::Atomic(AtomicKind::F16x2) => core::mem::size_of::<f32>(),
Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
Elem::Atomic(AtomicKind::BF16x2) => core::mem::size_of::<f32>(),
Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
Elem::_Dialect(_) => 0,
}
}
pub const fn size_bits(&self) -> usize {
match self {
Elem::FP4(_) => 4,
other => other.size() * 8,
}
}
pub const fn unpacked(&self) -> Self {
match self {
Elem::FP4x2(ty) => Elem::FP4(*ty),
Elem::FP6x2(ty) => Elem::FP6(*ty),
Elem::FP8x2(ty) => Elem::FP8(*ty),
Elem::F16x2 => Elem::F16,
Elem::BF16x2 => Elem::BF16,
elem => *elem,
}
}
pub const fn packing_factor(&self) -> usize {
match self {
Elem::FP4x2(_) | Elem::FP6x2(_) | Elem::FP8x2(_) | Elem::F16x2 | Elem::BF16x2 => 2,
_ => 1,
}
}
pub const fn ident(&self) -> &str {
match self {
Elem::FP4(_) => "fp4",
Elem::FP4x2(_) => "fp4x2",
Elem::FP6(_) => "fp6",
Elem::FP6x2(_) => "fp6x2",
Elem::FP8(_) => "fp8",
Elem::FP8x2(_) => "fp8x2",
Elem::F16 => "f16",
Elem::F16x2 => "f16x2",
Elem::BF16x2 => "bf16x2",
Elem::BF16 => "bf16",
Elem::TF32 => "tf32",
Elem::F32 => "f32",
Elem::F64 => "f64",
Elem::I8 => "i8",
Elem::I16 => "i16",
Elem::I32 => "i32",
Elem::I64 => "i64",
Elem::U8 => "u8",
Elem::U16 => "u16",
Elem::U32 => "u32",
Elem::U64 => "u64",
Elem::Bool => "bool",
Elem::Barrier(BarrierLevel::Cube) => "cuda::barrier<cuda::thread_scope_block>",
Elem::Barrier(BarrierLevel::Unit) => "cuda::barrier<cuda::thread_scope_thread>",
Elem::Atomic(_) => "atomic",
Elem::_Dialect(_) => "",
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum AtomicKind<D: Dialect> {
I32,
I64,
U32,
U64,
F16,
F16x2,
BF16,
BF16x2,
F32,
F64,
_Dialect(std::marker::PhantomData<D>),
}
impl<D: Dialect> From<ElemType> for AtomicKind<D> {
fn from(value: ElemType) -> Self {
match value {
ElemType::Float(FloatKind::F16) => AtomicKind::F16,
ElemType::Float(FloatKind::BF16) => AtomicKind::BF16,
ElemType::Float(FloatKind::F32) => AtomicKind::F32,
ElemType::Float(FloatKind::F64) => AtomicKind::F64,
ElemType::Int(IntKind::I32) => AtomicKind::I32,
ElemType::Int(IntKind::I64) => AtomicKind::I64,
ElemType::UInt(UIntKind::U32) => AtomicKind::U32,
ElemType::UInt(UIntKind::U64) => AtomicKind::U64,
other => unimplemented!("Invalid atomic type: {other}"),
}
}
}
impl<D: Dialect> AtomicKind<D> {
pub fn as_elem(self) -> Elem<D> {
match self {
AtomicKind::I32 => Elem::I32,
AtomicKind::I64 => Elem::I64,
AtomicKind::U32 => Elem::U32,
AtomicKind::U64 => Elem::U64,
AtomicKind::F16 => Elem::F16,
AtomicKind::F16x2 => Elem::F16x2,
AtomicKind::BF16 => Elem::BF16,
AtomicKind::BF16x2 => Elem::BF16x2,
AtomicKind::F32 => Elem::F32,
AtomicKind::F64 => Elem::F64,
AtomicKind::_Dialect(_) => unreachable!(),
}
}
}
impl<D: Dialect> Display for AtomicKind<D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
D::compile_atomic_kind(f, self)
}
}