1use burn_tensor::Element;
2use ndarray::LinalgScalar;
3use num_traits::Signed;
4
5#[cfg(not(feature = "std"))]
6#[allow(unused_imports)]
7use num_traits::Float;
8
9use num_traits::Pow;
10
11use libm::{log1p, log1pf};
12
13pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar + Signed
15where
16 Self: Sized,
17{
18}
19
20pub trait IntNdArrayElement: NdArrayElement + Signed {}
22
23pub trait NdArrayElement:
25 Element
26 + ndarray::LinalgScalar
27 + ndarray::ScalarOperand
28 + ExpElement
29 + num_traits::FromPrimitive
30 + core::ops::AddAssign
31 + core::cmp::PartialEq
32 + core::cmp::PartialOrd<Self>
33 + core::ops::Rem<Output = Self>
34{
35}
36
/// Elementwise math operations on a single scalar.
///
/// In this module every implementation is generated by `make_elem!`, which
/// evaluates the operation in `f32` or `f64` and casts back to `Self`.
pub trait ExpElement {
    /// Returns `e` raised to the power `self`.
    fn exp_elem(self) -> Self;

    /// Returns the natural logarithm of `self`.
    fn log_elem(self) -> Self;

    /// Returns `ln(1 + self)`.
    fn log1p_elem(self) -> Self;

    /// Returns `self` raised to the floating-point exponent `value`.
    fn powf_elem(self, value: f32) -> Self;

    /// Returns `self` raised to the integer exponent `value`.
    fn powi_elem(self, value: i32) -> Self;

    /// Returns the square root of `self`.
    fn sqrt_elem(self) -> Self;

    /// Returns the absolute value of `self`, computed in floating point.
    fn abs_elem(self) -> Self;

    /// Returns the absolute value of `self`, computed through an integer cast
    /// (meant for integer element types).
    fn int_abs_elem(self) -> Self;
}
56
57pub trait QuantElement: NdArrayElement {}
59
60impl QuantElement for i8 {}
61
62impl FloatNdArrayElement for f64 {}
63impl FloatNdArrayElement for f32 {}
64
65impl IntNdArrayElement for i64 {}
66impl IntNdArrayElement for i32 {}
67
/// Implements [`NdArrayElement`] and [`ExpElement`] for a primitive type.
///
/// * `make_elem!(double T)` evaluates the floating-point math (`exp`, `ln`,
///   `log1p`, `pow`, `sqrt`, `abs`) in `f64` and casts the result back to `T`.
/// * `make_elem!(single T)` does the same in `f32`.
///
/// `int_abs_elem` is computed in `i128` for both variants. Every integer type
/// of 64 bits or fewer converts losslessly to `i128`, so unsigned values come
/// back unchanged. The absolute value of a signed `MIN` (e.g. `i64::MIN`)
/// wraps back to `MIN` when narrowed, like `wrapping_abs`, instead of
/// overflowing and panicking. Float inputs are truncated toward zero (and
/// saturated) by the `as i128` cast.
macro_rules! make_elem {
    (
        double
        $ty:ty
    ) => {
        make_elem!(@impl $ty, f64, log1p);
    };
    (
        single
        $ty:ty
    ) => {
        make_elem!(@impl $ty, f32, log1pf);
    };
    // Shared implementation: `$float` is the intermediate float type used for
    // the math and `$log1p` is the matching `libm` function for that type.
    (@impl $ty:ty, $float:ty, $log1p:ident) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as $float).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as $float).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                $log1p(self as $float) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as $float).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = (self as $float).powi(value) as $ty;

                // Without `std` there is no inherent `powi`; fall back to `powf`.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as $float).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as $float).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // Widening to i128 is lossless for all <=64-bit integers, so
                // `unsigned_abs` cannot overflow and unsigned inputs are not
                // truncated; narrowing back wraps a signed `MIN` to itself.
                (self as i128).unsigned_abs() as $ty
            }
        }
    };
}
178
179make_elem!(double f64);
180make_elem!(double i64);
181
182make_elem!(single f32);
183make_elem!(single i32);
184make_elem!(single i16);
185make_elem!(single i8);
186make_elem!(single u64);
187make_elem!(single u32);
188make_elem!(single u16);
189make_elem!(single u8);