Skip to main content

burn_ndarray/
element.rs

1use burn_backend::Element;
2use num_traits::Signed;
3
4#[cfg(not(feature = "std"))]
5#[allow(unused_imports)]
6use num_traits::Float;
7
8use num_traits::Pow;
9
10use libm::{log1p, log1pf};
11
12/// A float element for ndarray backend.
13pub trait FloatNdArrayElement: NdArrayElement + Signed + core::cmp::PartialOrd<Self>
14where
15    Self: Sized,
16{
17}
18
19/// An int element for ndarray backend.
20pub trait IntNdArrayElement: NdArrayElement + core::cmp::PartialOrd<Self> {}
21
22/// A general element for ndarray backend.
23pub trait NdArrayElement:
24    Element
25    + ndarray::LinalgScalar
26    + ndarray::ScalarOperand
27    + ExpElement
28    + AddAssignElement
29    + num_traits::FromPrimitive
30    + core::ops::AddAssign
31    + core::cmp::PartialEq
32    + core::ops::Rem<Output = Self>
33{
34}
35
/// An element for the ndarray backend that supports exponential-style operations.
///
/// Every method consumes `self` and returns a value of the same type. The
/// floating-point implementations in this module forward to the matching float
/// routine; the integer implementations evaluate the operation in floating
/// point and cast the result back to the integer type.
pub trait ExpElement {
    /// Returns `e` raised to the power of `self`.
    fn exp_elem(self) -> Self;
    /// Returns the natural logarithm of `self`.
    fn log_elem(self) -> Self;
    /// Returns `ln(1 + self)`.
    fn log1p_elem(self) -> Self;
    /// Raises `self` to the floating-point power `value`.
    fn powf_elem(self, value: f32) -> Self;
    /// Raises `self` to the integer power `value`.
    fn powi_elem(self, value: i32) -> Self;
    /// Returns the square root of `self`.
    fn sqrt_elem(self) -> Self;
    /// Returns the absolute value of `self`.
    fn abs_elem(self) -> Self;
}
53
/// Addition-assignment for ndarray elements.
///
/// This is a separate trait from [`core::ops::AddAssign`] so that types without
/// a `+=` operator (such as `bool`) can still define an in-place accumulation.
pub trait AddAssignElement<Rhs = Self> {
    /// Accumulates `rhs` into `self` in place.
    ///
    /// For `bool`, this corresponds to logical OR assignment.
    fn add_assign(&mut self, rhs: Rhs);
}
61
62impl<E: NdArrayElement> AddAssignElement for E {
63    fn add_assign(&mut self, rhs: Self) {
64        *self += rhs;
65    }
66}
67
68impl AddAssignElement for bool {
69    fn add_assign(&mut self, rhs: Self) {
70        *self = *self || rhs; // logical OR for bool
71    }
72}
73
74/// A quantized element for the ndarray backend.
75pub trait QuantElement: NdArrayElement {}
76
77impl QuantElement for i8 {}
78
79impl FloatNdArrayElement for f64 {}
80impl FloatNdArrayElement for f32 {}
81
82impl IntNdArrayElement for i64 {}
83impl IntNdArrayElement for i32 {}
84impl IntNdArrayElement for i16 {}
85impl IntNdArrayElement for i8 {}
86
87impl IntNdArrayElement for u64 {}
88impl IntNdArrayElement for u32 {}
89impl IntNdArrayElement for u16 {}
90impl IntNdArrayElement for u8 {}
91
// Implements `NdArrayElement` and `ExpElement` for the floating-point type `$ty`.
//
// `$log1p` is an expression callable as `$log1p(self)` that returns
// `ln(1 + self)` for that type; every other operation forwards to the float
// method of the same name.
macro_rules! make_float {
    (
        $ty:ty,
        $log1p:expr
    ) => {
        impl NdArrayElement for $ty {}

        #[allow(clippy::cast_abs_to_unsigned)]
        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                self.exp()
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                self.ln()
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                $log1p(self)
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                // The exponent is always `f32`, even when `$ty` is wider.
                Pow::pow(self, value)
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                // With `std` the float `powi` method is used directly.
                #[cfg(feature = "std")]
                let raised = self.powi(value);

                // Without `std` the integer exponent is routed through `powf_elem`.
                #[cfg(not(feature = "std"))]
                let raised = self.powf_elem(value as f32);

                raised
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                self.sqrt()
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                self.abs()
            }
        }
    };
}
// Implements `NdArrayElement` and `ExpElement` for the integer type `$ty`.
//
// `$abs` is an expression callable as `$abs(self)` that returns the absolute
// value of a `$ty`.
//
// All other operations are evaluated in `f64` and converted back with `as`,
// which truncates toward zero and saturates at the bounds of `$ty`. `f64`
// represents every 8/16/32-bit integer exactly and 64-bit integers up to 2^53.
// An `f32` round-trip, by contrast, already corrupts inputs above 2^24: for
// example `16_777_217.powi_elem(1)` would come back as `16_777_216`.
macro_rules! make_int {
    (
        $ty:ty,
        $abs:expr
    ) => {
        impl NdArrayElement for $ty {}

        #[allow(clippy::cast_abs_to_unsigned)]
        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f64).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f64).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1p(self as f64) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                // Same `Pow<f32> for f64` path that the `f64` float element uses.
                (self as f64).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = f64::powi(self as f64, value) as $ty;

                // Without `std` the integer exponent is routed through `powf_elem`.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f64).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                // Stays in the integer domain; no float round-trip.
                $abs(self)
            }
        }
    };
}
196
197make_float!(f64, log1p);
198make_float!(f32, log1pf);
199
200make_int!(i64, i64::wrapping_abs);
201make_int!(i32, i32::wrapping_abs);
202make_int!(i16, i16::wrapping_abs);
203make_int!(i8, i8::wrapping_abs);
204make_int!(u64, |x| x);
205make_int!(u32, |x| x);
206make_int!(u16, |x| x);
207make_int!(u8, |x| x);