1use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, flex32, tf32, ue8m0};
2use cubecl_ir::StorageType;
3
4use crate::{
5 ir::{ElemType, FloatKind, IntKind, UIntKind},
6 prelude::Numeric,
7};
8
9pub trait CubeElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod {
11 fn type_name() -> &'static str;
13 fn as_bytes(slice: &[Self]) -> &[u8];
15 fn from_bytes(bytes: &[u8]) -> &[Self];
17 fn cube_type() -> StorageType;
19 fn maximum_value() -> Self;
21 fn minimum_value() -> Self;
23}
24
25impl CubeElement for u64 {
26 fn type_name() -> &'static str {
27 "u64"
28 }
29 fn as_bytes(slice: &[Self]) -> &[u8] {
30 bytemuck::cast_slice(slice)
31 }
32 fn from_bytes(bytes: &[u8]) -> &[Self] {
33 bytemuck::cast_slice(bytes)
34 }
35 fn cube_type() -> StorageType {
36 ElemType::UInt(UIntKind::U64).into()
37 }
38 fn maximum_value() -> Self {
39 u64::MAX
40 }
41 fn minimum_value() -> Self {
42 u64::MIN
43 }
44}
45
46impl CubeElement for u32 {
47 fn type_name() -> &'static str {
48 "u32"
49 }
50 fn as_bytes(slice: &[Self]) -> &[u8] {
51 bytemuck::cast_slice(slice)
52 }
53 fn from_bytes(bytes: &[u8]) -> &[Self] {
54 bytemuck::cast_slice(bytes)
55 }
56 fn cube_type() -> StorageType {
57 ElemType::UInt(UIntKind::U32).into()
58 }
59 fn maximum_value() -> Self {
60 u32::MAX
61 }
62 fn minimum_value() -> Self {
63 u32::MIN
64 }
65}
66
67impl CubeElement for u16 {
68 fn type_name() -> &'static str {
69 "u16"
70 }
71 fn as_bytes(slice: &[Self]) -> &[u8] {
72 bytemuck::cast_slice(slice)
73 }
74 fn from_bytes(bytes: &[u8]) -> &[Self] {
75 bytemuck::cast_slice(bytes)
76 }
77 fn cube_type() -> StorageType {
78 ElemType::UInt(UIntKind::U16).into()
79 }
80 fn maximum_value() -> Self {
81 u16::MAX
82 }
83 fn minimum_value() -> Self {
84 u16::MIN
85 }
86}
87
88impl CubeElement for u8 {
89 fn type_name() -> &'static str {
90 "u8"
91 }
92 fn as_bytes(slice: &[Self]) -> &[u8] {
93 bytemuck::cast_slice(slice)
94 }
95 fn from_bytes(bytes: &[u8]) -> &[Self] {
96 bytemuck::cast_slice(bytes)
97 }
98 fn cube_type() -> StorageType {
99 ElemType::UInt(UIntKind::U8).into()
100 }
101 fn maximum_value() -> Self {
102 u8::MAX
103 }
104 fn minimum_value() -> Self {
105 u8::MIN
106 }
107}
108
109impl CubeElement for i64 {
110 fn type_name() -> &'static str {
111 "i64"
112 }
113 fn as_bytes(slice: &[Self]) -> &[u8] {
114 bytemuck::cast_slice(slice)
115 }
116 fn from_bytes(bytes: &[u8]) -> &[Self] {
117 bytemuck::cast_slice(bytes)
118 }
119 fn cube_type() -> StorageType {
120 ElemType::Int(IntKind::I64).into()
121 }
122 fn maximum_value() -> Self {
123 i64::MAX - 1
125 }
126 fn minimum_value() -> Self {
127 i64::MIN + 1
129 }
130}
131
132impl CubeElement for i32 {
133 fn type_name() -> &'static str {
134 "i32"
135 }
136 fn as_bytes(slice: &[Self]) -> &[u8] {
137 bytemuck::cast_slice(slice)
138 }
139 fn from_bytes(bytes: &[u8]) -> &[Self] {
140 bytemuck::cast_slice(bytes)
141 }
142 fn cube_type() -> StorageType {
143 ElemType::Int(IntKind::I32).into()
144 }
145 fn maximum_value() -> Self {
146 i32::MAX - 1
148 }
149 fn minimum_value() -> Self {
150 i32::MIN + 1
152 }
153}
154
155impl CubeElement for i16 {
156 fn type_name() -> &'static str {
157 "i16"
158 }
159 fn as_bytes(slice: &[Self]) -> &[u8] {
160 bytemuck::cast_slice(slice)
161 }
162 fn from_bytes(bytes: &[u8]) -> &[Self] {
163 bytemuck::cast_slice(bytes)
164 }
165 fn cube_type() -> StorageType {
166 ElemType::Int(IntKind::I16).into()
167 }
168 fn maximum_value() -> Self {
169 i16::MAX - 1
171 }
172 fn minimum_value() -> Self {
173 i16::MIN + 1
175 }
176}
177
178impl CubeElement for i8 {
179 fn type_name() -> &'static str {
180 "i8"
181 }
182 fn as_bytes(slice: &[Self]) -> &[u8] {
183 bytemuck::cast_slice(slice)
184 }
185 fn from_bytes(bytes: &[u8]) -> &[Self] {
186 bytemuck::cast_slice(bytes)
187 }
188 fn cube_type() -> StorageType {
189 ElemType::Int(IntKind::I8).into()
190 }
191 fn maximum_value() -> Self {
192 i8::MAX - 1
194 }
195 fn minimum_value() -> Self {
196 i8::MIN + 1
198 }
199}
200
201impl CubeElement for f64 {
202 fn type_name() -> &'static str {
203 "f64"
204 }
205 fn as_bytes(slice: &[Self]) -> &[u8] {
206 bytemuck::cast_slice(slice)
207 }
208 fn from_bytes(bytes: &[u8]) -> &[Self] {
209 bytemuck::cast_slice(bytes)
210 }
211 fn cube_type() -> StorageType {
212 ElemType::Float(FloatKind::F64).into()
213 }
214 fn maximum_value() -> Self {
215 f64::MAX
216 }
217 fn minimum_value() -> Self {
218 f64::MIN
219 }
220}
221
222impl CubeElement for f32 {
223 fn type_name() -> &'static str {
224 "f32"
225 }
226 fn as_bytes(slice: &[Self]) -> &[u8] {
227 bytemuck::cast_slice(slice)
228 }
229 fn from_bytes(bytes: &[u8]) -> &[Self] {
230 bytemuck::cast_slice(bytes)
231 }
232 fn cube_type() -> StorageType {
233 ElemType::Float(FloatKind::F32).into()
234 }
235 fn maximum_value() -> Self {
236 f32::MAX
237 }
238 fn minimum_value() -> Self {
239 f32::MIN
240 }
241}
242
243impl CubeElement for half::f16 {
244 fn type_name() -> &'static str {
245 "f16"
246 }
247 fn as_bytes(slice: &[Self]) -> &[u8] {
248 bytemuck::cast_slice(slice)
249 }
250 fn from_bytes(bytes: &[u8]) -> &[Self] {
251 bytemuck::cast_slice(bytes)
252 }
253 fn cube_type() -> StorageType {
254 ElemType::Float(FloatKind::F16).into()
255 }
256 fn maximum_value() -> Self {
257 half::f16::MAX
258 }
259 fn minimum_value() -> Self {
260 half::f16::MIN
261 }
262}
263
264impl CubeElement for half::bf16 {
265 fn type_name() -> &'static str {
266 "bf16"
267 }
268 fn as_bytes(slice: &[Self]) -> &[u8] {
269 bytemuck::cast_slice(slice)
270 }
271 fn from_bytes(bytes: &[u8]) -> &[Self] {
272 bytemuck::cast_slice(bytes)
273 }
274 fn cube_type() -> StorageType {
275 ElemType::Float(FloatKind::BF16).into()
276 }
277 fn maximum_value() -> Self {
278 half::bf16::MAX
279 }
280 fn minimum_value() -> Self {
281 half::bf16::MIN
282 }
283}
284
285impl CubeElement for flex32 {
286 fn type_name() -> &'static str {
287 "flex32"
288 }
289 fn as_bytes(slice: &[Self]) -> &[u8] {
290 bytemuck::cast_slice(slice)
291 }
292 fn from_bytes(bytes: &[u8]) -> &[Self] {
293 bytemuck::cast_slice(bytes)
294 }
295 fn cube_type() -> StorageType {
296 ElemType::Float(FloatKind::Flex32).into()
297 }
298 fn maximum_value() -> Self {
299 <flex32 as num_traits::Float>::max_value()
300 }
301 fn minimum_value() -> Self {
302 <flex32 as num_traits::Float>::min_value()
303 }
304}
305
306impl CubeElement for tf32 {
307 fn type_name() -> &'static str {
308 "tf32"
309 }
310
311 fn as_bytes(slice: &[Self]) -> &[u8] {
312 bytemuck::cast_slice(slice)
313 }
314
315 fn from_bytes(bytes: &[u8]) -> &[Self] {
316 bytemuck::cast_slice(bytes)
317 }
318
319 fn cube_type() -> StorageType {
320 ElemType::Float(FloatKind::TF32).into()
321 }
322
323 fn maximum_value() -> Self {
324 tf32::max_value()
325 }
326
327 fn minimum_value() -> Self {
328 tf32::min_value()
329 }
330}
331
332impl CubeElement for e4m3 {
333 fn type_name() -> &'static str {
334 "e4m3"
335 }
336
337 fn as_bytes(slice: &[Self]) -> &[u8] {
338 bytemuck::cast_slice(slice)
339 }
340
341 fn from_bytes(bytes: &[u8]) -> &[Self] {
342 bytemuck::cast_slice(bytes)
343 }
344
345 fn cube_type() -> StorageType {
346 ElemType::Float(FloatKind::E4M3).into()
347 }
348
349 fn maximum_value() -> Self {
350 e4m3::max_value()
351 }
352
353 fn minimum_value() -> Self {
354 e4m3::min_value()
355 }
356}
357
358impl CubeElement for e5m2 {
359 fn type_name() -> &'static str {
360 "e5m2"
361 }
362
363 fn as_bytes(slice: &[Self]) -> &[u8] {
364 bytemuck::cast_slice(slice)
365 }
366
367 fn from_bytes(bytes: &[u8]) -> &[Self] {
368 bytemuck::cast_slice(bytes)
369 }
370
371 fn cube_type() -> StorageType {
372 ElemType::Float(FloatKind::E5M2).into()
373 }
374
375 fn maximum_value() -> Self {
376 e5m2::max_value()
377 }
378
379 fn minimum_value() -> Self {
380 e5m2::min_value()
381 }
382}
383
384impl CubeElement for ue8m0 {
385 fn type_name() -> &'static str {
386 "ue8m0"
387 }
388
389 fn as_bytes(slice: &[Self]) -> &[u8] {
390 bytemuck::cast_slice(slice)
391 }
392
393 fn from_bytes(bytes: &[u8]) -> &[Self] {
394 bytemuck::cast_slice(bytes)
395 }
396
397 fn cube_type() -> StorageType {
398 ElemType::Float(FloatKind::UE8M0).into()
399 }
400
401 fn maximum_value() -> Self {
402 ue8m0::max_value()
403 }
404
405 fn minimum_value() -> Self {
406 ue8m0::min_value()
407 }
408}
409
410impl CubeElement for e2m1x2 {
411 fn type_name() -> &'static str {
412 "e2m1x2"
413 }
414
415 fn as_bytes(slice: &[Self]) -> &[u8] {
416 bytemuck::cast_slice(slice)
417 }
418
419 fn from_bytes(bytes: &[u8]) -> &[Self] {
420 bytemuck::cast_slice(bytes)
421 }
422
423 fn cube_type() -> StorageType {
424 StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2)
425 }
426
427 fn maximum_value() -> Self {
428 let max = e2m1::MAX.to_bits() as u8;
429 e2m1x2::from_bits(max << 4 | max)
430 }
431
432 fn minimum_value() -> Self {
433 let min = e2m1::MIN.to_bits() as u8;
434 e2m1x2::from_bits(min << 4 | min)
435 }
436}