// cubecl_core/pod.rs

1use cubecl_common::flex32;
2
3use crate::ir::{Elem, FloatKind, IntKind, UIntKind};
4
5/// The base element trait for the jit backend.
6pub trait CubeElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod {
7    /// Returns the name of the type.
8    fn type_name() -> &'static str;
9    /// Convert a slice of elements to a slice of bytes.
10    fn as_bytes(slice: &[Self]) -> &[u8];
11    /// Convert a slice of bytes to a slice of elements.
12    fn from_bytes(bytes: &[u8]) -> &[Self];
13    /// Element representation for `cubecl`.
14    fn cube_elem() -> Elem;
15    /// Highest possible value
16    fn maximum_value() -> Self;
17    /// Lowest possible value
18    fn minimum_value() -> Self;
19}
20
21impl CubeElement for u64 {
22    fn type_name() -> &'static str {
23        "u64"
24    }
25    fn as_bytes(slice: &[Self]) -> &[u8] {
26        bytemuck::cast_slice(slice)
27    }
28    fn from_bytes(bytes: &[u8]) -> &[Self] {
29        bytemuck::cast_slice(bytes)
30    }
31    fn cube_elem() -> Elem {
32        Elem::UInt(UIntKind::U64)
33    }
34    fn maximum_value() -> Self {
35        u64::MAX
36    }
37    fn minimum_value() -> Self {
38        u64::MIN
39    }
40}
41
42impl CubeElement for u32 {
43    fn type_name() -> &'static str {
44        "u32"
45    }
46    fn as_bytes(slice: &[Self]) -> &[u8] {
47        bytemuck::cast_slice(slice)
48    }
49    fn from_bytes(bytes: &[u8]) -> &[Self] {
50        bytemuck::cast_slice(bytes)
51    }
52    fn cube_elem() -> Elem {
53        Elem::UInt(UIntKind::U32)
54    }
55    fn maximum_value() -> Self {
56        u32::MAX
57    }
58    fn minimum_value() -> Self {
59        u32::MIN
60    }
61}
62
63impl CubeElement for u16 {
64    fn type_name() -> &'static str {
65        "u16"
66    }
67    fn as_bytes(slice: &[Self]) -> &[u8] {
68        bytemuck::cast_slice(slice)
69    }
70    fn from_bytes(bytes: &[u8]) -> &[Self] {
71        bytemuck::cast_slice(bytes)
72    }
73    fn cube_elem() -> Elem {
74        Elem::UInt(UIntKind::U16)
75    }
76    fn maximum_value() -> Self {
77        u16::MAX
78    }
79    fn minimum_value() -> Self {
80        u16::MIN
81    }
82}
83
84impl CubeElement for u8 {
85    fn type_name() -> &'static str {
86        "u8"
87    }
88    fn as_bytes(slice: &[Self]) -> &[u8] {
89        bytemuck::cast_slice(slice)
90    }
91    fn from_bytes(bytes: &[u8]) -> &[Self] {
92        bytemuck::cast_slice(bytes)
93    }
94    fn cube_elem() -> Elem {
95        Elem::UInt(UIntKind::U8)
96    }
97    fn maximum_value() -> Self {
98        u8::MAX
99    }
100    fn minimum_value() -> Self {
101        u8::MIN
102    }
103}
104
105impl CubeElement for i64 {
106    fn type_name() -> &'static str {
107        "i64"
108    }
109    fn as_bytes(slice: &[Self]) -> &[u8] {
110        bytemuck::cast_slice(slice)
111    }
112    fn from_bytes(bytes: &[u8]) -> &[Self] {
113        bytemuck::cast_slice(bytes)
114    }
115    fn cube_elem() -> Elem {
116        Elem::Int(IntKind::I64)
117    }
118    fn maximum_value() -> Self {
119        // Seems to cause problem for some GPU
120        i64::MAX - 1
121    }
122    fn minimum_value() -> Self {
123        // Seems to cause problem for some GPU
124        i64::MIN + 1
125    }
126}
127
128impl CubeElement for i32 {
129    fn type_name() -> &'static str {
130        "i32"
131    }
132    fn as_bytes(slice: &[Self]) -> &[u8] {
133        bytemuck::cast_slice(slice)
134    }
135    fn from_bytes(bytes: &[u8]) -> &[Self] {
136        bytemuck::cast_slice(bytes)
137    }
138    fn cube_elem() -> Elem {
139        Elem::Int(IntKind::I32)
140    }
141    fn maximum_value() -> Self {
142        // Seems to cause problem for some GPU
143        i32::MAX - 1
144    }
145    fn minimum_value() -> Self {
146        // Seems to cause problem for some GPU
147        i32::MIN + 1
148    }
149}
150
151impl CubeElement for i16 {
152    fn type_name() -> &'static str {
153        "i16"
154    }
155    fn as_bytes(slice: &[Self]) -> &[u8] {
156        bytemuck::cast_slice(slice)
157    }
158    fn from_bytes(bytes: &[u8]) -> &[Self] {
159        bytemuck::cast_slice(bytes)
160    }
161    fn cube_elem() -> Elem {
162        Elem::Int(IntKind::I16)
163    }
164    fn maximum_value() -> Self {
165        // Seems to cause problem for some GPU
166        i16::MAX - 1
167    }
168    fn minimum_value() -> Self {
169        // Seems to cause problem for some GPU
170        i16::MIN + 1
171    }
172}
173
174impl CubeElement for i8 {
175    fn type_name() -> &'static str {
176        "i8"
177    }
178    fn as_bytes(slice: &[Self]) -> &[u8] {
179        bytemuck::cast_slice(slice)
180    }
181    fn from_bytes(bytes: &[u8]) -> &[Self] {
182        bytemuck::cast_slice(bytes)
183    }
184    fn cube_elem() -> Elem {
185        Elem::Int(IntKind::I8)
186    }
187    fn maximum_value() -> Self {
188        // Seems to cause problem for some GPU
189        i8::MAX - 1
190    }
191    fn minimum_value() -> Self {
192        // Seems to cause problem for some GPU
193        i8::MIN + 1
194    }
195}
196
197impl CubeElement for f64 {
198    fn type_name() -> &'static str {
199        "f64"
200    }
201    fn as_bytes(slice: &[Self]) -> &[u8] {
202        bytemuck::cast_slice(slice)
203    }
204    fn from_bytes(bytes: &[u8]) -> &[Self] {
205        bytemuck::cast_slice(bytes)
206    }
207    fn cube_elem() -> Elem {
208        Elem::Float(FloatKind::F64)
209    }
210    fn maximum_value() -> Self {
211        f64::MAX
212    }
213    fn minimum_value() -> Self {
214        f64::MIN
215    }
216}
217
218impl CubeElement for f32 {
219    fn type_name() -> &'static str {
220        "f32"
221    }
222    fn as_bytes(slice: &[Self]) -> &[u8] {
223        bytemuck::cast_slice(slice)
224    }
225    fn from_bytes(bytes: &[u8]) -> &[Self] {
226        bytemuck::cast_slice(bytes)
227    }
228    fn cube_elem() -> Elem {
229        Elem::Float(FloatKind::F32)
230    }
231    fn maximum_value() -> Self {
232        f32::MAX
233    }
234    fn minimum_value() -> Self {
235        f32::MIN
236    }
237}
238
239impl CubeElement for half::f16 {
240    fn type_name() -> &'static str {
241        "f16"
242    }
243    fn as_bytes(slice: &[Self]) -> &[u8] {
244        bytemuck::cast_slice(slice)
245    }
246    fn from_bytes(bytes: &[u8]) -> &[Self] {
247        bytemuck::cast_slice(bytes)
248    }
249    fn cube_elem() -> Elem {
250        Elem::Float(FloatKind::F16)
251    }
252    fn maximum_value() -> Self {
253        half::f16::MAX
254    }
255    fn minimum_value() -> Self {
256        half::f16::MIN
257    }
258}
259
260impl CubeElement for half::bf16 {
261    fn type_name() -> &'static str {
262        "bf16"
263    }
264    fn as_bytes(slice: &[Self]) -> &[u8] {
265        bytemuck::cast_slice(slice)
266    }
267    fn from_bytes(bytes: &[u8]) -> &[Self] {
268        bytemuck::cast_slice(bytes)
269    }
270    fn cube_elem() -> Elem {
271        Elem::Float(FloatKind::BF16)
272    }
273    fn maximum_value() -> Self {
274        half::bf16::MAX
275    }
276    fn minimum_value() -> Self {
277        half::bf16::MIN
278    }
279}
280
281impl CubeElement for flex32 {
282    fn type_name() -> &'static str {
283        "flex32"
284    }
285    fn as_bytes(slice: &[Self]) -> &[u8] {
286        bytemuck::cast_slice(slice)
287    }
288    fn from_bytes(bytes: &[u8]) -> &[Self] {
289        bytemuck::cast_slice(bytes)
290    }
291    fn cube_elem() -> Elem {
292        Elem::Float(FloatKind::Flex32)
293    }
294    fn maximum_value() -> Self {
295        <flex32 as num_traits::Float>::max_value()
296    }
297    fn minimum_value() -> Self {
298        <flex32 as num_traits::Float>::min_value()
299    }
300}