burn_ndarray/
element.rs

1use burn_backend::Element;
2use num_traits::Signed;
3
4#[cfg(not(feature = "std"))]
5#[allow(unused_imports)]
6use num_traits::Float;
7
8use num_traits::Pow;
9
10use libm::{log1p, log1pf};
11
12/// A float element for ndarray backend.
13pub trait FloatNdArrayElement: NdArrayElement + Signed
14where
15    Self: Sized,
16{
17}
18
19/// An int element for ndarray backend.
20pub trait IntNdArrayElement: NdArrayElement {}
21
22/// A general element for ndarray backend.
23pub trait NdArrayElement:
24    Element
25    + ndarray::LinalgScalar
26    + ndarray::ScalarOperand
27    + ExpElement
28    + AddAssignElement
29    + num_traits::FromPrimitive
30    + core::ops::AddAssign
31    + core::cmp::PartialEq
32    + core::cmp::PartialOrd<Self>
33    + core::ops::Rem<Output = Self>
34{
35}
36
/// An element for the ndarray backend that supports exponential, logarithmic,
/// power, square-root and absolute-value operations.
///
/// Every method consumes `self` and returns a value of the same type.
pub trait ExpElement {
    /// Returns `e` raised to the power `self`.
    fn exp_elem(self) -> Self;
    /// Returns the natural logarithm of `self`.
    fn log_elem(self) -> Self;
    /// Returns `ln(1 + self)`.
    fn log1p_elem(self) -> Self;
    /// Returns `self` raised to the floating-point exponent `value`.
    fn powf_elem(self, value: f32) -> Self;
    /// Returns `self` raised to the integer exponent `value`.
    fn powi_elem(self, value: i32) -> Self;
    /// Returns the square root of `self`.
    fn sqrt_elem(self) -> Self;
    /// Returns the absolute value of `self` (the implementations in this module
    /// compute it through a float intermediate).
    fn abs_elem(self) -> Self;
    /// Returns the absolute value of `self` (the implementations in this module
    /// compute it through an integer intermediate).
    fn int_abs_elem(self) -> Self;
}
56
/// In-place accumulation (`+=`) for ndarray elements, extended to `bool`.
///
/// Numeric elements add `rhs` into `self`. `bool` has no `+=`, so its
/// implementation accumulates with logical OR instead.
pub trait AddAssignElement<Rhs = Self> {
    /// Accumulates `rhs` into `self` in place.
    ///
    /// For `bool`, this is a logical OR assignment.
    fn add_assign(&mut self, rhs: Rhs);
}
64
65impl<E: NdArrayElement> AddAssignElement for E {
66    fn add_assign(&mut self, rhs: Self) {
67        *self += rhs;
68    }
69}
70
71impl AddAssignElement for bool {
72    fn add_assign(&mut self, rhs: Self) {
73        *self = *self || rhs; // logical OR for bool
74    }
75}
76
77/// A quantized element for the ndarray backend.
78pub trait QuantElement: NdArrayElement {}
79
80impl QuantElement for i8 {}
81
82impl FloatNdArrayElement for f64 {}
83impl FloatNdArrayElement for f32 {}
84
85impl IntNdArrayElement for i64 {}
86impl IntNdArrayElement for i32 {}
87impl IntNdArrayElement for i16 {}
88impl IntNdArrayElement for i8 {}
89
90impl IntNdArrayElement for u64 {}
91impl IntNdArrayElement for u32 {}
92impl IntNdArrayElement for u16 {}
93impl IntNdArrayElement for u8 {}
94
/// Implements [`NdArrayElement`] and [`ExpElement`] for a primitive scalar type.
///
/// The leading keyword selects the intermediate type the math is carried out in:
///
/// - `double $ty`: float ops via `f64`, `int_abs_elem` via `i64`.
/// - `single $ty`: float ops via `f32`, `int_abs_elem` via `i32`.
///
/// Results are converted back to `$ty` with `as`, so integer types see the usual
/// float-to-int truncation and saturation.
///
/// `int_abs_elem` never panics. Negative inputs go through `unsigned_abs`, so a
/// signed `MIN` wraps back to itself instead of overflowing. Non-negative inputs,
/// including every unsigned value, are returned unchanged (floats are truncated
/// toward zero by the round-trip through the integer intermediate).
macro_rules! make_elem {
    (
        double
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f64).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f64).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1p(self as f64) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f64).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = f64::powi(self as f64, value) as $ty;

                // Without the `std` feature, delegate to the float power.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f64).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f64).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                let as_int = self as i64;
                if self < (0 as $ty) {
                    // `unsigned_abs` cannot overflow, whereas `i64::MIN.abs()`
                    // panics in debug builds. `i64::MIN` wraps back to itself
                    // on the final cast.
                    as_int.unsigned_abs() as $ty
                } else {
                    // Already non-negative. The `i64` round-trip is lossless for
                    // `u64` values above `i64::MAX` (taking `abs` of their negative
                    // reinterpretation would corrupt them) and truncates floats
                    // toward zero.
                    as_int as $ty
                }
            }
        }
    };
    (
        single
        $ty:ty
    ) => {
        impl NdArrayElement for $ty {}

        impl ExpElement for $ty {
            #[inline(always)]
            fn exp_elem(self) -> Self {
                (self as f32).exp() as $ty
            }

            #[inline(always)]
            fn log_elem(self) -> Self {
                (self as f32).ln() as $ty
            }

            #[inline(always)]
            fn log1p_elem(self) -> Self {
                log1pf(self as f32) as $ty
            }

            #[inline(always)]
            fn powf_elem(self, value: f32) -> Self {
                (self as f32).pow(value) as $ty
            }

            #[inline(always)]
            fn powi_elem(self, value: i32) -> Self {
                #[cfg(feature = "std")]
                let val = f32::powi(self as f32, value) as $ty;

                // Without the `std` feature, delegate to the float power.
                #[cfg(not(feature = "std"))]
                let val = Self::powf_elem(self, value as f32);

                val
            }

            #[inline(always)]
            fn sqrt_elem(self) -> Self {
                (self as f32).sqrt() as $ty
            }

            #[inline(always)]
            fn abs_elem(self) -> Self {
                (self as f32).abs() as $ty
            }

            #[inline(always)]
            fn int_abs_elem(self) -> Self {
                let as_int = self as i32;
                if self < (0 as $ty) {
                    // `unsigned_abs` cannot overflow; a signed `MIN` (e.g.
                    // `i8::MIN`, `i32::MIN`) wraps back to itself on the cast.
                    as_int.unsigned_abs() as $ty
                } else {
                    // Already non-negative. The `i32` round-trip keeps `u32`
                    // values above `i32::MAX` intact and truncates floats
                    // toward zero.
                    as_int as $ty
                }
            }
        }
    };
}
206
207make_elem!(double f64);
208make_elem!(double i64);
209make_elem!(double u64);
210
211make_elem!(single f32);
212make_elem!(single i32);
213make_elem!(single i16);
214make_elem!(single i8);
215make_elem!(single u32);
216make_elem!(single u16);
217make_elem!(single u8);