burn_ndarray/
element.rs

1use burn_tensor::Element;
2use ndarray::LinalgScalar;
3use num_traits::Signed;
4
5#[cfg(not(feature = "std"))]
6#[allow(unused_imports)]
7use num_traits::Float;
8
9use num_traits::Pow;
10
11use libm::{log1p, log1pf};
12
13/// A float element for ndarray backend.
14pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar + Signed
15where
16    Self: Sized,
17{
18}
19
20/// An int element for ndarray backend.
21pub trait IntNdArrayElement: NdArrayElement {}
22
23/// A general element for ndarray backend.
24pub trait NdArrayElement:
25    Element
26    + ndarray::LinalgScalar
27    + ndarray::ScalarOperand
28    + ExpElement
29    + num_traits::FromPrimitive
30    + core::ops::AddAssign
31    + core::cmp::PartialEq
32    + core::cmp::PartialOrd<Self>
33    + core::ops::Rem<Output = Self>
34{
35}
36
/// An element of the ndarray backend that provides elementwise math helpers
/// (exponential, logarithms, powers, square root and absolute value).
pub trait ExpElement {
    /// Returns `e^self`.
    fn exp_elem(self) -> Self;
    /// Returns the natural logarithm of `self`.
    fn log_elem(self) -> Self;
    /// Returns `ln(1 + self)`.
    fn log1p_elem(self) -> Self;
    /// Returns the square root of `self`.
    fn sqrt_elem(self) -> Self;
    /// Raises `self` to the floating-point power `value`.
    fn powf_elem(self, value: f32) -> Self;
    /// Raises `self` to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;
    /// Returns the absolute value of `self`, intended for float elements.
    fn abs_elem(self) -> Self;
    /// Returns the absolute value of `self`, intended for integer elements.
    fn int_abs_elem(self) -> Self;
}
56
57/// A quantized element for the ndarray backend.
58pub trait QuantElement: NdArrayElement {}
59
60impl QuantElement for i8 {}
61
62impl FloatNdArrayElement for f64 {}
63impl FloatNdArrayElement for f32 {}
64
65impl IntNdArrayElement for i64 {}
66impl IntNdArrayElement for i32 {}
67impl IntNdArrayElement for i16 {}
68impl IntNdArrayElement for i8 {}
69
70impl IntNdArrayElement for u64 {}
71impl IntNdArrayElement for u32 {}
72impl IntNdArrayElement for u16 {}
73impl IntNdArrayElement for u8 {}
74
/// Implements [`NdArrayElement`] and [`ExpElement`] for a primitive numeric type.
///
/// * `make_elem!(double T)` evaluates every math helper through `f64`.
/// * `make_elem!(single T)` evaluates every math helper through `f32`.
///
/// Each helper casts `self` to the intermediate type with `as`, computes, and casts the
/// result back with `as`. Integer results therefore follow Rust's truncating/saturating
/// cast rules.
///
/// `int_abs_elem` first widens to a signed integer that is strictly wider than any integer
/// type passed to that arm (`i128` for `double`, `i64` for `single`), then calls
/// `unsigned_abs`. As a result:
/// * unsigned inputs above the signed maximum (e.g. `u32::MAX`, `u64::MAX`) come back
///   unchanged instead of being reinterpreted as negative values, and
/// * the signed minimum (e.g. `i64::MIN`) wraps to itself instead of overflowing `abs`,
///   which panics in debug builds.
macro_rules! make_elem {
    (
        double
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f64).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f64).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1p(self as f64) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f64).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                // With `std` this uses `f64::powi`; otherwise it routes through `powf_elem`.
                #[cfg(feature = "std")]
                let val = f64::powi(self as f64, value) as $ty;

                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f64).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f64).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // `i128` holds every `i64`/`u64` value exactly, and `unsigned_abs` cannot
                // overflow. The final `as` wraps `i64::MIN` back to itself.
                (self as i128).unsigned_abs() as $ty
            }
        }
    };
    (
        single
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f32).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f32).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1pf(self as f32) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f32).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                // With `std` this uses `f32::powi`; otherwise it routes through `powf_elem`.
                #[cfg(feature = "std")]
                let val = f32::powi(self as f32, value) as $ty;

                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f32).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f32).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // `i64` holds every `i32`/`u32`/`i16`/`u16`/`i8`/`u8` value exactly, so
                // large `u32` values are not misread as negative. The final `as` wraps
                // `i32::MIN` (and `i16::MIN`, `i8::MIN`) back to itself.
                (self as i64).unsigned_abs() as $ty
            }
        }
    };
}
186
187make_elem!(double f64);
188make_elem!(double i64);
189make_elem!(double u64);
190
191make_elem!(single f32);
192make_elem!(single i32);
193make_elem!(single i16);
194make_elem!(single i8);
195make_elem!(single u32);
196make_elem!(single u16);
197make_elem!(single u8);