burn_backend/element/base.rs

1use core::cmp::Ordering;
2use rand::RngCore;
3
4use crate::distribution::Distribution;
5use burn_std::{DType, bf16, f16};
6
7#[cfg(feature = "cubecl")]
8use burn_std::flex32;
9
10use super::cast::ToElement;
11
12/// Element trait for tensor.
13pub trait Element:
14    ToElement
15    + ElementRandom
16    + ElementConversion
17    + ElementComparison
18    + ElementLimits
19    + bytemuck::CheckedBitPattern
20    + bytemuck::NoUninit
21    + bytemuck::Zeroable
22    + core::fmt::Debug
23    + core::fmt::Display
24    + Default
25    + Send
26    + Sync
27    + Copy
28    + 'static
29{
30    /// The dtype of the element.
31    fn dtype() -> DType;
32}
33
34/// Element conversion trait for tensor.
35pub trait ElementConversion {
36    /// Converts an element to another element.
37    ///
38    /// # Arguments
39    ///
40    /// * `elem` - The element to convert.
41    ///
42    /// # Returns
43    ///
44    /// The converted element.
45    fn from_elem<E: ToElement>(elem: E) -> Self;
46
47    /// Converts and returns the converted element.
48    fn elem<E: Element>(self) -> E;
49}
50
51/// Element trait for random value of a tensor.
52pub trait ElementRandom {
53    /// Returns a random value for the given distribution.
54    ///
55    /// # Arguments
56    ///
57    /// * `distribution` - The distribution to sample from.
58    /// * `rng` - The random number generator.
59    ///
60    /// # Returns
61    ///
62    /// The random value.
63    fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
64}
65
/// Element ordering trait.
pub trait ElementComparison {
    /// Returns an [Ordering] between `self` and `other`.
    ///
    /// The implementations generated in this module by `make_element!` use
    /// `total_cmp` for floating-point types and [`Ord::cmp`] for integers and `bool`.
    fn cmp(&self, other: &Self) -> Ordering;
}
71
/// Element limits trait: the range of values an element type exposes.
///
/// An implementation may advertise a narrower range than its storage type can
/// hold; in this module, `flex32` reports the `f16` limits.
pub trait ElementLimits {
    /// The minimum representable value.
    const MIN: Self;
    /// The maximum representable value.
    const MAX: Self;
}
79
/// Macro to implement the element trait for a type.
///
/// Generates impls of [`Element`], [`ElementConversion`], [`ElementRandom`],
/// [`ElementComparison`] and [`ElementLimits`] for `ty`.
///
/// # Arguments
///
/// * `ty` - The type to implement the traits for.
/// * `convert` - Callable invoked with `&elem` (where `elem: impl ToElement`); returns `Self`.
/// * `random` - Callable invoked with `(distribution, rng)`; returns a random `Self`.
/// * `cmp` - Callable invoked with `(&Self, &Self)`; returns a `core::cmp::Ordering`.
/// * `dtype` - Expression returned by `Element::dtype`.
/// * `min` / `max` - Optional limits; when omitted, `ty::MIN` and `ty::MAX` are used.
///
/// The expansion names `Element`, `ElementConversion`, `ElementRandom`,
/// `ElementComparison`, `ElementLimits`, `ToElement`, `Distribution` and `RngCore`
/// by their bare names and refers to `burn_std::DType`, so those must resolve at
/// the invocation site.
#[macro_export]
macro_rules! make_element {
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        // `$crate::` lets the recursive call resolve even when the caller has not
        // imported `make_element` by name (e.g. when invoked from another crate).
        $crate::make_element!(
            ty $type,
            convert $convert,
            random $random,
            cmp $cmp,
            dtype $dtype,
            min $type::MIN,
            max $type::MAX
        );
    };
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> burn_std::DType {
                $dtype
            }
        }

        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }
            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> ::core::cmp::Ordering {
                // `self` and `other` already have type `&$type`, so they are handed to
                // the comparator directly; converting a value to its own type is a no-op.
                #[allow(clippy::redundant_closure_call)]
                $cmp(self, other)
            }
        }

        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }
    };
}
142
143make_element!(
144    ty f64,
145    convert ToElement::to_f64,
146    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
147    cmp |a: &f64, b: &f64| a.total_cmp(b),
148    dtype DType::F64
149);
150
151make_element!(
152    ty f32,
153    convert ToElement::to_f32,
154    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
155    cmp |a: &f32, b: &f32| a.total_cmp(b),
156    dtype DType::F32
157);
158
159make_element!(
160    ty i64,
161    convert ToElement::to_i64,
162    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
163    cmp |a: &i64, b: &i64| Ord::cmp(a, b),
164    dtype DType::I64
165);
166
167make_element!(
168    ty u64,
169    convert ToElement::to_u64,
170    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
171    cmp |a: &u64, b: &u64| Ord::cmp(a, b),
172    dtype DType::U64
173);
174
175make_element!(
176    ty i32,
177    convert ToElement::to_i32,
178    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
179    cmp |a: &i32, b: &i32| Ord::cmp(a, b),
180    dtype DType::I32
181);
182
183make_element!(
184    ty u32,
185    convert ToElement::to_u32,
186    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
187    cmp |a: &u32, b: &u32| Ord::cmp(a, b),
188    dtype DType::U32
189);
190
191make_element!(
192    ty i16,
193    convert ToElement::to_i16,
194    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
195    cmp |a: &i16, b: &i16| Ord::cmp(a, b),
196    dtype DType::I16
197);
198
199make_element!(
200    ty u16,
201    convert ToElement::to_u16,
202    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
203    cmp |a: &u16, b: &u16| Ord::cmp(a, b),
204    dtype DType::U16
205);
206
207make_element!(
208    ty i8,
209    convert ToElement::to_i8,
210    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
211    cmp |a: &i8, b: &i8| Ord::cmp(a, b),
212    dtype DType::I8
213);
214
215make_element!(
216    ty u8,
217    convert ToElement::to_u8,
218    random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
219    cmp |a: &u8, b: &u8| Ord::cmp(a, b),
220    dtype DType::U8
221);
222
223make_element!(
224    ty f16,
225    convert ToElement::to_f16,
226    random |distribution: Distribution, rng: &mut R| {
227        let sample: f32 = distribution.sampler(rng).sample();
228        f16::from_elem(sample)
229    },
230    cmp |a: &f16, b: &f16| a.total_cmp(b),
231    dtype DType::F16
232);
233make_element!(
234    ty bf16,
235    convert ToElement::to_bf16,
236    random |distribution: Distribution, rng: &mut R| {
237        let sample: f32 = distribution.sampler(rng).sample();
238        bf16::from_elem(sample)
239    },
240    cmp |a: &bf16, b: &bf16| a.total_cmp(b),
241    dtype DType::BF16
242);
243
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32,
    // Every source value is routed through f32 before becoming a flex32.
    convert |value: &dyn ToElement| flex32::from_f32(ToElement::to_f32(value)),
    random |dist: Distribution, generator: &mut R| {
        flex32::from_elem::<f32>(dist.sampler(generator).sample())
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::Flex32,
    // Limits are taken from the f16 range rather than the f32 range.
    min flex32::from_f32(f16::MIN.to_f32_const()),
    max flex32::from_f32(f16::MAX.to_f32_const())
);
257
258make_element!(
259    ty bool,
260    convert ToElement::to_bool,
261    random |distribution: Distribution, rng: &mut R| {
262        let sample: u8 = distribution.sampler(rng).sample();
263        bool::from_elem(sample)
264    },
265    cmp |a: &bool, b: &bool| Ord::cmp(a, b),
266    dtype DType::Bool,
267    min false,
268    max true
269);