1use core::cmp::Ordering;
2
3use crate::{Distribution, cast::ToElement, quantization::QuantScheme};
4#[cfg(feature = "cubecl")]
5use cubecl::flex32;
6
7use cubecl_quant::scheme::{QuantStore, QuantValue};
8use half::{bf16, f16};
9use rand::RngCore;
10use serde::{Deserialize, Serialize};
11
12pub trait Element:
14 ToElement
15 + ElementRandom
16 + ElementConversion
17 + ElementComparison
18 + ElementLimits
19 + bytemuck::CheckedBitPattern
20 + bytemuck::NoUninit
21 + bytemuck::Zeroable
22 + core::fmt::Debug
23 + core::fmt::Display
24 + Default
25 + Send
26 + Sync
27 + Copy
28 + 'static
29{
30 fn dtype() -> DType;
32}
33
34pub trait ElementConversion {
36 fn from_elem<E: ToElement>(elem: E) -> Self;
46
47 fn elem<E: Element>(self) -> E;
49}
50
51pub trait ElementRandom {
53 fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
64}
65
/// Comparison between two values of the same element type.
///
/// Unlike [`PartialOrd`], this always produces an [`Ordering`].
pub trait ElementComparison {
    /// Orders `self` relative to `other`.
    fn cmp(&self, other: &Self) -> Ordering;
}
71
/// Lower and upper bounds for the values of an element type.
///
/// For most types these are the primitive's own `MIN`/`MAX` constants, but an implementation
/// may choose narrower bounds.
pub trait ElementLimits {
    /// Lower bound of the element's value range.
    const MIN: Self;
    /// Upper bound of the element's value range.
    const MAX: Self;
}
79
/// Implements [`Element`], [`ElementConversion`], [`ElementRandom`], [`ElementComparison`] and
/// [`ElementLimits`] for a scalar type.
///
/// Arguments, in order:
///
/// - `ty`: the type receiving the implementations (an identifier).
/// - `convert`: a callable applied to `&E` (with `E: ToElement`) that builds the type; it backs
///   `ElementConversion::from_elem`.
/// - `random`: a callable `(Distribution, &mut R) -> ty`, where `R` is the `RngCore` type
///   parameter of `ElementRandom::random` and may be named in the callable's signature.
/// - `cmp`: a callable taking two `&ty` and returning a `core::cmp::Ordering`.
/// - `dtype`: the expression returned by `Element::dtype`.
/// - `min` / `max` (optional): the `ElementLimits` bounds. They default to `ty::MIN` and
///   `ty::MAX`.
///
/// The expansion refers to the element traits, `ToElement`, `Distribution` and `RngCore`
/// without a path, so those names must be in scope at the invocation site.
#[macro_export]
macro_rules! make_element {
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        // Recurse through `$crate` so the call resolves even when the caller invoked this
        // macro by path without importing `make_element` by name.
        $crate::make_element!(
            ty $type,
            convert $convert,
            random $random,
            cmp $cmp,
            dtype $dtype,
            min $type::MIN,
            max $type::MAX
        );
    };
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> $crate::DType {
                $dtype
            }
        }

        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }
            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            // Fully qualified so callers need not import `Ordering` themselves.
            fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
                // Both operands already have type `$type`. Converting them to their own type
                // first (via `elem::<$type>()`) would only add work.
                #[allow(clippy::redundant_closure_call)]
                $cmp(self, other)
            }
        }

        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }
    };
}
142
143make_element!(
144 ty f64,
145 convert ToElement::to_f64,
146 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
147 cmp |a: &f64, b: &f64| a.total_cmp(b),
148 dtype DType::F64
149);
150
151make_element!(
152 ty f32,
153 convert ToElement::to_f32,
154 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
155 cmp |a: &f32, b: &f32| a.total_cmp(b),
156 dtype DType::F32
157);
158
159make_element!(
160 ty i64,
161 convert ToElement::to_i64,
162 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
163 cmp |a: &i64, b: &i64| Ord::cmp(a, b),
164 dtype DType::I64
165);
166
167make_element!(
168 ty u64,
169 convert ToElement::to_u64,
170 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
171 cmp |a: &u64, b: &u64| Ord::cmp(a, b),
172 dtype DType::U64
173);
174
175make_element!(
176 ty i32,
177 convert ToElement::to_i32,
178 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
179 cmp |a: &i32, b: &i32| Ord::cmp(a, b),
180 dtype DType::I32
181);
182
183make_element!(
184 ty u32,
185 convert ToElement::to_u32,
186 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
187 cmp |a: &u32, b: &u32| Ord::cmp(a, b),
188 dtype DType::U32
189);
190
191make_element!(
192 ty i16,
193 convert ToElement::to_i16,
194 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
195 cmp |a: &i16, b: &i16| Ord::cmp(a, b),
196 dtype DType::I16
197);
198
199make_element!(
200 ty u16,
201 convert ToElement::to_u16,
202 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
203 cmp |a: &u16, b: &u16| Ord::cmp(a, b),
204 dtype DType::U16
205);
206
207make_element!(
208 ty i8,
209 convert ToElement::to_i8,
210 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
211 cmp |a: &i8, b: &i8| Ord::cmp(a, b),
212 dtype DType::I8
213);
214
215make_element!(
216 ty u8,
217 convert ToElement::to_u8,
218 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
219 cmp |a: &u8, b: &u8| Ord::cmp(a, b),
220 dtype DType::U8
221);
222
223make_element!(
224 ty f16,
225 convert ToElement::to_f16,
226 random |distribution: Distribution, rng: &mut R| {
227 let sample: f32 = distribution.sampler(rng).sample();
228 f16::from_elem(sample)
229 },
230 cmp |a: &f16, b: &f16| a.total_cmp(b),
231 dtype DType::F16
232);
233make_element!(
234 ty bf16,
235 convert ToElement::to_bf16,
236 random |distribution: Distribution, rng: &mut R| {
237 let sample: f32 = distribution.sampler(rng).sample();
238 bf16::from_elem(sample)
239 },
240 cmp |a: &bf16, b: &bf16| a.total_cmp(b),
241 dtype DType::BF16
242);
243
// `flex32` is converted and sampled through `f32`. Its `ElementLimits` bounds are taken from
// the `f16` range rather than from `flex32` itself.
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32,
    convert |value: &dyn ToElement| flex32::from_f32(value.to_f32()),
    random |dist: Distribution, generator: &mut R| {
        let drawn: f32 = dist.sampler(generator).sample();
        flex32::from_elem(drawn)
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::Flex32,
    min flex32::from_f32(half::f16::MIN.to_f32_const()),
    max flex32::from_f32(half::f16::MAX.to_f32_const())
);
257
258make_element!(
259 ty bool,
260 convert ToElement::to_bool,
261 random |distribution: Distribution, rng: &mut R| {
262 let sample: u8 = distribution.sampler(rng).sample();
263 bool::from_elem(sample)
264 },
265 cmp |a: &bool, b: &bool| Ord::cmp(a, b),
266 dtype DType::Bool,
267 min false,
268 max true
269);
270
271#[allow(missing_docs)]
272#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
273pub enum DType {
274 F64,
275 F32,
276 Flex32,
277 F16,
278 BF16,
279 I64,
280 I32,
281 I16,
282 I8,
283 U64,
284 U32,
285 U16,
286 U8,
287 Bool,
288 QFloat(QuantScheme),
289}
290
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
    /// Maps a CubeCL element type onto the matching tensor [`DType`].
    ///
    /// # Panics
    ///
    /// - `FloatKind::TF32`, and any element type that is not a float, int or uint kind, panic
    ///   with "Not a valid DType for tensors.".
    /// - The minifloat kinds (`E2M1`, `E2M3`, `E3M2`, `E4M3`, `E5M2`, `UE8M0`) hit
    ///   `unimplemented!`.
    fn from(value: cubecl::ir::ElemType) -> Self {
        use cubecl::ir::{ElemType, FloatKind, IntKind, UIntKind};

        match value {
            ElemType::Float(kind) => match kind {
                FloatKind::F64 => DType::F64,
                FloatKind::F32 => DType::F32,
                FloatKind::Flex32 => DType::Flex32,
                FloatKind::F16 => DType::F16,
                FloatKind::BF16 => DType::BF16,
                FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            ElemType::Int(kind) => match kind {
                IntKind::I64 => DType::I64,
                IntKind::I32 => DType::I32,
                IntKind::I16 => DType::I16,
                IntKind::I8 => DType::I8,
            },
            ElemType::UInt(kind) => match kind {
                UIntKind::U64 => DType::U64,
                UIntKind::U32 => DType::U32,
                UIntKind::U16 => DType::U16,
                UIntKind::U8 => DType::U8,
            },
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
327
328impl DType {
329 pub const fn size(&self) -> usize {
331 match self {
332 DType::F64 => core::mem::size_of::<f64>(),
333 DType::F32 => core::mem::size_of::<f32>(),
334 DType::Flex32 => core::mem::size_of::<f32>(),
335 DType::F16 => core::mem::size_of::<f16>(),
336 DType::BF16 => core::mem::size_of::<bf16>(),
337 DType::I64 => core::mem::size_of::<i64>(),
338 DType::I32 => core::mem::size_of::<i32>(),
339 DType::I16 => core::mem::size_of::<i16>(),
340 DType::I8 => core::mem::size_of::<i8>(),
341 DType::U64 => core::mem::size_of::<u64>(),
342 DType::U32 => core::mem::size_of::<u32>(),
343 DType::U16 => core::mem::size_of::<u16>(),
344 DType::U8 => core::mem::size_of::<u8>(),
345 DType::Bool => core::mem::size_of::<bool>(),
346 DType::QFloat(scheme) => match scheme.store {
347 QuantStore::Native => match scheme.value {
348 QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
349 QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
352 core::mem::size_of::<u8>()
353 }
354 QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
355 0
357 }
358 },
359 QuantStore::U32 => core::mem::size_of::<u32>(),
360 },
361 }
362 }
363 pub fn is_float(&self) -> bool {
365 matches!(
366 self,
367 DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
368 )
369 }
370 pub fn is_int(&self) -> bool {
372 matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
373 }
374 pub fn is_uint(&self) -> bool {
376 matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
377 }
378
379 pub fn is_bool(&self) -> bool {
381 matches!(self, DType::Bool)
382 }
383
384 pub fn name(&self) -> &'static str {
386 match self {
387 DType::F64 => "f64",
388 DType::F32 => "f32",
389 DType::Flex32 => "flex32",
390 DType::F16 => "f16",
391 DType::BF16 => "bf16",
392 DType::I64 => "i64",
393 DType::I32 => "i32",
394 DType::I16 => "i16",
395 DType::I8 => "i8",
396 DType::U64 => "u64",
397 DType::U32 => "u32",
398 DType::U16 => "u16",
399 DType::U8 => "u8",
400 DType::Bool => "bool",
401 DType::QFloat(_) => "qfloat",
402 }
403 }
404}
405
/// Floating point subset of [`DType`].
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so the variant order is kept as is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FloatDType {
    /// 64-bit floating point (`f64`).
    F64,
    /// 32-bit floating point (`f32`).
    F32,
    /// `flex32` floating point, stored in 32 bits.
    Flex32,
    /// 16-bit half precision floating point (`half::f16`).
    F16,
    /// 16-bit brain floating point (`half::bf16`).
    BF16,
}
415
416impl From<DType> for FloatDType {
417 fn from(value: DType) -> Self {
418 match value {
419 DType::F64 => FloatDType::F64,
420 DType::F32 => FloatDType::F32,
421 DType::Flex32 => FloatDType::Flex32,
422 DType::F16 => FloatDType::F16,
423 DType::BF16 => FloatDType::BF16,
424 _ => panic!("Expected float data type, got {value:?}"),
425 }
426 }
427}
428
429impl From<FloatDType> for DType {
430 fn from(value: FloatDType) -> Self {
431 match value {
432 FloatDType::F64 => DType::F64,
433 FloatDType::F32 => DType::F32,
434 FloatDType::Flex32 => DType::Flex32,
435 FloatDType::F16 => DType::F16,
436 FloatDType::BF16 => DType::BF16,
437 }
438 }
439}
440
/// Integer subset of [`DType`], covering both signed and unsigned widths.
///
/// The derived `PartialOrd`/`Ord` follow declaration order (all signed types before the
/// unsigned ones), so the variant order is kept as is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntDType {
    /// 64-bit signed integer.
    I64,
    /// 32-bit signed integer.
    I32,
    /// 16-bit signed integer.
    I16,
    /// 8-bit signed integer.
    I8,
    /// 64-bit unsigned integer.
    U64,
    /// 32-bit unsigned integer.
    U32,
    /// 16-bit unsigned integer.
    U16,
    /// 8-bit unsigned integer.
    U8,
}
453
454impl From<DType> for IntDType {
455 fn from(value: DType) -> Self {
456 match value {
457 DType::I64 => IntDType::I64,
458 DType::I32 => IntDType::I32,
459 DType::I16 => IntDType::I16,
460 DType::I8 => IntDType::I8,
461 DType::U64 => IntDType::U64,
462 DType::U32 => IntDType::U32,
463 DType::U16 => IntDType::U16,
464 DType::U8 => IntDType::U8,
465 _ => panic!("Expected int data type, got {value:?}"),
466 }
467 }
468}
469
470impl From<IntDType> for DType {
471 fn from(value: IntDType) -> Self {
472 match value {
473 IntDType::I64 => DType::I64,
474 IntDType::I32 => DType::I32,
475 IntDType::I16 => DType::I16,
476 IntDType::I8 => DType::I8,
477 IntDType::U64 => DType::U64,
478 IntDType::U32 => DType::U32,
479 IntDType::U16 => DType::U16,
480 IntDType::U8 => DType::U8,
481 }
482 }
483}