burn_tensor/tensor/element/base.rs

1use core::cmp::Ordering;
2
3use crate::{
4    Distribution,
5    cast::ToElement,
6    quantization::{QuantizationScheme, QuantizationType},
7};
8#[cfg(feature = "cubecl")]
9use cubecl::flex32;
10use half::{bf16, f16};
11use rand::RngCore;
12use serde::{Deserialize, Serialize};
13
14/// Element trait for tensor.
15pub trait Element:
16    ToElement
17    + ElementRandom
18    + ElementConversion
19    + ElementPrecision
20    + ElementComparison
21    + ElementLimits
22    + bytemuck::CheckedBitPattern
23    + bytemuck::NoUninit
24    + bytemuck::Zeroable
25    + core::fmt::Debug
26    + core::fmt::Display
27    + Default
28    + Send
29    + Sync
30    + Copy
31    + 'static
32{
33    /// The dtype of the element.
34    fn dtype() -> DType;
35}
36
37/// Element conversion trait for tensor.
38pub trait ElementConversion {
39    /// Converts an element to another element.
40    ///
41    /// # Arguments
42    ///
43    /// * `elem` - The element to convert.
44    ///
45    /// # Returns
46    ///
47    /// The converted element.
48    fn from_elem<E: ToElement>(elem: E) -> Self;
49
50    /// Converts and returns the converted element.
51    fn elem<E: Element>(self) -> E;
52}
53
54/// Element trait for random value of a tensor.
55pub trait ElementRandom {
56    /// Returns a random value for the given distribution.
57    ///
58    /// # Arguments
59    ///
60    /// * `distribution` - The distribution to sample from.
61    /// * `rng` - The random number generator.
62    ///
63    /// # Returns
64    ///
65    /// The random value.
66    fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
67}
68
/// Element ordering trait.
///
/// Compares two elements of the same type.
///
/// Because the method is named `cmp`, calling `a.cmp(&b)` on a type that also
/// implements [`Ord`] can be ambiguous when both traits are in scope; call it as
/// `ElementComparison::cmp(&a, &b)` in that case.
pub trait ElementComparison {
    /// Returns an [`Ordering`] between `self` and `other`.
    fn cmp(&self, other: &Self) -> Ordering;
}
74
/// Element limits trait.
///
/// Exposes the smallest and largest values an element type provides as limits.
pub trait ElementLimits {
    /// The minimum representable value.
    const MIN: Self;
    /// The maximum representable value.
    const MAX: Self;
}
82
/// Precision category of an element type, as returned by
/// [`ElementPrecision::precision`].
///
/// The category is shared by floating point and integer types; for example
/// both `f32` and `i32` report [`Precision::Full`].
#[derive(Clone, PartialEq, Eq, Copy, Debug)]
pub enum Precision {
    /// Double precision, e.g. `f64`.
    Double,

    /// Full precision, e.g. `f32`.
    Full,

    /// Half precision, e.g. `f16`.
    Half,

    /// Any other precision, e.g. 8-bit integers or `bool`.
    Other,
}
98
99/// Element precision trait for tensor.
100pub trait ElementPrecision {
101    /// Returns the precision of the element.
102    fn precision() -> Precision;
103}
104
/// Implements [`Element`] and its companion traits for a primitive type.
///
/// One invocation generates implementations of [`Element`],
/// [`ElementConversion`], [`ElementPrecision`], [`ElementRandom`],
/// [`ElementComparison`] and [`ElementLimits`].
///
/// # Arguments
///
/// * `ty` - The type to implement the traits for, followed by the [`Precision`]
///   returned by [`ElementPrecision::precision`].
/// * `convert` - A callable invoked as `convert(&elem)` with `elem: E` for any
///   `E: ToElement`, returning the converted value (e.g. `ToElement::to_f32`).
/// * `random` - A callable invoked as `random(distribution, rng)` with a
///   [`Distribution`] and `rng: &mut R`. `R` is the generic parameter of the
///   generated [`ElementRandom::random`].
/// * `cmp` - A callable invoked with two references to the type, returning an
///   [`Ordering`].
/// * `dtype` - The [`DType`](crate::DType) returned by [`Element::dtype`].
/// * `min` / `max` - Optional [`ElementLimits`] values. When omitted they default
///   to the type's own `MIN` and `MAX` constants.
///
/// The generated code names the element traits, `ToElement`, `Precision`,
/// `Distribution`, `RngCore` and `Ordering` unqualified, so they must be in
/// scope at the invocation site.
#[macro_export]
macro_rules! make_element {
    (
        ty $type:ident $precision:expr,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        // Without explicit limits, fall back to the type's own `MIN`/`MAX` constants.
        // `$crate::` keeps this recursive call resolvable from other crates.
        $crate::make_element!(ty $type $precision, convert $convert, random $random, cmp $cmp, dtype $dtype, min $type::MIN, max $type::MAX);
    };
    (
        ty $type:ident $precision:expr,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> $crate::DType {
                $dtype
            }
        }

        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }
            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementPrecision for $type {
            fn precision() -> Precision {
                $precision
            }
        }

        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> Ordering {
                // Both operands already have type `$type`, so they go straight to
                // the comparator; no conversion through `elem::<$type>()` is needed.
                #[allow(clippy::redundant_closure_call)]
                $cmp(self, other)
            }
        }

        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }
    };
}
173
174make_element!(
175    ty f64 Precision::Double,
176    convert ToElement::to_f64,
177    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
178    cmp |a: &f64, b: &f64| a.total_cmp(b),
179    dtype DType::F64
180);
181
182make_element!(
183    ty f32 Precision::Full,
184    convert ToElement::to_f32,
185    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
186    cmp |a: &f32, b: &f32| a.total_cmp(b),
187    dtype DType::F32
188);
189
190make_element!(
191    ty i64 Precision::Double,
192    convert ToElement::to_i64,
193    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
194    cmp |a: &i64, b: &i64| Ord::cmp(a, b),
195    dtype DType::I64
196);
197
198make_element!(
199    ty u64 Precision::Double,
200    convert ToElement::to_u64,
201    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
202    cmp |a: &u64, b: &u64| Ord::cmp(a, b),
203    dtype DType::U64
204);
205
206make_element!(
207    ty i32 Precision::Full,
208    convert ToElement::to_i32,
209    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
210    cmp |a: &i32, b: &i32| Ord::cmp(a, b),
211    dtype DType::I32
212);
213
214make_element!(
215    ty u32 Precision::Full,
216    convert ToElement::to_u32,
217    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
218    cmp |a: &u32, b: &u32| Ord::cmp(a, b),
219    dtype DType::U32
220);
221
222make_element!(
223    ty i16 Precision::Half,
224    convert ToElement::to_i16,
225    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
226    cmp |a: &i16, b: &i16| Ord::cmp(a, b),
227    dtype DType::I16
228);
229
230make_element!(
231    ty u16 Precision::Half,
232    convert ToElement::to_u16,
233    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
234    cmp |a: &u16, b: &u16| Ord::cmp(a, b),
235    dtype DType::U16
236);
237
238make_element!(
239    ty i8 Precision::Other,
240    convert ToElement::to_i8,
241    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
242    cmp |a: &i8, b: &i8| Ord::cmp(a, b),
243    dtype DType::I8
244);
245
246make_element!(
247    ty u8 Precision::Other,
248    convert ToElement::to_u8,
249    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
250    cmp |a: &u8, b: &u8| Ord::cmp(a, b),
251    dtype DType::U8
252);
253
254make_element!(
255    ty f16 Precision::Half,
256    convert ToElement::to_f16,
257    random |distribution: Distribution, rng: &mut R| {
258        let sample: f32 = distribution.sampler(rng).sample();
259        f16::from_elem(sample)
260    },
261    cmp |a: &f16, b: &f16| a.total_cmp(b),
262    dtype DType::F16
263);
264make_element!(
265    ty bf16 Precision::Half,
266    convert ToElement::to_bf16,
267    random |distribution: Distribution, rng: &mut R| {
268        let sample: f32 = distribution.sampler(rng).sample();
269        bf16::from_elem(sample)
270    },
271    cmp |a: &bf16, b: &bf16| a.total_cmp(b),
272    dtype DType::BF16
273);
274
275#[cfg(feature = "cubecl")]
276make_element!(
277    ty flex32 Precision::Half,
278    convert |elem: &dyn ToElement| flex32::from_f32(elem.to_f32()),
279    random |distribution: Distribution, rng: &mut R| {
280        let sample: f32 = distribution.sampler(rng).sample();
281        flex32::from_elem(sample)
282    },
283    cmp |a: &flex32, b: &flex32| a.total_cmp(b),
284    dtype DType::Flex32,
285    min flex32::from_f32(half::f16::MIN.to_f32_const()),
286    max flex32::from_f32(half::f16::MAX.to_f32_const())
287);
288
289make_element!(
290    ty bool Precision::Other,
291    convert ToElement::to_bool,
292    random |distribution: Distribution, rng: &mut R| {
293        let sample: u8 = distribution.sampler(rng).sample();
294        bool::from_elem(sample)
295    },
296    cmp |a: &bool, b: &bool| Ord::cmp(a, b),
297    dtype DType::Bool,
298    min false,
299    max true
300);
301
302#[allow(missing_docs)]
303#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
304pub enum DType {
305    F64,
306    F32,
307    Flex32,
308    F16,
309    BF16,
310    I64,
311    I32,
312    I16,
313    I8,
314    U64,
315    U32,
316    U16,
317    U8,
318    Bool,
319    QFloat(QuantizationScheme),
320}
321
322impl DType {
323    /// Returns the size of a type in bytes.
324    pub const fn size(&self) -> usize {
325        match self {
326            DType::F64 => core::mem::size_of::<f64>(),
327            DType::F32 => core::mem::size_of::<f32>(),
328            DType::Flex32 => core::mem::size_of::<f32>(),
329            DType::F16 => core::mem::size_of::<f16>(),
330            DType::BF16 => core::mem::size_of::<bf16>(),
331            DType::I64 => core::mem::size_of::<i64>(),
332            DType::I32 => core::mem::size_of::<i32>(),
333            DType::I16 => core::mem::size_of::<i16>(),
334            DType::I8 => core::mem::size_of::<i8>(),
335            DType::U64 => core::mem::size_of::<u64>(),
336            DType::U32 => core::mem::size_of::<u32>(),
337            DType::U16 => core::mem::size_of::<u16>(),
338            DType::U8 => core::mem::size_of::<u8>(),
339            DType::Bool => core::mem::size_of::<bool>(),
340            DType::QFloat(scheme) => match scheme {
341                QuantizationScheme::PerTensor(_mode, QuantizationType::QInt8) => {
342                    core::mem::size_of::<i8>()
343                }
344            },
345        }
346    }
347    /// Returns true if the data type is a floating point type.
348    pub fn is_float(&self) -> bool {
349        matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)
350    }
351    /// Returns true if the data type is a signed integer type.
352    pub fn is_int(&self) -> bool {
353        matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
354    }
355
356    /// Returns true if the data type is a boolean type
357    pub fn is_bool(&self) -> bool {
358        matches!(self, DType::Bool)
359    }
360
361    /// Returns the data type name.
362    pub fn name(&self) -> &'static str {
363        match self {
364            DType::F64 => "f64",
365            DType::F32 => "f32",
366            DType::Flex32 => "flex32",
367            DType::F16 => "f16",
368            DType::BF16 => "bf16",
369            DType::I64 => "i64",
370            DType::I32 => "i32",
371            DType::I16 => "i16",
372            DType::I8 => "i8",
373            DType::U64 => "u64",
374            DType::U32 => "u32",
375            DType::U16 => "u16",
376            DType::U8 => "u8",
377            DType::Bool => "bool",
378            DType::QFloat(_) => "qfloat",
379        }
380    }
381}
382
/// Floating point data types, convertible to and from [`DType`].
///
/// `Flex32` and quantized floats have no counterpart here; converting them
/// from a [`DType`] panics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FloatDType {
    /// 64-bit floating point (`f64`).
    F64,
    /// 32-bit floating point (`f32`).
    F32,
    /// 16-bit IEEE half precision floating point (`f16`).
    F16,
    /// 16-bit brain floating point (`bf16`).
    BF16,
}
391
392impl From<DType> for FloatDType {
393    fn from(value: DType) -> Self {
394        match value {
395            DType::F64 => FloatDType::F64,
396            DType::F32 => FloatDType::F32,
397            DType::F16 => FloatDType::F16,
398            DType::BF16 => FloatDType::BF16,
399            _ => panic!("Expected float data type, got {value:?}"),
400        }
401    }
402}
403
404impl From<FloatDType> for DType {
405    fn from(value: FloatDType) -> Self {
406        match value {
407            FloatDType::F64 => DType::F64,
408            FloatDType::F32 => DType::F32,
409            FloatDType::F16 => DType::F16,
410            FloatDType::BF16 => DType::BF16,
411        }
412    }
413}