cubecl_core/
pod.rs

1use crate::{
2    flex32,
3    ir::{Elem, FloatKind, IntKind, UIntKind},
4};
5
6/// The base element trait for the jit backend.
7pub trait CubeElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod {
8    /// Returns the name of the type.
9    fn type_name() -> &'static str;
10    /// Convert a slice of elements to a slice of bytes.
11    fn as_bytes(slice: &[Self]) -> &[u8];
12    /// Convert a slice of bytes to a slice of elements.
13    fn from_bytes(bytes: &[u8]) -> &[Self];
14    /// Element representation for `cubecl`.
15    fn cube_elem() -> Elem;
16    /// Highest possible value
17    fn maximum_value() -> Self;
18    /// Lowest possible value
19    fn minimum_value() -> Self;
20}
21
22impl CubeElement for u64 {
23    fn type_name() -> &'static str {
24        "u64"
25    }
26    fn as_bytes(slice: &[Self]) -> &[u8] {
27        bytemuck::cast_slice(slice)
28    }
29    fn from_bytes(bytes: &[u8]) -> &[Self] {
30        bytemuck::cast_slice(bytes)
31    }
32    fn cube_elem() -> Elem {
33        Elem::UInt(UIntKind::U64)
34    }
35    fn maximum_value() -> Self {
36        u64::MAX
37    }
38    fn minimum_value() -> Self {
39        u64::MIN
40    }
41}
42
43impl CubeElement for u32 {
44    fn type_name() -> &'static str {
45        "u32"
46    }
47    fn as_bytes(slice: &[Self]) -> &[u8] {
48        bytemuck::cast_slice(slice)
49    }
50    fn from_bytes(bytes: &[u8]) -> &[Self] {
51        bytemuck::cast_slice(bytes)
52    }
53    fn cube_elem() -> Elem {
54        Elem::UInt(UIntKind::U32)
55    }
56    fn maximum_value() -> Self {
57        u32::MAX
58    }
59    fn minimum_value() -> Self {
60        u32::MIN
61    }
62}
63
64impl CubeElement for u16 {
65    fn type_name() -> &'static str {
66        "u16"
67    }
68    fn as_bytes(slice: &[Self]) -> &[u8] {
69        bytemuck::cast_slice(slice)
70    }
71    fn from_bytes(bytes: &[u8]) -> &[Self] {
72        bytemuck::cast_slice(bytes)
73    }
74    fn cube_elem() -> Elem {
75        Elem::UInt(UIntKind::U16)
76    }
77    fn maximum_value() -> Self {
78        u16::MAX
79    }
80    fn minimum_value() -> Self {
81        u16::MIN
82    }
83}
84
85impl CubeElement for u8 {
86    fn type_name() -> &'static str {
87        "u8"
88    }
89    fn as_bytes(slice: &[Self]) -> &[u8] {
90        bytemuck::cast_slice(slice)
91    }
92    fn from_bytes(bytes: &[u8]) -> &[Self] {
93        bytemuck::cast_slice(bytes)
94    }
95    fn cube_elem() -> Elem {
96        Elem::UInt(UIntKind::U8)
97    }
98    fn maximum_value() -> Self {
99        u8::MAX
100    }
101    fn minimum_value() -> Self {
102        u8::MIN
103    }
104}
105
106impl CubeElement for i64 {
107    fn type_name() -> &'static str {
108        "i64"
109    }
110    fn as_bytes(slice: &[Self]) -> &[u8] {
111        bytemuck::cast_slice(slice)
112    }
113    fn from_bytes(bytes: &[u8]) -> &[Self] {
114        bytemuck::cast_slice(bytes)
115    }
116    fn cube_elem() -> Elem {
117        Elem::Int(IntKind::I64)
118    }
119    fn maximum_value() -> Self {
120        // Seems to cause problem for some GPU
121        i64::MAX - 1
122    }
123    fn minimum_value() -> Self {
124        // Seems to cause problem for some GPU
125        i64::MIN + 1
126    }
127}
128
129impl CubeElement for i32 {
130    fn type_name() -> &'static str {
131        "i32"
132    }
133    fn as_bytes(slice: &[Self]) -> &[u8] {
134        bytemuck::cast_slice(slice)
135    }
136    fn from_bytes(bytes: &[u8]) -> &[Self] {
137        bytemuck::cast_slice(bytes)
138    }
139    fn cube_elem() -> Elem {
140        Elem::Int(IntKind::I32)
141    }
142    fn maximum_value() -> Self {
143        // Seems to cause problem for some GPU
144        i32::MAX - 1
145    }
146    fn minimum_value() -> Self {
147        // Seems to cause problem for some GPU
148        i32::MIN + 1
149    }
150}
151
152impl CubeElement for i16 {
153    fn type_name() -> &'static str {
154        "i16"
155    }
156    fn as_bytes(slice: &[Self]) -> &[u8] {
157        bytemuck::cast_slice(slice)
158    }
159    fn from_bytes(bytes: &[u8]) -> &[Self] {
160        bytemuck::cast_slice(bytes)
161    }
162    fn cube_elem() -> Elem {
163        Elem::Int(IntKind::I16)
164    }
165    fn maximum_value() -> Self {
166        // Seems to cause problem for some GPU
167        i16::MAX - 1
168    }
169    fn minimum_value() -> Self {
170        // Seems to cause problem for some GPU
171        i16::MIN + 1
172    }
173}
174
175impl CubeElement for i8 {
176    fn type_name() -> &'static str {
177        "i8"
178    }
179    fn as_bytes(slice: &[Self]) -> &[u8] {
180        bytemuck::cast_slice(slice)
181    }
182    fn from_bytes(bytes: &[u8]) -> &[Self] {
183        bytemuck::cast_slice(bytes)
184    }
185    fn cube_elem() -> Elem {
186        Elem::Int(IntKind::I8)
187    }
188    fn maximum_value() -> Self {
189        // Seems to cause problem for some GPU
190        i8::MAX - 1
191    }
192    fn minimum_value() -> Self {
193        // Seems to cause problem for some GPU
194        i8::MIN + 1
195    }
196}
197
198impl CubeElement for f64 {
199    fn type_name() -> &'static str {
200        "f64"
201    }
202    fn as_bytes(slice: &[Self]) -> &[u8] {
203        bytemuck::cast_slice(slice)
204    }
205    fn from_bytes(bytes: &[u8]) -> &[Self] {
206        bytemuck::cast_slice(bytes)
207    }
208    fn cube_elem() -> Elem {
209        Elem::Float(FloatKind::F64)
210    }
211    fn maximum_value() -> Self {
212        f64::MAX
213    }
214    fn minimum_value() -> Self {
215        f64::MIN
216    }
217}
218
219impl CubeElement for f32 {
220    fn type_name() -> &'static str {
221        "f32"
222    }
223    fn as_bytes(slice: &[Self]) -> &[u8] {
224        bytemuck::cast_slice(slice)
225    }
226    fn from_bytes(bytes: &[u8]) -> &[Self] {
227        bytemuck::cast_slice(bytes)
228    }
229    fn cube_elem() -> Elem {
230        Elem::Float(FloatKind::F32)
231    }
232    fn maximum_value() -> Self {
233        f32::MAX
234    }
235    fn minimum_value() -> Self {
236        f32::MIN
237    }
238}
239
240impl CubeElement for half::f16 {
241    fn type_name() -> &'static str {
242        "f16"
243    }
244    fn as_bytes(slice: &[Self]) -> &[u8] {
245        bytemuck::cast_slice(slice)
246    }
247    fn from_bytes(bytes: &[u8]) -> &[Self] {
248        bytemuck::cast_slice(bytes)
249    }
250    fn cube_elem() -> Elem {
251        Elem::Float(FloatKind::F16)
252    }
253    fn maximum_value() -> Self {
254        half::f16::MAX
255    }
256    fn minimum_value() -> Self {
257        half::f16::MIN
258    }
259}
260
261impl CubeElement for half::bf16 {
262    fn type_name() -> &'static str {
263        "bf16"
264    }
265    fn as_bytes(slice: &[Self]) -> &[u8] {
266        bytemuck::cast_slice(slice)
267    }
268    fn from_bytes(bytes: &[u8]) -> &[Self] {
269        bytemuck::cast_slice(bytes)
270    }
271    fn cube_elem() -> Elem {
272        Elem::Float(FloatKind::BF16)
273    }
274    fn maximum_value() -> Self {
275        half::bf16::MAX
276    }
277    fn minimum_value() -> Self {
278        half::bf16::MIN
279    }
280}
281
282impl CubeElement for flex32 {
283    fn type_name() -> &'static str {
284        "flex32"
285    }
286    fn as_bytes(slice: &[Self]) -> &[u8] {
287        bytemuck::cast_slice(slice)
288    }
289    fn from_bytes(bytes: &[u8]) -> &[Self] {
290        bytemuck::cast_slice(bytes)
291    }
292    fn cube_elem() -> Elem {
293        Elem::Float(FloatKind::Flex32)
294    }
295    fn maximum_value() -> Self {
296        <flex32 as num_traits::Float>::max_value()
297    }
298    fn minimum_value() -> Self {
299        <flex32 as num_traits::Float>::min_value()
300    }
301}