1use core::cmp::Ordering;
2
3use crate::{
4 cast::ToElement,
5 quantization::{QuantizationScheme, QuantizationType},
6 Distribution,
7};
8#[cfg(feature = "cubecl")]
9use cubecl::flex32;
10use half::{bf16, f16};
11use rand::RngCore;
12use serde::{Deserialize, Serialize};
13
14pub trait Element:
16 ToElement
17 + ElementRandom
18 + ElementConversion
19 + ElementPrecision
20 + ElementComparison
21 + bytemuck::CheckedBitPattern
22 + bytemuck::NoUninit
23 + core::fmt::Debug
24 + core::fmt::Display
25 + Default
26 + Send
27 + Sync
28 + Copy
29 + 'static
30{
31 fn dtype() -> DType;
33}
34
35pub trait ElementConversion {
37 fn from_elem<E: ToElement>(elem: E) -> Self;
47
48 fn elem<E: Element>(self) -> E;
50}
51
52pub trait ElementRandom {
54 fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
65}
66
/// Ordering between two elements of the same type.
pub trait ElementComparison {
    /// Returns whether `self` is less than, equal to, or greater than `rhs`.
    fn cmp(&self, rhs: &Self) -> Ordering;
}
72
/// Coarse precision class that an element type reports through [`ElementPrecision`].
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum Precision {
    /// Assigned by this module to the 64-bit primitives (`f64`, `i64`, `u64`).
    Double,

    /// Assigned by this module to the 32-bit primitives (`f32`, `i32`, `u32`).
    Full,

    /// Assigned by this module to the 16-bit types (`f16`, `bf16`, `i16`, `u16`) and to `flex32`.
    Half,

    /// Assigned by this module to everything else (`i8`, `u8`, `bool`).
    Other,
}
88
89pub trait ElementPrecision {
91 fn precision() -> Precision;
93}
94
/// Implements [`Element`] and its supporting traits for a concrete type.
///
/// The expansion names `Element`, `ElementConversion`, `ElementPrecision`,
/// `ElementRandom`, `ElementComparison`, `ToElement`, `Precision`, `Distribution`,
/// `RngCore` and `Ordering` unqualified. Those names must therefore be in scope at the
/// invocation site.
///
/// Arguments:
/// - `ty $type $precision`: the target type and its [`Precision`] class.
/// - `convert`: closure `|elem: &dyn ToElement| -> $type` used by `from_elem`.
/// - `random`: closure `|distribution: Distribution, rng: &mut R| -> $type`. `R` is the
///   generic parameter of `ElementRandom::random`.
/// - `cmp`: closure `|a: &$type, b: &$type| -> Ordering` defining the ordering.
/// - `dtype`: the `DType` value returned by `Element::dtype`.
///
/// A trailing comma after the `dtype` argument is accepted.
#[macro_export]
macro_rules! make_element {
    (
        ty $type:ident $precision:expr,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
        $(,)?
    ) => {
        impl Element for $type {
            fn dtype() -> $crate::DType {
                $dtype
            }
        }

        impl ElementConversion for $type {
            fn from_elem<E: ToElement>(elem: E) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementPrecision for $type {
            fn precision() -> Precision {
                $precision
            }
        }

        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> Ordering {
                // Compare the stored values directly. Converting each operand to its own
                // type first would do wasted work on every comparison. For the half-width
                // floats it would also route through `f32`, which can alter NaN payloads
                // that `total_cmp` distinguishes.
                #[allow(clippy::redundant_closure_call)]
                $cmp(self, other)
            }
        }
    };
}
145
146make_element!(
147 ty f64 Precision::Double,
148 convert |elem: &dyn ToElement| elem.to_f64(),
149 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
150 cmp |a: &f64, b: &f64| a.total_cmp(b),
151 dtype DType::F64
152);
153
154make_element!(
155 ty f32 Precision::Full,
156 convert |elem: &dyn ToElement| elem.to_f32(),
157 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
158 cmp |a: &f32, b: &f32| a.total_cmp(b),
159 dtype DType::F32
160);
161
162make_element!(
163 ty i64 Precision::Double,
164 convert |elem: &dyn ToElement| elem.to_i64(),
165 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
166 cmp |a: &i64, b: &i64| Ord::cmp(a, b),
167 dtype DType::I64
168);
169
170make_element!(
171 ty u64 Precision::Double,
172 convert |elem: &dyn ToElement| elem.to_u64(),
173 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
174 cmp |a: &u64, b: &u64| Ord::cmp(a, b),
175 dtype DType::U64
176);
177
178make_element!(
179 ty i32 Precision::Full,
180 convert |elem: &dyn ToElement| elem.to_i32(),
181 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
182 cmp |a: &i32, b: &i32| Ord::cmp(a, b),
183 dtype DType::I32
184);
185
186make_element!(
187 ty u32 Precision::Full,
188 convert |elem: &dyn ToElement| elem.to_u32(),
189 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
190 cmp |a: &u32, b: &u32| Ord::cmp(a, b),
191 dtype DType::U32
192);
193
194make_element!(
195 ty i16 Precision::Half,
196 convert |elem: &dyn ToElement| elem.to_i16(),
197 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
198 cmp |a: &i16, b: &i16| Ord::cmp(a, b),
199 dtype DType::I16
200);
201
202make_element!(
203 ty u16 Precision::Half,
204 convert |elem: &dyn ToElement| elem.to_u16(),
205 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
206 cmp |a: &u16, b: &u16| Ord::cmp(a, b),
207 dtype DType::U16
208);
209
210make_element!(
211 ty i8 Precision::Other,
212 convert |elem: &dyn ToElement| elem.to_i8(),
213 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
214 cmp |a: &i8, b: &i8| Ord::cmp(a, b),
215 dtype DType::I8
216);
217
218make_element!(
219 ty u8 Precision::Other,
220 convert |elem: &dyn ToElement| elem.to_u8(),
221 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
222 cmp |a: &u8, b: &u8| Ord::cmp(a, b),
223 dtype DType::U8
224);
225
226make_element!(
227 ty f16 Precision::Half,
228 convert |elem: &dyn ToElement| f16::from_f32(elem.to_f32()),
229 random |distribution: Distribution, rng: &mut R| {
230 let sample: f32 = distribution.sampler(rng).sample();
231 f16::from_elem(sample)
232 },
233 cmp |a: &f16, b: &f16| a.total_cmp(b),
234 dtype DType::F16
235);
236make_element!(
237 ty bf16 Precision::Half,
238 convert |elem: &dyn ToElement| bf16::from_f32(elem.to_f32()),
239 random |distribution: Distribution, rng: &mut R| {
240 let sample: f32 = distribution.sampler(rng).sample();
241 bf16::from_elem(sample)
242 },
243 cmp |a: &bf16, b: &bf16| a.total_cmp(b),
244 dtype DType::BF16
245);
246
// `flex32` is only available with the `cubecl` feature. It reports `DType::F32`
// and the `Half` precision class.
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32 Precision::Half,
    convert |value: &dyn ToElement| flex32::from_f32(value.to_f32()),
    random |dist: Distribution, generator: &mut R| {
        let drawn: f32 = dist.sampler(generator).sample();
        flex32::from_elem(drawn)
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::F32
);
258
259make_element!(
260 ty bool Precision::Other,
261 convert |elem: &dyn ToElement| elem.to_u8() != 0,
262 random |distribution: Distribution, rng: &mut R| {
263 let sample: u8 = distribution.sampler(rng).sample();
264 bool::from_elem(sample)
265 },
266 cmp |a: &bool, b: &bool| Ord::cmp(a, b),
267 dtype DType::Bool
268);
269
270#[allow(missing_docs)]
271#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
272pub enum DType {
273 F64,
274 F32,
275 F16,
276 BF16,
277 I64,
278 I32,
279 I16,
280 I8,
281 U64,
282 U32,
283 U16,
284 U8,
285 Bool,
286 QFloat(QuantizationScheme),
287}
288
289impl DType {
290 pub const fn size(&self) -> usize {
292 match self {
293 DType::F64 => core::mem::size_of::<f64>(),
294 DType::F32 => core::mem::size_of::<f32>(),
295 DType::F16 => core::mem::size_of::<f16>(),
296 DType::BF16 => core::mem::size_of::<bf16>(),
297 DType::I64 => core::mem::size_of::<i64>(),
298 DType::I32 => core::mem::size_of::<i32>(),
299 DType::I16 => core::mem::size_of::<i16>(),
300 DType::I8 => core::mem::size_of::<i8>(),
301 DType::U64 => core::mem::size_of::<u64>(),
302 DType::U32 => core::mem::size_of::<u32>(),
303 DType::U16 => core::mem::size_of::<u16>(),
304 DType::U8 => core::mem::size_of::<u8>(),
305 DType::Bool => core::mem::size_of::<bool>(),
306 DType::QFloat(scheme) => match scheme {
307 QuantizationScheme::PerTensorAffine(qtype)
308 | QuantizationScheme::PerTensorSymmetric(qtype) => match qtype {
309 QuantizationType::QInt8 => core::mem::size_of::<i8>(),
310 },
311 },
312 }
313 }
314 pub fn is_float(&self) -> bool {
316 matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)
317 }
318 pub fn is_int(&self) -> bool {
320 matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
321 }
322
323 pub fn is_bool(&self) -> bool {
325 matches!(self, DType::Bool)
326 }
327
328 pub fn name(&self) -> &'static str {
330 match self {
331 DType::F64 => "f64",
332 DType::F32 => "f32",
333 DType::F16 => "f16",
334 DType::BF16 => "bf16",
335 DType::I64 => "i64",
336 DType::I32 => "i32",
337 DType::I16 => "i16",
338 DType::I8 => "i8",
339 DType::U64 => "u64",
340 DType::U32 => "u32",
341 DType::U16 => "u16",
342 DType::U8 => "u8",
343 DType::Bool => "bool",
344 DType::QFloat(_) => "qfloat",
345 }
346 }
347}
348
/// The floating-point subset of [`DType`].
///
/// It derives the same value-like traits as `DType` (`Copy`, `PartialEq`, `Eq`, `Hash`).
/// This lets float data types be compared, copied, and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FloatDType {
    /// 64-bit float (`f64`).
    F64,
    /// 32-bit float (`f32`).
    F32,
    /// 16-bit IEEE half float (`f16`).
    F16,
    /// 16-bit brain float (`bf16`).
    BF16,
}
357
358impl From<DType> for FloatDType {
359 fn from(value: DType) -> Self {
360 match value {
361 DType::F64 => FloatDType::F64,
362 DType::F32 => FloatDType::F32,
363 DType::F16 => FloatDType::F16,
364 DType::BF16 => FloatDType::BF16,
365 _ => panic!("Expected float data type, got {value:?}"),
366 }
367 }
368}
369
370impl From<FloatDType> for DType {
371 fn from(value: FloatDType) -> Self {
372 match value {
373 FloatDType::F64 => DType::F64,
374 FloatDType::F32 => DType::F32,
375 FloatDType::F16 => DType::F16,
376 FloatDType::BF16 => DType::BF16,
377 }
378 }
379}