1use cubecl_common::{e2m1, e2m1x2, e4m3, e5m2, flex32, tf32, ue8m0};
2use cubecl_ir::StorageType;
3
4use crate::{
5 ir::{ElemType, FloatKind, IntKind, UIntKind},
6 prelude::{CubePrimitive, Numeric},
7};
8
9pub trait CubeElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod {
11 fn type_name() -> &'static str;
13 fn as_bytes(slice: &[Self]) -> &[u8];
15 fn from_bytes(bytes: &[u8]) -> &[Self];
17 fn cube_type() -> StorageType;
19 fn maximum_value() -> Self;
21 fn minimum_value() -> Self;
23}
24
25pub trait CubeScalar: CubeElement + CubePrimitive + num_traits::NumCast {}
26
27impl<E: CubeElement + CubePrimitive + num_traits::NumCast> CubeScalar for E {}
28
29impl CubeElement for u64 {
30 fn type_name() -> &'static str {
31 "u64"
32 }
33 fn as_bytes(slice: &[Self]) -> &[u8] {
34 bytemuck::cast_slice(slice)
35 }
36 fn from_bytes(bytes: &[u8]) -> &[Self] {
37 bytemuck::cast_slice(bytes)
38 }
39 fn cube_type() -> StorageType {
40 ElemType::UInt(UIntKind::U64).into()
41 }
42 fn maximum_value() -> Self {
43 u64::MAX
44 }
45 fn minimum_value() -> Self {
46 u64::MIN
47 }
48}
49
50impl CubeElement for u32 {
51 fn type_name() -> &'static str {
52 "u32"
53 }
54 fn as_bytes(slice: &[Self]) -> &[u8] {
55 bytemuck::cast_slice(slice)
56 }
57 fn from_bytes(bytes: &[u8]) -> &[Self] {
58 bytemuck::cast_slice(bytes)
59 }
60 fn cube_type() -> StorageType {
61 ElemType::UInt(UIntKind::U32).into()
62 }
63 fn maximum_value() -> Self {
64 u32::MAX
65 }
66 fn minimum_value() -> Self {
67 u32::MIN
68 }
69}
70
71impl CubeElement for u16 {
72 fn type_name() -> &'static str {
73 "u16"
74 }
75 fn as_bytes(slice: &[Self]) -> &[u8] {
76 bytemuck::cast_slice(slice)
77 }
78 fn from_bytes(bytes: &[u8]) -> &[Self] {
79 bytemuck::cast_slice(bytes)
80 }
81 fn cube_type() -> StorageType {
82 ElemType::UInt(UIntKind::U16).into()
83 }
84 fn maximum_value() -> Self {
85 u16::MAX
86 }
87 fn minimum_value() -> Self {
88 u16::MIN
89 }
90}
91
92impl CubeElement for u8 {
93 fn type_name() -> &'static str {
94 "u8"
95 }
96 fn as_bytes(slice: &[Self]) -> &[u8] {
97 bytemuck::cast_slice(slice)
98 }
99 fn from_bytes(bytes: &[u8]) -> &[Self] {
100 bytemuck::cast_slice(bytes)
101 }
102 fn cube_type() -> StorageType {
103 ElemType::UInt(UIntKind::U8).into()
104 }
105 fn maximum_value() -> Self {
106 u8::MAX
107 }
108 fn minimum_value() -> Self {
109 u8::MIN
110 }
111}
112
113impl CubeElement for i64 {
114 fn type_name() -> &'static str {
115 "i64"
116 }
117 fn as_bytes(slice: &[Self]) -> &[u8] {
118 bytemuck::cast_slice(slice)
119 }
120 fn from_bytes(bytes: &[u8]) -> &[Self] {
121 bytemuck::cast_slice(bytes)
122 }
123 fn cube_type() -> StorageType {
124 ElemType::Int(IntKind::I64).into()
125 }
126 fn maximum_value() -> Self {
127 i64::MAX - 1
129 }
130 fn minimum_value() -> Self {
131 i64::MIN + 1
133 }
134}
135
136impl CubeElement for i32 {
137 fn type_name() -> &'static str {
138 "i32"
139 }
140 fn as_bytes(slice: &[Self]) -> &[u8] {
141 bytemuck::cast_slice(slice)
142 }
143 fn from_bytes(bytes: &[u8]) -> &[Self] {
144 bytemuck::cast_slice(bytes)
145 }
146 fn cube_type() -> StorageType {
147 ElemType::Int(IntKind::I32).into()
148 }
149 fn maximum_value() -> Self {
150 i32::MAX - 1
152 }
153 fn minimum_value() -> Self {
154 i32::MIN + 1
156 }
157}
158
159impl CubeElement for i16 {
160 fn type_name() -> &'static str {
161 "i16"
162 }
163 fn as_bytes(slice: &[Self]) -> &[u8] {
164 bytemuck::cast_slice(slice)
165 }
166 fn from_bytes(bytes: &[u8]) -> &[Self] {
167 bytemuck::cast_slice(bytes)
168 }
169 fn cube_type() -> StorageType {
170 ElemType::Int(IntKind::I16).into()
171 }
172 fn maximum_value() -> Self {
173 i16::MAX - 1
175 }
176 fn minimum_value() -> Self {
177 i16::MIN + 1
179 }
180}
181
182impl CubeElement for i8 {
183 fn type_name() -> &'static str {
184 "i8"
185 }
186 fn as_bytes(slice: &[Self]) -> &[u8] {
187 bytemuck::cast_slice(slice)
188 }
189 fn from_bytes(bytes: &[u8]) -> &[Self] {
190 bytemuck::cast_slice(bytes)
191 }
192 fn cube_type() -> StorageType {
193 ElemType::Int(IntKind::I8).into()
194 }
195 fn maximum_value() -> Self {
196 i8::MAX - 1
198 }
199 fn minimum_value() -> Self {
200 i8::MIN + 1
202 }
203}
204
205impl CubeElement for f64 {
206 fn type_name() -> &'static str {
207 "f64"
208 }
209 fn as_bytes(slice: &[Self]) -> &[u8] {
210 bytemuck::cast_slice(slice)
211 }
212 fn from_bytes(bytes: &[u8]) -> &[Self] {
213 bytemuck::cast_slice(bytes)
214 }
215 fn cube_type() -> StorageType {
216 ElemType::Float(FloatKind::F64).into()
217 }
218 fn maximum_value() -> Self {
219 f64::MAX
220 }
221 fn minimum_value() -> Self {
222 f64::MIN
223 }
224}
225
226impl CubeElement for f32 {
227 fn type_name() -> &'static str {
228 "f32"
229 }
230 fn as_bytes(slice: &[Self]) -> &[u8] {
231 bytemuck::cast_slice(slice)
232 }
233 fn from_bytes(bytes: &[u8]) -> &[Self] {
234 bytemuck::cast_slice(bytes)
235 }
236 fn cube_type() -> StorageType {
237 ElemType::Float(FloatKind::F32).into()
238 }
239 fn maximum_value() -> Self {
240 f32::MAX
241 }
242 fn minimum_value() -> Self {
243 f32::MIN
244 }
245}
246
247impl CubeElement for half::f16 {
248 fn type_name() -> &'static str {
249 "f16"
250 }
251 fn as_bytes(slice: &[Self]) -> &[u8] {
252 bytemuck::cast_slice(slice)
253 }
254 fn from_bytes(bytes: &[u8]) -> &[Self] {
255 bytemuck::cast_slice(bytes)
256 }
257 fn cube_type() -> StorageType {
258 ElemType::Float(FloatKind::F16).into()
259 }
260 fn maximum_value() -> Self {
261 half::f16::MAX
262 }
263 fn minimum_value() -> Self {
264 half::f16::MIN
265 }
266}
267
268impl CubeElement for half::bf16 {
269 fn type_name() -> &'static str {
270 "bf16"
271 }
272 fn as_bytes(slice: &[Self]) -> &[u8] {
273 bytemuck::cast_slice(slice)
274 }
275 fn from_bytes(bytes: &[u8]) -> &[Self] {
276 bytemuck::cast_slice(bytes)
277 }
278 fn cube_type() -> StorageType {
279 ElemType::Float(FloatKind::BF16).into()
280 }
281 fn maximum_value() -> Self {
282 half::bf16::MAX
283 }
284 fn minimum_value() -> Self {
285 half::bf16::MIN
286 }
287}
288
289impl CubeElement for flex32 {
290 fn type_name() -> &'static str {
291 "flex32"
292 }
293 fn as_bytes(slice: &[Self]) -> &[u8] {
294 bytemuck::cast_slice(slice)
295 }
296 fn from_bytes(bytes: &[u8]) -> &[Self] {
297 bytemuck::cast_slice(bytes)
298 }
299 fn cube_type() -> StorageType {
300 ElemType::Float(FloatKind::Flex32).into()
301 }
302 fn maximum_value() -> Self {
303 <flex32 as num_traits::Float>::max_value()
304 }
305 fn minimum_value() -> Self {
306 <flex32 as num_traits::Float>::min_value()
307 }
308}
309
310impl CubeElement for tf32 {
311 fn type_name() -> &'static str {
312 "tf32"
313 }
314
315 fn as_bytes(slice: &[Self]) -> &[u8] {
316 bytemuck::cast_slice(slice)
317 }
318
319 fn from_bytes(bytes: &[u8]) -> &[Self] {
320 bytemuck::cast_slice(bytes)
321 }
322
323 fn cube_type() -> StorageType {
324 ElemType::Float(FloatKind::TF32).into()
325 }
326
327 fn maximum_value() -> Self {
328 tf32::max_value()
329 }
330
331 fn minimum_value() -> Self {
332 tf32::min_value()
333 }
334}
335
336impl CubeElement for e4m3 {
337 fn type_name() -> &'static str {
338 "e4m3"
339 }
340
341 fn as_bytes(slice: &[Self]) -> &[u8] {
342 bytemuck::cast_slice(slice)
343 }
344
345 fn from_bytes(bytes: &[u8]) -> &[Self] {
346 bytemuck::cast_slice(bytes)
347 }
348
349 fn cube_type() -> StorageType {
350 ElemType::Float(FloatKind::E4M3).into()
351 }
352
353 fn maximum_value() -> Self {
354 e4m3::max_value()
355 }
356
357 fn minimum_value() -> Self {
358 e4m3::min_value()
359 }
360}
361
362impl CubeElement for e5m2 {
363 fn type_name() -> &'static str {
364 "e5m2"
365 }
366
367 fn as_bytes(slice: &[Self]) -> &[u8] {
368 bytemuck::cast_slice(slice)
369 }
370
371 fn from_bytes(bytes: &[u8]) -> &[Self] {
372 bytemuck::cast_slice(bytes)
373 }
374
375 fn cube_type() -> StorageType {
376 ElemType::Float(FloatKind::E5M2).into()
377 }
378
379 fn maximum_value() -> Self {
380 e5m2::max_value()
381 }
382
383 fn minimum_value() -> Self {
384 e5m2::min_value()
385 }
386}
387
388impl CubeElement for ue8m0 {
389 fn type_name() -> &'static str {
390 "ue8m0"
391 }
392
393 fn as_bytes(slice: &[Self]) -> &[u8] {
394 bytemuck::cast_slice(slice)
395 }
396
397 fn from_bytes(bytes: &[u8]) -> &[Self] {
398 bytemuck::cast_slice(bytes)
399 }
400
401 fn cube_type() -> StorageType {
402 ElemType::Float(FloatKind::UE8M0).into()
403 }
404
405 fn maximum_value() -> Self {
406 ue8m0::max_value()
407 }
408
409 fn minimum_value() -> Self {
410 ue8m0::min_value()
411 }
412}
413
414impl CubeElement for e2m1x2 {
415 fn type_name() -> &'static str {
416 "e2m1x2"
417 }
418
419 fn as_bytes(slice: &[Self]) -> &[u8] {
420 bytemuck::cast_slice(slice)
421 }
422
423 fn from_bytes(bytes: &[u8]) -> &[Self] {
424 bytemuck::cast_slice(bytes)
425 }
426
427 fn cube_type() -> StorageType {
428 StorageType::Packed(ElemType::Float(FloatKind::E2M1), 2)
429 }
430
431 fn maximum_value() -> Self {
432 let max = e2m1::MAX.to_bits() as u8;
433 e2m1x2::from_bits(max << 4 | max)
434 }
435
436 fn minimum_value() -> Self {
437 let min = e2m1::MIN.to_bits() as u8;
438 e2m1x2::from_bits(min << 4 | min)
439 }
440}