use core::fmt;
/// Element data types, each with a C++/CUTLASS type spelling, a short
/// lowercase identifier and a logical bit width.
///
/// * [`Self::as_cutlass_type`] gives the C++ spelling.
/// * [`Self::short_name`] gives the identifier that `Display` prints.
/// * [`Self::size_bits`] gives the width of one element in bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CutlassDtype {
    /// 32-bit element spelled `float`.
    F32,
    /// 64-bit element spelled `double`.
    F64,
    /// 16-bit element spelled `cutlass::half_t`.
    F16,
    /// 16-bit element spelled `cutlass::bfloat16_t`.
    Bf16,
    /// 8-bit element spelled `cutlass::float_e4m3_t`.
    F8E4m3,
    /// 8-bit element spelled `cutlass::float_e5m2_t`.
    F8E5m2,
    /// 4-bit element spelled `cutlass::float_e2m1_t`.
    F4E2m1,
    /// 8-bit element spelled `int8_t`.
    I8,
    /// 32-bit element spelled `int32_t`.
    I32,
    /// 8-bit element spelled `uint8_t`.
    U8,
}

impl CutlassDtype {
    /// Returns the C++ type name for this dtype, e.g. `"cutlass::half_t"`.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn as_cutlass_type(self) -> &'static str {
        match self {
            Self::F32 => "float",
            Self::F64 => "double",
            Self::F16 => "cutlass::half_t",
            Self::Bf16 => "cutlass::bfloat16_t",
            Self::F8E4m3 => "cutlass::float_e4m3_t",
            Self::F8E5m2 => "cutlass::float_e5m2_t",
            Self::F4E2m1 => "cutlass::float_e2m1_t",
            Self::I8 => "int8_t",
            Self::I32 => "int32_t",
            Self::U8 => "uint8_t",
        }
    }

    /// Returns the short lowercase identifier for this dtype, e.g. `"bf16"`.
    ///
    /// This is the text produced by the `Display` implementation.
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn short_name(self) -> &'static str {
        match self {
            Self::F32 => "f32",
            Self::F64 => "f64",
            Self::F16 => "f16",
            Self::Bf16 => "bf16",
            Self::F8E4m3 => "f8e4m3",
            Self::F8E5m2 => "f8e5m2",
            Self::F4E2m1 => "f4e2m1",
            Self::I8 => "i8",
            Self::I32 => "i32",
            Self::U8 => "u8",
        }
    }

    /// Returns the width of one element in bits (4, 8, 16, 32 or 64).
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn size_bits(self) -> u32 {
        match self {
            Self::F64 => 64,
            Self::F32 | Self::I32 => 32,
            Self::F16 | Self::Bf16 => 16,
            Self::F8E4m3 | Self::F8E5m2 | Self::I8 | Self::U8 => 8,
            Self::F4E2m1 => 4,
        }
    }
}

/// Formats the dtype as its [`CutlassDtype::short_name`].
impl fmt::Display for CutlassDtype {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` instead of `write_str`: plain `{}` output is identical, but
        // width/fill/alignment flags (e.g. `{:>8}`) are honoured rather than
        // silently dropped. A precision, if given, truncates as it does for
        // `str`.
        f.pad(self.short_name())
    }
}
/// GPU architecture targets, identified by their `sm_XX` names.
///
/// * [`Self::nvrtc_flag`] gives the matching `--gpu-architecture=` option.
/// * [`Self::short_name`] gives the `sm_XX` text that `Display` prints.
/// * The `supports_*` predicates report per-architecture capabilities.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmArch {
    /// `sm_80`
    Sm80,
    /// `sm_86`
    Sm86,
    /// `sm_89`
    Sm89,
    /// `sm_90`
    Sm90,
    /// `sm_90a`
    Sm90a,
    /// `sm_100`
    Sm100,
    /// `sm_120`
    Sm120,
}

impl SmArch {
    /// Returns the `--gpu-architecture=compute_XX` option string for this
    /// architecture.
    ///
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn nvrtc_flag(self) -> &'static str {
        match self {
            Self::Sm80 => "--gpu-architecture=compute_80",
            Self::Sm86 => "--gpu-architecture=compute_86",
            Self::Sm89 => "--gpu-architecture=compute_89",
            Self::Sm90 => "--gpu-architecture=compute_90",
            Self::Sm90a => "--gpu-architecture=compute_90a",
            Self::Sm100 => "--gpu-architecture=compute_100",
            Self::Sm120 => "--gpu-architecture=compute_120",
        }
    }

    /// Returns the `sm_XX` identifier for this architecture, e.g. `"sm_90a"`.
    ///
    /// This is the text produced by the `Display` implementation.
    /// Usable in `const` contexts.
    #[must_use]
    pub const fn short_name(self) -> &'static str {
        match self {
            Self::Sm80 => "sm_80",
            Self::Sm86 => "sm_86",
            Self::Sm89 => "sm_89",
            Self::Sm90 => "sm_90",
            Self::Sm90a => "sm_90a",
            Self::Sm100 => "sm_100",
            Self::Sm120 => "sm_120",
        }
    }

    /// Returns `true` for the architectures flagged as FP8-capable:
    /// `Sm89`, `Sm90`, `Sm90a`, `Sm100` and `Sm120`.
    #[must_use]
    pub const fn supports_fp8(self) -> bool {
        matches!(
            self,
            Self::Sm89 | Self::Sm90 | Self::Sm90a | Self::Sm100 | Self::Sm120
        )
    }

    /// Returns `true` for the architectures flagged as FP4-capable:
    /// `Sm100` and `Sm120`.
    #[must_use]
    pub const fn supports_fp4(self) -> bool {
        matches!(self, Self::Sm100 | Self::Sm120)
    }

    /// Returns `true` for the architectures flagged as supporting persistent
    /// kernels: `Sm90`, `Sm90a`, `Sm100` and `Sm120`.
    #[must_use]
    pub const fn supports_persistent_kernels(self) -> bool {
        matches!(
            self,
            Self::Sm90 | Self::Sm90a | Self::Sm100 | Self::Sm120
        )
    }
}

/// Formats the architecture as its [`SmArch::short_name`].
impl fmt::Display for SmArch {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` instead of `write_str`: plain `{}` output is identical, but
        // width/fill/alignment flags (e.g. `{:<8}`) are honoured rather than
        // silently dropped. A precision, if given, truncates as it does for
        // `str`.
        f.pad(self.short_name())
    }
}
/// Associates a Rust element type with its [`CutlassDtype`] tag.
///
/// Implementors must be `Copy + Send + Sync + 'static`. The tag is exposed as
/// an associated constant, so generic code can read `T::DTYPE` without needing
/// a value of `T`.
pub trait GemmSupported: Copy + Send + Sync + 'static {
/// The [`CutlassDtype`] variant corresponding to the implementing type.
const DTYPE: CutlassDtype;
}
impl GemmSupported for f32 {
const DTYPE: CutlassDtype = CutlassDtype::F32;
}
impl GemmSupported for f64 {
const DTYPE: CutlassDtype = CutlassDtype::F64;
}
impl GemmSupported for i8 {
const DTYPE: CutlassDtype = CutlassDtype::I8;
}
impl GemmSupported for i32 {
const DTYPE: CutlassDtype = CutlassDtype::I32;
}
impl GemmSupported for u8 {
const DTYPE: CutlassDtype = CutlassDtype::U8;
}
/// Newtype around a `u16`, tagged as [`CutlassDtype::F16`] through
/// [`GemmSupported`].
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped `u16`.
/// The derived `PartialEq`/`Eq`/`Hash` operate on the wrapped integer.
/// NOTE(review): presumably the field holds the raw half-precision bit
/// pattern — confirm against callers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct F16(pub u16);
// Tags the wrapper with its dtype.
impl GemmSupported for F16 {
const DTYPE: CutlassDtype = CutlassDtype::F16;
}
/// Newtype around a `u16`, tagged as [`CutlassDtype::Bf16`] through
/// [`GemmSupported`].
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped `u16`.
/// The derived `PartialEq`/`Eq`/`Hash` operate on the wrapped integer.
/// NOTE(review): presumably the field holds the raw bfloat16 bit pattern —
/// confirm against callers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Bf16(pub u16);
// Tags the wrapper with its dtype.
impl GemmSupported for Bf16 {
const DTYPE: CutlassDtype = CutlassDtype::Bf16;
}
/// Newtype around a `u8`, tagged as [`CutlassDtype::F8E4m3`] through
/// [`GemmSupported`].
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped `u8`.
/// The derived `PartialEq`/`Eq`/`Hash` operate on the wrapped integer.
/// NOTE(review): presumably the field holds the raw FP8 (e4m3) bit pattern —
/// confirm against callers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct F8E4m3(pub u8);
// Tags the wrapper with its dtype.
impl GemmSupported for F8E4m3 {
const DTYPE: CutlassDtype = CutlassDtype::F8E4m3;
}
/// Newtype around a `u8`, tagged as [`CutlassDtype::F8E5m2`] through
/// [`GemmSupported`].
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped `u8`.
/// The derived `PartialEq`/`Eq`/`Hash` operate on the wrapped integer.
/// NOTE(review): presumably the field holds the raw FP8 (e5m2) bit pattern —
/// confirm against callers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct F8E5m2(pub u8);
// Tags the wrapper with its dtype.
impl GemmSupported for F8E5m2 {
const DTYPE: CutlassDtype = CutlassDtype::F8E5m2;
}
/// Newtype around a `u8`, tagged as [`CutlassDtype::F4E2m1`] through
/// [`GemmSupported`].
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped `u8`.
/// The derived `PartialEq`/`Eq`/`Hash` operate on the wrapped integer.
/// NOTE(review): this wrapper occupies a full byte, while
/// `CutlassDtype::F4E2m1.size_bits()` reports 4 — confirm how values are
/// packed (one per byte vs. two per byte) against callers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct F4E2m1(pub u8);
// Tags the wrapper with its dtype.
impl GemmSupported for F4E2m1 {
const DTYPE: CutlassDtype = CutlassDtype::F4E2m1;
}
/// Reports whether `dtype` is usable on `arch`.
///
/// * `F8E4m3` / `F8E5m2` defer to [`SmArch::supports_fp8`].
/// * `F4E2m1` defers to [`SmArch::supports_fp4`].
/// * Every other dtype is reported as supported on every architecture.
pub fn is_supported_for(dtype: CutlassDtype, arch: SmArch) -> bool {
    match dtype {
        CutlassDtype::F8E4m3 | CutlassDtype::F8E5m2 => arch.supports_fp8(),
        CutlassDtype::F4E2m1 => arch.supports_fp4(),
        // Spelled out instead of `_ => true` so that adding a new dtype
        // variant is a compile error here rather than silently "supported
        // everywhere".
        CutlassDtype::F32
        | CutlassDtype::F64
        | CutlassDtype::F16
        | CutlassDtype::Bf16
        | CutlassDtype::I8
        | CutlassDtype::I32
        | CutlassDtype::U8 => true,
    }
}
pub fn is_fp8_supported(arch: SmArch) -> bool {
arch.supports_fp8()
}
pub fn is_fp4_supported(arch: SmArch) -> bool {
arch.supports_fp4()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks each capability predicate on supporting and
    /// non-supporting architectures.
    #[test]
    fn arch_capability_predicates() {
        let fp8_cases = [
            (SmArch::Sm80, false),
            (SmArch::Sm89, true),
            (SmArch::Sm90a, true),
        ];
        for (arch, expected) in fp8_cases.iter().copied() {
            assert_eq!(arch.supports_fp8(), expected, "supports_fp8 on {}", arch);
        }

        let fp4_cases = [(SmArch::Sm100, true), (SmArch::Sm89, false)];
        for (arch, expected) in fp4_cases.iter().copied() {
            assert_eq!(arch.supports_fp4(), expected, "supports_fp4 on {}", arch);
        }

        let persistent_cases = [(SmArch::Sm90a, true), (SmArch::Sm80, false)];
        for (arch, expected) in persistent_cases.iter().copied() {
            assert_eq!(
                arch.supports_persistent_kernels(),
                expected,
                "supports_persistent_kernels on {}",
                arch
            );
        }
    }

    /// No two dtypes may share a short name.
    #[test]
    fn dtype_short_names_unique() {
        let all = [
            CutlassDtype::F32,
            CutlassDtype::F64,
            CutlassDtype::F16,
            CutlassDtype::Bf16,
            CutlassDtype::F8E4m3,
            CutlassDtype::F8E5m2,
            CutlassDtype::F4E2m1,
            CutlassDtype::I8,
            CutlassDtype::I32,
            CutlassDtype::U8,
        ];
        let names: Vec<&'static str> = all.iter().map(|dt| dt.short_name()).collect();
        // Each name must be absent from the names that precede it.
        for (idx, name) in names.iter().enumerate() {
            assert!(
                !names[..idx].contains(name),
                "duplicate short name {:?}",
                name
            );
        }
    }

    /// Samples the dtype × architecture support matrix.
    #[test]
    fn is_supported_for_matrix() {
        let cases = [
            (CutlassDtype::F32, SmArch::Sm80, true),
            (CutlassDtype::F8E4m3, SmArch::Sm80, false),
            (CutlassDtype::F8E4m3, SmArch::Sm90a, true),
            (CutlassDtype::F4E2m1, SmArch::Sm89, false),
            (CutlassDtype::F4E2m1, SmArch::Sm100, true),
        ];
        for (dtype, arch, expected) in cases.iter().copied() {
            assert_eq!(
                is_supported_for(dtype, arch),
                expected,
                "is_supported_for({}, {})",
                dtype,
                arch
            );
        }
    }
}