burn_ndarray/
element.rs

1use burn_tensor::Element;
2use ndarray::LinalgScalar;
3use num_traits::Signed;
4
5#[cfg(not(feature = "std"))]
6#[allow(unused_imports)]
7use num_traits::Float;
8
9use num_traits::Pow;
10
11use libm::{log1p, log1pf};
12
13/// A float element for ndarray backend.
14pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar + Signed
15where
16    Self: Sized,
17{
18}
19
20/// An int element for ndarray backend.
21pub trait IntNdArrayElement: NdArrayElement + Signed {}
22
23/// A general element for ndarray backend.
24pub trait NdArrayElement:
25    Element
26    + ndarray::LinalgScalar
27    + ndarray::ScalarOperand
28    + ExpElement
29    + num_traits::FromPrimitive
30    + core::ops::AddAssign
31    + core::cmp::PartialEq
32    + core::cmp::PartialOrd<Self>
33    + core::ops::Rem<Output = Self>
34{
35}
36
/// Element-wise math operations required from every ndarray-backend element.
///
/// Every method consumes `self` and returns a value of the same type, so the
/// result is always expressed in the element's own type regardless of how the
/// implementor computes it internally.
pub trait ExpElement {
    /// Returns the exponential of `self`.
    fn exp_elem(self) -> Self;

    /// Returns the logarithm of `self`.
    fn log_elem(self) -> Self;

    /// Returns the `log1p` of `self`, i.e. the logarithm of `1 + self`.
    fn log1p_elem(self) -> Self;

    /// Raises `self` to the floating-point power `value`.
    fn powf_elem(self, value: f32) -> Self;

    /// Raises `self` to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;

    /// Returns the square root of `self`.
    fn sqrt_elem(self) -> Self;

    /// Returns the absolute value of `self`.
    fn abs_elem(self) -> Self;

    /// Returns the absolute value of `self`, computed with integer semantics.
    fn int_abs_elem(self) -> Self;
}
56
57/// A quantized element for the ndarray backend.
58pub trait QuantElement: NdArrayElement {}
59
60impl QuantElement for i8 {}
61
62impl FloatNdArrayElement for f64 {}
63impl FloatNdArrayElement for f32 {}
64
65impl IntNdArrayElement for i64 {}
66impl IntNdArrayElement for i32 {}
67
/// Implements [`NdArrayElement`] and [`ExpElement`] for a primitive numeric type.
///
/// Usage: `make_elem!(double T);` or `make_elem!(single T);`.
///
/// The leading keyword selects the intermediate types the math is carried out in:
///
/// * `double` — the value is cast to `f64` (and to `i64` for `int_abs_elem`),
///   the operation is applied, and the result is cast back to `T`.
/// * `single` — same, but through `f32` / `i32`.
///
/// All conversions are plain `as` casts, so a `T` that is wider than the
/// chosen intermediate type is rounded or truncated on the way in.
///
/// With the `std` feature `powi_elem` calls the float type's `powi`; without
/// it, it falls back to `powf_elem` with the exponent converted to `f32`.
macro_rules! make_elem {
    (
        double
        $ty:ty
    ) => {
        make_elem!(@impl $ty, f64, i64, log1p);
    };
    (
        single
        $ty:ty
    ) => {
        make_elem!(@impl $ty, f32, i32, log1pf);
    };
    // Internal rule shared by both public arms.
    //
    // `$float` / `$int` are the intermediate float and integer types and
    // `$log1p` is the libm function matching `$float`.
    (
        @impl
        $ty:ty,
        $float:ty,
        $int:ty,
        $log1p:ident
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as $float).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as $float).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                $log1p(self as $float) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as $float).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = <$float>::powi(self as $float, value) as $ty;

                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as $float).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as $float).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // `unsigned_abs` cannot overflow, unlike `abs`, which panics
                // in debug builds (and wraps in release) when the value is
                // the minimum of `$int`.
                (self as $int).unsigned_abs() as $ty
            }
        }
    };
}
178
179make_elem!(double f64);
180make_elem!(double i64);
181
182make_elem!(single f32);
183make_elem!(single i32);
184make_elem!(single i16);
185make_elem!(single i8);
186make_elem!(single u64);
187make_elem!(single u32);
188make_elem!(single u16);
189make_elem!(single u8);