1use burn_tensor::Element;
2use ndarray::LinalgScalar;
3use num_traits::Signed;
4
5#[cfg(not(feature = "std"))]
6#[allow(unused_imports)]
7use num_traits::Float;
8
9use num_traits::Pow;
10
11use libm::{log1p, log1pf};
12
13pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar + Signed
15where
16 Self: Sized,
17{
18}
19
20pub trait IntNdArrayElement: NdArrayElement {}
22
23pub trait NdArrayElement:
25 Element
26 + ndarray::LinalgScalar
27 + ndarray::ScalarOperand
28 + ExpElement
29 + num_traits::FromPrimitive
30 + core::ops::AddAssign
31 + core::cmp::PartialEq
32 + core::cmp::PartialOrd<Self>
33 + core::ops::Rem<Output = Self>
34{
35}
36
/// Element-wise math helpers implemented for each backend element type.
///
/// Every method consumes the element by value and returns a new element of
/// the same type.
pub trait ExpElement {
    // --- exponentials and logarithms ---

    /// Returns `e` raised to the power of `self`.
    fn exp_elem(self) -> Self;

    /// Returns the natural logarithm of `self`.
    fn log_elem(self) -> Self;

    /// Returns `ln(1 + self)`.
    fn log1p_elem(self) -> Self;

    // --- powers and roots ---

    /// Returns `self` raised to the floating-point power `value`.
    fn powf_elem(self, value: f32) -> Self;

    /// Returns `self` raised to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;

    /// Returns the square root of `self`.
    fn sqrt_elem(self) -> Self;

    // --- absolute values ---

    /// Returns the absolute value of `self`.
    fn abs_elem(self) -> Self;

    /// Returns the absolute value of `self` computed with integer arithmetic.
    fn int_abs_elem(self) -> Self;
}
56
57pub trait QuantElement: NdArrayElement {}
59
60impl QuantElement for i8 {}
61
62impl FloatNdArrayElement for f64 {}
63impl FloatNdArrayElement for f32 {}
64
65impl IntNdArrayElement for i64 {}
66impl IntNdArrayElement for i32 {}
67impl IntNdArrayElement for i16 {}
68impl IntNdArrayElement for i8 {}
69
70impl IntNdArrayElement for u64 {}
71impl IntNdArrayElement for u32 {}
72impl IntNdArrayElement for u16 {}
73impl IntNdArrayElement for u8 {}
74
/// Implements [`NdArrayElement`] and [`ExpElement`] for a primitive numeric type.
///
/// * `double $ty` evaluates the floating-point helpers in `f64` and the
///   integer absolute value through `i64`.
/// * `single $ty` evaluates the floating-point helpers in `f32` and the
///   integer absolute value through `i32`.
///
/// Results are converted back to `$ty` with `as`. For integer types this
/// truncates (and, for out-of-range values, saturates) the floating-point
/// results.
macro_rules! make_elem {
    (
        double
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f64).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f64).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1p(self as f64) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f64).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                // `f64::powi` is only available with `std`; without it, fall
                // back to the floating-point power.
                #[cfg(feature = "std")]
                let val = f64::powi(self as f64, value) as $ty;

                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f64).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f64).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                if <$ty>::MIN == (0 as $ty) {
                    // Unsigned types are already non-negative. Routing them
                    // through `i64` would reinterpret values above `i64::MAX`
                    // as negative (e.g. `u64::MAX as i64 == -1`) and return
                    // the wrong magnitude.
                    self
                } else {
                    // `unsigned_abs` cannot overflow, unlike `i64::abs`, which
                    // panics on `i64::MIN` in debug builds. The cast back wraps
                    // `i64::MIN` to itself, matching the `single` arm's
                    // behavior for `i32::MIN`.
                    (self as i64).unsigned_abs() as $ty
                }
            }
        }
    };
    (
        single
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f32).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f32).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1pf(self as f32) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f32).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                // `f32::powi` is only available with `std`; without it, fall
                // back to the floating-point power.
                #[cfg(feature = "std")]
                let val = f32::powi(self as f32, value) as $ty;

                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f32).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f32).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                if <$ty>::MIN == (0 as $ty) {
                    // Unsigned types are already non-negative. Routing them
                    // through `i32` would reinterpret values above `i32::MAX`
                    // (e.g. large `u32`s) as negative and return the wrong
                    // magnitude.
                    self
                } else {
                    // `unsigned_abs` cannot overflow; the cast back wraps the
                    // signed minimum to itself.
                    (self as i32).unsigned_abs() as $ty
                }
            }
        }
    };
}
186
187make_elem!(double f64);
188make_elem!(double i64);
189make_elem!(double u64);
190
191make_elem!(single f32);
192make_elem!(single i32);
193make_elem!(single i16);
194make_elem!(single i8);
195make_elem!(single u32);
196make_elem!(single u16);
197make_elem!(single u8);