1use cubecl_common::{e2m1, e2m1x2, e3m2, e5m2};
2use cubecl_core::{
3 ir::{BarrierLevel, ElemType, FloatKind, IntKind, UIntKind},
4 tf32,
5};
6use half::{bf16, f16};
7use std::fmt::Display;
8
9use super::Dialect;
10
11#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
12pub enum Elem<D: Dialect> {
13 TF32,
14 F32,
15 F64,
16 F16,
17 F16x2,
18 BF16,
19 BF16x2,
20 FP4(FP4Kind),
21 FP4x2(FP4Kind),
22 FP6(FP6Kind),
23 FP6x2(FP6Kind),
24 FP8(FP8Kind),
25 FP8x2(FP8Kind),
26 I8,
27 I16,
28 I32,
29 I64,
30 U8,
31 U16,
32 U32,
33 U64,
34 Bool,
35 Barrier(BarrierLevel),
36 Atomic(AtomicKind<D>),
37 _Dialect(std::marker::PhantomData<D>),
38}
39
/// Sub-format of a 4-bit floating-point element.
///
/// Only `E2M1` is currently defined; it is displayed as `e2m1`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FP4Kind {
    E2M1,
}
44
/// Sub-format of a 6-bit floating-point element.
///
/// Displayed as `e2m3` / `e3m2` respectively.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FP6Kind {
    E2M3,
    E3M2,
}
50
/// Sub-format of an 8-bit floating-point element.
///
/// Displayed as `e4m3`, `e5m2` and `e8m0`. Note that `UE8M0` drops the leading `u`
/// in its displayed form.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum FP8Kind {
    E4M3,
    E5M2,
    UE8M0,
}
57
58impl Display for FP4Kind {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 let name = match self {
61 FP4Kind::E2M1 => "e2m1",
62 };
63 f.write_str(name)
64 }
65}
66
67impl Display for FP6Kind {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 let name = match self {
70 FP6Kind::E2M3 => "e2m3",
71 FP6Kind::E3M2 => "e3m2",
72 };
73 f.write_str(name)
74 }
75}
76
77impl Display for FP8Kind {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 let name = match self {
80 FP8Kind::E4M3 => "e4m3",
81 FP8Kind::E5M2 => "e5m2",
82 FP8Kind::UE8M0 => "e8m0",
83 };
84 f.write_str(name)
85 }
86}
87
88impl<D: Dialect> Display for Elem<D> {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 D::compile_elem(f, self, false)
91 }
92}
93
94impl<D: Dialect> Elem<D> {
95 pub const fn size(&self) -> usize {
96 match self {
97 Elem::FP4(_) => core::mem::size_of::<e2m1>(),
98 Elem::FP4x2(_) => core::mem::size_of::<e2m1x2>(),
99 Elem::FP6(_) => core::mem::size_of::<e3m2>(),
100 Elem::FP6x2(_) => 2 * core::mem::size_of::<e3m2>(),
101 Elem::FP8(_) => core::mem::size_of::<e5m2>(),
102 Elem::FP8x2(_) => 2 * core::mem::size_of::<e5m2>(),
103 Elem::F16 => core::mem::size_of::<f16>(),
104 Elem::F16x2 => 2 * core::mem::size_of::<f16>(),
105 Elem::BF16x2 => 2 * core::mem::size_of::<bf16>(),
106 Elem::BF16 => core::mem::size_of::<bf16>(),
107 Elem::TF32 => core::mem::size_of::<tf32>(),
108 Elem::F32 => core::mem::size_of::<f32>(),
109 Elem::F64 => core::mem::size_of::<f64>(),
110 Elem::I8 => core::mem::size_of::<i8>(),
111 Elem::I16 => core::mem::size_of::<i16>(),
112 Elem::I32 => core::mem::size_of::<i32>(),
113 Elem::I64 => core::mem::size_of::<i64>(),
114 Elem::U8 => core::mem::size_of::<u8>(),
115 Elem::U16 => core::mem::size_of::<u16>(),
116 Elem::U32 => core::mem::size_of::<u32>(),
117 Elem::U64 => core::mem::size_of::<u64>(),
118 Elem::Bool => core::mem::size_of::<bool>(),
119 Elem::Barrier(_) => core::mem::size_of::<u64>(),
120 Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
121 Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
122 Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
123 Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
124 Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
125 Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
126 Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
127 Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
128 Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
129 Elem::_Dialect(_) => 0,
130 }
131 }
132
133 pub const fn size_bits(&self) -> usize {
134 match self {
135 Elem::FP4(_) => 4,
136 other => other.size() * 8,
137 }
138 }
139
140 pub const fn unpacked(&self) -> Self {
141 match self {
142 Elem::FP4x2(ty) => Elem::FP4(*ty),
143 Elem::FP6x2(ty) => Elem::FP6(*ty),
144 Elem::FP8x2(ty) => Elem::FP8(*ty),
145 Elem::F16x2 => Elem::F16,
146 Elem::BF16x2 => Elem::BF16,
147 elem => *elem,
148 }
149 }
150
151 pub const fn packing_factor(&self) -> usize {
153 match self {
154 Elem::FP4x2(_) | Elem::FP6x2(_) | Elem::FP8x2(_) | Elem::F16x2 | Elem::BF16x2 => 2,
155 _ => 1,
156 }
157 }
158
159 pub const fn ident(&self) -> &str {
160 match self {
161 Elem::FP4(_) => "fp4",
162 Elem::FP4x2(_) => "fp4x2",
163 Elem::FP6(_) => "fp6",
164 Elem::FP6x2(_) => "fp6x2",
165 Elem::FP8(_) => "fp8",
166 Elem::FP8x2(_) => "fp8x2",
167 Elem::F16 => "f16",
168 Elem::F16x2 => "f16x2",
169 Elem::BF16x2 => "bf16x2",
170 Elem::BF16 => "bf16",
171 Elem::TF32 => "tf32",
172 Elem::F32 => "f32",
173 Elem::F64 => "f64",
174 Elem::I8 => "i8",
175 Elem::I16 => "i16",
176 Elem::I32 => "i32",
177 Elem::I64 => "i64",
178 Elem::U8 => "u8",
179 Elem::U16 => "u16",
180 Elem::U32 => "u32",
181 Elem::U64 => "u64",
182 Elem::Bool => "bool",
183 Elem::Barrier(BarrierLevel::Cube) => "cuda::barrier<cuda::thread_scope_block>",
184 Elem::Barrier(BarrierLevel::Unit) => "cuda::barrier<cuda::thread_scope_thread>",
185 Elem::Atomic(_) => "atomic",
186 Elem::_Dialect(_) => "",
187 }
188 }
189}
190
191#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
192pub enum AtomicKind<D: Dialect> {
193 I32,
194 I64,
195 U32,
196 U64,
197 F16,
198 BF16,
199 F32,
200 F64,
201 _Dialect(std::marker::PhantomData<D>),
203}
204
205impl<D: Dialect> From<ElemType> for AtomicKind<D> {
206 fn from(value: ElemType) -> Self {
207 match value {
208 ElemType::Float(FloatKind::F16) => AtomicKind::F16,
209 ElemType::Float(FloatKind::BF16) => AtomicKind::BF16,
210 ElemType::Float(FloatKind::F32) => AtomicKind::F32,
211 ElemType::Float(FloatKind::F64) => AtomicKind::F64,
212 ElemType::Int(IntKind::I32) => AtomicKind::I32,
213 ElemType::Int(IntKind::I64) => AtomicKind::I64,
214 ElemType::UInt(UIntKind::U32) => AtomicKind::U32,
215 ElemType::UInt(UIntKind::U64) => AtomicKind::U64,
216 other => unimplemented!("Invalid atomic type: {other}"),
217 }
218 }
219}
220
221impl<D: Dialect> Display for AtomicKind<D> {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 D::compile_atomic_kind(f, self)
224 }
225}