burn_ndarray/
element.rs

1use burn_tensor::Element;
2use ndarray::LinalgScalar;
3use num_traits::Signed;
4
5#[cfg(not(feature = "std"))]
6use num_traits::Float;
7
8use num_traits::Pow;
9
10use libm::{log1p, log1pf};
11
12/// A float element for ndarray backend.
13pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar + Signed
14where
15    Self: Sized,
16{
17}
18
19/// An int element for ndarray backend.
20pub trait IntNdArrayElement: NdArrayElement + Signed {}
21
22/// A general element for ndarray backend.
23pub trait NdArrayElement:
24    Element
25    + ndarray::LinalgScalar
26    + ndarray::ScalarOperand
27    + ExpElement
28    + num_traits::FromPrimitive
29    + core::ops::AddAssign
30    + core::cmp::PartialEq
31    + core::cmp::PartialOrd<Self>
32    + core::ops::Rem<Output = Self>
33{
34}
35
/// Element-wise math (absolute value, square root, powers, exponential and
/// logarithms) that the ndarray backend needs on each element type.
pub trait ExpElement {
    /// Absolute value of `self`.
    fn abs_elem(self) -> Self;
    /// Absolute value of `self`, intended for integer element types.
    fn int_abs_elem(self) -> Self;
    /// Square root of `self`.
    fn sqrt_elem(self) -> Self;
    /// `self` raised to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;
    /// `self` raised to the floating-point power `value`.
    fn powf_elem(self, value: f32) -> Self;
    /// Exponential `e^self`.
    fn exp_elem(self) -> Self;
    /// Natural logarithm of `self`.
    fn log_elem(self) -> Self;
    /// Natural logarithm of `1 + self`.
    fn log1p_elem(self) -> Self;
}
55
56/// A quantized element for the ndarray backend.
57pub trait QuantElement: NdArrayElement {}
58
59impl QuantElement for i8 {}
60
61impl FloatNdArrayElement for f64 {}
62impl FloatNdArrayElement for f32 {}
63
64impl IntNdArrayElement for i64 {}
65impl IntNdArrayElement for i32 {}
66
67macro_rules! make_elem {
68    (
69        double
70        $ty:ty
71    ) => {
72        impl NdArrayElement for $ty {}
73
74        impl ExpElement for $ty {
75            #[inline(always)]
76            fn exp_elem(self) -> Self {
77                (self as f64).exp() as $ty
78            }
79
80            #[inline(always)]
81            fn log_elem(self) -> Self {
82                (self as f64).ln() as $ty
83            }
84
85            #[inline(always)]
86            fn log1p_elem(self) -> Self {
87                log1p(self as f64) as $ty
88            }
89
90            #[inline(always)]
91            fn powf_elem(self, value: f32) -> Self {
92                (self as f64).pow(value) as $ty
93            }
94
95            #[inline(always)]
96            fn powi_elem(self, value: i32) -> Self {
97                #[cfg(feature = "std")]
98                let val = f64::powi(self as f64, value) as $ty;
99
100                #[cfg(not(feature = "std"))]
101                let val = Self::powf_elem(self, value as f32);
102
103                val
104            }
105
106            #[inline(always)]
107            fn sqrt_elem(self) -> Self {
108                (self as f64).sqrt() as $ty
109            }
110
111            #[inline(always)]
112            fn abs_elem(self) -> Self {
113                (self as f64).abs() as $ty
114            }
115
116            #[inline(always)]
117            fn int_abs_elem(self) -> Self {
118                (self as i64).abs() as $ty
119            }
120        }
121    };
122    (
123        single
124        $ty:ty
125    ) => {
126        impl NdArrayElement for $ty {}
127
128        impl ExpElement for $ty {
129            #[inline(always)]
130            fn exp_elem(self) -> Self {
131                (self as f32).exp() as $ty
132            }
133
134            #[inline(always)]
135            fn log_elem(self) -> Self {
136                (self as f32).ln() as $ty
137            }
138
139            #[inline(always)]
140            fn log1p_elem(self) -> Self {
141                log1pf(self as f32) as $ty
142            }
143
144            #[inline(always)]
145            fn powf_elem(self, value: f32) -> Self {
146                (self as f32).pow(value) as $ty
147            }
148
149            #[inline(always)]
150            fn powi_elem(self, value: i32) -> Self {
151                #[cfg(feature = "std")]
152                let val = f32::powi(self as f32, value) as $ty;
153
154                #[cfg(not(feature = "std"))]
155                let val = Self::powf_elem(self, value as f32);
156
157                val
158            }
159
160            #[inline(always)]
161            fn sqrt_elem(self) -> Self {
162                (self as f32).sqrt() as $ty
163            }
164
165            #[inline(always)]
166            fn abs_elem(self) -> Self {
167                (self as f32).abs() as $ty
168            }
169
170            #[inline(always)]
171            fn int_abs_elem(self) -> Self {
172                (self as i32).unsigned_abs() as $ty
173            }
174        }
175    };
176}
177
178make_elem!(double f64);
179make_elem!(double i64);
180
181make_elem!(single f32);
182make_elem!(single i32);
183make_elem!(single i16);
184make_elem!(single i8);
185make_elem!(single u64);
186make_elem!(single u32);
187make_elem!(single u16);
188make_elem!(single u8);