1use burn_tensor::Element;
2use ndarray::LinalgScalar;
3use num_traits::Signed;
4
5#[cfg(not(feature = "std"))]
6use num_traits::Float;
7
8use num_traits::Pow;
9
10use libm::{log1p, log1pf};
11
12pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar + Signed
14where
15 Self: Sized,
16{
17}
18
19pub trait IntNdArrayElement: NdArrayElement + Signed {}
20
21pub trait NdArrayElement:
23 Element
24 + ndarray::LinalgScalar
25 + ndarray::ScalarOperand
26 + ExpElement
27 + num_traits::FromPrimitive
28 + core::ops::AddAssign
29 + core::cmp::PartialEq
30 + core::cmp::PartialOrd<Self>
31 + core::ops::Rem<Output = Self>
32{
33}
34
/// Scalar math operations an element type must provide.
///
/// Every method consumes `self` and returns a value of the same type.
pub trait ExpElement {
    /// Returns the exponential of `self`.
    fn exp_elem(self) -> Self;

    /// Returns the natural logarithm of `self`.
    fn log_elem(self) -> Self;

    /// Returns the natural logarithm of `1 + self`.
    fn log1p_elem(self) -> Self;

    /// Raises `self` to the floating-point power `value`.
    fn powf_elem(self, value: f32) -> Self;

    /// Raises `self` to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;

    /// Returns the square root of `self`.
    fn sqrt_elem(self) -> Self;

    /// Returns the absolute value of `self`.
    fn abs_elem(self) -> Self;

    /// Returns the absolute value of `self` computed on its integer conversion.
    fn int_abs_elem(self) -> Self;
}
46
47pub trait QuantElement: NdArrayElement {}
49
50impl QuantElement for i8 {}
51
52impl FloatNdArrayElement for f64 {}
53impl FloatNdArrayElement for f32 {}
54
55impl IntNdArrayElement for i64 {}
56impl IntNdArrayElement for i32 {}
57
// Implements `NdArrayElement` and `ExpElement` for a primitive numeric type.
//
// Usage: `make_elem!(double T)` or `make_elem!(single T)`.
//
// * `double` evaluates the floating-point operations in `f64` and the integer
//   absolute value in `i64`.
// * `single` evaluates the floating-point operations in `f32` and the integer
//   absolute value in `i32`.
//
// In both arms the receiver is converted with `as` to the working type, the
// operation is applied, and the result is converted back to `$ty` with `as`.
// The usual `as` cast rules (rounding, truncation, saturation) therefore apply
// on the way in and on the way out.
macro_rules! make_elem {
    (
        double
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f64).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f64).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                // `log1p` is the `libm` function, not a std float method.
                log1p(self as f64) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                // `pow` is resolved through the `Pow` trait imported by this file.
                (self as f64).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                // With `std`, use the dedicated integer-power routine.
                #[cfg(feature = "std")]
                let val = f64::powi(self as f64, value) as $ty;

                // Without `std`, fall back to the floating-point power with the
                // exponent converted to `f32`.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f64).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f64).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // `unsigned_abs` cannot overflow, unlike `abs`, which panics in
                // debug builds when the converted value is `i64::MIN` (also what
                // very negative or infinite floats saturate to). This matches
                // the `single` arm below.
                (self as i64).unsigned_abs() as $ty
            }
        }
    };
    (
        single
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f32).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f32).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                // `log1pf` is the single-precision `libm` function.
                log1pf(self as f32) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                // `pow` is resolved through the `Pow` trait imported by this file.
                (self as f32).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                // With `std`, use the dedicated integer-power routine.
                #[cfg(feature = "std")]
                let val = f32::powi(self as f32, value) as $ty;

                // Without `std`, fall back to the floating-point power with the
                // exponent converted to `f32`.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f32).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f32).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // `unsigned_abs` avoids the overflow `abs` has on `i32::MIN`.
                (self as i32).unsigned_abs() as $ty
            }
        }
    };
}
168
169make_elem!(double f64);
170make_elem!(double i64);
171
172make_elem!(single f32);
173make_elem!(single i32);
174make_elem!(single i16);
175make_elem!(single i8);
176make_elem!(single u8);