burn_tensor/tensor/element/base.rs

1use core::cmp::Ordering;
2
3use crate::{
4    cast::ToElement,
5    quantization::{QuantizationScheme, QuantizationType},
6    Distribution,
7};
8#[cfg(feature = "cubecl")]
9use cubecl::flex32;
10use half::{bf16, f16};
11use rand::RngCore;
12use serde::{Deserialize, Serialize};
13
14/// Element trait for tensor.
15pub trait Element:
16    ToElement
17    + ElementRandom
18    + ElementConversion
19    + ElementPrecision
20    + ElementComparison
21    + bytemuck::CheckedBitPattern
22    + bytemuck::NoUninit
23    + core::fmt::Debug
24    + core::fmt::Display
25    + Default
26    + Send
27    + Sync
28    + Copy
29    + 'static
30{
31    /// The dtype of the element.
32    fn dtype() -> DType;
33}
34
35/// Element conversion trait for tensor.
36pub trait ElementConversion {
37    /// Converts an element to another element.
38    ///
39    /// # Arguments
40    ///
41    /// * `elem` - The element to convert.
42    ///
43    /// # Returns
44    ///
45    /// The converted element.
46    fn from_elem<E: ToElement>(elem: E) -> Self;
47
48    /// Converts and returns the converted element.
49    fn elem<E: Element>(self) -> E;
50}
51
52/// Element trait for random value of a tensor.
53pub trait ElementRandom {
54    /// Returns a random value for the given distribution.
55    ///
56    /// # Arguments
57    ///
58    /// * `distribution` - The distribution to sample from.
59    /// * `rng` - The random number generator.
60    ///
61    /// # Returns
62    ///
63    /// The random value.
64    fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
65}
66
/// Element ordering trait.
///
/// Unlike `PartialOrd`, this always yields an ordering for any two values of the same type.
/// The floating-point implementations in this module are built on `total_cmp`, so NaN
/// values are ordered too instead of being incomparable.
pub trait ElementComparison {
    /// Returns an [`Ordering`] between `self` and `other`.
    fn cmp(&self, other: &Self) -> Ordering;
}
72
/// Coarse precision class of an element type, as reported by `ElementPrecision::precision`.
#[derive(Clone, PartialEq, Eq, Copy, Debug, Hash)]
pub enum Precision {
    /// Double precision, e.g. `f64` (also used for the 64-bit integers `i64` and `u64`).
    Double,

    /// Full precision, e.g. `f32` (also used for the 32-bit integers `i32` and `u32`).
    Full,

    /// Half precision, e.g. `f16` and `bf16` (also used for the 16-bit integers `i16` and `u16`).
    Half,

    /// Any other precision, e.g. the 8-bit integers and `bool`.
    Other,
}
88
89/// Element precision trait for tensor.
90pub trait ElementPrecision {
91    /// Returns the precision of the element.
92    fn precision() -> Precision;
93}
94
/// Macro to implement the element traits for a type.
///
/// Generates `Element`, `ElementConversion`, `ElementPrecision`, `ElementRandom` and
/// `ElementComparison` impls for `$type`. Arguments:
///
/// * `ty` - the type identifier, followed by its `Precision` expression.
/// * `convert` - a closure `|elem: &dyn ToElement| -> $type` used by `from_elem`.
/// * `random` - a closure `|distribution: Distribution, rng: &mut R| -> $type`, where `R` is
///   the `RngCore` type parameter of the generated `random` method.
/// * `cmp` - a closure `|a: &$type, b: &$type| -> Ordering` used by `ElementComparison::cmp`.
/// * `dtype` - the `DType` expression returned by `Element::dtype`. A trailing comma is accepted.
///
/// The generated code names `Element`, `ElementConversion`, `ElementPrecision`,
/// `ElementRandom`, `ElementComparison`, `ToElement`, `Precision`, `Distribution`, `RngCore`
/// and `Ordering` without qualification, so they must be in scope where the macro is invoked.
#[macro_export]
macro_rules! make_element {
    (
        ty $type:ident $precision:expr,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
        $(,)?
    ) => {
        impl Element for $type {
            fn dtype() -> $crate::DType {
                $dtype
            }
        }

        impl ElementConversion for $type {
            fn from_elem<E: ToElement>(elem: E) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementPrecision for $type {
            fn precision() -> Precision {
                $precision
            }
        }

        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> Ordering {
                // `self` and `other` already have type `$type`; compare them directly instead of
                // round-tripping through a same-type `elem::<$type>()` conversion, which is
                // wasted work and (for f16/bf16) can disturb NaN bit patterns via f32.
                #[allow(clippy::redundant_closure_call)]
                $cmp(self, other)
            }
        }
    };
}
145
146make_element!(
147    ty f64 Precision::Double,
148    convert |elem: &dyn ToElement| elem.to_f64(),
149    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
150    cmp |a: &f64, b: &f64| a.total_cmp(b),
151    dtype DType::F64
152);
153
154make_element!(
155    ty f32 Precision::Full,
156    convert |elem: &dyn ToElement| elem.to_f32(),
157    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
158    cmp |a: &f32, b: &f32| a.total_cmp(b),
159    dtype DType::F32
160);
161
162make_element!(
163    ty i64 Precision::Double,
164    convert |elem: &dyn ToElement| elem.to_i64(),
165    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
166    cmp |a: &i64, b: &i64| Ord::cmp(a, b),
167    dtype DType::I64
168);
169
170make_element!(
171    ty u64 Precision::Double,
172    convert |elem: &dyn ToElement| elem.to_u64(),
173    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
174    cmp |a: &u64, b: &u64| Ord::cmp(a, b),
175    dtype DType::U64
176);
177
178make_element!(
179    ty i32 Precision::Full,
180    convert |elem: &dyn ToElement| elem.to_i32(),
181    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
182    cmp |a: &i32, b: &i32| Ord::cmp(a, b),
183    dtype DType::I32
184);
185
186make_element!(
187    ty u32 Precision::Full,
188    convert |elem: &dyn ToElement| elem.to_u32(),
189    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
190    cmp |a: &u32, b: &u32| Ord::cmp(a, b),
191    dtype DType::U32
192);
193
194make_element!(
195    ty i16 Precision::Half,
196    convert |elem: &dyn ToElement| elem.to_i16(),
197    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
198    cmp |a: &i16, b: &i16| Ord::cmp(a, b),
199    dtype DType::I16
200);
201
202make_element!(
203    ty u16 Precision::Half,
204    convert |elem: &dyn ToElement| elem.to_u16(),
205    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
206    cmp |a: &u16, b: &u16| Ord::cmp(a, b),
207    dtype DType::U16
208);
209
210make_element!(
211    ty i8 Precision::Other,
212    convert |elem: &dyn ToElement| elem.to_i8(),
213    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
214    cmp |a: &i8, b: &i8| Ord::cmp(a, b),
215    dtype DType::I8
216);
217
218make_element!(
219    ty u8 Precision::Other,
220    convert |elem: &dyn ToElement| elem.to_u8(),
221    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
222    cmp |a: &u8, b: &u8| Ord::cmp(a, b),
223    dtype DType::U8
224);
225
226make_element!(
227    ty f16 Precision::Half,
228    convert |elem: &dyn ToElement| f16::from_f32(elem.to_f32()),
229    random |distribution: Distribution, rng: &mut R| {
230        let sample: f32 = distribution.sampler(rng).sample();
231        f16::from_elem(sample)
232    },
233    cmp |a: &f16, b: &f16| a.total_cmp(b),
234    dtype DType::F16
235);
236make_element!(
237    ty bf16 Precision::Half,
238    convert |elem: &dyn ToElement| bf16::from_f32(elem.to_f32()),
239    random |distribution: Distribution, rng: &mut R| {
240        let sample: f32 = distribution.sampler(rng).sample();
241        bf16::from_elem(sample)
242    },
243    cmp |a: &bf16, b: &bf16| a.total_cmp(b),
244    dtype DType::BF16
245);
246
// cubecl's flex32: reported as `DType::F32` and bucketed as half precision; sampled via f32.
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32 Precision::Half,
    convert |src: &dyn ToElement| flex32::from_f32(src.to_f32()),
    random |dist: Distribution, generator: &mut R| {
        let drawn: f32 = dist.sampler(generator).sample();
        flex32::from_elem(drawn)
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::F32
);
258
259make_element!(
260    ty bool Precision::Other,
261    convert |elem: &dyn ToElement| elem.to_u8() != 0,
262    random |distribution: Distribution, rng: &mut R| {
263        let sample: u8 = distribution.sampler(rng).sample();
264        bool::from_elem(sample)
265    },
266    cmp |a: &bool, b: &bool| Ord::cmp(a, b),
267    dtype DType::Bool
268);
269
270#[allow(missing_docs)]
271#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
272pub enum DType {
273    F64,
274    F32,
275    F16,
276    BF16,
277    I64,
278    I32,
279    I16,
280    I8,
281    U64,
282    U32,
283    U16,
284    U8,
285    Bool,
286    QFloat(QuantizationScheme),
287}
288
289impl DType {
290    /// Returns the size of a type in bytes.
291    pub const fn size(&self) -> usize {
292        match self {
293            DType::F64 => core::mem::size_of::<f64>(),
294            DType::F32 => core::mem::size_of::<f32>(),
295            DType::F16 => core::mem::size_of::<f16>(),
296            DType::BF16 => core::mem::size_of::<bf16>(),
297            DType::I64 => core::mem::size_of::<i64>(),
298            DType::I32 => core::mem::size_of::<i32>(),
299            DType::I16 => core::mem::size_of::<i16>(),
300            DType::I8 => core::mem::size_of::<i8>(),
301            DType::U64 => core::mem::size_of::<u64>(),
302            DType::U32 => core::mem::size_of::<u32>(),
303            DType::U16 => core::mem::size_of::<u16>(),
304            DType::U8 => core::mem::size_of::<u8>(),
305            DType::Bool => core::mem::size_of::<bool>(),
306            DType::QFloat(scheme) => match scheme {
307                QuantizationScheme::PerTensorAffine(qtype)
308                | QuantizationScheme::PerTensorSymmetric(qtype) => match qtype {
309                    QuantizationType::QInt8 => core::mem::size_of::<i8>(),
310                },
311            },
312        }
313    }
314    /// Returns true if the data type is a floating point type.
315    pub fn is_float(&self) -> bool {
316        matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)
317    }
318    /// Returns true if the data type is a signed integer type.
319    pub fn is_int(&self) -> bool {
320        matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
321    }
322
323    /// Returns true if the data type is a boolean type
324    pub fn is_bool(&self) -> bool {
325        matches!(self, DType::Bool)
326    }
327
328    /// Returns the data type name.
329    pub fn name(&self) -> &'static str {
330        match self {
331            DType::F64 => "f64",
332            DType::F32 => "f32",
333            DType::F16 => "f16",
334            DType::BF16 => "bf16",
335            DType::I64 => "i64",
336            DType::I32 => "i32",
337            DType::I16 => "i16",
338            DType::I8 => "i8",
339            DType::U64 => "u64",
340            DType::U32 => "u32",
341            DType::U16 => "u16",
342            DType::U8 => "u8",
343            DType::Bool => "bool",
344            DType::QFloat(_) => "qfloat",
345        }
346    }
347}
348
/// Floating-point subset of [`DType`].
///
/// Converts into a `DType` infallibly. Converting a `DType` into a `FloatDType` panics when
/// the `DType` is not one of these floating-point variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FloatDType {
    /// 64-bit floating point (`f64`).
    F64,
    /// 32-bit floating point (`f32`).
    F32,
    /// 16-bit half-precision floating point (`f16`).
    F16,
    /// 16-bit brain floating point (`bf16`).
    BF16,
}
357
358impl From<DType> for FloatDType {
359    fn from(value: DType) -> Self {
360        match value {
361            DType::F64 => FloatDType::F64,
362            DType::F32 => FloatDType::F32,
363            DType::F16 => FloatDType::F16,
364            DType::BF16 => FloatDType::BF16,
365            _ => panic!("Expected float data type, got {value:?}"),
366        }
367    }
368}
369
370impl From<FloatDType> for DType {
371    fn from(value: FloatDType) -> Self {
372        match value {
373            FloatDType::F64 => DType::F64,
374            FloatDType::F32 => DType::F32,
375            FloatDType::F16 => DType::F16,
376            FloatDType::BF16 => DType::BF16,
377        }
378    }
379}