1use core::cmp::Ordering;
2use rand::RngCore;
3
4use crate::distribution::Distribution;
5use burn_std::{DType, bf16, f16};
6
7#[cfg(feature = "cubecl")]
8use burn_std::flex32;
9
10use super::cast::ToElement;
11
12pub trait Element:
14 ToElement
15 + ElementRandom
16 + ElementConversion
17 + ElementComparison
18 + ElementLimits
19 + bytemuck::CheckedBitPattern
20 + bytemuck::NoUninit
21 + bytemuck::Zeroable
22 + core::fmt::Debug
23 + core::fmt::Display
24 + Default
25 + Send
26 + Sync
27 + Copy
28 + 'static
29{
30 fn dtype() -> DType;
32}
33
34pub trait ElementConversion {
36 fn from_elem<E: ToElement>(elem: E) -> Self;
46
47 fn elem<E: Element>(self) -> E;
49}
50
51pub trait ElementRandom {
53 fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self;
64}
65
/// Total ordering between two element values.
///
/// This is separate from [`Ord`] because floating-point types do not implement
/// `Ord`; the float implementations in this module order with `total_cmp`.
pub trait ElementComparison {
    /// Compares `self` against `other` and returns their [`Ordering`].
    fn cmp(&self, other: &Self) -> Ordering;
}
71
/// Lowest and highest values associated with an element type.
///
/// The implementations in this module use the type's own `MIN` / `MAX` by
/// default; `bool` uses `false` / `true`, and `flex32` (behind the `cubecl`
/// feature) uses the `f16` range.
pub trait ElementLimits {
    /// Lowest value for this element type.
    const MIN: Self;
    /// Highest value for this element type.
    const MAX: Self;
}
79
/// Implements [`Element`] and its companion traits for a scalar type `$type`.
///
/// Two invocation forms are accepted:
///
/// * `ty`, `convert`, `random`, `cmp`, `dtype`: `ElementLimits` is filled in from
///   the type's own `$type::MIN` / `$type::MAX` constants.
/// * the same five arguments followed by `min` and `max` expressions, for types
///   whose limits must be given explicitly.
///
/// How each argument is used by the expansion:
///
/// * `convert` is called as `convert(&elem)` where `elem: E` and `E: ToElement`.
/// * `random` is called as `random(distribution, rng)` inside the generated
///   `fn random<R: RngCore>`; the closure may name that `R` in its signature.
/// * `cmp` is called with two `&$type` operands and must return an `Ordering`.
/// * `dtype` is returned as-is from `Element::dtype`.
///
/// The expansion refers to `Element`, `ElementConversion`, `ElementRandom`,
/// `ElementComparison`, `ElementLimits`, `ToElement`, `Distribution`, `RngCore`
/// and `Ordering` without a path, so those names must be in scope wherever the
/// macro is invoked.
#[macro_export]
macro_rules! make_element {
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr
    ) => {
        // Recurse through `$crate` so this arm also works when the caller invoked
        // the macro by path without importing `make_element` by name.
        $crate::make_element!(
            ty $type,
            convert $convert,
            random $random,
            cmp $cmp,
            dtype $dtype,
            min $type::MIN,
            max $type::MAX
        );
    };
    (
        ty $type:ident,
        convert $convert:expr,
        random $random:expr,
        cmp $cmp:expr,
        dtype $dtype:expr,
        min $min:expr,
        max $max:expr
    ) => {
        impl Element for $type {
            #[inline(always)]
            fn dtype() -> burn_std::DType {
                $dtype
            }
        }

        impl ElementConversion for $type {
            #[inline(always)]
            fn from_elem<E: ToElement>(elem: E) -> Self {
                // `$convert`, `$random` and `$cmp` may be closure literals that are
                // called immediately, which clippy's `redundant_closure_call` flags.
                #[allow(clippy::redundant_closure_call)]
                $convert(&elem)
            }

            #[inline(always)]
            fn elem<E: Element>(self) -> E {
                E::from_elem(self)
            }
        }

        impl ElementRandom for $type {
            fn random<R: RngCore>(distribution: Distribution, rng: &mut R) -> Self {
                // `$random` is typically an inline closure (see `from_elem`).
                #[allow(clippy::redundant_closure_call)]
                $random(distribution, rng)
            }
        }

        impl ElementComparison for $type {
            fn cmp(&self, other: &Self) -> Ordering {
                // Both operands already are `$type`: compare them directly rather
                // than round-tripping each through `ToElement` via `elem::<$type>()`.
                #[allow(clippy::redundant_closure_call)]
                $cmp(self, other)
            }
        }

        impl ElementLimits for $type {
            const MIN: Self = $min;
            const MAX: Self = $max;
        }
    };
}
142
143make_element!(
144 ty f64,
145 convert ToElement::to_f64,
146 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
147 cmp |a: &f64, b: &f64| a.total_cmp(b),
148 dtype DType::F64
149);
150
151make_element!(
152 ty f32,
153 convert ToElement::to_f32,
154 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
155 cmp |a: &f32, b: &f32| a.total_cmp(b),
156 dtype DType::F32
157);
158
159make_element!(
160 ty i64,
161 convert ToElement::to_i64,
162 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
163 cmp |a: &i64, b: &i64| Ord::cmp(a, b),
164 dtype DType::I64
165);
166
167make_element!(
168 ty u64,
169 convert ToElement::to_u64,
170 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
171 cmp |a: &u64, b: &u64| Ord::cmp(a, b),
172 dtype DType::U64
173);
174
175make_element!(
176 ty i32,
177 convert ToElement::to_i32,
178 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
179 cmp |a: &i32, b: &i32| Ord::cmp(a, b),
180 dtype DType::I32
181);
182
183make_element!(
184 ty u32,
185 convert ToElement::to_u32,
186 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
187 cmp |a: &u32, b: &u32| Ord::cmp(a, b),
188 dtype DType::U32
189);
190
191make_element!(
192 ty i16,
193 convert ToElement::to_i16,
194 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
195 cmp |a: &i16, b: &i16| Ord::cmp(a, b),
196 dtype DType::I16
197);
198
199make_element!(
200 ty u16,
201 convert ToElement::to_u16,
202 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
203 cmp |a: &u16, b: &u16| Ord::cmp(a, b),
204 dtype DType::U16
205);
206
207make_element!(
208 ty i8,
209 convert ToElement::to_i8,
210 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
211 cmp |a: &i8, b: &i8| Ord::cmp(a, b),
212 dtype DType::I8
213);
214
215make_element!(
216 ty u8,
217 convert ToElement::to_u8,
218 random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
219 cmp |a: &u8, b: &u8| Ord::cmp(a, b),
220 dtype DType::U8
221);
222
223make_element!(
224 ty f16,
225 convert ToElement::to_f16,
226 random |distribution: Distribution, rng: &mut R| {
227 let sample: f32 = distribution.sampler(rng).sample();
228 f16::from_elem(sample)
229 },
230 cmp |a: &f16, b: &f16| a.total_cmp(b),
231 dtype DType::F16
232);
233make_element!(
234 ty bf16,
235 convert ToElement::to_bf16,
236 random |distribution: Distribution, rng: &mut R| {
237 let sample: f32 = distribution.sampler(rng).sample();
238 bf16::from_elem(sample)
239 },
240 cmp |a: &bf16, b: &bf16| a.total_cmp(b),
241 dtype DType::BF16
242);
243
// `flex32` element (only with the `cubecl` feature). Conversions and random
// draws go through `f32`. The limits are taken from the `f16` range rather
// than the `f32` range — presumably because flex32 values may be computed at
// reduced precision; NOTE(review): confirm against cubecl's `flex32` docs.
#[cfg(feature = "cubecl")]
make_element!(
    ty flex32,
    convert |value: &dyn ToElement| flex32::from_f32(value.to_f32()),
    random |dist: Distribution, rng: &mut R| {
        let drawn: f32 = dist.sampler(rng).sample();
        flex32::from_elem(drawn)
    },
    cmp |lhs: &flex32, rhs: &flex32| lhs.total_cmp(rhs),
    dtype DType::Flex32,
    min flex32::from_f32(f16::MIN.to_f32_const()),
    max flex32::from_f32(f16::MAX.to_f32_const())
);
257
258make_element!(
259 ty bool,
260 convert ToElement::to_bool,
261 random |distribution: Distribution, rng: &mut R| {
262 let sample: u8 = distribution.sampler(rng).sample();
263 bool::from_elem(sample)
264 },
265 cmp |a: &bool, b: &bool| Ord::cmp(a, b),
266 dtype DType::Bool,
267 min false,
268 max true
269);