use super::{ConstantValue, Variable, VariableKind};
use crate::{BarrierLevel, TypeHash};
use core::fmt::Display;
use cubecl_common::{
e2m1, e2m1x2, e2m3, e3m2, e4m3, e5m2, flex32,
quant::scheme::{QuantParam, QuantValue},
tf32, ue8m0,
};
use derive_more::From;
use half::{bf16, f16};
/// The floating-point formats an [`ElemType::Float`] can take.
///
/// Variant order matters: it drives the derived `PartialOrd`/`Ord` implementations.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum FloatKind {
/// `e2m1`: the only kind reported as 4 bits wide by `ElemType::size_bits`.
E2M1,
/// `e2m3`: occupies one byte in `ElemType::size` (called "fp6" in this file's conversion stubs).
E2M3,
/// `e3m2`: occupies one byte in `ElemType::size` (called "fp6" in this file's conversion stubs).
E3M2,
/// `e4m3`: occupies one byte in `ElemType::size`.
E4M3,
/// `e5m2`: occupies one byte in `ElemType::size`.
E5M2,
/// `ue8m0`: occupies one byte in `ElemType::size`.
UE8M0,
/// Half precision, sized as `half::f16`.
F16,
/// Brain float, sized as `half::bf16`.
BF16,
/// `flex32`: reported with the same size and limits as `f32`.
Flex32,
/// Single precision (`f32`).
F32,
/// `tf32`: reported with the same size and limits as `f32`.
TF32,
/// Double precision (`f64`).
F64,
}
/// The signed integer widths an [`ElemType::Int`] can take.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum IntKind {
/// 8-bit signed integer (`i8`).
I8,
/// 16-bit signed integer (`i16`).
I16,
/// 32-bit signed integer (`i32`).
I32,
/// 64-bit signed integer (`i64`).
I64,
}
/// The unsigned integer widths an [`ElemType::UInt`] can take.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(missing_docs)]
pub enum UIntKind {
/// 8-bit unsigned integer (`u8`).
U8,
/// 16-bit unsigned integer (`u16`).
U16,
/// 32-bit unsigned integer (`u32`).
U32,
/// 64-bit unsigned integer (`u64`).
U64,
}
/// A scalar element type: a float, a signed or unsigned integer, or a boolean.
///
/// The derived `From` lets a [`FloatKind`], [`IntKind`] or [`UIntKind`] convert into the
/// matching variant.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord, From)]
#[allow(missing_docs)]
pub enum ElemType {
/// Floating-point element of the given kind.
Float(FloatKind),
/// Signed integer element of the given width.
Int(IntKind),
/// Unsigned integer element of the given width.
UInt(UIntKind),
/// Boolean element; `is_int` and `is_unsigned_int` also report `true` for it.
Bool,
}
/// Storage types that have a size but no element type.
///
/// `StorageType::elem_type` is unimplemented for values wrapping one of these.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum OpaqueType {
/// A barrier at the given level; reported as 8 bytes (64 bits) wide.
Barrier(BarrierLevel),
}
/// Types that have no storage representation.
///
/// A [`Type::Semantic`] reports a size and vector size of 0, and `Type::storage_type` is
/// unimplemented for it.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum SemanticType {
/// Displayed as `barrier_token`.
BarrierToken,
/// Displayed as `pipeline`.
Pipeline,
/// Displayed as `tensor_map`.
TensorMap,
}
/// How a value is stored: a plain element, several elements packed together, an atomic
/// element, or an opaque object.
///
/// `Debug` is implemented by hand for this type rather than derived.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum StorageType {
/// A single element.
Scalar(ElemType),
/// An element type and a packing factor; the bit size is the element bit size times the factor.
Packed(ElemType, usize),
/// An element accessed atomically; sized like the plain element.
Atomic(ElemType),
/// An opaque object with no element type.
Opaque(OpaqueType),
}
impl core::fmt::Debug for StorageType {
    /// Writes the value as a tuple-style variant such as `Scalar(..)` or `Packed(.., ..)`.
    ///
    /// The output goes through a private wrapper rendered with a plain `{:?}`, so the
    /// wrapper is formatted by that format string and not by the flags of the caller's
    /// formatter.
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        // Wrapper whose `Debug` impl carries the actual per-variant formatting.
        struct Compact<'a>(&'a StorageType);

        impl core::fmt::Debug for Compact<'_> {
            fn fmt(&self, out: &mut core::fmt::Formatter) -> core::fmt::Result {
                match *self.0 {
                    StorageType::Scalar(elem) => out.debug_tuple("Scalar").field(&elem).finish(),
                    StorageType::Packed(elem, factor) => out
                        .debug_tuple("Packed")
                        .field(&elem)
                        .field(&factor)
                        .finish(),
                    StorageType::Atomic(elem) => out.debug_tuple("Atomic").field(&elem).finish(),
                    StorageType::Opaque(opaque) => {
                        out.debug_tuple("Opaque").field(&opaque).finish()
                    }
                }
            }
        }

        write!(f, "{:?}", Compact(self))
    }
}
impl ElemType {
    /// Maps a [`QuantParam`] to the element type that represents it.
    pub fn from_quant_param(quant_param: QuantParam) -> Self {
        match quant_param {
            QuantParam::F32 => Self::Float(FloatKind::F32),
            QuantParam::F16 => Self::Float(FloatKind::F16),
            QuantParam::BF16 => Self::Float(FloatKind::BF16),
            QuantParam::UE8M0 => Self::Float(FloatKind::UE8M0),
            // UE4M3 is a 4-exponent/3-mantissa format, so it maps to E4M3. Mapping it to
            // UE8M0 would make it indistinguishable from the arm above.
            QuantParam::UE4M3 => Self::Float(FloatKind::E4M3),
        }
    }

    /// Maps a [`QuantValue`] to the element type that represents it.
    ///
    /// # Panics
    ///
    /// Panics for any quant value other than `E5M2`, `E4M3`, `E2M1`, `Q8F` and `Q8S`.
    pub fn from_quant_value(quant_value: QuantValue) -> Self {
        match quant_value {
            QuantValue::E5M2 => Self::Float(FloatKind::E5M2),
            QuantValue::E4M3 => Self::Float(FloatKind::E4M3),
            QuantValue::E2M1 => Self::Float(FloatKind::E2M1),
            QuantValue::Q8F | QuantValue::Q8S => Self::Int(IntKind::I8),
            other => panic!("Unsupported quant value {other:?}"),
        }
    }

    /// Creates a constant [`Variable`] holding `val`, typed as a scalar of this element type.
    pub fn constant(&self, val: ConstantValue) -> Variable {
        Variable::constant(val, Type::scalar(*self))
    }

    /// Returns the size of the element in bytes.
    ///
    /// Sub-byte and 8-bit float kinds are all reported as one byte; use
    /// [`size_bits`](Self::size_bits) for the exact width.
    pub const fn size(&self) -> usize {
        match self {
            ElemType::Float(kind) => match kind {
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => core::mem::size_of::<u8>(),
                FloatKind::F16 => core::mem::size_of::<half::f16>(),
                FloatKind::BF16 => core::mem::size_of::<half::bf16>(),
                FloatKind::F32 => core::mem::size_of::<f32>(),
                FloatKind::F64 => core::mem::size_of::<f64>(),
                FloatKind::Flex32 => core::mem::size_of::<f32>(),
                FloatKind::TF32 => core::mem::size_of::<f32>(),
            },
            ElemType::Int(kind) => match kind {
                IntKind::I8 => core::mem::size_of::<i8>(),
                IntKind::I16 => core::mem::size_of::<i16>(),
                IntKind::I32 => core::mem::size_of::<i32>(),
                IntKind::I64 => core::mem::size_of::<i64>(),
            },
            ElemType::UInt(kind) => match kind {
                UIntKind::U8 => core::mem::size_of::<u8>(),
                UIntKind::U16 => core::mem::size_of::<u16>(),
                UIntKind::U32 => core::mem::size_of::<u32>(),
                UIntKind::U64 => core::mem::size_of::<u64>(),
            },
            ElemType::Bool => core::mem::size_of::<bool>(),
        }
    }

    /// Returns the size of the element in bits.
    ///
    /// This is `size() * 8` for every type except `E2M1`, which is 4 bits wide.
    pub const fn size_bits(&self) -> usize {
        match self {
            ElemType::Float(kind) => match kind {
                FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0
                | FloatKind::F16
                | FloatKind::BF16
                | FloatKind::F32
                | FloatKind::F64
                | FloatKind::Flex32
                | FloatKind::TF32 => self.size() * 8,
                FloatKind::E2M1 => 4,
            },
            ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool => self.size() * 8,
        }
    }

    /// Returns the smallest vector size usable with this element type.
    ///
    /// This is 2 for `E2M1` and 1 for everything else (presumably because two 4-bit
    /// values are needed to fill a byte — confirm against callers).
    pub const fn min_vector_size(&self) -> u8 {
        match self {
            ElemType::Float(FloatKind::E2M1) => 2,
            _ => 1,
        }
    }

    /// Returns `true` for signed integers, unsigned integers and booleans.
    pub fn is_int(&self) -> bool {
        matches!(self, ElemType::Int(_) | ElemType::UInt(_) | ElemType::Bool)
    }

    /// Returns `true` for signed integers only.
    pub fn is_signed_int(&self) -> bool {
        matches!(self, ElemType::Int(_))
    }

    /// Returns `true` for unsigned integers and booleans.
    pub fn is_unsigned_int(&self) -> bool {
        matches!(self, ElemType::UInt(_) | ElemType::Bool)
    }

    /// Returns `true` for any floating-point kind.
    pub fn is_float(&self) -> bool {
        matches!(self, ElemType::Float(_))
    }

    /// Returns `true` for the boolean type.
    pub fn is_bool(&self) -> bool {
        matches!(self, ElemType::Bool)
    }

    /// Returns the float kind if this is a float element, `None` otherwise.
    pub fn as_float(&self) -> Option<FloatKind> {
        match self {
            ElemType::Float(kind) => Some(*kind),
            _ => None,
        }
    }

    /// Creates a constant scalar [`Variable`] holding this type's `MAX` value.
    ///
    /// `Flex32` and `TF32` use the `f32` limit; `Bool` yields `true`.
    pub fn max_variable(&self) -> Variable {
        let value = match self {
            ElemType::Float(kind) => match kind {
                FloatKind::E2M1 => e2m1::MAX,
                FloatKind::E2M3 => e2m3::MAX,
                FloatKind::E3M2 => e3m2::MAX,
                FloatKind::E4M3 => e4m3::MAX,
                FloatKind::E5M2 => e5m2::MAX,
                FloatKind::UE8M0 => ue8m0::MAX,
                FloatKind::F16 => half::f16::MAX.to_f64(),
                FloatKind::BF16 => half::bf16::MAX.to_f64(),
                FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MAX as f64,
                FloatKind::F64 => f64::MAX,
            }
            .into(),
            ElemType::Int(kind) => match kind {
                IntKind::I8 => i8::MAX as i64,
                IntKind::I16 => i16::MAX as i64,
                IntKind::I32 => i32::MAX as i64,
                IntKind::I64 => i64::MAX,
            }
            .into(),
            ElemType::UInt(kind) => match kind {
                UIntKind::U8 => u8::MAX as u64,
                UIntKind::U16 => u16::MAX as u64,
                UIntKind::U32 => u32::MAX as u64,
                UIntKind::U64 => u64::MAX,
            }
            .into(),
            ElemType::Bool => true.into(),
        };

        Variable::new(VariableKind::Constant(value), Type::scalar(*self))
    }

    /// Creates a constant scalar [`Variable`] holding this type's `MIN` value.
    ///
    /// `Flex32` and `TF32` use the `f32` limit; `Bool` yields `false`.
    pub fn min_variable(&self) -> Variable {
        let value = match self {
            ElemType::Float(kind) => match kind {
                FloatKind::E2M1 => e2m1::MIN,
                FloatKind::E2M3 => e2m3::MIN,
                FloatKind::E3M2 => e3m2::MIN,
                FloatKind::E4M3 => e4m3::MIN,
                FloatKind::E5M2 => e5m2::MIN,
                FloatKind::UE8M0 => ue8m0::MIN,
                FloatKind::F16 => half::f16::MIN.to_f64(),
                FloatKind::BF16 => half::bf16::MIN.to_f64(),
                FloatKind::Flex32 | FloatKind::TF32 | FloatKind::F32 => f32::MIN as f64,
                FloatKind::F64 => f64::MIN,
            }
            .into(),
            ElemType::Int(kind) => match kind {
                IntKind::I8 => i8::MIN as i64,
                IntKind::I16 => i16::MIN as i64,
                IntKind::I32 => i32::MIN as i64,
                IntKind::I64 => i64::MIN,
            }
            .into(),
            ElemType::UInt(kind) => match kind {
                UIntKind::U8 => u8::MIN as u64,
                UIntKind::U16 => u16::MIN as u64,
                UIntKind::U32 => u32::MIN as u64,
                UIntKind::U64 => u64::MIN,
            }
            .into(),
            ElemType::Bool => false.into(),
        };

        Variable::new(VariableKind::Constant(value), Type::scalar(*self))
    }

    /// Returns the epsilon associated with this element type, as an `f64`.
    ///
    /// `F16`, `F32` (also used for `Flex32`/`TF32`) and `F64` use their `EPSILON`
    /// constants, `BF16` uses `2^-7`, and integers and booleans use a step of 1.
    ///
    /// NOTE(review): the minifloat kinds use `0.5 * (MAX - MIN)`, which is half the
    /// value range, not a machine epsilon — confirm this is intended.
    pub fn epsilon(&self) -> f64 {
        match self {
            ElemType::Float(kind) => match kind {
                FloatKind::E2M1 => 0.5 * (e2m1::MAX - e2m1::MIN),
                FloatKind::E2M3 => 0.5 * (e2m3::MAX - e2m3::MIN),
                FloatKind::E3M2 => 0.5 * (e3m2::MAX - e3m2::MIN),
                FloatKind::E4M3 => 0.5 * (e4m3::MAX - e4m3::MIN),
                FloatKind::E5M2 => 0.5 * (e5m2::MAX - e5m2::MIN),
                FloatKind::UE8M0 => 0.5 * (ue8m0::MAX - ue8m0::MIN),
                FloatKind::F16 => half::f16::EPSILON.to_f64(),
                // 2^-7
                FloatKind::BF16 => 0.0078125,
                FloatKind::Flex32 | FloatKind::F32 | FloatKind::TF32 => f32::EPSILON.into(),
                FloatKind::F64 => f64::EPSILON,
            },
            // Smallest step between two distinct integer values.
            ElemType::Int(_) | ElemType::UInt(_) => 1.0,
            ElemType::Bool => 1.0,
        }
    }
}
impl OpaqueType {
pub const fn size(&self) -> usize {
match self {
OpaqueType::Barrier(_) => 8,
}
}
pub const fn size_bits(&self) -> usize {
match self {
OpaqueType::Barrier(_) => 64,
}
}
}
impl StorageType {
    /// Returns the element type behind a scalar, packed or atomic storage type.
    ///
    /// # Panics
    ///
    /// Panics (unimplemented) for opaque storage, which has no element type.
    pub fn elem_type(&self) -> ElemType {
        match *self {
            Self::Opaque(_) => unimplemented!("Can't get elem type for opaque type"),
            Self::Scalar(elem) | Self::Atomic(elem) | Self::Packed(elem, _) => elem,
        }
    }

    /// Returns the packing factor: the stored factor for packed storage, 1 otherwise.
    pub fn packing_factor(&self) -> usize {
        if let Self::Packed(_, factor) = *self {
            factor
        } else {
            1
        }
    }

    /// Returns `true` if this is atomic storage.
    pub fn is_atomic(&self) -> bool {
        matches!(*self, Self::Atomic(_))
    }

    /// Returns the size in bytes, i.e. the bit size rounded up to a whole byte.
    pub fn size(&self) -> usize {
        let bits = self.size_bits();
        bits.div_ceil(8)
    }

    /// Returns the size in bits; packed storage multiplies the element width by its factor.
    pub fn size_bits(&self) -> usize {
        match *self {
            Self::Scalar(elem) | Self::Atomic(elem) => elem.size_bits(),
            Self::Packed(elem, factor) => elem.size_bits() * factor,
            Self::Opaque(opaque) => opaque.size_bits(),
        }
    }

    /// Whether the element type is an integer; panics for opaque storage.
    pub fn is_int(&self) -> bool {
        let elem = self.elem_type();
        elem.is_int()
    }

    /// Whether the element type is a signed integer; panics for opaque storage.
    pub fn is_signed_int(&self) -> bool {
        let elem = self.elem_type();
        elem.is_signed_int()
    }

    /// Whether the element type is an unsigned integer; panics for opaque storage.
    pub fn is_unsigned_int(&self) -> bool {
        let elem = self.elem_type();
        elem.is_unsigned_int()
    }

    /// Whether the element type is a float; panics for opaque storage.
    pub fn is_float(&self) -> bool {
        let elem = self.elem_type();
        elem.is_float()
    }

    /// Whether the element type is a boolean; panics for opaque storage.
    pub fn is_bool(&self) -> bool {
        let elem = self.elem_type();
        elem.is_bool()
    }

    /// Returns the epsilon of the element type; packed storage scales it by the packing factor.
    ///
    /// # Panics
    ///
    /// Panics for opaque storage.
    pub fn epsilon(&self) -> f64 {
        match *self {
            Self::Opaque(_) => panic!("Opaque type does not have an epsilon"),
            Self::Scalar(elem) | Self::Atomic(elem) => elem.epsilon(),
            Self::Packed(elem, factor) => elem.epsilon() * (factor as f64),
        }
    }

    /// Creates a constant [`Variable`] holding `value`, typed as a scalar of this storage type.
    pub fn constant(&self, value: ConstantValue) -> Variable {
        let ty = Type::Scalar(*self);
        Variable::constant(value, ty)
    }
}
macro_rules! storage_from_elem {
($($ty: ty),*) => {
$(impl From<$ty> for StorageType {
fn from(value: $ty) -> Self {
StorageType::Scalar(value.into())
}
})*
};
}
storage_from_elem!(FloatKind, IntKind, UIntKind, ElemType);
impl From<OpaqueType> for StorageType {
fn from(val: OpaqueType) -> Self {
StorageType::Opaque(val)
}
}
impl<T> From<T> for Type
where
    T: Into<StorageType>,
{
    /// Converts anything convertible to a [`StorageType`] into a scalar [`Type`].
    fn from(val: T) -> Self {
        let storage: StorageType = val.into();
        Self::Scalar(storage)
    }
}
impl From<SemanticType> for Type {
fn from(val: SemanticType) -> Self {
Type::semantic(val)
}
}
/// The type of a value: a scalar, a vector of a storage type, or a storage-less semantic type.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, TypeHash, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Type {
/// A single value of the given storage type.
Scalar(StorageType),
/// A vector holding the given number of values of the storage type.
Vector(StorageType, VectorSize),
/// A semantic type, which has no storage type and a reported size of 0.
Semantic(SemanticType),
}
/// Number of elements in a [`Type::Vector`].
pub type VectorSize = usize;
impl Type {
    /// Returns the element type of the underlying storage.
    ///
    /// Panics for semantic types and for opaque storage.
    pub fn elem_type(&self) -> ElemType {
        let storage = self.storage_type();
        storage.elem_type()
    }

    /// Creates a scalar type from a storage type.
    pub fn new(storage: StorageType) -> Self {
        Self::Scalar(storage)
    }

    /// Creates a scalar type whose storage is a plain element.
    pub fn scalar(elem: ElemType) -> Self {
        Self::Scalar(StorageType::Scalar(elem))
    }

    /// Creates a semantic type.
    pub fn semantic(ty: SemanticType) -> Self {
        Type::Semantic(ty)
    }

    /// Returns this type's storage as a vector of `vector_size`, or as a scalar when
    /// `vector_size` is 0 or 1.
    ///
    /// Panics for semantic types.
    pub fn with_vector_size(self, vector_size: VectorSize) -> Type {
        let storage = self.storage_type();
        if vector_size > 1 {
            Type::Vector(storage, vector_size)
        } else {
            Type::Scalar(storage)
        }
    }

    /// Returns a type with the given storage and the vector size of `self`.
    pub fn with_storage_type(self, storage: StorageType) -> Type {
        Type::Scalar(storage).with_vector_size(self.vector_size())
    }

    /// Returns the number of elements: 1 for scalars, 0 for semantic types.
    pub fn vector_size(&self) -> VectorSize {
        match *self {
            Type::Semantic(_) => 0,
            Type::Scalar(_) => 1,
            Type::Vector(_, count) => count,
        }
    }

    /// Returns the size in bytes (storage size times vector size); 0 for semantic types.
    pub fn size(&self) -> usize {
        match *self {
            Type::Semantic(_) => 0,
            Type::Scalar(storage) => storage.size(),
            Type::Vector(storage, count) => storage.size() * count,
        }
    }

    /// Returns the size in bits (storage bits times vector size); 0 for semantic types.
    pub fn size_bits(&self) -> usize {
        match *self {
            Type::Semantic(_) => 0,
            Type::Scalar(storage) => storage.size_bits(),
            Type::Vector(storage, count) => storage.size_bits() * count,
        }
    }

    /// Returns the packing factor of the storage type; 1 for semantic types.
    pub fn packing_factor(&self) -> usize {
        match *self {
            Type::Semantic(_) => 1,
            Type::Scalar(storage) | Type::Vector(storage, _) => storage.packing_factor(),
        }
    }

    /// Whether the storage is atomic; always `false` for semantic types.
    pub fn is_atomic(&self) -> bool {
        if self.is_semantic() {
            return false;
        }
        self.storage_type().is_atomic()
    }

    /// Whether the storage element is an integer; always `false` for semantic types.
    pub fn is_int(&self) -> bool {
        if self.is_semantic() {
            return false;
        }
        self.storage_type().is_int()
    }

    /// Whether the storage element is a signed integer; always `false` for semantic types.
    pub fn is_signed_int(&self) -> bool {
        if self.is_semantic() {
            return false;
        }
        self.storage_type().is_signed_int()
    }

    /// Whether the storage element is an unsigned integer; always `false` for semantic types.
    pub fn is_unsigned_int(&self) -> bool {
        if self.is_semantic() {
            return false;
        }
        self.storage_type().is_unsigned_int()
    }

    /// Whether the storage element is a float; always `false` for semantic types.
    pub fn is_float(&self) -> bool {
        if self.is_semantic() {
            return false;
        }
        self.storage_type().is_float()
    }

    /// Whether the storage element is a boolean; always `false` for semantic types.
    pub fn is_bool(&self) -> bool {
        if self.is_semantic() {
            return false;
        }
        self.storage_type().is_bool()
    }

    /// Returns the storage type of a scalar or vector.
    ///
    /// # Panics
    ///
    /// Panics (unimplemented) for semantic types, which have no storage.
    pub fn storage_type(&self) -> StorageType {
        match *self {
            Type::Semantic(_) => unimplemented!("Can't get storage for semantic type"),
            Type::Scalar(storage) | Type::Vector(storage, _) => storage,
        }
    }

    /// Returns `true` if this is a semantic type.
    pub fn is_semantic(&self) -> bool {
        matches!(*self, Type::Semantic(_))
    }

    /// Creates a constant [`Variable`] holding `value` with this type.
    pub fn constant(&self, value: ConstantValue) -> Variable {
        let ty = *self;
        Variable::constant(value, ty)
    }
}
impl Display for Type {
    /// Writes scalars and semantic types as their inner type, and vectors as
    /// `vector<storage, size>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Vector(storage, count) => write!(f, "vector<{}, {}>", storage, count),
            Self::Scalar(storage) => write!(f, "{}", storage),
            Self::Semantic(semantic) => write!(f, "{}", semantic),
        }
    }
}
impl Display for StorageType {
    /// Writes scalar and opaque storage as the inner type, packed storage as
    /// `packed<elem, factor>` and atomic storage as `atomic<elem>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Packed(elem, factor) => write!(f, "packed<{}, {}>", elem, factor),
            Self::Atomic(elem) => write!(f, "atomic<{}>", elem),
            Self::Scalar(elem) => write!(f, "{}", elem),
            Self::Opaque(opaque) => write!(f, "{}", opaque),
        }
    }
}
impl Display for ElemType {
    /// Writes the lowercase name of the element type, e.g. `f32`, `u8` or `bool`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Pick the name first, then emit it with a single write.
        let name = match self {
            Self::Float(kind) => match kind {
                FloatKind::E2M1 => "e2m1",
                FloatKind::E2M3 => "e2m3",
                FloatKind::E3M2 => "e3m2",
                FloatKind::E4M3 => "e4m3",
                FloatKind::E5M2 => "e5m2",
                FloatKind::UE8M0 => "ue8m0",
                FloatKind::F16 => "f16",
                FloatKind::BF16 => "bf16",
                FloatKind::Flex32 => "flex32",
                FloatKind::TF32 => "tf32",
                FloatKind::F32 => "f32",
                FloatKind::F64 => "f64",
            },
            Self::Int(kind) => match kind {
                IntKind::I8 => "i8",
                IntKind::I16 => "i16",
                IntKind::I32 => "i32",
                IntKind::I64 => "i64",
            },
            Self::UInt(kind) => match kind {
                UIntKind::U8 => "u8",
                UIntKind::U16 => "u16",
                UIntKind::U32 => "u32",
                UIntKind::U64 => "u64",
            },
            Self::Bool => "bool",
        };

        f.write_str(name)
    }
}
impl Display for SemanticType {
    /// Writes the snake_case name of the semantic type.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let name = match self {
            Self::BarrierToken => "barrier_token",
            Self::Pipeline => "pipeline",
            Self::TensorMap => "tensor_map",
        };

        f.write_str(name)
    }
}
impl Display for OpaqueType {
    /// Writes a barrier as `barrier<level>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Barrier(level) => write!(f, "barrier<{}>", level),
        }
    }
}
// Placeholder conversion: the impl exists so the trait bound is satisfied, but calling it
// always panics via `unimplemented!`.
impl From<e2m1x2> for Variable {
fn from(_value: e2m1x2) -> Self {
unimplemented!("Can't currently construct e2m1x2")
}
}
// Placeholder conversion: always panics via `unimplemented!`.
impl From<e2m3> for Variable {
fn from(_value: e2m3) -> Self {
unimplemented!("Can't currently construct fp6")
}
}
// Placeholder conversion: always panics via `unimplemented!`.
impl From<e3m2> for Variable {
fn from(_value: e3m2) -> Self {
unimplemented!("Can't currently construct fp6")
}
}
impl From<i8> for ConstantValue {
fn from(value: i8) -> Self {
ConstantValue::Int(value as i64)
}
}
impl From<i16> for ConstantValue {
fn from(value: i16) -> Self {
ConstantValue::Int(value as i64)
}
}
impl From<i32> for ConstantValue {
fn from(value: i32) -> Self {
ConstantValue::Int(value as i64)
}
}
impl From<isize> for ConstantValue {
fn from(value: isize) -> Self {
ConstantValue::Int(value as i64)
}
}
impl From<u8> for ConstantValue {
fn from(value: u8) -> Self {
ConstantValue::UInt(value as u64)
}
}
impl From<u16> for ConstantValue {
fn from(value: u16) -> Self {
ConstantValue::UInt(value as u64)
}
}
impl From<u32> for ConstantValue {
fn from(value: u32) -> Self {
ConstantValue::UInt(value as u64)
}
}
impl From<usize> for ConstantValue {
fn from(value: usize) -> Self {
ConstantValue::UInt(value as u64)
}
}
impl From<e2m1> for ConstantValue {
    /// Converts the value to `f64` and stores it as a float constant.
    fn from(value: e2m1) -> Self {
        let wide = value.to_f64();
        Self::Float(wide)
    }
}

impl From<e4m3> for ConstantValue {
    /// Converts the value to `f64` and stores it as a float constant.
    fn from(value: e4m3) -> Self {
        let wide = value.to_f64();
        Self::Float(wide)
    }
}

impl From<e5m2> for ConstantValue {
    /// Converts the value to `f64` and stores it as a float constant.
    fn from(value: e5m2) -> Self {
        let wide = value.to_f64();
        Self::Float(wide)
    }
}

impl From<ue8m0> for ConstantValue {
    /// Converts the value to `f64` and stores it as a float constant.
    fn from(value: ue8m0) -> Self {
        let wide = value.to_f64();
        Self::Float(wide)
    }
}

impl From<half::f16> for ConstantValue {
    /// Converts the value to `f64` and stores it as a float constant.
    fn from(value: half::f16) -> Self {
        let wide = value.to_f64();
        Self::Float(wide)
    }
}

impl From<half::bf16> for ConstantValue {
    /// Converts the value to `f64` and stores it as a float constant.
    fn from(value: half::bf16) -> Self {
        let wide = value.to_f64();
        Self::Float(wide)
    }
}

impl From<flex32> for ConstantValue {
    /// Converts the value to `f64` and stores it as a float constant.
    fn from(value: flex32) -> Self {
        let wide = value.to_f64();
        Self::Float(wide)
    }
}

impl From<tf32> for ConstantValue {
    /// Converts the value to `f64` and stores it as a float constant.
    fn from(value: tf32) -> Self {
        let wide = value.to_f64();
        Self::Float(wide)
    }
}

impl From<f32> for ConstantValue {
    /// Widens the value to `f64` and stores it as a float constant.
    fn from(value: f32) -> Self {
        Self::Float(f64::from(value))
    }
}
// Generates `From<$source> for Variable` for each `type => kind` pair: the value becomes a
// constant and the kind is converted into the variable's type.
macro_rules! impl_into_variable {
    ($($source: ty => $elem_kind: path,)*) => {
        $(
            impl From<$source> for Variable {
                fn from(value: $source) -> Self {
                    let constant = VariableKind::Constant(value.into());
                    Variable::new(constant, $elem_kind.into())
                }
            }
        )*
    };
}

impl_into_variable!(
    bool => ElemType::Bool,

    i8 => IntKind::I8,
    i16 => IntKind::I16,
    i32 => IntKind::I32,
    i64 => IntKind::I64,

    u8 => UIntKind::U8,
    u16 => UIntKind::U16,
    u32 => UIntKind::U32,
    u64 => UIntKind::U64,

    e2m1 => FloatKind::E2M1,
    e4m3 => FloatKind::E4M3,
    e5m2 => FloatKind::E5M2,
    ue8m0 => FloatKind::UE8M0,
    f16 => FloatKind::F16,
    bf16 => FloatKind::BF16,
    f32 => FloatKind::F32,
    flex32 => FloatKind::Flex32,
    tf32 => FloatKind::TF32,
    f64 => FloatKind::F64,

    // Pointer-sized integers are typed as 32-bit kinds.
    usize => UIntKind::U32,
    isize => IntKind::I32,
);