burn_tensor/tensor/element/base.rs

1use core::cmp::Ordering;
2
3use crate::{Distribution, cast::ToElement, quantization::QuantScheme};
4#[cfg(feature = "cubecl")]
5use cubecl::flex32;
6
7use cubecl_quant::scheme::{QuantStore, QuantValue};
8use half::{bf16, f16};
9use rand::RngCore;
10use serde::{Deserialize, Serialize};
11
12/// Element trait for tensor.
13pub trait Element:
14    ToElement
15    + ElementRandom
16    + ElementConversion
17    + ElementComparison
18    + ElementLimits
19    + bytemuck::CheckedBitPattern
20    + bytemuck::NoUninit
21    + bytemuck::Zeroable
22    + core::fmt::Debug
23    + core::fmt::Display
24    + Default
25    + Send
26    + Sync
27    + Copy
28    + 'static
29{
30    /// The dtype of the element.
31    fn dtype() -> DType;
32}
33
34/// Element conversion trait for tensor.
35pub trait ElementConversion {
36    /// Converts an element to another element.
37    ///
38    /// # Arguments
39    ///
40    /// * `elem` - The element to convert.
41    ///
42    /// # Returns
43    ///
44    /// The converted element.
45    fn from_elem<E: ToElement>(elem: E) -> Self;
46
47    /// Converts and returns the converted element.
48    fn elem<E: Element>(self) -> E;
49}
50
51/// Element trait for random value of a tensor.
52pub trait ElementRandom {
53    /// Returns a random value for the given distribution.
54    ///
55    /// # Arguments
56    ///
57    /// * `distribution` - The distribution to sample from.
58    /// * `rng` - The random number generator.
59    ///
60    /// # Returns
61    ///
62    /// The random value.
63    fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
64}
65
/// Element ordering trait.
///
/// The return type is a plain [Ordering] (not an `Option`), so every pair of values must be
/// ordered. The implementations generated by `make_element!` in this module use `total_cmp`
/// for floating point types and `Ord::cmp` for the others.
pub trait ElementComparison {
    /// Returns an [Ordering] between `self` and `other`.
    fn cmp(&self, other: &Self) -> Ordering;
}
71
/// Element limits trait.
///
/// Exposes the smallest and largest values of an element type as associated constants.
/// Implementations may pick bounds narrower than the raw storage allows; for example, the
/// `flex32` implementation in this module uses the `f16` range.
pub trait ElementLimits {
    /// The minimum representable value.
    const MIN: Self;
    /// The maximum representable value.
    const MAX: Self;
}
79
/// Macro to implement the element traits for a type.
///
/// The expansion implements `Element`, `ElementConversion`, `ElementRandom`,
/// `ElementComparison` and `ElementLimits` for `ty`.
///
/// # Arguments
///
/// * `ty` - The type receiving the implementations.
/// * `convert` - Callable applied to `&E` (with `E: ToElement`) producing the type; backs
///   `ElementConversion::from_elem`.
/// * `random` - Callable `(Distribution, &mut R) -> ty` backing `ElementRandom::random`. `R` is
///   the `R: RngCore` generic parameter of that method and may be named directly in the closure.
/// * `cmp` - Callable `(&ty, &ty) -> Ordering` backing `ElementComparison::cmp`.
/// * `dtype` - Expression returned by `Element::dtype`.
/// * `min` / `max` - Optional `ElementLimits` values; when omitted they default to `ty::MIN`
///   and `ty::MAX`.
///
/// The generated code names the element traits, `ToElement`, `Distribution`, `RngCore` and
/// `Ordering` without a path, so those names must be in scope where the macro is invoked.
#[macro_export]
macro_rules! make_element {
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        // `$crate::` keeps the recursive call resolvable from crates that never imported
        // `make_element` by name.
        $crate::make_element!(
            ty $type,
            convert $convert,
            random $random,
            cmp $cmp,
            dtype $dtype,
            min $type::MIN,
            max $type::MAX
        );
    };
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> $crate::DType {
                $dtype
            }
        }

        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }
            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> Ordering {
                // Both operands already have the target type, so they are handed to the
                // comparison callable directly instead of being re-converted through `elem`.
                #[allow(clippy::redundant_closure_call)]
                $cmp(self, other)
            }
        }

        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }
    };
}
142
143make_element!(
144    ty f64,
145    convert ToElement::to_f64,
146    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
147    cmp |a: &f64, b: &f64| a.total_cmp(b),
148    dtype DType::F64
149);
150
151make_element!(
152    ty f32,
153    convert ToElement::to_f32,
154    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
155    cmp |a: &f32, b: &f32| a.total_cmp(b),
156    dtype DType::F32
157);
158
159make_element!(
160    ty i64,
161    convert ToElement::to_i64,
162    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
163    cmp |a: &i64, b: &i64| Ord::cmp(a, b),
164    dtype DType::I64
165);
166
167make_element!(
168    ty u64,
169    convert ToElement::to_u64,
170    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
171    cmp |a: &u64, b: &u64| Ord::cmp(a, b),
172    dtype DType::U64
173);
174
175make_element!(
176    ty i32,
177    convert ToElement::to_i32,
178    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
179    cmp |a: &i32, b: &i32| Ord::cmp(a, b),
180    dtype DType::I32
181);
182
183make_element!(
184    ty u32,
185    convert ToElement::to_u32,
186    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
187    cmp |a: &u32, b: &u32| Ord::cmp(a, b),
188    dtype DType::U32
189);
190
191make_element!(
192    ty i16,
193    convert ToElement::to_i16,
194    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
195    cmp |a: &i16, b: &i16| Ord::cmp(a, b),
196    dtype DType::I16
197);
198
199make_element!(
200    ty u16,
201    convert ToElement::to_u16,
202    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
203    cmp |a: &u16, b: &u16| Ord::cmp(a, b),
204    dtype DType::U16
205);
206
207make_element!(
208    ty i8,
209    convert ToElement::to_i8,
210    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
211    cmp |a: &i8, b: &i8| Ord::cmp(a, b),
212    dtype DType::I8
213);
214
215make_element!(
216    ty u8,
217    convert ToElement::to_u8,
218    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
219    cmp |a: &u8, b: &u8| Ord::cmp(a, b),
220    dtype DType::U8
221);
222
223make_element!(
224    ty f16,
225    convert ToElement::to_f16,
226    random |distribution: Distribution, rng: &mut R| {
227        let sample: f32 = distribution.sampler(rng).sample();
228        f16::from_elem(sample)
229    },
230    cmp |a: &f16, b: &f16| a.total_cmp(b),
231    dtype DType::F16
232);
233make_element!(
234    ty bf16,
235    convert ToElement::to_bf16,
236    random |distribution: Distribution, rng: &mut R| {
237        let sample: f32 = distribution.sampler(rng).sample();
238        bf16::from_elem(sample)
239    },
240    cmp |a: &bf16, b: &bf16| a.total_cmp(b),
241    dtype DType::BF16
242);
243
// `flex32` is converted through `f32` and sampled as `f32`. Its limits are the finite `f16`
// range rather than the `f32` range.
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32,
    convert |value: &dyn ToElement| flex32::from_f32(value.to_f32()),
    random |dist: Distribution, rng: &mut R| {
        let value: f32 = dist.sampler(rng).sample();
        flex32::from_elem(value)
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::Flex32,
    min flex32::from_f32(f16::MIN.to_f32_const()),
    max flex32::from_f32(f16::MAX.to_f32_const())
);
257
258make_element!(
259    ty bool,
260    convert ToElement::to_bool,
261    random |distribution: Distribution, rng: &mut R| {
262        let sample: u8 = distribution.sampler(rng).sample();
263        bool::from_elem(sample)
264    },
265    cmp |a: &bool, b: &bool| Ord::cmp(a, b),
266    dtype DType::Bool,
267    min false,
268    max true
269);
270
271#[allow(missing_docs)]
272#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
273pub enum DType {
274    F64,
275    F32,
276    Flex32,
277    F16,
278    BF16,
279    I64,
280    I32,
281    I16,
282    I8,
283    U64,
284    U32,
285    U16,
286    U8,
287    Bool,
288    QFloat(QuantScheme),
289}
290
#[cfg(feature = "cubecl")]
impl From<cubecl::ir::ElemType> for DType {
    /// Maps a cubecl element type onto the matching tensor [`DType`].
    ///
    /// # Panics
    ///
    /// Panics for element types without a tensor dtype (including `TF32`). The float kinds
    /// `E2M1`, `E2M3`, `E3M2`, `E4M3`, `E5M2` and `UE8M0` hit `unimplemented!`.
    fn from(value: cubecl::ir::ElemType) -> Self {
        use cubecl::ir::{ElemType, FloatKind, IntKind, UIntKind};

        match value {
            ElemType::Float(kind) => match kind {
                FloatKind::F64 => DType::F64,
                FloatKind::F32 => DType::F32,
                FloatKind::Flex32 => DType::Flex32,
                FloatKind::F16 => DType::F16,
                FloatKind::BF16 => DType::BF16,
                // TF32 has no tensor dtype counterpart.
                FloatKind::TF32 => panic!("Not a valid DType for tensors."),
                // Not supported yet; the message notes these are meant for quantization.
                FloatKind::E2M1
                | FloatKind::E2M3
                | FloatKind::E3M2
                | FloatKind::E4M3
                | FloatKind::E5M2
                | FloatKind::UE8M0 => {
                    unimplemented!("Not yet supported, will be used for quantization")
                }
            },
            ElemType::Int(kind) => match kind {
                IntKind::I64 => DType::I64,
                IntKind::I32 => DType::I32,
                IntKind::I16 => DType::I16,
                IntKind::I8 => DType::I8,
            },
            ElemType::UInt(kind) => match kind {
                UIntKind::U64 => DType::U64,
                UIntKind::U32 => DType::U32,
                UIntKind::U16 => DType::U16,
                UIntKind::U8 => DType::U8,
            },
            // Every other cubecl element kind has no tensor dtype mapping.
            _ => panic!("Not a valid DType for tensors."),
        }
    }
}
327
328impl DType {
329    /// Returns the size of a type in bytes.
330    pub const fn size(&self) -> usize {
331        match self {
332            DType::F64 => core::mem::size_of::<f64>(),
333            DType::F32 => core::mem::size_of::<f32>(),
334            DType::Flex32 => core::mem::size_of::<f32>(),
335            DType::F16 => core::mem::size_of::<f16>(),
336            DType::BF16 => core::mem::size_of::<bf16>(),
337            DType::I64 => core::mem::size_of::<i64>(),
338            DType::I32 => core::mem::size_of::<i32>(),
339            DType::I16 => core::mem::size_of::<i16>(),
340            DType::I8 => core::mem::size_of::<i8>(),
341            DType::U64 => core::mem::size_of::<u64>(),
342            DType::U32 => core::mem::size_of::<u32>(),
343            DType::U16 => core::mem::size_of::<u16>(),
344            DType::U8 => core::mem::size_of::<u8>(),
345            DType::Bool => core::mem::size_of::<bool>(),
346            DType::QFloat(scheme) => match scheme.store {
347                QuantStore::Native => match scheme.value {
348                    QuantValue::Q8F | QuantValue::Q8S => core::mem::size_of::<i8>(),
349                    // e2m1 native is automatically packed by the kernels, so the actual storage is
350                    // 8 bits wide.
351                    QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => {
352                        core::mem::size_of::<u8>()
353                    }
354                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S => {
355                        // Sub-byte values have fractional size
356                        0
357                    }
358                },
359                QuantStore::U32 => core::mem::size_of::<u32>(),
360            },
361        }
362    }
363    /// Returns true if the data type is a floating point type.
364    pub fn is_float(&self) -> bool {
365        matches!(
366            self,
367            DType::F64 | DType::F32 | DType::Flex32 | DType::F16 | DType::BF16
368        )
369    }
370    /// Returns true if the data type is a signed integer type.
371    pub fn is_int(&self) -> bool {
372        matches!(self, DType::I64 | DType::I32 | DType::I16 | DType::I8)
373    }
374    /// Returns true if the data type is an unsigned integer type.
375    pub fn is_uint(&self) -> bool {
376        matches!(self, DType::U64 | DType::U32 | DType::U16 | DType::U8)
377    }
378
379    /// Returns true if the data type is a boolean type
380    pub fn is_bool(&self) -> bool {
381        matches!(self, DType::Bool)
382    }
383
384    /// Returns the data type name.
385    pub fn name(&self) -> &'static str {
386        match self {
387            DType::F64 => "f64",
388            DType::F32 => "f32",
389            DType::Flex32 => "flex32",
390            DType::F16 => "f16",
391            DType::BF16 => "bf16",
392            DType::I64 => "i64",
393            DType::I32 => "i32",
394            DType::I16 => "i16",
395            DType::I8 => "i8",
396            DType::U64 => "u64",
397            DType::U32 => "u32",
398            DType::U16 => "u16",
399            DType::U8 => "u8",
400            DType::Bool => "bool",
401            DType::QFloat(_) => "qfloat",
402        }
403    }
404}
405
/// Floating point subset of [`DType`].
///
/// The derived `PartialOrd`/`Ord` follow the declaration order of the variants below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum FloatDType {
    /// 64-bit floating point.
    F64,
    /// 32-bit floating point.
    F32,
    /// cubecl `flex32` floating point.
    Flex32,
    /// 16-bit half precision floating point.
    F16,
    /// 16-bit brain floating point.
    BF16,
}
415
416impl From<DType> for FloatDType {
417    fn from(value: DType) -> Self {
418        match value {
419            DType::F64 => FloatDType::F64,
420            DType::F32 => FloatDType::F32,
421            DType::Flex32 => FloatDType::Flex32,
422            DType::F16 => FloatDType::F16,
423            DType::BF16 => FloatDType::BF16,
424            _ => panic!("Expected float data type, got {value:?}"),
425        }
426    }
427}
428
429impl From<FloatDType> for DType {
430    fn from(value: FloatDType) -> Self {
431        match value {
432            FloatDType::F64 => DType::F64,
433            FloatDType::F32 => DType::F32,
434            FloatDType::Flex32 => DType::Flex32,
435            FloatDType::F16 => DType::F16,
436            FloatDType::BF16 => DType::BF16,
437        }
438    }
439}
440
/// Integer subset of [`DType`], covering both signed and unsigned integers.
///
/// The derived `PartialOrd`/`Ord` follow the declaration order of the variants below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum IntDType {
    /// 64-bit signed integer.
    I64,
    /// 32-bit signed integer.
    I32,
    /// 16-bit signed integer.
    I16,
    /// 8-bit signed integer.
    I8,
    /// 64-bit unsigned integer.
    U64,
    /// 32-bit unsigned integer.
    U32,
    /// 16-bit unsigned integer.
    U16,
    /// 8-bit unsigned integer.
    U8,
}
453
454impl From<DType> for IntDType {
455    fn from(value: DType) -> Self {
456        match value {
457            DType::I64 => IntDType::I64,
458            DType::I32 => IntDType::I32,
459            DType::I16 => IntDType::I16,
460            DType::I8 => IntDType::I8,
461            DType::U64 => IntDType::U64,
462            DType::U32 => IntDType::U32,
463            DType::U16 => IntDType::U16,
464            DType::U8 => IntDType::U8,
465            _ => panic!("Expected int data type, got {value:?}"),
466        }
467    }
468}
469
470impl From<IntDType> for DType {
471    fn from(value: IntDType) -> Self {
472        match value {
473            IntDType::I64 => DType::I64,
474            IntDType::I32 => DType::I32,
475            IntDType::I16 => DType::I16,
476            IntDType::I8 => DType::I8,
477            IntDType::U64 => DType::U64,
478            IntDType::U32 => DType::U32,
479            IntDType::U16 => DType::U16,
480            IntDType::U8 => DType::U8,
481        }
482    }
483}