1use cubecl_common::{e2m1x2, e3m2, e5m2};
2use cubecl_core::tf32;
3use half::{bf16, f16};
4use std::fmt::Display;
5
6use super::Dialect;
7
8#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
9pub enum Elem<D: Dialect> {
10 TF32,
11 F32,
12 F64,
13 F16,
14 F16x2,
15 BF16,
16 BF16x2,
17 FP4(FP4Kind),
18 FP4x2(FP4Kind),
19 FP6(FP6Kind),
20 FP6x2(FP6Kind),
21 FP8(FP8Kind),
22 FP8x2(FP8Kind),
23 I8,
24 I16,
25 I32,
26 I64,
27 U8,
28 U16,
29 U32,
30 U64,
31 Bool,
32 Atomic(AtomicKind<D>),
33 _Dialect(std::marker::PhantomData<D>),
34}
35
/// Encodings available for `FP4` elements.
///
/// Variant names appear to follow the `E<exponent bits>M<mantissa bits>`
/// convention — confirm against the target dialect's type definitions.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP4Kind {
    /// The `e2m1` encoding.
    E2M1,
}
40
/// Encodings available for `FP6` elements.
///
/// Variant names appear to follow the `E<exponent bits>M<mantissa bits>`
/// convention — confirm against the target dialect's type definitions.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP6Kind {
    /// The `e2m3` encoding.
    E2M3,
    /// The `e3m2` encoding.
    E3M2,
}
46
/// Encodings available for `FP8` elements.
///
/// Variant names appear to follow the `E<exponent bits>M<mantissa bits>`
/// convention — confirm against the target dialect's type definitions.
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum FP8Kind {
    /// The `e4m3` encoding.
    E4M3,
    /// The `e5m2` encoding.
    E5M2,
    /// The `ue8m0` encoding; the `U` prefix presumably marks it as unsigned — TODO confirm.
    UE8M0,
}
53
54impl Display for FP4Kind {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 let name = match self {
57 FP4Kind::E2M1 => "e2m1",
58 };
59 f.write_str(name)
60 }
61}
62
63impl Display for FP6Kind {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 let name = match self {
66 FP6Kind::E2M3 => "e2m3",
67 FP6Kind::E3M2 => "e3m2",
68 };
69 f.write_str(name)
70 }
71}
72
73impl Display for FP8Kind {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 let name = match self {
76 FP8Kind::E4M3 => "e4m3",
77 FP8Kind::E5M2 => "e5m2",
78 FP8Kind::UE8M0 => "e8m0",
79 };
80 f.write_str(name)
81 }
82}
83
84impl<D: Dialect> Display for Elem<D> {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 D::compile_elem(f, self, false)
87 }
88}
89
90impl<D: Dialect> Elem<D> {
91 pub const fn size(&self) -> usize {
92 match self {
93 Elem::FP4(_) => panic!("Can't get byte size of sub-byte type"),
94 Elem::FP4x2(_) => core::mem::size_of::<e2m1x2>(),
95 Elem::FP6(_) => core::mem::size_of::<e3m2>(),
96 Elem::FP6x2(_) => 2 * core::mem::size_of::<e3m2>(),
97 Elem::FP8(_) => core::mem::size_of::<e5m2>(),
98 Elem::FP8x2(_) => 2 * core::mem::size_of::<e5m2>(),
99 Elem::F16 => core::mem::size_of::<f16>(),
100 Elem::F16x2 => 2 * core::mem::size_of::<f16>(),
101 Elem::BF16x2 => 2 * core::mem::size_of::<bf16>(),
102 Elem::BF16 => core::mem::size_of::<bf16>(),
103 Elem::TF32 => core::mem::size_of::<tf32>(),
104 Elem::F32 => core::mem::size_of::<f32>(),
105 Elem::F64 => core::mem::size_of::<f64>(),
106 Elem::I8 => core::mem::size_of::<i8>(),
107 Elem::I16 => core::mem::size_of::<i16>(),
108 Elem::I32 => core::mem::size_of::<i32>(),
109 Elem::I64 => core::mem::size_of::<i64>(),
110 Elem::U8 => core::mem::size_of::<u8>(),
111 Elem::U16 => core::mem::size_of::<u16>(),
112 Elem::U32 => core::mem::size_of::<u32>(),
113 Elem::U64 => core::mem::size_of::<u64>(),
114 Elem::Bool => core::mem::size_of::<bool>(),
115 Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
116 Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
117 Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
118 Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
119 Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
120 Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
121 Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
122 Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
123 Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
124 Elem::_Dialect(_) => 0,
125 }
126 }
127
128 pub const fn size_bits(&self) -> usize {
129 match self {
130 Elem::FP4(_) => 4,
131 other => other.size() * 8,
132 }
133 }
134}
135
136#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
137pub enum AtomicKind<D: Dialect> {
138 I32,
139 I64,
140 U32,
141 U64,
142 F16,
143 BF16,
144 F32,
145 F64,
146 _Dialect(std::marker::PhantomData<D>),
148}
149
150impl<D: Dialect> Display for AtomicKind<D> {
151 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
152 D::compile_atomic_kind(f, self)
153 }
154}