1use core::cmp::Ordering;
2use rand::Rng;
3
4use crate::distribution::Distribution;
5use burn_std::{BoolStore, DType, bf16, f16};
6
7#[cfg(feature = "cubecl")]
8use burn_std::flex32;
9
10use super::cast::ToElement;
11
12pub trait Element:
17 ToElement
18 + ElementRandom
19 + ElementConversion
20 + ElementEq
21 + bytemuck::CheckedBitPattern
22 + bytemuck::NoUninit
23 + bytemuck::Zeroable
24 + core::fmt::Debug
25 + core::fmt::Display
26 + Default
27 + Send
28 + Sync
29 + Copy
30 + 'static
31{
32 fn dtype() -> DType;
34}
35
36pub trait ElementOrdered: Element + ElementComparison + ElementLimits {}
44
45pub trait ElementConversion {
47 fn from_elem<E: ToElement>(elem: E) -> Self;
57
58 fn elem<E: Element>(self) -> E;
60}
61
62pub trait ElementRandom {
64 fn random<R: Rng>(distribution: Distribution, rng: &mut R) -> Self;
75}
76
/// Equality between two elements of the same type.
pub trait ElementEq {
    /// Returns `true` when `self` and `other` are considered equal.
    fn eq(&self, other: &Self) -> bool;
}
82
/// Ordering comparison between two elements of the same type.
pub trait ElementComparison {
    /// Compares `self` against `other` and returns their relative [`Ordering`].
    fn cmp(&self, other: &Self) -> Ordering;
}
88
/// Lower and upper bounds reported for an element type.
///
/// Implementations choose these values; they are not necessarily the full
/// representable range of the underlying storage.
pub trait ElementLimits {
    /// Lower bound reported for the element type.
    const MIN: Self;
    /// Upper bound reported for the element type.
    const MAX: Self;
}
96
/// Implements the element traits for a scalar type.
///
/// For the given `ty`, this generates impls of `Element`, `ElementEq`,
/// `ElementConversion`, `ElementRandom`, `ElementComparison`, `ElementLimits`
/// and `ElementOrdered`.
///
/// # Arguments
///
/// - `ty`: the type to implement the traits for.
/// - `convert`: a callable applied to `&E` (with `E: ToElement`) that returns `ty`;
///   it backs `ElementConversion::from_elem`. `ElementConversion::elem` is
///   implemented as `E::from_elem(self)`.
/// - `random`: a callable `(Distribution, &mut R) -> ty`. `R` is the `R: Rng`
///   generic of the generated `ElementRandom::random`, and closures may name it.
/// - `cmp`: a callable `(&ty, &ty) -> Ordering` backing `ElementComparison::cmp`.
/// - `dtype`: the `burn_std::DType` value returned by `Element::dtype`.
/// - `min` / `max` (optional, given together): values for `ElementLimits`; when
///   omitted they default to `ty::MIN` / `ty::MAX`.
///
/// The generated code refers to `Element`, `ElementEq`, `ElementConversion`,
/// `ElementRandom`, `ElementComparison`, `ElementLimits`, `ElementOrdered`,
/// `ToElement`, `Distribution` and `Rng` without a path, so those names must be
/// in scope at the invocation site.
#[macro_export]
macro_rules! make_element {
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        // Recurse through `$crate` so this also resolves when the macro is invoked
        // by path from another crate without being imported there.
        $crate::make_element!(
            ty $type,
            convert $convert,
            random $random,
            cmp $cmp,
            dtype $dtype,
            min $type::MIN,
            max $type::MAX
        );
    };
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> burn_std::DType {
                $dtype
            }
        }

        impl ElementEq for $type {
            fn eq(&self, other: &Self) -> bool {
                self == other
            }
        }

        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }

            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementRandom for $type {
            // NOTE: keep the generic named `R`: `random` closures at invocation
            // sites refer to it by name (e.g. `rng: &mut R`).
            fn random<R: Rng>(distribution: Distribution, rng: &mut R) -> Self {
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> core::cmp::Ordering {
                // Both operands already have type `$type`. Compare them directly
                // rather than round-tripping through `elem::<$type>()`, which is
                // redundant and, depending on the `ToElement` conversion, is not
                // guaranteed to preserve the exact bit pattern (relevant for
                // bit-level comparisons such as `total_cmp`).
                #[allow(clippy::redundant_closure_call)]
                $cmp(self, other)
            }
        }

        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }

        impl ElementOrdered for $type {}
    };
}
167
168make_element!(
169 ty f64,
170 convert ToElement::to_f64,
171 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
172 cmp |a: &f64, b: &f64| a.total_cmp(b),
173 dtype DType::F64
174);
175
176make_element!(
177 ty f32,
178 convert ToElement::to_f32,
179 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
180 cmp |a: &f32, b: &f32| a.total_cmp(b),
181 dtype DType::F32
182);
183
184make_element!(
185 ty i64,
186 convert ToElement::to_i64,
187 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
188 cmp |a: &i64, b: &i64| Ord::cmp(a, b),
189 dtype DType::I64
190);
191
192make_element!(
193 ty u64,
194 convert ToElement::to_u64,
195 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
196 cmp |a: &u64, b: &u64| Ord::cmp(a, b),
197 dtype DType::U64
198);
199
200make_element!(
201 ty i32,
202 convert ToElement::to_i32,
203 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
204 cmp |a: &i32, b: &i32| Ord::cmp(a, b),
205 dtype DType::I32
206);
207
208make_element!(
209 ty u32,
210 convert ToElement::to_u32,
211 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
212 cmp |a: &u32, b: &u32| Ord::cmp(a, b),
213 dtype DType::U32
214);
215
216make_element!(
217 ty i16,
218 convert ToElement::to_i16,
219 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
220 cmp |a: &i16, b: &i16| Ord::cmp(a, b),
221 dtype DType::I16
222);
223
224make_element!(
225 ty u16,
226 convert ToElement::to_u16,
227 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
228 cmp |a: &u16, b: &u16| Ord::cmp(a, b),
229 dtype DType::U16
230);
231
232make_element!(
233 ty i8,
234 convert ToElement::to_i8,
235 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
236 cmp |a: &i8, b: &i8| Ord::cmp(a, b),
237 dtype DType::I8
238);
239
240make_element!(
241 ty u8,
242 convert ToElement::to_u8,
243 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
244 cmp |a: &u8, b: &u8| Ord::cmp(a, b),
245 dtype DType::U8
246);
247
248make_element!(
249 ty f16,
250 convert ToElement::to_f16,
251 random |distribution: Distribution, rng: &mut R| {
252 let sample: f32 = distribution.sampler(rng).sample();
253 f16::from_elem(sample)
254 },
255 cmp |a: &f16, b: &f16| a.total_cmp(b),
256 dtype DType::F16
257);
258make_element!(
259 ty bf16,
260 convert ToElement::to_bf16,
261 random |distribution: Distribution, rng: &mut R| {
262 let sample: f32 = distribution.sampler(rng).sample();
263 bf16::from_elem(sample)
264 },
265 cmp |a: &bf16, b: &bf16| a.total_cmp(b),
266 dtype DType::BF16
267);
268
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32,
    convert |value: &dyn ToElement| flex32::from_f32(value.to_f32()),
    random |dist: Distribution, source: &mut R| {
        let drawn: f32 = dist.sampler(source).sample();
        flex32::from_elem(drawn)
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::Flex32,
    // Limits are taken from the f16 range, not the full f32 range.
    min flex32::from_f32(f16::MIN.to_f32_const()),
    max flex32::from_f32(f16::MAX.to_f32_const())
);
282
283make_element!(
284 ty bool,
285 convert ToElement::to_bool,
286 random |distribution: Distribution, rng: &mut R| {
287 let sample: u8 = distribution.sampler(rng).sample();
288 bool::from_elem(sample)
289 },
290 cmp |a: &bool, b: &bool| Ord::cmp(a, b),
291 dtype DType::Bool(BoolStore::Native),
292 min false,
293 max true
294);