1use burn_backend::Element;
2use num_traits::Signed;
3
4#[cfg(not(feature = "std"))]
5#[allow(unused_imports)]
6use num_traits::Float;
7
8use num_traits::Pow;
9
10use libm::{log1p, log1pf};
11
12pub trait FloatNdArrayElement: NdArrayElement + Signed + core::cmp::PartialOrd<Self>
14where
15 Self: Sized,
16{
17}
18
19pub trait IntNdArrayElement: NdArrayElement + core::cmp::PartialOrd<Self> {}
21
22pub trait NdArrayElement:
24 Element
25 + ndarray::LinalgScalar
26 + ndarray::ScalarOperand
27 + ExpElement
28 + AddAssignElement
29 + num_traits::FromPrimitive
30 + core::ops::AddAssign
31 + core::cmp::PartialEq
32 + core::ops::Rem<Output = Self>
33{
34}
35
/// Element-wise math operations that every ndarray element type provides.
///
/// Each method takes the element by value and returns a value of the same type.
/// Float types compute these natively. Integer types may round-trip through a
/// float type, so their results can be truncated or saturated by the cast back.
pub trait ExpElement {
    /// Computes `e^self`.
    fn exp_elem(self) -> Self;

    /// Computes the natural logarithm of `self`.
    fn log_elem(self) -> Self;

    /// Computes `ln(1 + self)`.
    fn log1p_elem(self) -> Self;

    /// Raises `self` to the floating-point power `value`.
    fn powf_elem(self, value: f32) -> Self;

    /// Raises `self` to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;

    /// Computes the square root of `self`.
    fn sqrt_elem(self) -> Self;

    /// Computes the absolute value of `self`.
    fn abs_elem(self) -> Self;
}
53
/// In-place accumulation of a right-hand-side value into an element.
///
/// `Rhs` defaults to `Self`, so `T: AddAssignElement` means "accumulate a `T`
/// into a `T`".
pub trait AddAssignElement<Rhs = Self> {
    /// Folds `rhs` into `self`, mutating it in place.
    fn add_assign(&mut self, rhs: Rhs);
}
61
62impl<E: NdArrayElement> AddAssignElement for E {
63 fn add_assign(&mut self, rhs: Self) {
64 *self += rhs;
65 }
66}
67
68impl AddAssignElement for bool {
69 fn add_assign(&mut self, rhs: Self) {
70 *self = *self || rhs; }
72}
73
74pub trait QuantElement: NdArrayElement {}
76
77impl QuantElement for i8 {}
78
79impl FloatNdArrayElement for f64 {}
80impl FloatNdArrayElement for f32 {}
81
82impl IntNdArrayElement for i64 {}
83impl IntNdArrayElement for i32 {}
84impl IntNdArrayElement for i16 {}
85impl IntNdArrayElement for i8 {}
86
87impl IntNdArrayElement for u64 {}
88impl IntNdArrayElement for u32 {}
89impl IntNdArrayElement for u16 {}
90impl IntNdArrayElement for u8 {}
91
// Implements `NdArrayElement` and `ExpElement` for a primitive float type.
//
// Arguments:
// - `$float`: the float type to implement the traits for.
// - `$log1p_fn`: a callable `fn($float) -> $float` computing `ln(1 + x)`.
//
// Most operations call the float's own `abs`/`sqrt`/`exp`/`ln` methods.
// `powf_elem` goes through the `pow` method (from `num_traits::Pow`), and
// `log1p_elem` calls `$log1p_fn`.
macro_rules! make_float {
    (
        $float:ty,
        $log1p_fn:expr
    ) => {
        impl NdArrayElement for $float {}

        #[allow(clippy::cast_abs_to_unsigned)]
        impl ExpElement for $float {
            #[inline(always)]
            fn abs_elem(self) -> Self {
                self.abs()
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                self.sqrt()
            }

            #[inline(always)]
            fn exp_elem(self) -> Self {
                self.exp()
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                self.ln()
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                $log1p_fn(self)
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                self.pow(value)
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                // With `std`, use the float's `powi`. Otherwise emulate it
                // through `powf_elem` with the exponent cast to `f32`.
                #[cfg(not(feature = "std"))]
                let powered = <$float as ExpElement>::powf_elem(self, value as f32);

                #[cfg(feature = "std")]
                let powered = self.powi(value);

                powered
            }
        }
    };
}
// Implements `NdArrayElement` and `ExpElement` for a primitive integer type.
//
// Arguments:
// - `$ty`: the integer type to implement the traits for.
// - `$abs`: a callable `fn($ty) -> $ty` used for `abs_elem`.
//
// Every operation except `abs_elem` follows the same three steps:
// 1. widen the integer to `f64`;
// 2. apply the float operation;
// 3. convert back with `as`, which truncates toward zero, saturates at the
//    type's bounds, and maps NaN to 0.
//
// Why `f64` rather than `f32`: every 32-bit integer, and every 64-bit integer
// up to 2^53 in magnitude, is exact in `f64`. `f32` has only a 24-bit mantissa,
// which corrupted results above 2^24. For example, `sqrt_elem` of
// `(2^24 + 1)^2` gave `2^24`, and `exp_elem(20_i32)` was off by 11.
macro_rules! make_int {
    (
        $ty:ty,
        $abs:expr
    ) => {
        impl NdArrayElement for $ty {}

        #[allow(clippy::cast_abs_to_unsigned)]
        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f64).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f64).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1p(self as f64) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f64).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = f64::powi(self as f64, value) as $ty;

                // Without `std`, emulate `powi` through `powf_elem`.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f64).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                $abs(self)
            }
        }
    };
}
196
197make_float!(f64, log1p);
198make_float!(f32, log1pf);
199
200make_int!(i64, i64::wrapping_abs);
201make_int!(i32, i32::wrapping_abs);
202make_int!(i16, i16::wrapping_abs);
203make_int!(i8, i8::wrapping_abs);
204make_int!(u64, |x| x);
205make_int!(u32, |x| x);
206make_int!(u16, |x| x);
207make_int!(u8, |x| x);