1use cubecl_common::{e2m1, e2m1x2, e3m2, e5m2};
2use cubecl_core::{
3 ir::{BarrierLevel, ElemType, FloatKind, IntKind, UIntKind},
4 tf32,
5};
6use half::{bf16, f16};
7use std::fmt::Display;
8
9use super::Dialect;
10
11#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
12pub enum Elem<D: Dialect> {
13 TF32,
14 F32,
15 F64,
16 F16,
17 F16x2,
18 BF16,
19 BF16x2,
20 FP4(FP4Kind),
21 FP4x2(FP4Kind),
22 FP6(FP6Kind),
23 FP6x2(FP6Kind),
24 FP8(FP8Kind),
25 FP8x2(FP8Kind),
26 I8,
27 I16,
28 I32,
29 I64,
30 U8,
31 U16,
32 U32,
33 U64,
34 Bool,
35 Barrier(BarrierLevel),
36 Atomic(AtomicKind<D>),
37 _Dialect(std::marker::PhantomData<D>),
38}
39
40#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
41pub enum FP4Kind {
42 E2M1,
43}
44
45#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
46pub enum FP6Kind {
47 E2M3,
48 E3M2,
49}
50
51#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
52pub enum FP8Kind {
53 E4M3,
54 E5M2,
55 UE8M0,
56}
57
58impl Display for FP4Kind {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 let name = match self {
61 FP4Kind::E2M1 => "e2m1",
62 };
63 f.write_str(name)
64 }
65}
66
67impl Display for FP6Kind {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 let name = match self {
70 FP6Kind::E2M3 => "e2m3",
71 FP6Kind::E3M2 => "e3m2",
72 };
73 f.write_str(name)
74 }
75}
76
77impl Display for FP8Kind {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 let name = match self {
80 FP8Kind::E4M3 => "e4m3",
81 FP8Kind::E5M2 => "e5m2",
82 FP8Kind::UE8M0 => "e8m0",
83 };
84 f.write_str(name)
85 }
86}
87
88impl<D: Dialect> Display for Elem<D> {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 D::compile_elem(f, self, false)
91 }
92}
93
94impl<D: Dialect> Elem<D> {
95 pub const fn size(&self) -> usize {
96 match self {
97 Elem::FP4(_) => core::mem::size_of::<e2m1>(),
98 Elem::FP4x2(_) => core::mem::size_of::<e2m1x2>(),
99 Elem::FP6(_) => core::mem::size_of::<e3m2>(),
100 Elem::FP6x2(_) => 2 * core::mem::size_of::<e3m2>(),
101 Elem::FP8(_) => core::mem::size_of::<e5m2>(),
102 Elem::FP8x2(_) => 2 * core::mem::size_of::<e5m2>(),
103 Elem::F16 => core::mem::size_of::<f16>(),
104 Elem::F16x2 => 2 * core::mem::size_of::<f16>(),
105 Elem::BF16x2 => 2 * core::mem::size_of::<bf16>(),
106 Elem::BF16 => core::mem::size_of::<bf16>(),
107 Elem::TF32 => core::mem::size_of::<tf32>(),
108 Elem::F32 => core::mem::size_of::<f32>(),
109 Elem::F64 => core::mem::size_of::<f64>(),
110 Elem::I8 => core::mem::size_of::<i8>(),
111 Elem::I16 => core::mem::size_of::<i16>(),
112 Elem::I32 => core::mem::size_of::<i32>(),
113 Elem::I64 => core::mem::size_of::<i64>(),
114 Elem::U8 => core::mem::size_of::<u8>(),
115 Elem::U16 => core::mem::size_of::<u16>(),
116 Elem::U32 => core::mem::size_of::<u32>(),
117 Elem::U64 => core::mem::size_of::<u64>(),
118 Elem::Bool => core::mem::size_of::<bool>(),
119 Elem::Barrier(_) => core::mem::size_of::<u64>(),
120 Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
121 Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
122 Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
123 Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
124 Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
125 Elem::Atomic(AtomicKind::F16x2) => core::mem::size_of::<f32>(),
126 Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
127 Elem::Atomic(AtomicKind::BF16x2) => core::mem::size_of::<f32>(),
128 Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
129 Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
130 Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
131 Elem::_Dialect(_) => 0,
132 }
133 }
134
135 pub const fn size_bits(&self) -> usize {
136 match self {
137 Elem::FP4(_) => 4,
138 other => other.size() * 8,
139 }
140 }
141
142 pub const fn unpacked(&self) -> Self {
143 match self {
144 Elem::FP4x2(ty) => Elem::FP4(*ty),
145 Elem::FP6x2(ty) => Elem::FP6(*ty),
146 Elem::FP8x2(ty) => Elem::FP8(*ty),
147 Elem::F16x2 => Elem::F16,
148 Elem::BF16x2 => Elem::BF16,
149 elem => *elem,
150 }
151 }
152
153 pub const fn packing_factor(&self) -> usize {
155 match self {
156 Elem::FP4x2(_) | Elem::FP6x2(_) | Elem::FP8x2(_) | Elem::F16x2 | Elem::BF16x2 => 2,
157 _ => 1,
158 }
159 }
160
161 pub const fn ident(&self) -> &str {
162 match self {
163 Elem::FP4(_) => "fp4",
164 Elem::FP4x2(_) => "fp4x2",
165 Elem::FP6(_) => "fp6",
166 Elem::FP6x2(_) => "fp6x2",
167 Elem::FP8(_) => "fp8",
168 Elem::FP8x2(_) => "fp8x2",
169 Elem::F16 => "f16",
170 Elem::F16x2 => "f16x2",
171 Elem::BF16x2 => "bf16x2",
172 Elem::BF16 => "bf16",
173 Elem::TF32 => "tf32",
174 Elem::F32 => "f32",
175 Elem::F64 => "f64",
176 Elem::I8 => "i8",
177 Elem::I16 => "i16",
178 Elem::I32 => "i32",
179 Elem::I64 => "i64",
180 Elem::U8 => "u8",
181 Elem::U16 => "u16",
182 Elem::U32 => "u32",
183 Elem::U64 => "u64",
184 Elem::Bool => "bool",
185 Elem::Barrier(BarrierLevel::Cube) => "cuda::barrier<cuda::thread_scope_block>",
186 Elem::Barrier(BarrierLevel::Unit) => "cuda::barrier<cuda::thread_scope_thread>",
187 Elem::Atomic(_) => "atomic",
188 Elem::_Dialect(_) => "",
189 }
190 }
191}
192
193#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
194pub enum AtomicKind<D: Dialect> {
195 I32,
196 I64,
197 U32,
198 U64,
199 F16,
200 F16x2,
201 BF16,
202 BF16x2,
203 F32,
204 F64,
205 _Dialect(std::marker::PhantomData<D>),
207}
208
209impl<D: Dialect> From<ElemType> for AtomicKind<D> {
210 fn from(value: ElemType) -> Self {
211 match value {
212 ElemType::Float(FloatKind::F16) => AtomicKind::F16,
213 ElemType::Float(FloatKind::BF16) => AtomicKind::BF16,
214 ElemType::Float(FloatKind::F32) => AtomicKind::F32,
215 ElemType::Float(FloatKind::F64) => AtomicKind::F64,
216 ElemType::Int(IntKind::I32) => AtomicKind::I32,
217 ElemType::Int(IntKind::I64) => AtomicKind::I64,
218 ElemType::UInt(UIntKind::U32) => AtomicKind::U32,
219 ElemType::UInt(UIntKind::U64) => AtomicKind::U64,
220 other => unimplemented!("Invalid atomic type: {other}"),
221 }
222 }
223}
224
225impl<D: Dialect> AtomicKind<D> {
226 pub fn as_elem(self) -> Elem<D> {
227 match self {
228 AtomicKind::I32 => Elem::I32,
229 AtomicKind::I64 => Elem::I64,
230 AtomicKind::U32 => Elem::U32,
231 AtomicKind::U64 => Elem::U64,
232 AtomicKind::F16 => Elem::F16,
233 AtomicKind::F16x2 => Elem::F16x2,
234 AtomicKind::BF16 => Elem::BF16,
235 AtomicKind::BF16x2 => Elem::BF16x2,
236 AtomicKind::F32 => Elem::F32,
237 AtomicKind::F64 => Elem::F64,
238 AtomicKind::_Dialect(_) => unreachable!(),
239 }
240 }
241}
242
243impl<D: Dialect> Display for AtomicKind<D> {
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 D::compile_atomic_kind(f, self)
246 }
247}