1use core::cmp::Ordering;
2
3use crate::{
4 Distribution,
5 cast::ToElement,
6 quantization::{QuantizationScheme, QuantizationType},
7};
8#[cfg(feature = "cubecl")]
9use cubecl::flex32;
10use half::{bf16, f16};
11use rand::RngCore;
12use serde::{Deserialize, Serialize};
13
14pub trait Element:
16 ToElement
17 + ElementRandom
18 + ElementConversion
19 + ElementPrecision
20 + ElementComparison
21 + ElementLimits
22 + bytemuck::CheckedBitPattern
23 + bytemuck::NoUninit
24 + bytemuck::Zeroable
25 + core::fmt::Debug
26 + core::fmt::Display
27 + Default
28 + Send
29 + Sync
30 + Copy
31 + 'static
32{
33 fn dtype() -> DType;
35}
36
37pub trait ElementConversion {
39 fn from_elem<E: ToElement>(elem: E) -> Self;
49
50 fn elem<E: Element>(self) -> E;
52}
53
54pub trait ElementRandom {
56 fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
67}
68
/// Ordering between two elements of the same type.
pub trait ElementComparison {
    /// Returns how `self` orders relative to `other`.
    fn cmp(&self, other: &Self) -> Ordering;
}
74
/// Lower and upper bounds associated with an element type.
pub trait ElementLimits {
    /// Smallest value the element type is considered to hold.
    const MIN: Self;
    /// Largest value the element type is considered to hold.
    const MAX: Self;
}
82
/// Coarse precision class of an element type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// Double precision (64-bit types such as `f64`, `i64`, `u64`).
    Double,

    /// Full precision (32-bit types such as `f32`, `i32`, `u32`).
    Full,

    /// Half precision (16-bit types such as `f16`, `bf16`, `i16`, `u16`).
    Half,

    /// Any other precision (e.g. 8-bit integers and `bool`).
    Other,
}
98
99pub trait ElementPrecision {
101 fn precision() -> Precision;
103}
104
/// Implements this module's element traits for a scalar type.
///
/// Generates impls of `Element`, `ElementConversion`, `ElementPrecision`,
/// `ElementRandom`, `ElementComparison` and `ElementLimits` for `$type`.
///
/// Arguments:
/// - `ty $type $precision`: the type to implement for and its `Precision`.
/// - `convert $convert`: callable invoked as `$convert(&elem)` with `elem: E`
///   where `E: ToElement`; must produce a `$type`.
/// - `random $random`: callable invoked as `$random(distribution, rng)`; closures
///   passed here may name the generic `R` of `random<R: RngCore>`, as the
///   invocations in this file do.
/// - `cmp $cmp`: callable invoked as `$cmp(&a, &b)` on two `$type` values,
///   returning an `Ordering`.
/// - `dtype $dtype`: the `DType` returned by `Element::dtype`.
/// - `min $min`, `max $max` (optional): values for `ElementLimits`; when omitted,
///   `$type::MIN` and `$type::MAX` are used.
///
/// NOTE(review): apart from `$crate::DType`, the generated code refers to names
/// unqualified (`Element`, `ElementConversion`, `ToElement`, `Precision`,
/// `RngCore`, `Distribution`, `Ordering`, and the recursive `make_element!`),
/// so they must be in scope where the macro is invoked. TODO confirm before
/// using this exported macro from another crate.
#[macro_export]
macro_rules! make_element {
    // Short form: forwards to the full form with the type's own `MIN`/`MAX`
    // associated constants as limits.
    (
        ty $type:ident $precision:expr,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        make_element!(ty $type $precision, convert $convert, random $random, cmp $cmp, dtype $dtype, min $type::MIN, max $type::MAX);
    };
    // Full form: explicit `min`/`max` limits (needed for types without
    // `MIN`/`MAX` constants, or with custom limits).
    (
        ty $type:ident $precision:expr,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> $crate::DType {
                $dtype
            }
        }

        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                // `$convert` may be a closure literal, which clippy would flag
                // as an immediately-called closure.
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }
            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                // Conversion is always routed through the target's `from_elem`.
                E::from_elem(self)
            }
        }

        impl ElementPrecision for $type {
            fn precision() -> Precision {
                $precision
            }
        }

        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> Ordering {
                // Both operands are passed through `elem::<$type>()` (a same-type
                // conversion) before `$cmp` is applied.
                let a = self.elem::<$type>();
                let b = other.elem::<$type>();
                #[allow(clippy::redundant_closure_call)]
                $cmp(&a, &b)
            }
        }

        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }
    };
}
173
174make_element!(
175 ty f64 Precision::Double,
176 convert ToElement::to_f64,
177 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
178 cmp |a: &f64, b: &f64| a.total_cmp(b),
179 dtype DType::F64
180);
181
182make_element!(
183 ty f32 Precision::Full,
184 convert ToElement::to_f32,
185 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
186 cmp |a: &f32, b: &f32| a.total_cmp(b),
187 dtype DType::F32
188);
189
190make_element!(
191 ty i64 Precision::Double,
192 convert ToElement::to_i64,
193 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
194 cmp |a: &i64, b: &i64| Ord::cmp(a, b),
195 dtype DType::I64
196);
197
198make_element!(
199 ty u64 Precision::Double,
200 convert ToElement::to_u64,
201 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
202 cmp |a: &u64, b: &u64| Ord::cmp(a, b),
203 dtype DType::U64
204);
205
206make_element!(
207 ty i32 Precision::Full,
208 convert ToElement::to_i32,
209 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
210 cmp |a: &i32, b: &i32| Ord::cmp(a, b),
211 dtype DType::I32
212);
213
214make_element!(
215 ty u32 Precision::Full,
216 convert ToElement::to_u32,
217 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
218 cmp |a: &u32, b: &u32| Ord::cmp(a, b),
219 dtype DType::U32
220);
221
222make_element!(
223 ty i16 Precision::Half,
224 convert ToElement::to_i16,
225 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
226 cmp |a: &i16, b: &i16| Ord::cmp(a, b),
227 dtype DType::I16
228);
229
230make_element!(
231 ty u16 Precision::Half,
232 convert ToElement::to_u16,
233 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
234 cmp |a: &u16, b: &u16| Ord::cmp(a, b),
235 dtype DType::U16
236);
237
238make_element!(
239 ty i8 Precision::Other,
240 convert ToElement::to_i8,
241 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
242 cmp |a: &i8, b: &i8| Ord::cmp(a, b),
243 dtype DType::I8
244);
245
246make_element!(
247 ty u8 Precision::Other,
248 convert ToElement::to_u8,
249 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
250 cmp |a: &u8, b: &u8| Ord::cmp(a, b),
251 dtype DType::U8
252);
253
254make_element!(
255 ty f16 Precision::Half,
256 convert ToElement::to_f16,
257 random |distribution: Distribution, rng: &mut R| {
258 let sample: f32 = distribution.sampler(rng).sample();
259 f16::from_elem(sample)
260 },
261 cmp |a: &f16, b: &f16| a.total_cmp(b),
262 dtype DType::F16
263);
264make_element!(
265 ty bf16 Precision::Half,
266 convert ToElement::to_bf16,
267 random |distribution: Distribution, rng: &mut R| {
268 let sample: f32 = distribution.sampler(rng).sample();
269 bf16::from_elem(sample)
270 },
271 cmp |a: &bf16, b: &bf16| a.total_cmp(b),
272 dtype DType::BF16
273);
274
// `flex32` converts and samples through `f32`; its limits are borrowed from
// `f16`, so they are given explicitly instead of via `flex32::MIN`/`MAX`.
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32 Precision::Half,
    convert |src: &dyn ToElement| flex32::from_f32(src.to_f32()),
    random |dist: Distribution, rng: &mut R| {
        let drawn: f32 = dist.sampler(rng).sample();
        flex32::from_elem(drawn)
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::Flex32,
    min flex32::from_f32(half::f16::MIN.to_f32_const()),
    max flex32::from_f32(half::f16::MAX.to_f32_const())
);
288
289make_element!(
290 ty bool Precision::Other,
291 convert ToElement::to_bool,
292 random |distribution: Distribution, rng: &mut R| {
293 let sample: u8 = distribution.sampler(rng).sample();
294 bool::from_elem(sample)
295 },
296 cmp |a: &bool, b: &bool| Ord::cmp(a, b),
297 dtype DType::Bool,
298 min false,
299 max true
300);
301
302#[allow(missing_docs)]
303#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
304pub enum DType {
305 F64,
306 F32,
307 Flex32,
308 F16,
309 BF16,
310 I64,
311 I32,
312 I16,
313 I8,
314 U64,
315 U32,
316 U16,
317 U8,
318 Bool,
319 QFloat(QuantizationScheme),
320}
321
322impl DType {
323 pub const fn size(&self) -> usize {
325 match self {
326 DType::F64 => core::mem::size_of::<f64>(),
327 DType::F32 => core::mem::size_of::<f32>(),
328 DType::Flex32 => core::mem::size_of::<f32>(),
329 DType::F16 => core::mem::size_of::<f16>(),
330 DType::BF16 => core::mem::size_of::<bf16>(),
331 DType::I64 => core::mem::size_of::<i64>(),
332 DType::I32 => core::mem::size_of::<i32>(),
333 DType::I16 => core::mem::size_of::<i16>(),
334 DType::I8 => core::mem::size_of::<i8>(),
335 DType::U64 => core::mem::size_of::<u64>(),
336 DType::U32 => core::mem::size_of::<u32>(),
337 DType::U16 => core::mem::size_of::<u16>(),
338 DType::U8 => core::mem::size_of::<u8>(),
339 DType::Bool => core::mem::size_of::<bool>(),
340 DType::QFloat(scheme) => match scheme {
341 QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => {
342 core::mem::size_of::<i8>()
343 }
344 },
345 }
346 }
347 pub fn is_float(&self) -> bool {
349 matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)
350 }
351 pub fn is_int(&self) -> bool {
353 matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
354 }
355
356 pub fn is_bool(&self) -> bool {
358 matches!(self, DType::Bool)
359 }
360
361 pub fn name(&self) -> &'static str {
363 match self {
364 DType::F64 => "f64",
365 DType::F32 => "f32",
366 DType::Flex32 => "flex32",
367 DType::F16 => "f16",
368 DType::BF16 => "bf16",
369 DType::I64 => "i64",
370 DType::I32 => "i32",
371 DType::I16 => "i16",
372 DType::I8 => "i8",
373 DType::U64 => "u64",
374 DType::U32 => "u32",
375 DType::U16 => "u16",
376 DType::U8 => "u8",
377 DType::Bool => "bool",
378 DType::QFloat(_) => "qfloat",
379 }
380 }
381}
382
/// Subset of [`DType`] restricted to floating point data types.
///
/// Derives the same value-semantics traits as `DType` (`Copy`, `PartialEq`,
/// `Eq`, `Hash`) so the two can be used interchangeably as keys or flags.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FloatDType {
    /// 64-bit floating point (`f64`).
    F64,
    /// 32-bit floating point (`f32`).
    F32,
    /// 16-bit half precision floating point (`f16`).
    F16,
    /// 16-bit brain floating point (`bf16`).
    BF16,
}
391
392impl From<DType> for FloatDType {
393 fn from(value: DType) -> Self {
394 match value {
395 DType::F64 => FloatDType::F64,
396 DType::F32 => FloatDType::F32,
397 DType::F16 => FloatDType::F16,
398 DType::BF16 => FloatDType::BF16,
399 _ => panic!("Expected float data type, got {value:?}"),
400 }
401 }
402}
403
404impl From<FloatDType> for DType {
405 fn from(value: FloatDType) -> Self {
406 match value {
407 FloatDType::F64 => DType::F64,
408 FloatDType::F32 => DType::F32,
409 FloatDType::F16 => DType::F16,
410 FloatDType::BF16 => DType::BF16,
411 }
412 }
413}