1use burn_tensor::Element;
2use ndarray::LinalgScalar;
3use num_traits::Signed;
4
5#[cfg(not(feature = "std"))]
6use num_traits::Float;
7
8use num_traits::Pow;
9
10use libm::{log1p, log1pf};
11
12pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar + Signed
14where
15 Self: Sized,
16{
17}
18
19pub trait IntNdArrayElement: NdArrayElement + Signed {}
21
22pub trait NdArrayElement:
24 Element
25 + ndarray::LinalgScalar
26 + ndarray::ScalarOperand
27 + ExpElement
28 + num_traits::FromPrimitive
29 + core::ops::AddAssign
30 + core::cmp::PartialEq
31 + core::cmp::PartialOrd<Self>
32 + core::ops::Rem<Output = Self>
33{
34}
35
/// Per-element math operations, implemented for each type by `make_elem!`.
///
/// Implementations compute through `f32` or `f64` (depending on the macro arm)
/// and cast the result back to `Self`.
pub trait ExpElement {
    /// Natural exponential, `e^self`.
    fn exp_elem(self) -> Self;

    /// Natural logarithm, `ln(self)`.
    fn log_elem(self) -> Self;

    /// `ln(1 + self)`.
    fn log1p_elem(self) -> Self;

    /// `self` raised to the floating-point power `value`.
    fn powf_elem(self, value: f32) -> Self;

    /// `self` raised to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;

    /// Square root.
    fn sqrt_elem(self) -> Self;

    /// Absolute value computed through a float intermediate.
    fn abs_elem(self) -> Self;

    /// Absolute value computed through an integer intermediate.
    fn int_abs_elem(self) -> Self;
}
55
56pub trait QuantElement: NdArrayElement {}
58
59impl QuantElement for i8 {}
60
61impl FloatNdArrayElement for f64 {}
62impl FloatNdArrayElement for f32 {}
63
64impl IntNdArrayElement for i64 {}
65impl IntNdArrayElement for i32 {}
66
/// Implements [`NdArrayElement`] and [`ExpElement`] for a primitive type.
///
/// * `double $ty`: float math is performed in `f64` and cast back to `$ty`.
/// * `single $ty`: float math is performed in `f32` and cast back to `$ty`.
///   For integer types wider than 24 bits of magnitude (`i32`, `u32`, `u64`),
///   the `as f32` conversion rounds large values before the operation.
///
/// In both arms, `int_abs_elem` is computed in `i128`, which represents every
/// integer type up to 64 bits (signed or unsigned) exactly. Consequences:
/// * unsigned values are returned unchanged;
/// * a signed `MIN` wraps back to `MIN` instead of panicking;
/// * for float types the value is first truncated toward zero (saturating,
///   NaN becomes 0) by the `as i128` cast.
macro_rules! make_elem {
    (
        double
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f64).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f64).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1p(self as f64) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f64).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = f64::powi(self as f64, value) as $ty;

                // `powi` is not available without std; fall back to `powf`.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f64).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f64).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // Widen to i128 so no 64-bit value is truncated and
                // `i64::MIN.abs()` cannot overflow; the narrowing cast wraps
                // `2^63` back to `i64::MIN`.
                (self as i128).unsigned_abs() as $ty
            }
        }
    };
    (
        single
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f32).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f32).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1pf(self as f32) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f32).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = f32::powi(self as f32, value) as $ty;

                // `powi` is not available without std; fall back to `powf`.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f32).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f32).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // Going through `i32` would truncate `u64` values and turn
                // `u32` values above `i32::MAX` negative. `i128` holds every
                // type routed through this arm exactly: unsigned inputs come
                // back unchanged and signed `MIN` wraps back to `MIN`.
                (self as i128).unsigned_abs() as $ty
            }
        }
    };
}
177
178make_elem!(double f64);
179make_elem!(double i64);
180
181make_elem!(single f32);
182make_elem!(single i32);
183make_elem!(single i16);
184make_elem!(single i8);
185make_elem!(single u64);
186make_elem!(single u32);
187make_elem!(single u16);
188make_elem!(single u8);