1use ndarray::{Array, ArrayBase, Data, Dimension};
6use num_traits::Signed;
7
/// Declares public unary-operation traits.
///
/// The input is a comma-separated list of entries of the form
/// `TraitName::method(receiver)`, with an optional trailing comma. Each entry may be
/// preceded by attributes; `///` doc comments arrive as `#[doc = ...]` attributes
/// and are forwarded to the generated trait as well. Every entry expands to
///
/// ```ignore
/// pub trait TraitName {
///     type Output;
///     fn method(receiver) -> Self::Output;
/// }
/// ```
///
/// The receiver tokens are copied verbatim, so both `self` and `&self` are accepted.
macro_rules! unary {
    // Emits exactly one trait definition; the list arm below dispatches here.
    (@impl $(#[$attr:meta])* $trait_name:ident::$method:ident($($receiver:tt)*)) => {
        $(#[$attr])*
        pub trait $trait_name {
            type Output;

            fn $method($($receiver)*) -> Self::Output;
        }
    };
    // Public form: any number of entries.
    ($($(#[$attr:meta])* $trait_name:ident::$method:ident($($receiver:tt)*)),* $(,)?) => {
        $(
            unary! { @impl $(#[$attr])* $trait_name::$method($($receiver)*) }
        )*
    };
}
21
/// Implements unary-operation traits for concrete types by forwarding to the
/// type's own method of the same name.
///
/// Each comma-separated entry reads `Trait::<[T1, T2, ...]>::method`. For every
/// listed type `T` it produces `impl Trait for T` with `Output = T`.
macro_rules! impl_unary_op {
    // One trait for one concrete type.
    (@impl $trait_name:ident::<$ty:ty>::$method:ident) => {
        impl $trait_name for $ty {
            type Output = $ty;

            fn $method(self) -> Self::Output {
                // `<T>::method` resolves to T's inherent method when one exists,
                // because inherent items take precedence over trait items. A type
                // without such an inherent method would resolve back to this trait
                // method and recurse.
                <$ty>::$method(self)
            }
        }
    };
    // Public form: fan every entry out over its bracketed list of types.
    ($($trait_name:ident::<[$($ty:ty),*]>::$method:ident),* $(,)?) => {
        $(
            $(
                impl_unary_op! { @impl $trait_name::<$ty>::$method }
            )*
        )*
    };
}
36
/// Implements a unary-operation trait for a list of types that share one
/// method body.
///
/// Entries are separated by `;` (a trailing `;` is allowed) and read
/// `Trait::<[T1, T2, ...]>::method(self) -> Output { body }`. Each expands to one
/// `impl Trait for T { type Output = Output; fn method(self) -> Self::Output { body } }`
/// per listed type.
macro_rules! impl_something {
    // One impl. This arm also accepts extra `name: Type` parameters after the
    // receiver, but the public arm below only ever forwards the receiver.
    (@impl $tr:ident::<$ty:ty>::$method:ident($recv:ident $(, $($arg:ident: $arg_ty:ty),*)?) -> $output:ty {$body:expr}) => {
        impl $tr for $ty {
            type Output = $output;

            fn $method($recv $(, $($arg: $arg_ty),*)?) -> Self::Output {
                $body
            }
        }
    };
    // Public form: one impl per (entry, type) pair.
    ($($tr:ident::<[$($ty:ty),* $(,)?]>::$method:ident($recv:ident) -> $output:ty {$body:expr});* $(;)?) => {
        $(
            $(
                impl_something! { @impl $tr::<$ty>::$method($recv) -> $output {$body} }
            )*
        )*
    };
}
51
52unary! {
53 Abs::abs(self),
54 Cos::cos(self),
55 Cosh::cosh(self),
56 Exp::exp(self),
57 Sine::sin(self),
58 Sinh::sinh(self),
59 Tan::tan(self),
60 Tanh::tanh(self),
61 Squared::pow2(self),
62 Cubed::pow3(self),
63 SquareRoot::sqrt(self),
64 Conjugate::conj(&self),
65}
66
67impl_unary_op! {
68 Abs::<[i8, i16, i32, i64, i128, isize, f32, f64]>::abs,
69 Cos::<[f32, f64]>::cos,
70 Cosh::<[f32, f64]>::cosh,
71 Exp::<[f32, f64]>::exp,
72 Sinh::<[f32, f64]>::sinh,
73 Sine::<[f32, f64]>::sin,
74 Tan::<[f32, f64]>::tan,
75 Tanh::<[f32, f64]>::tanh,
76 SquareRoot::<[f32, f64]>::sqrt
77}
78
79impl_something! {
80 Squared::<[u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64]>::pow2(self) -> Self {
81 self * self
82 };
83 Cubed::<[u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64]>::pow3(self) -> Self {
84 self * self * self
85 };
86}
87
88impl<A, S, D> Abs for ArrayBase<S, D>
89where
90 A: Clone + Signed,
91 D: Dimension,
92 S: Data<Elem = A>,
93{
94 type Output = Array<A, D>;
95
96 fn abs(self) -> Self::Output {
97 self.mapv(|x| x.abs())
98 }
99}
100
101impl<A, S, D> Abs for &ArrayBase<S, D>
102where
103 A: Clone + Signed,
104 D: Dimension,
105 S: Data<Elem = A>,
106{
107 type Output = Array<A, D>;
108
109 fn abs(self) -> Self::Output {
110 self.mapv(|x| x.abs())
111 }
112}
113
114impl<A, B, S, D> SquareRoot for ArrayBase<S, D>
115where
116 A: Clone + SquareRoot<Output = B>,
117 D: Dimension,
118 S: Data<Elem = A>,
119{
120 type Output = Array<B, D>;
121
122 fn sqrt(self) -> Self::Output {
123 self.mapv(|x| x.sqrt())
124 }
125}
126
127impl<A, B, S, D> Exp for ArrayBase<S, D>
128where
129 A: Clone + Exp<Output = B>,
130 D: Dimension,
131 S: Data<Elem = A>,
132{
133 type Output = Array<B, D>;
134
135 fn exp(self) -> Self::Output {
136 self.mapv(|x| x.exp())
137 }
138}
139impl<A, S, D> Exp for &ArrayBase<S, D, A>
140where
141 A: Clone + Exp<Output = A>,
142 D: Dimension,
143 S: Data<Elem = A>,
144{
145 type Output = Array<A, D>;
146
147 fn exp(self) -> Self::Output {
148 self.mapv(|x| x.exp())
149 }
150}
151
#[cfg(feature = "complex")]
mod impl_complex {
    //! Unary-operation impls for `num_complex::Complex` values, real-to-complex
    //! conjugation for `f32`/`f64`, and element-wise conjugation of arrays.

    use super::*;

    use ndarray::{Array, ArrayBase, Data, Dimension};
    use num_complex::{Complex, ComplexFloat};
    use num_traits::Signed;

    /// Implements a unary-operation trait for `Complex<T>` by forwarding to the
    /// [`ComplexFloat`] method of the same name.
    ///
    /// The enclosing module is already gated on the `complex` feature, so the
    /// generated impls carry no `cfg` of their own.
    macro_rules! impl_complex_for {
        (@impl $name:ident::<$T:ident>::$method:ident) => {
            impl<$T> $name for num_complex::Complex<$T>
            where
                num_complex::Complex<$T>: num_complex::ComplexFloat<Real = $T>,
            {
                type Output = num_complex::Complex<$T>;

                fn $method(self) -> Self::Output {
                    num_complex::ComplexFloat::$method(self)
                }
            }
        };
        ($($name:ident::<$T:ident>::$method:ident),* $(,)?) => {
            $(impl_complex_for!(@impl $name::<$T>::$method);)*
        };
    }

    impl_complex_for! {
        Cos::<T>::cos,
        Cosh::<T>::cosh,
        Exp::<T>::exp,
        Sine::<T>::sin,
        Sinh::<T>::sinh,
        Tan::<T>::tan,
        Tanh::<T>::tanh,
        SquareRoot::<T>::sqrt,
    }

    /// Implements [`Conjugate`] for a real scalar type `$t`.
    ///
    /// The real value `x` is lifted to the complex number `$res { re: x, im: 0 }`.
    macro_rules! impl_conj {
        ($($t:ident<$res:ident>),*) => {
            $(
                impl_conj!(@impl $t<$res>);
            )*
        };
        (@impl $t:ident<$res:ident>) => {
            impl Conjugate for $t {
                type Output = $res<$t>;

                fn conj(&self) -> Self::Output {
                    Complex { re: *self, im: num_traits::Zero::zero() }
                }
            }
        };
    }

    impl_conj!(f32<Complex>, f64<Complex>);

    /// Complex conjugate, delegating to the inherent [`Complex::conj`].
    impl<T> Conjugate for Complex<T>
    where
        T: Clone + Signed,
    {
        type Output = Complex<T>;

        fn conj(&self) -> Self::Output {
            // Inherent associated functions take precedence over trait methods in
            // path resolution, so this calls `Complex::conj`, not itself.
            Complex::<T>::conj(self)
        }
    }

    /// Element-wise conjugate of any array whose elements implement
    /// [`ComplexFloat`].
    ///
    /// Only read access (`S: Data`) is required, so owned arrays, shared arrays
    /// and views are all accepted (previously only owned `Array<T, D>` was). The
    /// result is a newly allocated `Array<T, D>`.
    impl<T, S, D> Conjugate for ArrayBase<S, D>
    where
        D: Dimension,
        S: Data<Elem = T>,
        T: Clone + ComplexFloat,
    {
        type Output = Array<T, D>;

        fn conj(&self) -> Self::Output {
            self.mapv(|x| x.conj())
        }
    }
}