1use cubecl_common::{e2m1, e2m1x2, e3m2, e5m2};
2use cubecl_core::{
3 ir::{ElemType, FloatKind, IntKind, UIntKind},
4 tf32,
5};
6use half::{bf16, f16};
7use std::fmt::Display;
8
9use super::Dialect;
10
11#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
12pub enum Elem<D: Dialect> {
13 TF32,
14 F32,
15 F64,
16 F16,
17 F16x2,
18 BF16,
19 BF16x2,
20 FP4(FP4Kind),
21 FP4x2(FP4Kind),
22 FP6(FP6Kind),
23 FP6x2(FP6Kind),
24 FP8(FP8Kind),
25 FP8x2(FP8Kind),
26 I8,
27 I16,
28 I32,
29 I64,
30 U8,
31 U16,
32 U32,
33 U64,
34 Bool,
35 Atomic(AtomicKind<D>),
36 _Dialect(std::marker::PhantomData<D>),
37}
38
/// Format selector carried by `Elem::FP4` and `Elem::FP4x2`.
///
/// Its `Display` impl writes the lowercase format name.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FP4Kind {
    E2M1,
}
43
/// Format selector carried by `Elem::FP6` and `Elem::FP6x2`.
///
/// Its `Display` impl writes the lowercase format name.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FP6Kind {
    E2M3,
    E3M2,
}
49
/// Format selector carried by `Elem::FP8` and `Elem::FP8x2`.
///
/// Its `Display` impl writes the lowercase format name; note that `UE8M0` is
/// written as `e8m0`, without the leading `u`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum FP8Kind {
    E4M3,
    E5M2,
    UE8M0,
}
56
57impl Display for FP4Kind {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 let name = match self {
60 FP4Kind::E2M1 => "e2m1",
61 };
62 f.write_str(name)
63 }
64}
65
66impl Display for FP6Kind {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 let name = match self {
69 FP6Kind::E2M3 => "e2m3",
70 FP6Kind::E3M2 => "e3m2",
71 };
72 f.write_str(name)
73 }
74}
75
76impl Display for FP8Kind {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 let name = match self {
79 FP8Kind::E4M3 => "e4m3",
80 FP8Kind::E5M2 => "e5m2",
81 FP8Kind::UE8M0 => "e8m0",
82 };
83 f.write_str(name)
84 }
85}
86
87impl<D: Dialect> Display for Elem<D> {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 D::compile_elem(f, self, false)
90 }
91}
92
93impl<D: Dialect> Elem<D> {
94 pub const fn size(&self) -> usize {
95 match self {
96 Elem::FP4(_) => core::mem::size_of::<e2m1>(),
97 Elem::FP4x2(_) => core::mem::size_of::<e2m1x2>(),
98 Elem::FP6(_) => core::mem::size_of::<e3m2>(),
99 Elem::FP6x2(_) => 2 * core::mem::size_of::<e3m2>(),
100 Elem::FP8(_) => core::mem::size_of::<e5m2>(),
101 Elem::FP8x2(_) => 2 * core::mem::size_of::<e5m2>(),
102 Elem::F16 => core::mem::size_of::<f16>(),
103 Elem::F16x2 => 2 * core::mem::size_of::<f16>(),
104 Elem::BF16x2 => 2 * core::mem::size_of::<bf16>(),
105 Elem::BF16 => core::mem::size_of::<bf16>(),
106 Elem::TF32 => core::mem::size_of::<tf32>(),
107 Elem::F32 => core::mem::size_of::<f32>(),
108 Elem::F64 => core::mem::size_of::<f64>(),
109 Elem::I8 => core::mem::size_of::<i8>(),
110 Elem::I16 => core::mem::size_of::<i16>(),
111 Elem::I32 => core::mem::size_of::<i32>(),
112 Elem::I64 => core::mem::size_of::<i64>(),
113 Elem::U8 => core::mem::size_of::<u8>(),
114 Elem::U16 => core::mem::size_of::<u16>(),
115 Elem::U32 => core::mem::size_of::<u32>(),
116 Elem::U64 => core::mem::size_of::<u64>(),
117 Elem::Bool => core::mem::size_of::<bool>(),
118 Elem::Atomic(AtomicKind::I32) => core::mem::size_of::<i32>(),
119 Elem::Atomic(AtomicKind::I64) => core::mem::size_of::<i64>(),
120 Elem::Atomic(AtomicKind::U32) => core::mem::size_of::<u32>(),
121 Elem::Atomic(AtomicKind::U64) => core::mem::size_of::<u64>(),
122 Elem::Atomic(AtomicKind::F16) => core::mem::size_of::<f16>(),
123 Elem::Atomic(AtomicKind::BF16) => core::mem::size_of::<bf16>(),
124 Elem::Atomic(AtomicKind::F32) => core::mem::size_of::<f32>(),
125 Elem::Atomic(AtomicKind::F64) => core::mem::size_of::<f64>(),
126 Elem::Atomic(AtomicKind::_Dialect(_)) => 0,
127 Elem::_Dialect(_) => 0,
128 }
129 }
130
131 pub const fn size_bits(&self) -> usize {
132 match self {
133 Elem::FP4(_) => 4,
134 other => other.size() * 8,
135 }
136 }
137
138 pub const fn unpacked(&self) -> Self {
139 match self {
140 Elem::FP4x2(ty) => Elem::FP4(*ty),
141 Elem::FP6x2(ty) => Elem::FP6(*ty),
142 Elem::FP8x2(ty) => Elem::FP8(*ty),
143 Elem::F16x2 => Elem::F16,
144 Elem::BF16x2 => Elem::BF16,
145 elem => *elem,
146 }
147 }
148
149 pub const fn packing_factor(&self) -> usize {
151 match self {
152 Elem::FP4x2(_) | Elem::FP6x2(_) | Elem::FP8x2(_) | Elem::F16x2 | Elem::BF16x2 => 2,
153 _ => 1,
154 }
155 }
156
157 pub const fn ident(&self) -> &str {
158 match self {
159 Elem::FP4(_) => "fp4",
160 Elem::FP4x2(_) => "fp4x2",
161 Elem::FP6(_) => "fp6",
162 Elem::FP6x2(_) => "fp6x2",
163 Elem::FP8(_) => "fp8",
164 Elem::FP8x2(_) => "fp8x2",
165 Elem::F16 => "f16",
166 Elem::F16x2 => "f16x2",
167 Elem::BF16x2 => "bf16x2",
168 Elem::BF16 => "bf16",
169 Elem::TF32 => "tf32",
170 Elem::F32 => "f32",
171 Elem::F64 => "f64",
172 Elem::I8 => "i8",
173 Elem::I16 => "i16",
174 Elem::I32 => "i32",
175 Elem::I64 => "i64",
176 Elem::U8 => "u8",
177 Elem::U16 => "u16",
178 Elem::U32 => "u32",
179 Elem::U64 => "u64",
180 Elem::Bool => "bool",
181 Elem::Atomic(_) => "atomic",
182 Elem::_Dialect(_) => "",
183 }
184 }
185}
186
187#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash)]
188pub enum AtomicKind<D: Dialect> {
189 I32,
190 I64,
191 U32,
192 U64,
193 F16,
194 BF16,
195 F32,
196 F64,
197 _Dialect(std::marker::PhantomData<D>),
199}
200
201impl<D: Dialect> From<ElemType> for AtomicKind<D> {
202 fn from(value: ElemType) -> Self {
203 match value {
204 ElemType::Float(FloatKind::F16) => AtomicKind::F16,
205 ElemType::Float(FloatKind::BF16) => AtomicKind::BF16,
206 ElemType::Float(FloatKind::F32) => AtomicKind::F32,
207 ElemType::Float(FloatKind::F64) => AtomicKind::F64,
208 ElemType::Int(IntKind::I32) => AtomicKind::I32,
209 ElemType::Int(IntKind::I64) => AtomicKind::I64,
210 ElemType::UInt(UIntKind::U32) => AtomicKind::U32,
211 ElemType::UInt(UIntKind::U64) => AtomicKind::U64,
212 other => unimplemented!("Invalid atomic type: {other}"),
213 }
214 }
215}
216
217impl<D: Dialect> Display for AtomicKind<D> {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 D::compile_atomic_kind(f, self)
220 }
221}