Skip to main content

burn_backend/element/
base.rs

1use core::cmp::Ordering;
2use rand::Rng;
3
4use crate::distribution::Distribution;
5use burn_std::{BoolStore, DType, bf16, f16};
6
7#[cfg(feature = "cubecl")]
8use burn_std::flex32;
9
10use super::cast::ToElement;
11
12/// Core element trait for tensor values.
13///
14/// This trait defines the minimal set of capabilities required for a type to be
15/// stored and manipulated as a tensor element across all backends.
16pub trait Element:
17    ToElement
18    + ElementRandom
19    + ElementConversion
20    + ElementEq
21    + bytemuck::CheckedBitPattern
22    + bytemuck::NoUninit
23    + bytemuck::Zeroable
24    + core::fmt::Debug
25    + core::fmt::Display
26    + Default
27    + Send
28    + Sync
29    + Copy
30    + 'static
31{
32    /// The dtype of the element.
33    fn dtype() -> DType;
34}
35
36/// Ordered element trait for tensor values.
37///
38/// This trait extends [`Element`] with ordering semantics, enabling comparison
39/// and order-dependent operations in generic Rust implementations.
40///
41/// Backends that implement these operations entirely at the device level do
42/// not rely on this trait. It only constrains the scalar type for generic Rust code.
43pub trait ElementOrdered: Element + ElementComparison + ElementLimits {}
44
45/// Element conversion trait for tensor.
46pub trait ElementConversion {
47    /// Converts an element to another element.
48    ///
49    /// # Arguments
50    ///
51    /// * `elem` - The element to convert.
52    ///
53    /// # Returns
54    ///
55    /// The converted element.
56    fn from_elem<E: ToElement>(elem: E) -> Self;
57
58    /// Converts and returns the converted element.
59    fn elem<E: Element>(self) -> E;
60}
61
62/// Element trait for random value of a tensor.
63pub trait ElementRandom {
64    /// Returns a random value for the given distribution.
65    ///
66    /// # Arguments
67    ///
68    /// * `distribution` - The distribution to sample from.
69    /// * `rng` - The random number generator.
70    ///
71    /// # Returns
72    ///
73    /// The random value.
74    fn random<R: Rng>(distribution: Distribution, rng: &mut R) -> Self;
75}
76
/// Equality test between two elements of the same type.
pub trait ElementEq {
    /// Returns `true` when `self` equals `other`, and `false` otherwise.
    fn eq(&self, other: &Self) -> bool;
}
82
/// Element ordering trait.
///
/// Compares two values of the same element type. Used, via [`ElementOrdered`],
/// by generic Rust code that performs order-dependent operations.
pub trait ElementComparison {
    /// Returns an [`Ordering`] between `self` and `other`.
    fn cmp(&self, other: &Self) -> Ordering;
}
88
/// Bounds of an element type.
///
/// Exposes the lowest and highest values of the type as associated constants.
pub trait ElementLimits {
    /// Lowest representable value of the type.
    const MIN: Self;
    /// Highest representable value of the type.
    const MAX: Self;
}
96
/// Implements the full set of element traits for a scalar type.
///
/// One invocation generates impls of `Element`, `ElementEq`,
/// `ElementConversion`, `ElementRandom`, `ElementComparison`, `ElementLimits`
/// and `ElementOrdered` for the given type.
///
/// # Arguments
///
/// * `ty` - The element type, given as a single identifier.
/// * `convert` - Callable applied to `&E` (where `E: ToElement`) that returns the
///   element type; it backs `ElementConversion::from_elem`.
/// * `random` - Callable `(Distribution, &mut R) -> ty`. `R` is the `Rng` type
///   parameter of `ElementRandom::random` and may be named inside the closure.
/// * `cmp` - Callable `(&ty, &ty) -> Ordering`; it backs `ElementComparison::cmp`.
/// * `dtype` - Expression evaluating to the element's `burn_std::DType`.
/// * `min`, `max` - Optional values for `ElementLimits`. When omitted they
///   default to `ty::MIN` and `ty::MAX`.
///
/// The generated code names the element traits, `ToElement`, `Distribution`,
/// `Rng` and `Ordering` by their bare names, so those must be in scope wherever
/// the macro is invoked.
#[macro_export]
macro_rules! make_element {
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        // Forward to the full arm using the type's own `MIN`/`MAX` constants.
        // The `$crate::` path keeps this recursive call resolvable even where
        // `make_element` is not imported by name at the invocation site.
        $crate::make_element!(
            ty $type,
            convert $convert,
            random $random,
            cmp $cmp,
            dtype $dtype,
            min $type::MIN,
            max $type::MAX
        );
    };
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> burn_std::DType {
                $dtype
            }
        }

        impl ElementEq for $type {
            fn eq(&self, other: &Self) -> bool {
                self == other
            }
        }

        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }

            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementRandom for $type {
            fn random<R: Rng>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> Ordering {
                // `self` and `other` are already `&$type`, exactly what the
                // comparator takes, so no identity conversion through `elem`
                // is needed before comparing.
                #[allow(clippy::redundant_closure_call)]
                $cmp(self, other)
            }
        }

        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }

        impl ElementOrdered for $type {}
    };
}
167
168make_element!(
169    ty f64,
170    convert ToElement::to_f64,
171    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
172    cmp |a: &f64, b: &f64| a.total_cmp(b),
173    dtype DType::F64
174);
175
176make_element!(
177    ty f32,
178    convert ToElement::to_f32,
179    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
180    cmp |a: &f32, b: &f32| a.total_cmp(b),
181    dtype DType::F32
182);
183
184make_element!(
185    ty i64,
186    convert ToElement::to_i64,
187    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
188    cmp |a: &i64, b: &i64| Ord::cmp(a, b),
189    dtype DType::I64
190);
191
192make_element!(
193    ty u64,
194    convert ToElement::to_u64,
195    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
196    cmp |a: &u64, b: &u64| Ord::cmp(a, b),
197    dtype DType::U64
198);
199
200make_element!(
201    ty i32,
202    convert ToElement::to_i32,
203    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
204    cmp |a: &i32, b: &i32| Ord::cmp(a, b),
205    dtype DType::I32
206);
207
208make_element!(
209    ty u32,
210    convert ToElement::to_u32,
211    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
212    cmp |a: &u32, b: &u32| Ord::cmp(a, b),
213    dtype DType::U32
214);
215
216make_element!(
217    ty i16,
218    convert ToElement::to_i16,
219    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
220    cmp |a: &i16, b: &i16| Ord::cmp(a, b),
221    dtype DType::I16
222);
223
224make_element!(
225    ty u16,
226    convert ToElement::to_u16,
227    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
228    cmp |a: &u16, b: &u16| Ord::cmp(a, b),
229    dtype DType::U16
230);
231
232make_element!(
233    ty i8,
234    convert ToElement::to_i8,
235    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
236    cmp |a: &i8, b: &i8| Ord::cmp(a, b),
237    dtype DType::I8
238);
239
240make_element!(
241    ty u8,
242    convert ToElement::to_u8,
243    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
244    cmp |a: &u8, b: &u8| Ord::cmp(a, b),
245    dtype DType::U8
246);
247
248make_element!(
249    ty f16,
250    convert ToElement::to_f16,
251    random |distribution: Distribution, rng: &mut R| {
252        let sample: f32 = distribution.sampler(rng).sample();
253        f16::from_elem(sample)
254    },
255    cmp |a: &f16, b: &f16| a.total_cmp(b),
256    dtype DType::F16
257);
258make_element!(
259    ty bf16,
260    convert ToElement::to_bf16,
261    random |distribution: Distribution, rng: &mut R| {
262        let sample: f32 = distribution.sampler(rng).sample();
263        bf16::from_elem(sample)
264    },
265    cmp |a: &bf16, b: &bf16| a.total_cmp(b),
266    dtype DType::BF16
267);
268
// `flex32`, available only with the `cubecl` feature. It converts and samples
// through `f32`, and its `ElementLimits` are set to the `f16` range rather
// than to an `f32` range.
// NOTE(review): presumably because a `flex32` value may be computed at `f16`
// precision on some devices, so only the `f16` range is guaranteed; confirm.
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32,
    convert |value: &dyn ToElement| flex32::from_f32(value.to_f32()),
    random |dist: Distribution, rng: &mut R| {
        let drawn: f32 = dist.sampler(rng).sample();
        <flex32 as ElementConversion>::from_elem(drawn)
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::Flex32,
    min flex32::from_f32(f16::MIN.to_f32_const()),
    max flex32::from_f32(f16::MAX.to_f32_const())
);
282
283make_element!(
284    ty bool,
285    convert ToElement::to_bool,
286    random |distribution: Distribution, rng: &mut R| {
287        let sample: u8 = distribution.sampler(rng).sample();
288        bool::from_elem(sample)
289    },
290    cmp |a: &bool, b: &bool| Ord::cmp(a, b),
291    dtype DType::Bool(BoolStore::Native),
292    min false,
293    max true
294);