1use cubecl_common::flex32;
2
3use crate::ir::{Elem, FloatKind, IntKind, UIntKind};
4
5pub trait CubeElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod {
7 fn type_name() -> &'static str;
9 fn as_bytes(slice: &[Self]) -> &[u8];
11 fn from_bytes(bytes: &[u8]) -> &[Self];
13 fn cube_elem() -> Elem;
15 fn maximum_value() -> Self;
17 fn minimum_value() -> Self;
19}
20
21impl CubeElement for u64 {
22 fn type_name() -> &'static str {
23 "u64"
24 }
25 fn as_bytes(slice: &[Self]) -> &[u8] {
26 bytemuck::cast_slice(slice)
27 }
28 fn from_bytes(bytes: &[u8]) -> &[Self] {
29 bytemuck::cast_slice(bytes)
30 }
31 fn cube_elem() -> Elem {
32 Elem::UInt(UIntKind::U64)
33 }
34 fn maximum_value() -> Self {
35 u64::MAX
36 }
37 fn minimum_value() -> Self {
38 u64::MIN
39 }
40}
41
42impl CubeElement for u32 {
43 fn type_name() -> &'static str {
44 "u32"
45 }
46 fn as_bytes(slice: &[Self]) -> &[u8] {
47 bytemuck::cast_slice(slice)
48 }
49 fn from_bytes(bytes: &[u8]) -> &[Self] {
50 bytemuck::cast_slice(bytes)
51 }
52 fn cube_elem() -> Elem {
53 Elem::UInt(UIntKind::U32)
54 }
55 fn maximum_value() -> Self {
56 u32::MAX
57 }
58 fn minimum_value() -> Self {
59 u32::MIN
60 }
61}
62
63impl CubeElement for u16 {
64 fn type_name() -> &'static str {
65 "u16"
66 }
67 fn as_bytes(slice: &[Self]) -> &[u8] {
68 bytemuck::cast_slice(slice)
69 }
70 fn from_bytes(bytes: &[u8]) -> &[Self] {
71 bytemuck::cast_slice(bytes)
72 }
73 fn cube_elem() -> Elem {
74 Elem::UInt(UIntKind::U16)
75 }
76 fn maximum_value() -> Self {
77 u16::MAX
78 }
79 fn minimum_value() -> Self {
80 u16::MIN
81 }
82}
83
84impl CubeElement for u8 {
85 fn type_name() -> &'static str {
86 "u8"
87 }
88 fn as_bytes(slice: &[Self]) -> &[u8] {
89 bytemuck::cast_slice(slice)
90 }
91 fn from_bytes(bytes: &[u8]) -> &[Self] {
92 bytemuck::cast_slice(bytes)
93 }
94 fn cube_elem() -> Elem {
95 Elem::UInt(UIntKind::U8)
96 }
97 fn maximum_value() -> Self {
98 u8::MAX
99 }
100 fn minimum_value() -> Self {
101 u8::MIN
102 }
103}
104
105impl CubeElement for i64 {
106 fn type_name() -> &'static str {
107 "i64"
108 }
109 fn as_bytes(slice: &[Self]) -> &[u8] {
110 bytemuck::cast_slice(slice)
111 }
112 fn from_bytes(bytes: &[u8]) -> &[Self] {
113 bytemuck::cast_slice(bytes)
114 }
115 fn cube_elem() -> Elem {
116 Elem::Int(IntKind::I64)
117 }
118 fn maximum_value() -> Self {
119 i64::MAX - 1
121 }
122 fn minimum_value() -> Self {
123 i64::MIN + 1
125 }
126}
127
128impl CubeElement for i32 {
129 fn type_name() -> &'static str {
130 "i32"
131 }
132 fn as_bytes(slice: &[Self]) -> &[u8] {
133 bytemuck::cast_slice(slice)
134 }
135 fn from_bytes(bytes: &[u8]) -> &[Self] {
136 bytemuck::cast_slice(bytes)
137 }
138 fn cube_elem() -> Elem {
139 Elem::Int(IntKind::I32)
140 }
141 fn maximum_value() -> Self {
142 i32::MAX - 1
144 }
145 fn minimum_value() -> Self {
146 i32::MIN + 1
148 }
149}
150
151impl CubeElement for i16 {
152 fn type_name() -> &'static str {
153 "i16"
154 }
155 fn as_bytes(slice: &[Self]) -> &[u8] {
156 bytemuck::cast_slice(slice)
157 }
158 fn from_bytes(bytes: &[u8]) -> &[Self] {
159 bytemuck::cast_slice(bytes)
160 }
161 fn cube_elem() -> Elem {
162 Elem::Int(IntKind::I16)
163 }
164 fn maximum_value() -> Self {
165 i16::MAX - 1
167 }
168 fn minimum_value() -> Self {
169 i16::MIN + 1
171 }
172}
173
174impl CubeElement for i8 {
175 fn type_name() -> &'static str {
176 "i8"
177 }
178 fn as_bytes(slice: &[Self]) -> &[u8] {
179 bytemuck::cast_slice(slice)
180 }
181 fn from_bytes(bytes: &[u8]) -> &[Self] {
182 bytemuck::cast_slice(bytes)
183 }
184 fn cube_elem() -> Elem {
185 Elem::Int(IntKind::I8)
186 }
187 fn maximum_value() -> Self {
188 i8::MAX - 1
190 }
191 fn minimum_value() -> Self {
192 i8::MIN + 1
194 }
195}
196
197impl CubeElement for f64 {
198 fn type_name() -> &'static str {
199 "f64"
200 }
201 fn as_bytes(slice: &[Self]) -> &[u8] {
202 bytemuck::cast_slice(slice)
203 }
204 fn from_bytes(bytes: &[u8]) -> &[Self] {
205 bytemuck::cast_slice(bytes)
206 }
207 fn cube_elem() -> Elem {
208 Elem::Float(FloatKind::F64)
209 }
210 fn maximum_value() -> Self {
211 f64::MAX
212 }
213 fn minimum_value() -> Self {
214 f64::MIN
215 }
216}
217
218impl CubeElement for f32 {
219 fn type_name() -> &'static str {
220 "f32"
221 }
222 fn as_bytes(slice: &[Self]) -> &[u8] {
223 bytemuck::cast_slice(slice)
224 }
225 fn from_bytes(bytes: &[u8]) -> &[Self] {
226 bytemuck::cast_slice(bytes)
227 }
228 fn cube_elem() -> Elem {
229 Elem::Float(FloatKind::F32)
230 }
231 fn maximum_value() -> Self {
232 f32::MAX
233 }
234 fn minimum_value() -> Self {
235 f32::MIN
236 }
237}
238
239impl CubeElement for half::f16 {
240 fn type_name() -> &'static str {
241 "f16"
242 }
243 fn as_bytes(slice: &[Self]) -> &[u8] {
244 bytemuck::cast_slice(slice)
245 }
246 fn from_bytes(bytes: &[u8]) -> &[Self] {
247 bytemuck::cast_slice(bytes)
248 }
249 fn cube_elem() -> Elem {
250 Elem::Float(FloatKind::F16)
251 }
252 fn maximum_value() -> Self {
253 half::f16::MAX
254 }
255 fn minimum_value() -> Self {
256 half::f16::MIN
257 }
258}
259
260impl CubeElement for half::bf16 {
261 fn type_name() -> &'static str {
262 "bf16"
263 }
264 fn as_bytes(slice: &[Self]) -> &[u8] {
265 bytemuck::cast_slice(slice)
266 }
267 fn from_bytes(bytes: &[u8]) -> &[Self] {
268 bytemuck::cast_slice(bytes)
269 }
270 fn cube_elem() -> Elem {
271 Elem::Float(FloatKind::BF16)
272 }
273 fn maximum_value() -> Self {
274 half::bf16::MAX
275 }
276 fn minimum_value() -> Self {
277 half::bf16::MIN
278 }
279}
280
281impl CubeElement for flex32 {
282 fn type_name() -> &'static str {
283 "flex32"
284 }
285 fn as_bytes(slice: &[Self]) -> &[u8] {
286 bytemuck::cast_slice(slice)
287 }
288 fn from_bytes(bytes: &[u8]) -> &[Self] {
289 bytemuck::cast_slice(bytes)
290 }
291 fn cube_elem() -> Elem {
292 Elem::Float(FloatKind::Flex32)
293 }
294 fn maximum_value() -> Self {
295 <flex32 as num_traits::Float>::max_value()
296 }
297 fn minimum_value() -> Self {
298 <flex32 as num_traits::Float>::min_value()
299 }
300}