1use core::cmp::Ordering;
2use rand::RngCore;
3
4use crate::distribution::Distribution;
5use burn_std::{DType, bf16, f16};
6
7#[cfg(feature = "cubecl")]
8use burn_std::flex32;
9
10use super::cast::ToElement;
11
12pub trait Element:
17 ToElement
18 + ElementRandom
19 + ElementConversion
20 + ElementEq
21 + ElementLimits
22 + bytemuck::CheckedBitPattern
23 + bytemuck::NoUninit
24 + bytemuck::Zeroable
25 + core::fmt::Debug
26 + core::fmt::Display
27 + Default
28 + Send
29 + Sync
30 + Copy
31 + 'static
32{
33 fn dtype() -> DType;
35}
36
37pub trait ElementOrdered: Element + ElementComparison {}
45
46pub trait ElementConversion {
48 fn from_elem<E: ToElement>(elem: E) -> Self;
58
59 fn elem<E: Element>(self) -> E;
61}
62
63pub trait ElementRandom {
65 fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
76}
77
/// Equality between two elements of the same type.
pub trait ElementEq {
    /// Returns `true` when `self` and `other` are considered equal.
    ///
    /// The exact notion of equality is decided by each implementation.
    fn eq(&self, other: &Self) -> bool;
}
83
/// Ordering between two elements of the same type.
pub trait ElementComparison {
    /// Compares `self` with `other` and returns the resulting [`Ordering`].
    ///
    /// The ordering rule is decided by each implementation.
    fn cmp(&self, other: &Self) -> Ordering;
}
89
/// Bounds of an element type, exposed as associated constants.
pub trait ElementLimits {
    /// The upper limit declared for this element type.
    const MAX: Self;
    /// The lower limit declared for this element type.
    const MIN: Self;
}
97
/// Implements the whole family of element traits for a concrete type.
///
/// For `$type` this generates impls of `Element`, `ElementEq`,
/// `ElementConversion`, `ElementRandom`, `ElementComparison`, `ElementLimits`
/// and `ElementOrdered`.
///
/// # Arguments
///
/// * `ty` - The type receiving the impls.
/// * `convert` - Callable invoked as `$convert(&elem)`, where `elem` is the
///   `ToElement` value handed to `from_elem`; it must return `$type`.
/// * `random` - Callable invoked as `$random(distribution, rng)`; it must
///   return `$type`. The invocations in this file annotate the rng parameter
///   as `&mut R`, naming the type parameter of the generated `random` function.
/// * `cmp` - Callable invoked as `$cmp(&a, &b)` on two `$type` values; it must
///   return an `Ordering`.
/// * `dtype` - Expression returned by `Element::dtype`.
/// * `min` / `max` - Optional values for `ElementLimits::MIN` and
///   `ElementLimits::MAX`. When both are omitted, `$type::MIN` and
///   `$type::MAX` are used.
///
/// The element traits, `ToElement`, `RngCore`, `Distribution` and `Ordering`
/// are referenced by their bare names, so they must be in scope where the
/// macro is invoked.
#[macro_export]
macro_rules! make_element {
    // Short form: the limits default to the type's own `MIN` / `MAX`.
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        // `$crate::` keeps the recursive call resolvable even when the caller
        // invoked the macro by path without importing it by name.
        $crate::make_element!(
            ty $type,
            convert $convert,
            random $random,
            cmp $cmp,
            dtype $dtype,
            min $type::MIN,
            max $type::MAX
        );
    };
    // Full form: explicit limits.
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> burn_std::DType {
                $dtype
            }
        }

        impl ElementEq for $type {
            fn eq(&self, other: &Self) -> bool {
                self == other
            }
        }

        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                // `$convert` may be a closure literal that is called in place.
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }

            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                // `$random` may be a closure literal that is called in place.
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> Ordering {
                // Both operands go through `ElementConversion::elem` to obtain
                // owned `$type` values before being handed to `$cmp`.
                let a = self.elem::<$type>();
                let b = other.elem::<$type>();
                // `$cmp` may be a closure literal that is called in place.
                #[allow(clippy::redundant_closure_call)]
                $cmp(&a, &b)
            }
        }

        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }

        impl ElementOrdered for $type {}
    };
}
168
169make_element!(
170 ty f64,
171 convert ToElement::to_f64,
172 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
173 cmp |a: &f64, b: &f64| a.total_cmp(b),
174 dtype DType::F64
175);
176
177make_element!(
178 ty f32,
179 convert ToElement::to_f32,
180 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
181 cmp |a: &f32, b: &f32| a.total_cmp(b),
182 dtype DType::F32
183);
184
185make_element!(
186 ty i64,
187 convert ToElement::to_i64,
188 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
189 cmp |a: &i64, b: &i64| Ord::cmp(a, b),
190 dtype DType::I64
191);
192
193make_element!(
194 ty u64,
195 convert ToElement::to_u64,
196 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
197 cmp |a: &u64, b: &u64| Ord::cmp(a, b),
198 dtype DType::U64
199);
200
201make_element!(
202 ty i32,
203 convert ToElement::to_i32,
204 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
205 cmp |a: &i32, b: &i32| Ord::cmp(a, b),
206 dtype DType::I32
207);
208
209make_element!(
210 ty u32,
211 convert ToElement::to_u32,
212 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
213 cmp |a: &u32, b: &u32| Ord::cmp(a, b),
214 dtype DType::U32
215);
216
217make_element!(
218 ty i16,
219 convert ToElement::to_i16,
220 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
221 cmp |a: &i16, b: &i16| Ord::cmp(a, b),
222 dtype DType::I16
223);
224
225make_element!(
226 ty u16,
227 convert ToElement::to_u16,
228 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
229 cmp |a: &u16, b: &u16| Ord::cmp(a, b),
230 dtype DType::U16
231);
232
233make_element!(
234 ty i8,
235 convert ToElement::to_i8,
236 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
237 cmp |a: &i8, b: &i8| Ord::cmp(a, b),
238 dtype DType::I8
239);
240
241make_element!(
242 ty u8,
243 convert ToElement::to_u8,
244 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
245 cmp |a: &u8, b: &u8| Ord::cmp(a, b),
246 dtype DType::U8
247);
248
249make_element!(
250 ty f16,
251 convert ToElement::to_f16,
252 random |distribution: Distribution, rng: &mut R| {
253 let sample: f32 = distribution.sampler(rng).sample();
254 f16::from_elem(sample)
255 },
256 cmp |a: &f16, b: &f16| a.total_cmp(b),
257 dtype DType::F16
258);
259make_element!(
260 ty bf16,
261 convert ToElement::to_bf16,
262 random |distribution: Distribution, rng: &mut R| {
263 let sample: f32 = distribution.sampler(rng).sample();
264 bf16::from_elem(sample)
265 },
266 cmp |a: &bf16, b: &bf16| a.total_cmp(b),
267 dtype DType::BF16
268);
269
// `flex32` elements, only available with the `cubecl` feature. Values are
// converted by going through `f32`, random values are drawn as `f32` and then
// converted through `from_elem`, and ordering uses `total_cmp`.
// The limits are given explicitly and are derived from `f16::MIN` / `f16::MAX`.
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32,
    convert |value: &dyn ToElement| flex32::from_f32(value.to_f32()),
    random |dist: Distribution, generator: &mut R| {
        // The turbofish pins the sampled type to `f32`.
        flex32::from_elem::<f32>(dist.sampler(generator).sample())
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::Flex32,
    min flex32::from_f32(f16::MIN.to_f32_const()),
    max flex32::from_f32(f16::MAX.to_f32_const())
);
283
284make_element!(
285 ty bool,
286 convert ToElement::to_bool,
287 random |distribution: Distribution, rng: &mut R| {
288 let sample: u8 = distribution.sampler(rng).sample();
289 bool::from_elem(sample)
290 },
291 cmp |a: &bool, b: &bool| Ord::cmp(a, b),
292 dtype DType::Bool,
293 min false,
294 max true
295);