1use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, flex32, tf32, ue8m0};
2use cubecl_ir::StorageType;
3
4use crate::{
5    ir::{ElemType, FloatKind, IntKind, UIntKind},
6    prelude::Numeric,
7};
8
9pub trait CubeElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod {
11    fn type_name() -> &'static str;
13    fn as_bytes(slice: &[Self]) -> &[u8];
15    fn from_bytes(bytes: &[u8]) -> &[Self];
17    fn cube_type() -> StorageType;
19    fn maximum_value() -> Self;
21    fn minimum_value() -> Self;
23}
24
25impl CubeElement for u64 {
26    fn type_name() -> &'static str {
27        "u64"
28    }
29    fn as_bytes(slice: &[Self]) -> &[u8] {
30        bytemuck::cast_slice(slice)
31    }
32    fn from_bytes(bytes: &[u8]) -> &[Self] {
33        bytemuck::cast_slice(bytes)
34    }
35    fn cube_type() -> StorageType {
36        ElemType::UInt(UIntKind::U64).into()
37    }
38    fn maximum_value() -> Self {
39        u64::MAX
40    }
41    fn minimum_value() -> Self {
42        u64::MIN
43    }
44}
45
46impl CubeElement for u32 {
47    fn type_name() -> &'static str {
48        "u32"
49    }
50    fn as_bytes(slice: &[Self]) -> &[u8] {
51        bytemuck::cast_slice(slice)
52    }
53    fn from_bytes(bytes: &[u8]) -> &[Self] {
54        bytemuck::cast_slice(bytes)
55    }
56    fn cube_type() -> StorageType {
57        ElemType::UInt(UIntKind::U32).into()
58    }
59    fn maximum_value() -> Self {
60        u32::MAX
61    }
62    fn minimum_value() -> Self {
63        u32::MIN
64    }
65}
66
67impl CubeElement for u16 {
68    fn type_name() -> &'static str {
69        "u16"
70    }
71    fn as_bytes(slice: &[Self]) -> &[u8] {
72        bytemuck::cast_slice(slice)
73    }
74    fn from_bytes(bytes: &[u8]) -> &[Self] {
75        bytemuck::cast_slice(bytes)
76    }
77    fn cube_type() -> StorageType {
78        ElemType::UInt(UIntKind::U16).into()
79    }
80    fn maximum_value() -> Self {
81        u16::MAX
82    }
83    fn minimum_value() -> Self {
84        u16::MIN
85    }
86}
87
88impl CubeElement for u8 {
89    fn type_name() -> &'static str {
90        "u8"
91    }
92    fn as_bytes(slice: &[Self]) -> &[u8] {
93        bytemuck::cast_slice(slice)
94    }
95    fn from_bytes(bytes: &[u8]) -> &[Self] {
96        bytemuck::cast_slice(bytes)
97    }
98    fn cube_type() -> StorageType {
99        ElemType::UInt(UIntKind::U8).into()
100    }
101    fn maximum_value() -> Self {
102        u8::MAX
103    }
104    fn minimum_value() -> Self {
105        u8::MIN
106    }
107}
108
109impl CubeElement for i64 {
110    fn type_name() -> &'static str {
111        "i64"
112    }
113    fn as_bytes(slice: &[Self]) -> &[u8] {
114        bytemuck::cast_slice(slice)
115    }
116    fn from_bytes(bytes: &[u8]) -> &[Self] {
117        bytemuck::cast_slice(bytes)
118    }
119    fn cube_type() -> StorageType {
120        ElemType::Int(IntKind::I64).into()
121    }
122    fn maximum_value() -> Self {
123        i64::MAX - 1
125    }
126    fn minimum_value() -> Self {
127        i64::MIN + 1
129    }
130}
131
132impl CubeElement for i32 {
133    fn type_name() -> &'static str {
134        "i32"
135    }
136    fn as_bytes(slice: &[Self]) -> &[u8] {
137        bytemuck::cast_slice(slice)
138    }
139    fn from_bytes(bytes: &[u8]) -> &[Self] {
140        bytemuck::cast_slice(bytes)
141    }
142    fn cube_type() -> StorageType {
143        ElemType::Int(IntKind::I32).into()
144    }
145    fn maximum_value() -> Self {
146        i32::MAX - 1
148    }
149    fn minimum_value() -> Self {
150        i32::MIN + 1
152    }
153}
154
155impl CubeElement for i16 {
156    fn type_name() -> &'static str {
157        "i16"
158    }
159    fn as_bytes(slice: &[Self]) -> &[u8] {
160        bytemuck::cast_slice(slice)
161    }
162    fn from_bytes(bytes: &[u8]) -> &[Self] {
163        bytemuck::cast_slice(bytes)
164    }
165    fn cube_type() -> StorageType {
166        ElemType::Int(IntKind::I16).into()
167    }
168    fn maximum_value() -> Self {
169        i16::MAX - 1
171    }
172    fn minimum_value() -> Self {
173        i16::MIN + 1
175    }
176}
177
178impl CubeElement for i8 {
179    fn type_name() -> &'static str {
180        "i8"
181    }
182    fn as_bytes(slice: &[Self]) -> &[u8] {
183        bytemuck::cast_slice(slice)
184    }
185    fn from_bytes(bytes: &[u8]) -> &[Self] {
186        bytemuck::cast_slice(bytes)
187    }
188    fn cube_type() -> StorageType {
189        ElemType::Int(IntKind::I8).into()
190    }
191    fn maximum_value() -> Self {
192        i8::MAX - 1
194    }
195    fn minimum_value() -> Self {
196        i8::MIN + 1
198    }
199}
200
201impl CubeElement for f64 {
202    fn type_name() -> &'static str {
203        "f64"
204    }
205    fn as_bytes(slice: &[Self]) -> &[u8] {
206        bytemuck::cast_slice(slice)
207    }
208    fn from_bytes(bytes: &[u8]) -> &[Self] {
209        bytemuck::cast_slice(bytes)
210    }
211    fn cube_type() -> StorageType {
212        ElemType::Float(FloatKind::F64).into()
213    }
214    fn maximum_value() -> Self {
215        f64::MAX
216    }
217    fn minimum_value() -> Self {
218        f64::MIN
219    }
220}
221
222impl CubeElement for f32 {
223    fn type_name() -> &'static str {
224        "f32"
225    }
226    fn as_bytes(slice: &[Self]) -> &[u8] {
227        bytemuck::cast_slice(slice)
228    }
229    fn from_bytes(bytes: &[u8]) -> &[Self] {
230        bytemuck::cast_slice(bytes)
231    }
232    fn cube_type() -> StorageType {
233        ElemType::Float(FloatKind::F32).into()
234    }
235    fn maximum_value() -> Self {
236        f32::MAX
237    }
238    fn minimum_value() -> Self {
239        f32::MIN
240    }
241}
242
243impl CubeElement for half::f16 {
244    fn type_name() -> &'static str {
245        "f16"
246    }
247    fn as_bytes(slice: &[Self]) -> &[u8] {
248        bytemuck::cast_slice(slice)
249    }
250    fn from_bytes(bytes: &[u8]) -> &[Self] {
251        bytemuck::cast_slice(bytes)
252    }
253    fn cube_type() -> StorageType {
254        ElemType::Float(FloatKind::F16).into()
255    }
256    fn maximum_value() -> Self {
257        half::f16::MAX
258    }
259    fn minimum_value() -> Self {
260        half::f16::MIN
261    }
262}
263
264impl CubeElement for half::bf16 {
265    fn type_name() -> &'static str {
266        "bf16"
267    }
268    fn as_bytes(slice: &[Self]) -> &[u8] {
269        bytemuck::cast_slice(slice)
270    }
271    fn from_bytes(bytes: &[u8]) -> &[Self] {
272        bytemuck::cast_slice(bytes)
273    }
274    fn cube_type() -> StorageType {
275        ElemType::Float(FloatKind::BF16).into()
276    }
277    fn maximum_value() -> Self {
278        half::bf16::MAX
279    }
280    fn minimum_value() -> Self {
281        half::bf16::MIN
282    }
283}
284
285impl CubeElement for flex32 {
286    fn type_name() -> &'static str {
287        "flex32"
288    }
289    fn as_bytes(slice: &[Self]) -> &[u8] {
290        bytemuck::cast_slice(slice)
291    }
292    fn from_bytes(bytes: &[u8]) -> &[Self] {
293        bytemuck::cast_slice(bytes)
294    }
295    fn cube_type() -> StorageType {
296        ElemType::Float(FloatKind::Flex32).into()
297    }
298    fn maximum_value() -> Self {
299        <flex32 as num_traits::Float>::max_value()
300    }
301    fn minimum_value() -> Self {
302        <flex32 as num_traits::Float>::min_value()
303    }
304}
305
306impl CubeElement for tf32 {
307    fn type_name() -> &'static str {
308        "tf32"
309    }
310
311    fn as_bytes(slice: &[Self]) -> &[u8] {
312        bytemuck::cast_slice(slice)
313    }
314
315    fn from_bytes(bytes: &[u8]) -> &[Self] {
316        bytemuck::cast_slice(bytes)
317    }
318
319    fn cube_type() -> StorageType {
320        ElemType::Float(FloatKind::TF32).into()
321    }
322
323    fn maximum_value() -> Self {
324        tf32::max_value()
325    }
326
327    fn minimum_value() -> Self {
328        tf32::min_value()
329    }
330}
331
332impl CubeElement for e4m3 {
333    fn type_name() -> &'static str {
334        "e4m3"
335    }
336
337    fn as_bytes(slice: &[Self]) -> &[u8] {
338        bytemuck::cast_slice(slice)
339    }
340
341    fn from_bytes(bytes: &[u8]) -> &[Self] {
342        bytemuck::cast_slice(bytes)
343    }
344
345    fn cube_type() -> StorageType {
346        ElemType::Float(FloatKind::E4M3).into()
347    }
348
349    fn maximum_value() -> Self {
350        e4m3::max_value()
351    }
352
353    fn minimum_value() -> Self {
354        e4m3::min_value()
355    }
356}
357
358impl CubeElement for e5m2 {
359    fn type_name() -> &'static str {
360        "e5m2"
361    }
362
363    fn as_bytes(slice: &[Self]) -> &[u8] {
364        bytemuck::cast_slice(slice)
365    }
366
367    fn from_bytes(bytes: &[u8]) -> &[Self] {
368        bytemuck::cast_slice(bytes)
369    }
370
371    fn cube_type() -> StorageType {
372        ElemType::Float(FloatKind::E5M2).into()
373    }
374
375    fn maximum_value() -> Self {
376        e5m2::max_value()
377    }
378
379    fn minimum_value() -> Self {
380        e5m2::min_value()
381    }
382}
383
384impl CubeElement for ue8m0 {
385    fn type_name() -> &'static str {
386        "ue8m0"
387    }
388
389    fn as_bytes(slice: &[Self]) -> &[u8] {
390        bytemuck::cast_slice(slice)
391    }
392
393    fn from_bytes(bytes: &[u8]) -> &[Self] {
394        bytemuck::cast_slice(bytes)
395    }
396
397    fn cube_type() -> StorageType {
398        ElemType::Float(FloatKind::UE8M0).into()
399    }
400
401    fn maximum_value() -> Self {
402        ue8m0::max_value()
403    }
404
405    fn minimum_value() -> Self {
406        ue8m0::min_value()
407    }
408}
409
410impl CubeElement for e2m1x2 {
411    fn type_name() -> &'static str {
412        "e2m1x2"
413    }
414
415    fn as_bytes(slice: &[Self]) -> &[u8] {
416        bytemuck::cast_slice(slice)
417    }
418
419    fn from_bytes(bytes: &[u8]) -> &[Self] {
420        bytemuck::cast_slice(bytes)
421    }
422
423    fn cube_type() -> StorageType {
424        StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2)
425    }
426
427    fn maximum_value() -> Self {
428        let max = e2m1::MAX.to_bits() as u8;
429        e2m1x2::from_bits(max << 4 | max)
430    }
431
432    fn minimum_value() -> Self {
433        let min = e2m1::MIN.to_bits() as u8;
434        e2m1x2::from_bits(min << 4 | min)
435    }
436}