Skip to main content

burn_backend/element/
base.rs

1use core::cmp::Ordering;
2use rand::RngCore;
3
4use crate::distribution::Distribution;
5use burn_std::{DType, bf16, f16};
6
7#[cfg(feature = "cubecl")]
8use burn_std::flex32;
9
10use super::cast::ToElement;
11
12/// Core element trait for tensor values.
13///
14/// This trait defines the minimal set of capabilities required for a type to be
15/// stored and manipulated as a tensor element across all backends.
16pub trait Element:
17    ToElement
18    + ElementRandom
19    + ElementConversion
20    + ElementEq
21    + ElementLimits
22    + bytemuck::CheckedBitPattern
23    + bytemuck::NoUninit
24    + bytemuck::Zeroable
25    + core::fmt::Debug
26    + core::fmt::Display
27    + Default
28    + Send
29    + Sync
30    + Copy
31    + 'static
32{
33    /// The dtype of the element.
34    fn dtype() -> DType;
35}
36
37/// Ordered element trait for tensor values.
38///
39/// This trait extends [`Element`] with ordering semantics, enabling comparison
40/// and order-dependent operations in generic Rust implementations.
41///
42/// Backends that implement these operations entirely at the device level do
43/// not rely on this trait. It only constrains the scalar type for generic Rust code.
44pub trait ElementOrdered: Element + ElementComparison {}
45
46/// Element conversion trait for tensor.
47pub trait ElementConversion {
48    /// Converts an element to another element.
49    ///
50    /// # Arguments
51    ///
52    /// * `elem` - The element to convert.
53    ///
54    /// # Returns
55    ///
56    /// The converted element.
57    fn from_elem<E: ToElement>(elem: E) -> Self;
58
59    /// Converts and returns the converted element.
60    fn elem<E: Element>(self) -> E;
61}
62
63/// Element trait for random value of a tensor.
64pub trait ElementRandom {
65    /// Returns a random value for the given distribution.
66    ///
67    /// # Arguments
68    ///
69    /// * `distribution` - The distribution to sample from.
70    /// * `rng` - The random number generator.
71    ///
72    /// # Returns
73    ///
74    /// The random value.
75    fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
76}
77
/// Equality comparison between tensor element values.
pub trait ElementEq {
    /// Returns `true` when `self` equals `other`, and `false` otherwise.
    fn eq(&self, other: &Self) -> bool;
}
83
/// Element ordering trait.
///
/// Lets generic code order two values of the same element type.
pub trait ElementComparison {
    /// Returns an [`Ordering`] describing how `self` compares to `other`.
    fn cmp(&self, other: &Self) -> Ordering;
}
89
/// Numeric bounds associated with an element type.
pub trait ElementLimits {
    /// The lowest value of the element type.
    const MIN: Self;
    /// The highest value of the element type.
    const MAX: Self;
}
97
/// Implements the element traits for a type.
///
/// One invocation generates implementations of `Element`, `ElementEq`,
/// `ElementConversion`, `ElementRandom`, `ElementComparison`, `ElementLimits`
/// and `ElementOrdered` for `ty`.
///
/// # Arguments
///
/// * `ty` - The type to implement the traits for.
/// * `convert` - A callable invoked as `$convert(&elem)` where `elem: E` and
///   `E: ToElement`; it must return the converted `ty` value.
/// * `random` - A callable invoked as `$random(distribution, rng)` from inside
///   `fn random<R: RngCore>`, so closures may name the generic parameter `R`.
/// * `cmp` - A callable taking two `&ty` values and returning a
///   `core::cmp::Ordering`.
/// * `dtype` - The `burn_std::DType` returned by `Element::dtype`.
/// * `min` / `max` - Optional values for `ElementLimits::MIN` / `ElementLimits::MAX`;
///   when omitted they default to `ty::MIN` / `ty::MAX`.
///
/// The generated impls name the element traits, `ToElement`, `Distribution` and
/// `RngCore` without a path, so those must be in scope where the macro is invoked.
#[macro_export]
macro_rules! make_element {
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        // `$crate::` keeps this recursive call resolvable from crates that invoke the
        // exported macro without importing `make_element` by name.
        $crate::make_element!(
            ty $type,
            convert $convert,
            random $random,
            cmp $cmp,
            dtype $dtype,
            min $type::MIN,
            max $type::MAX
        );
    };
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> burn_std::DType {
                $dtype
            }
        }

        impl ElementEq for $type {
            #[inline]
            fn eq(&self, other: &Self) -> bool {
                self == other
            }
        }

        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }
            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            #[inline]
            fn cmp(&self, other: &Self) -> core::cmp::Ordering {
                // Both operands already have type `$type`, so they are handed to
                // `$cmp` directly without any intermediate conversion.
                #[allow(clippy::redundant_closure_call)]
                $cmp(self, other)
            }
        }

        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }

        impl ElementOrdered for $type {}
    };
}
168
169make_element!(
170    ty f64,
171    convert ToElement::to_f64,
172    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
173    cmp |a: &f64, b: &f64| a.total_cmp(b),
174    dtype DType::F64
175);
176
177make_element!(
178    ty f32,
179    convert ToElement::to_f32,
180    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
181    cmp |a: &f32, b: &f32| a.total_cmp(b),
182    dtype DType::F32
183);
184
185make_element!(
186    ty i64,
187    convert ToElement::to_i64,
188    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
189    cmp |a: &i64, b: &i64| Ord::cmp(a, b),
190    dtype DType::I64
191);
192
193make_element!(
194    ty u64,
195    convert ToElement::to_u64,
196    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
197    cmp |a: &u64, b: &u64| Ord::cmp(a, b),
198    dtype DType::U64
199);
200
201make_element!(
202    ty i32,
203    convert ToElement::to_i32,
204    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
205    cmp |a: &i32, b: &i32| Ord::cmp(a, b),
206    dtype DType::I32
207);
208
209make_element!(
210    ty u32,
211    convert ToElement::to_u32,
212    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
213    cmp |a: &u32, b: &u32| Ord::cmp(a, b),
214    dtype DType::U32
215);
216
217make_element!(
218    ty i16,
219    convert ToElement::to_i16,
220    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
221    cmp |a: &i16, b: &i16| Ord::cmp(a, b),
222    dtype DType::I16
223);
224
225make_element!(
226    ty u16,
227    convert ToElement::to_u16,
228    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
229    cmp |a: &u16, b: &u16| Ord::cmp(a, b),
230    dtype DType::U16
231);
232
233make_element!(
234    ty i8,
235    convert ToElement::to_i8,
236    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
237    cmp |a: &i8, b: &i8| Ord::cmp(a, b),
238    dtype DType::I8
239);
240
241make_element!(
242    ty u8,
243    convert ToElement::to_u8,
244    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
245    cmp |a: &u8, b: &u8| Ord::cmp(a, b),
246    dtype DType::U8
247);
248
249make_element!(
250    ty f16,
251    convert ToElement::to_f16,
252    random |distribution: Distribution, rng: &mut R| {
253        let sample: f32 = distribution.sampler(rng).sample();
254        f16::from_elem(sample)
255    },
256    cmp |a: &f16, b: &f16| a.total_cmp(b),
257    dtype DType::F16
258);
259make_element!(
260    ty bf16,
261    convert ToElement::to_bf16,
262    random |distribution: Distribution, rng: &mut R| {
263        let sample: f32 = distribution.sampler(rng).sample();
264        bf16::from_elem(sample)
265    },
266    cmp |a: &bf16, b: &bf16| a.total_cmp(b),
267    dtype DType::BF16
268);
269
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32,
    convert |value: &dyn ToElement| flex32::from_f32(value.to_f32()),
    random |dist: Distribution, generator: &mut R| {
        // Sample in f32, then wrap as flex32.
        let drawn: f32 = dist.sampler(generator).sample();
        flex32::from_elem(drawn)
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::Flex32,
    // Limits are taken from f16's range rather than flex32::MIN / flex32::MAX.
    min flex32::from_f32(f16::MIN.to_f32_const()),
    max flex32::from_f32(f16::MAX.to_f32_const())
);
283
284make_element!(
285    ty bool,
286    convert ToElement::to_bool,
287    random |distribution: Distribution, rng: &mut R| {
288        let sample: u8 = distribution.sampler(rng).sample();
289        bool::from_elem(sample)
290    },
291    cmp |a: &bool, b: &bool| Ord::cmp(a, b),
292    dtype DType::Bool,
293    min false,
294    max true
295);