1use burn_backend::Element;
2use num_traits::Signed;
3
4#[cfg(not(feature = "std"))]
5#[allow(unused_imports)]
6use num_traits::Float;
7
8use num_traits::Pow;
9
10use libm::{log1p, log1pf};
11
12pub trait FloatNdArrayElement: NdArrayElement + Signed
14where
15 Self: Sized,
16{
17}
18
19pub trait IntNdArrayElement: NdArrayElement {}
21
22pub trait NdArrayElement:
24 Element
25 + ndarray::LinalgScalar
26 + ndarray::ScalarOperand
27 + ExpElement
28 + AddAssignElement
29 + num_traits::FromPrimitive
30 + core::ops::AddAssign
31 + core::cmp::PartialEq
32 + core::cmp::PartialOrd<Self>
33 + core::ops::Rem<Output = Self>
34{
35}
36
/// Scalar math helpers applied element-wise by the ndarray backend.
///
/// Every method consumes the scalar and yields a value of the same type.
pub trait ExpElement {
    /// Absolute value.
    fn abs_elem(self) -> Self;
    /// Absolute value intended for integer element types.
    fn int_abs_elem(self) -> Self;

    /// Exponential function, `e^self`.
    fn exp_elem(self) -> Self;
    /// Natural logarithm.
    fn log_elem(self) -> Self;
    /// `ln(1 + self)`.
    fn log1p_elem(self) -> Self;

    /// Square root.
    fn sqrt_elem(self) -> Self;
    /// Raises `self` to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;
    /// Raises `self` to the floating-point power `value`.
    fn powf_elem(self, value: f32) -> Self;
}
56
/// In-place accumulation of a right-hand-side value into an element.
///
/// `Rhs` defaults to `Self`, mirroring `core::ops::AddAssign`.
pub trait AddAssignElement<Rhs = Self> {
    /// Folds `rhs` into `self`.
    fn add_assign(&mut self, rhs: Rhs);
}
64
65impl<E: NdArrayElement> AddAssignElement for E {
66 fn add_assign(&mut self, rhs: Self) {
67 *self += rhs;
68 }
69}
70
71impl AddAssignElement for bool {
72 fn add_assign(&mut self, rhs: Self) {
73 *self = *self || rhs; }
75}
76
77pub trait QuantElement: NdArrayElement {}
79
80impl QuantElement for i8 {}
81
82impl FloatNdArrayElement for f64 {}
83impl FloatNdArrayElement for f32 {}
84
85impl IntNdArrayElement for i64 {}
86impl IntNdArrayElement for i32 {}
87impl IntNdArrayElement for i16 {}
88impl IntNdArrayElement for i8 {}
89
90impl IntNdArrayElement for u64 {}
91impl IntNdArrayElement for u32 {}
92impl IntNdArrayElement for u16 {}
93impl IntNdArrayElement for u8 {}
94
/// Implements `NdArrayElement` and `ExpElement` for a primitive scalar type.
///
/// * `double $ty` is for 64-bit types. The float helpers are evaluated in `f64`, and
///   `int_abs_elem` is computed through `i128`.
/// * `single $ty` is for types of at most 32 bits. The float helpers are evaluated in
///   `f32`, and `int_abs_elem` is computed through `i64`.
///
/// Every result is converted back to `$ty` with `as`. Float-to-integer conversions
/// therefore truncate toward zero and saturate at the bounds of `$ty`.
///
/// `int_abs_elem` returns unsigned values unchanged. For signed types it returns the
/// magnitude, except that `MIN` maps to itself (two's-complement wrap) instead of
/// overflowing.
macro_rules! make_elem {
    (
        double
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f64).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f64).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1p(self as f64) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f64).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = f64::powi(self as f64, value) as $ty;

                // `f64::powi` is a `std` method; without `std`, fall back to the float power.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f64).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f64).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // Widen to `i128` for two reasons:
                // - every `u64` stays non-negative (casting a `u64 >= 2^63` to `i64`
                //   would flip its sign);
                // - `i64::MIN` cannot overflow (`i64::abs` panics on it in debug builds).
                // The cast back is the identity for unsigned values and maps `i64::MIN`
                // to itself.
                (self as i128).unsigned_abs() as $ty
            }
        }
    };
    (
        single
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f32).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f32).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1pf(self as f32) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f32).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = f32::powi(self as f32, value) as $ty;

                // `f32::powi` is a `std` method; without `std`, fall back to the float power.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f32).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f32).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // Widen to `i64` so that every `u32` stays non-negative (casting a
                // `u32 >= 2^31` to `i32` would flip its sign). `unsigned_abs` cannot
                // overflow, and the cast back maps each signed `MIN` to itself.
                (self as i64).unsigned_abs() as $ty
            }
        }
    };
}
206
207make_elem!(double f64);
208make_elem!(double i64);
209make_elem!(double u64);
210
211make_elem!(single f32);
212make_elem!(single i32);
213make_elem!(single i16);
214make_elem!(single i8);
215make_elem!(single u32);
216make_elem!(single u16);
217make_elem!(single u8);