cubecl_core/
pod.rs

1use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, flex32, tf32, ue8m0};
2use cubecl_ir::StorageType;
3
4use crate::{
5    ir::{ElemType, FloatKind, IntKind, UIntKind},
6    prelude::{CubePrimitive, Numeric},
7};
8
9/// The base element trait for the jit backend.
10pub trait CubeElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod {
11    /// Returns the name of the type.
12    fn type_name() -> &'static str;
13    /// Convert a slice of elements to a slice of bytes.
14    fn as_bytes(slice: &[Self]) -> &[u8];
15    /// Convert a slice of bytes to a slice of elements.
16    fn from_bytes(bytes: &[u8]) -> &[Self];
17    /// Element representation for `cubecl`.
18    fn cube_type() -> StorageType;
19    /// Highest possible value
20    fn maximum_value() -> Self;
21    /// Lowest possible value
22    fn minimum_value() -> Self;
23}
24
25pub trait CubeScalar: CubeElement + CubePrimitive + num_traits::NumCast {}
26
27impl<E: CubeElement + CubePrimitive + num_traits::NumCast> CubeScalar for E {}
28
29impl CubeElement for u64 {
30    fn type_name() -> &'static str {
31        "u64"
32    }
33    fn as_bytes(slice: &[Self]) -> &[u8] {
34        bytemuck::cast_slice(slice)
35    }
36    fn from_bytes(bytes: &[u8]) -> &[Self] {
37        bytemuck::cast_slice(bytes)
38    }
39    fn cube_type() -> StorageType {
40        ElemType::UInt(UIntKind::U64).into()
41    }
42    fn maximum_value() -> Self {
43        u64::MAX
44    }
45    fn minimum_value() -> Self {
46        u64::MIN
47    }
48}
49
50impl CubeElement for u32 {
51    fn type_name() -> &'static str {
52        "u32"
53    }
54    fn as_bytes(slice: &[Self]) -> &[u8] {
55        bytemuck::cast_slice(slice)
56    }
57    fn from_bytes(bytes: &[u8]) -> &[Self] {
58        bytemuck::cast_slice(bytes)
59    }
60    fn cube_type() -> StorageType {
61        ElemType::UInt(UIntKind::U32).into()
62    }
63    fn maximum_value() -> Self {
64        u32::MAX
65    }
66    fn minimum_value() -> Self {
67        u32::MIN
68    }
69}
70
71impl CubeElement for u16 {
72    fn type_name() -> &'static str {
73        "u16"
74    }
75    fn as_bytes(slice: &[Self]) -> &[u8] {
76        bytemuck::cast_slice(slice)
77    }
78    fn from_bytes(bytes: &[u8]) -> &[Self] {
79        bytemuck::cast_slice(bytes)
80    }
81    fn cube_type() -> StorageType {
82        ElemType::UInt(UIntKind::U16).into()
83    }
84    fn maximum_value() -> Self {
85        u16::MAX
86    }
87    fn minimum_value() -> Self {
88        u16::MIN
89    }
90}
91
92impl CubeElement for u8 {
93    fn type_name() -> &'static str {
94        "u8"
95    }
96    fn as_bytes(slice: &[Self]) -> &[u8] {
97        bytemuck::cast_slice(slice)
98    }
99    fn from_bytes(bytes: &[u8]) -> &[Self] {
100        bytemuck::cast_slice(bytes)
101    }
102    fn cube_type() -> StorageType {
103        ElemType::UInt(UIntKind::U8).into()
104    }
105    fn maximum_value() -> Self {
106        u8::MAX
107    }
108    fn minimum_value() -> Self {
109        u8::MIN
110    }
111}
112
113impl CubeElement for i64 {
114    fn type_name() -> &'static str {
115        "i64"
116    }
117    fn as_bytes(slice: &[Self]) -> &[u8] {
118        bytemuck::cast_slice(slice)
119    }
120    fn from_bytes(bytes: &[u8]) -> &[Self] {
121        bytemuck::cast_slice(bytes)
122    }
123    fn cube_type() -> StorageType {
124        ElemType::Int(IntKind::I64).into()
125    }
126    fn maximum_value() -> Self {
127        // Seems to cause problem for some GPU
128        i64::MAX - 1
129    }
130    fn minimum_value() -> Self {
131        // Seems to cause problem for some GPU
132        i64::MIN + 1
133    }
134}
135
136impl CubeElement for i32 {
137    fn type_name() -> &'static str {
138        "i32"
139    }
140    fn as_bytes(slice: &[Self]) -> &[u8] {
141        bytemuck::cast_slice(slice)
142    }
143    fn from_bytes(bytes: &[u8]) -> &[Self] {
144        bytemuck::cast_slice(bytes)
145    }
146    fn cube_type() -> StorageType {
147        ElemType::Int(IntKind::I32).into()
148    }
149    fn maximum_value() -> Self {
150        // Seems to cause problem for some GPU
151        i32::MAX - 1
152    }
153    fn minimum_value() -> Self {
154        // Seems to cause problem for some GPU
155        i32::MIN + 1
156    }
157}
158
159impl CubeElement for i16 {
160    fn type_name() -> &'static str {
161        "i16"
162    }
163    fn as_bytes(slice: &[Self]) -> &[u8] {
164        bytemuck::cast_slice(slice)
165    }
166    fn from_bytes(bytes: &[u8]) -> &[Self] {
167        bytemuck::cast_slice(bytes)
168    }
169    fn cube_type() -> StorageType {
170        ElemType::Int(IntKind::I16).into()
171    }
172    fn maximum_value() -> Self {
173        // Seems to cause problem for some GPU
174        i16::MAX - 1
175    }
176    fn minimum_value() -> Self {
177        // Seems to cause problem for some GPU
178        i16::MIN + 1
179    }
180}
181
182impl CubeElement for i8 {
183    fn type_name() -> &'static str {
184        "i8"
185    }
186    fn as_bytes(slice: &[Self]) -> &[u8] {
187        bytemuck::cast_slice(slice)
188    }
189    fn from_bytes(bytes: &[u8]) -> &[Self] {
190        bytemuck::cast_slice(bytes)
191    }
192    fn cube_type() -> StorageType {
193        ElemType::Int(IntKind::I8).into()
194    }
195    fn maximum_value() -> Self {
196        // Seems to cause problem for some GPU
197        i8::MAX - 1
198    }
199    fn minimum_value() -> Self {
200        // Seems to cause problem for some GPU
201        i8::MIN + 1
202    }
203}
204
205impl CubeElement for f64 {
206    fn type_name() -> &'static str {
207        "f64"
208    }
209    fn as_bytes(slice: &[Self]) -> &[u8] {
210        bytemuck::cast_slice(slice)
211    }
212    fn from_bytes(bytes: &[u8]) -> &[Self] {
213        bytemuck::cast_slice(bytes)
214    }
215    fn cube_type() -> StorageType {
216        ElemType::Float(FloatKind::F64).into()
217    }
218    fn maximum_value() -> Self {
219        f64::MAX
220    }
221    fn minimum_value() -> Self {
222        f64::MIN
223    }
224}
225
226impl CubeElement for f32 {
227    fn type_name() -> &'static str {
228        "f32"
229    }
230    fn as_bytes(slice: &[Self]) -> &[u8] {
231        bytemuck::cast_slice(slice)
232    }
233    fn from_bytes(bytes: &[u8]) -> &[Self] {
234        bytemuck::cast_slice(bytes)
235    }
236    fn cube_type() -> StorageType {
237        ElemType::Float(FloatKind::F32).into()
238    }
239    fn maximum_value() -> Self {
240        f32::MAX
241    }
242    fn minimum_value() -> Self {
243        f32::MIN
244    }
245}
246
247impl CubeElement for half::f16 {
248    fn type_name() -> &'static str {
249        "f16"
250    }
251    fn as_bytes(slice: &[Self]) -> &[u8] {
252        bytemuck::cast_slice(slice)
253    }
254    fn from_bytes(bytes: &[u8]) -> &[Self] {
255        bytemuck::cast_slice(bytes)
256    }
257    fn cube_type() -> StorageType {
258        ElemType::Float(FloatKind::F16).into()
259    }
260    fn maximum_value() -> Self {
261        half::f16::MAX
262    }
263    fn minimum_value() -> Self {
264        half::f16::MIN
265    }
266}
267
268impl CubeElement for half::bf16 {
269    fn type_name() -> &'static str {
270        "bf16"
271    }
272    fn as_bytes(slice: &[Self]) -> &[u8] {
273        bytemuck::cast_slice(slice)
274    }
275    fn from_bytes(bytes: &[u8]) -> &[Self] {
276        bytemuck::cast_slice(bytes)
277    }
278    fn cube_type() -> StorageType {
279        ElemType::Float(FloatKind::BF16).into()
280    }
281    fn maximum_value() -> Self {
282        half::bf16::MAX
283    }
284    fn minimum_value() -> Self {
285        half::bf16::MIN
286    }
287}
288
289impl CubeElement for flex32 {
290    fn type_name() -> &'static str {
291        "flex32"
292    }
293    fn as_bytes(slice: &[Self]) -> &[u8] {
294        bytemuck::cast_slice(slice)
295    }
296    fn from_bytes(bytes: &[u8]) -> &[Self] {
297        bytemuck::cast_slice(bytes)
298    }
299    fn cube_type() -> StorageType {
300        ElemType::Float(FloatKind::Flex32).into()
301    }
302    fn maximum_value() -> Self {
303        <flex32 as num_traits::Float>::max_value()
304    }
305    fn minimum_value() -> Self {
306        <flex32 as num_traits::Float>::min_value()
307    }
308}
309
310impl CubeElement for tf32 {
311    fn type_name() -> &'static str {
312        "tf32"
313    }
314
315    fn as_bytes(slice: &[Self]) -> &[u8] {
316        bytemuck::cast_slice(slice)
317    }
318
319    fn from_bytes(bytes: &[u8]) -> &[Self] {
320        bytemuck::cast_slice(bytes)
321    }
322
323    fn cube_type() -> StorageType {
324        ElemType::Float(FloatKind::TF32).into()
325    }
326
327    fn maximum_value() -> Self {
328        tf32::max_value()
329    }
330
331    fn minimum_value() -> Self {
332        tf32::min_value()
333    }
334}
335
336impl CubeElement for e4m3 {
337    fn type_name() -> &'static str {
338        "e4m3"
339    }
340
341    fn as_bytes(slice: &[Self]) -> &[u8] {
342        bytemuck::cast_slice(slice)
343    }
344
345    fn from_bytes(bytes: &[u8]) -> &[Self] {
346        bytemuck::cast_slice(bytes)
347    }
348
349    fn cube_type() -> StorageType {
350        ElemType::Float(FloatKind::E4M3).into()
351    }
352
353    fn maximum_value() -> Self {
354        e4m3::max_value()
355    }
356
357    fn minimum_value() -> Self {
358        e4m3::min_value()
359    }
360}
361
362impl CubeElement for e5m2 {
363    fn type_name() -> &'static str {
364        "e5m2"
365    }
366
367    fn as_bytes(slice: &[Self]) -> &[u8] {
368        bytemuck::cast_slice(slice)
369    }
370
371    fn from_bytes(bytes: &[u8]) -> &[Self] {
372        bytemuck::cast_slice(bytes)
373    }
374
375    fn cube_type() -> StorageType {
376        ElemType::Float(FloatKind::E5M2).into()
377    }
378
379    fn maximum_value() -> Self {
380        e5m2::max_value()
381    }
382
383    fn minimum_value() -> Self {
384        e5m2::min_value()
385    }
386}
387
388impl CubeElement for ue8m0 {
389    fn type_name() -> &'static str {
390        "ue8m0"
391    }
392
393    fn as_bytes(slice: &[Self]) -> &[u8] {
394        bytemuck::cast_slice(slice)
395    }
396
397    fn from_bytes(bytes: &[u8]) -> &[Self] {
398        bytemuck::cast_slice(bytes)
399    }
400
401    fn cube_type() -> StorageType {
402        ElemType::Float(FloatKind::UE8M0).into()
403    }
404
405    fn maximum_value() -> Self {
406        ue8m0::max_value()
407    }
408
409    fn minimum_value() -> Self {
410        ue8m0::min_value()
411    }
412}
413
414impl CubeElement for e2m1x2 {
415    fn type_name() -> &'static str {
416        "e2m1x2"
417    }
418
419    fn as_bytes(slice: &[Self]) -> &[u8] {
420        bytemuck::cast_slice(slice)
421    }
422
423    fn from_bytes(bytes: &[u8]) -> &[Self] {
424        bytemuck::cast_slice(bytes)
425    }
426
427    fn cube_type() -> StorageType {
428        StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2)
429    }
430
431    fn maximum_value() -> Self {
432        let max = e2m1::MAX.to_bits() as u8;
433        e2m1x2::from_bits(max << 4 | max)
434    }
435
436    fn minimum_value() -> Self {
437        let min = e2m1::MIN.to_bits() as u8;
438        e2m1x2::from_bits(min << 4 | min)
439    }
440}