burn_ndarray/
element.rs

1use burn_tensor::Element;
2use ndarray::LinalgScalar;
3use num_traits::Signed;
4
5#[cfg(not(feature = "std"))]
6use num_traits::Float;
7
8use num_traits::Pow;
9
10use libm::{log1p, log1pf};
11
12/// A float element for ndarray backend.
13pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar + Signed
14where
15    Self: Sized,
16{
17}
18
19pub trait IntNdArrayElement: NdArrayElement + Signed {}
20
21/// A general element for ndarray backend.
22pub trait NdArrayElement:
23    Element
24    + ndarray::LinalgScalar
25    + ndarray::ScalarOperand
26    + ExpElement
27    + num_traits::FromPrimitive
28    + core::ops::AddAssign
29    + core::cmp::PartialEq
30    + core::cmp::PartialOrd<Self>
31    + core::ops::Rem<Output = Self>
32{
33}
34
/// Element-wise math operations required of ndarray backend element types.
///
/// Every method takes `self` by value and returns a value of the same type.
pub trait ExpElement {
    /// Exponential function, `e^self`.
    fn exp_elem(self) -> Self;

    /// Natural logarithm of `self`.
    fn log_elem(self) -> Self;

    /// Natural logarithm of `1 + self`.
    fn log1p_elem(self) -> Self;

    /// Raises `self` to the floating-point power `value`.
    fn powf_elem(self, value: f32) -> Self;

    /// Raises `self` to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;

    /// Square root of `self`.
    fn sqrt_elem(self) -> Self;

    /// Absolute value of `self`.
    fn abs_elem(self) -> Self;

    /// Absolute value computed through an integer intermediate.
    ///
    /// NOTE(review): presumably meant for integer element types — the
    /// in-crate implementations cast `self` to an integer first.
    fn int_abs_elem(self) -> Self;
}
46
47/// A quantized element for the ndarray backend.
48pub trait QuantElement: NdArrayElement {}
49
50impl QuantElement for i8 {}
51
52impl FloatNdArrayElement for f64 {}
53impl FloatNdArrayElement for f32 {}
54
55impl IntNdArrayElement for i64 {}
56impl IntNdArrayElement for i32 {}
57
/// Implements `NdArrayElement` and `ExpElement` for a primitive numeric type.
///
/// Two arms are provided:
///
/// * `double $ty` — math is computed in `f64` (and `i64` for `int_abs_elem`).
/// * `single $ty` — math is computed in `f32` (and `i32` for `int_abs_elem`).
///
/// Every operation casts `self` to the intermediate type with `as`, computes
/// there, and casts the result back to `$ty` with `as`. The usual `as` rules
/// therefore apply to the round trip: float→int truncates toward zero and
/// saturates, and int→int wraps.
///
/// Without the `std` feature, `powi_elem` falls back to `powf_elem` with the
/// exponent converted to `f32`.
///
/// `int_abs_elem` uses `unsigned_abs` in both arms. The absolute value of the
/// most negative intermediate (`i64::MIN` / `i32::MIN`) cannot overflow that
/// way, whereas `i64::abs(i64::MIN)` panics when overflow checks are enabled.
/// The cast back to `$ty` then wraps, so `i64::MIN` maps to itself.
macro_rules! make_elem {
    (
        double
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f64).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f64).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1p(self as f64) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f64).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = f64::powi(self as f64, value) as $ty;

                // No `powi` is available here without std; route through `powf`.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f64).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f64).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // `unsigned_abs` never overflows (unlike `i64::abs` on
                // `i64::MIN`); the cast back wraps, matching the `single` arm.
                (self as i64).unsigned_abs() as $ty
            }
        }
    };
    (
        single
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f32).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f32).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1pf(self as f32) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f32).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = f32::powi(self as f32, value) as $ty;

                // No `powi` is available here without std; route through `powf`.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f32).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f32).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                // `unsigned_abs` never overflows on `i32::MIN`; the cast back wraps.
                (self as i32).unsigned_abs() as $ty
            }
        }
    };
}
168
169make_elem!(double f64);
170make_elem!(double i64);
171
172make_elem!(single f32);
173make_elem!(single i32);
174make_elem!(single i16);
175make_elem!(single i8);
176make_elem!(single u8);