1use crate::type_promote::FloatOutUnary2;
2use crate::type_promote::{BitWiseOut2, Eval2, FloatOutBinary2, NormalOut2, NormalOutUnary2};
3use num_complex::ComplexFloat;
/// Implements the scalar operator traits (`FloatOutBinary2`, `NormalOut2`,
/// `NormalOutUnary2`, `BitWiseOut2` and `Eval2`) for one primitive integer type.
///
/// Arguments, in order:
/// * the integer type;
/// * `[abs tokens]`: appended after `self` in `__abs` (empty means identity);
/// * `[neg tokens]`: placed in front of `self` in `__neg` (empty means identity);
/// * `[signum tokens]`: appended after `self` in `__signum` (empty means identity).
///
/// Addition, subtraction, multiplication, fused multiply-add, power, square,
/// division and remainder all use wrapping arithmetic so that debug and release
/// builds behave the same on overflow.
macro_rules! impl_int_traits {
    ($type:ty, [$($abs:tt)*], [$($neg:tt)*], [$($signum:tt)*]) => {
        impl FloatOutBinary2 for $type {
            /// Integer division.
            ///
            /// # Panics
            /// Panics with a message naming the integer type when `rhs` is zero.
            #[inline(always)]
            fn __div(self, rhs: Self) -> Self {
                if rhs == 0 {
                    panic!("Division by zero for {}", stringify!($type));
                }
                // `wrapping_div` keeps `MIN / -1` from panicking, mirroring the
                // `wrapping_rem` used by `__rem`.
                self.wrapping_div(rhs)
            }

            /// Not available for integers; always panics.
            #[inline(always)]
            fn __log(self, _: Self) -> Self {
                panic!("Logarithm operation is not supported for {}", stringify!($type));
            }

            /// Not available for integers; always panics.
            #[inline(always)]
            fn __hypot(self, _: Self) -> Self {
                panic!("Hypot operation is not supported for {}", stringify!($type));
            }
        }

        impl NormalOut2 for $type {
            /// Wrapping addition.
            #[inline(always)]
            fn __add(self, rhs: Self) -> Self {
                self.wrapping_add(rhs)
            }

            /// Wrapping subtraction.
            #[inline(always)]
            fn __sub(self, rhs: Self) -> Self {
                self.wrapping_sub(rhs)
            }

            /// Computes `self * a + b` with wrapping arithmetic, consistent with
            /// `__mul` and `__add`.
            #[inline(always)]
            fn __mul_add(self, a: Self, b: Self) -> Self {
                self.wrapping_mul(a).wrapping_add(b)
            }

            /// Wrapping multiplication.
            #[inline(always)]
            fn __mul(self, rhs: Self) -> Self {
                self.wrapping_mul(rhs)
            }

            /// Wrapping exponentiation.
            ///
            /// The exponent is converted with `as u32`, which is a lossy cast:
            /// negative or out-of-range exponents are reinterpreted or truncated.
            #[inline(always)]
            fn __pow(self, rhs: Self) -> Self {
                self.wrapping_pow(rhs as u32)
            }

            /// Wrapping remainder. Like the standard `wrapping_rem`, this panics
            /// when `rhs` is zero.
            #[inline(always)]
            fn __rem(self, rhs: Self) -> Self {
                self.wrapping_rem(rhs)
            }

            /// Larger of the two values.
            #[inline(always)]
            fn __max(self, rhs: Self) -> Self {
                self.max(rhs)
            }

            /// Smaller of the two values.
            #[inline(always)]
            fn __min(self, rhs: Self) -> Self {
                self.min(rhs)
            }

            /// Restricts `self` to `[min, max]` via the integer's own `clamp`.
            #[inline(always)]
            fn __clamp(self, min: Self, max: Self) -> Self {
                self.clamp(min, max)
            }
        }

        impl NormalOutUnary2 for $type {
            /// Wrapping square (`self * self`).
            #[inline(always)]
            fn __square(self) -> Self {
                self.wrapping_mul(self)
            }

            /// Absolute value: `self` followed by the caller-supplied `abs` tokens.
            #[inline(always)]
            fn __abs(self) -> Self {
                self$($abs)*
            }

            /// Identity: integers are already whole numbers.
            #[inline(always)]
            fn __ceil(self) -> Self {
                self
            }

            /// Identity: integers are already whole numbers.
            #[inline(always)]
            fn __floor(self) -> Self {
                self
            }

            /// Negation: the caller-supplied `neg` tokens followed by `self`.
            #[inline(always)]
            fn __neg(self) -> Self {
                $($neg)*self
            }

            /// Identity: integers are already whole numbers.
            #[inline(always)]
            fn __round(self) -> Self {
                self
            }

            /// Sign: `self` followed by the caller-supplied `signum` tokens.
            #[inline(always)]
            fn __signum(self) -> Self {
                self$($signum)*
            }

            /// Identity: integers are already whole numbers.
            #[inline(always)]
            fn __trunc(self) -> Self {
                self
            }

            /// Leaky ReLU: `max(self, 0) + alpha * min(self, 0)`, evaluated with
            /// wrapping arithmetic so large `alpha` values cannot trigger an
            /// overflow panic.
            #[inline(always)]
            fn __leaky_relu(self, alpha: Self) -> Self {
                let positive_part = self.max(0);
                let negative_part = self.min(0);
                positive_part.wrapping_add(alpha.wrapping_mul(negative_part))
            }

            /// ReLU: `max(self, 0)`.
            #[inline(always)]
            fn __relu(self) -> Self {
                self.max(0)
            }

            /// ReLU6: `self` limited to the range `[0, 6]`.
            #[inline(always)]
            fn __relu6(self) -> Self {
                self.min(6).max(0)
            }

            /// Not available for integers; always panics.
            #[inline(always)]
            fn __copysign(self, _: Self) -> Self {
                panic!("copysign is not supported for integer types")
            }
        }

        impl BitWiseOut2 for $type {
            /// Bitwise AND.
            #[inline(always)]
            fn __bitand(self, rhs: Self) -> Self {
                self & rhs
            }

            /// Bitwise OR.
            #[inline(always)]
            fn __bitor(self, rhs: Self) -> Self {
                self | rhs
            }

            /// Bitwise XOR.
            #[inline(always)]
            fn __bitxor(self, rhs: Self) -> Self {
                self ^ rhs
            }

            /// Bitwise NOT.
            #[inline(always)]
            fn __not(self) -> Self {
                !self
            }

            /// Left shift via `wrapping_shl`; the shift amount is cast with `as u32`.
            #[inline(always)]
            fn __shl(self, rhs: Self) -> Self {
                self.wrapping_shl(rhs as u32)
            }

            /// Right shift via `wrapping_shr`; the shift amount is cast with `as u32`.
            #[inline(always)]
            fn __shr(self, rhs: Self) -> Self {
                self.wrapping_shr(rhs as u32)
            }
        }

        impl Eval2 for $type {
            type Output = bool;

            /// Integers are never NaN.
            #[inline(always)]
            fn __is_nan(&self) -> Self::Output {
                false
            }

            /// An integer is "true" when it is non-zero.
            #[inline(always)]
            fn __is_true(&self) -> Self::Output {
                *self != 0
            }

            /// Integers are never infinite.
            #[inline(always)]
            fn __is_inf(&self) -> Self::Output {
                false
            }
        }
    };
}
185
// Signed integers: `__abs` and `__signum` expand to the inherent `.abs()` /
// `.signum()` methods and `__neg` expands to unary `-`.
impl_int_traits!(i8, [.abs()], [-], [.signum()]);
impl_int_traits!(i16, [.abs()], [-], [.signum()]);
impl_int_traits!(i32, [.abs()], [-], [.signum()]);
impl_int_traits!(i64, [.abs()], [-], [.signum()]);
impl_int_traits!(i128, [.abs()], [-], [.signum()]);
impl_int_traits!(isize, [.abs()], [-], [.signum()]);
// Unsigned integers: all three token lists are empty, so the generated
// `__abs`, `__neg` and `__signum` bodies reduce to plain `self`.
impl_int_traits!(u8, [], [], []);
impl_int_traits!(u16, [], [], []);
impl_int_traits!(u32, [], [], []);
impl_int_traits!(u64, [], [], []);
impl_int_traits!(u128, [], [], []);
impl_int_traits!(usize, [], [], []);
198
199use num_complex::Complex;
/// Implements the scalar operator traits (`FloatOutBinary2`, `NormalOut2`,
/// `NormalOutUnary2`, `BitWiseOut2`, `Eval2` and `FloatOutUnary2`) for
/// `Complex<T>`, where `T` is the floating-point type passed to the macro.
///
/// Operations that have no implementation here panic with a message naming
/// the operation. Comparisons (`__max`, `__min`, `__clamp`, the ReLU family)
/// are based on the magnitude returned by `norm()`.
macro_rules! impl_complex {
    ($type:ident) => {
        impl FloatOutBinary2 for Complex<$type> {
            /// Complex division.
            #[inline(always)]
            fn __div(self, rhs: Self) -> Self {
                self / rhs
            }

            /// Logarithm of `self` in the given base.
            ///
            /// A base with zero imaginary part uses the real-base `log`; any
            /// other base is handled as `ln(self) / ln(base)` so its imaginary
            /// part is not silently discarded.
            #[inline(always)]
            fn __log(self, base: Self) -> Self {
                if base.im == 0.0 {
                    self.log(base.re)
                } else {
                    self.ln() / base.ln()
                }
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __hypot(self, _: Self) -> Self {
                panic!("Hypot operation is not supported for complex numbers");
            }
        }

        impl NormalOut2 for Complex<$type> {
            /// Complex addition.
            #[inline(always)]
            fn __add(self, rhs: Self) -> Self {
                self + rhs
            }

            /// Complex subtraction.
            #[inline(always)]
            fn __sub(self, rhs: Self) -> Self {
                self - rhs
            }

            /// Computes `self * a + b`.
            #[inline(always)]
            fn __mul_add(self, a: Self, b: Self) -> Self {
                (self * a) + b
            }

            /// Complex multiplication.
            #[inline(always)]
            fn __mul(self, rhs: Self) -> Self {
                self * rhs
            }

            /// Raises `self` to the power `rhs`.
            ///
            /// An exponent with zero imaginary part uses the real-exponent
            /// `powf`; any other exponent goes through `powc` so its imaginary
            /// part is not silently discarded.
            #[inline(always)]
            fn __pow(self, rhs: Self) -> Self {
                if rhs.im == 0.0 {
                    self.powf(rhs.re)
                } else {
                    self.powc(rhs)
                }
            }

            /// Complex remainder, as defined by the `%` operator of `Complex`.
            #[inline(always)]
            fn __rem(self, rhs: Self) -> Self {
                self % rhs
            }

            /// Returns the operand with the larger magnitude; `self` wins ties.
            #[inline(always)]
            fn __max(self, rhs: Self) -> Self {
                if self.norm() >= rhs.norm() {
                    self
                } else {
                    rhs
                }
            }

            /// Returns the operand with the smaller magnitude; `self` wins ties.
            #[inline(always)]
            fn __min(self, rhs: Self) -> Self {
                if self.norm() <= rhs.norm() {
                    self
                } else {
                    rhs
                }
            }

            /// Rescales `self` so that its magnitude lies between the magnitudes
            /// of `min` and `max`.
            // NOTE(review): when `self` has zero magnitude and `min` does not,
            // the rescale divides by zero and the result is not finite.
            #[inline(always)]
            fn __clamp(self, min: Self, max: Self) -> Self {
                let norm = self.norm();
                if norm < min.norm() {
                    self * (min.norm() / norm)
                } else if norm > max.norm() {
                    self * (max.norm() / norm)
                } else {
                    self
                }
            }
        }

        impl NormalOutUnary2 for Complex<$type> {
            /// `self * self`.
            #[inline(always)]
            fn __square(self) -> Self {
                self * self
            }

            /// Magnitude of `self`, converted back into a complex value.
            #[inline(always)]
            fn __abs(self) -> Self {
                self.abs().into()
            }

            /// Applies `ceil` to the real and imaginary parts independently.
            #[inline(always)]
            fn __ceil(self) -> Self {
                Complex::<$type>::new(self.re.ceil(), self.im.ceil())
            }

            /// Applies `floor` to the real and imaginary parts independently.
            #[inline(always)]
            fn __floor(self) -> Self {
                Complex::<$type>::new(self.re.floor(), self.im.floor())
            }

            /// Complex negation.
            #[inline(always)]
            fn __neg(self) -> Self {
                -self
            }

            /// Applies `round` to the real and imaginary parts independently.
            #[inline(always)]
            fn __round(self) -> Self {
                Complex::<$type>::new(self.re.round(), self.im.round())
            }

            /// `self` divided by its magnitude; zero is returned unchanged.
            #[inline(always)]
            fn __signum(self) -> Self {
                if self == Complex::<$type>::new(0.0, 0.0) {
                    self
                } else {
                    self / Complex::<$type>::from(self.norm())
                }
            }

            /// Applies `trunc` to the real and imaginary parts independently.
            #[inline(always)]
            fn __trunc(self) -> Self {
                Complex::<$type>::new(self.re.trunc(), self.im.trunc())
            }

            /// Returns `self` when its magnitude is greater than zero and
            /// `self * alpha` otherwise.
            #[inline(always)]
            fn __leaky_relu(self, alpha: Self) -> Self {
                let norm = self.norm();
                if norm > 0.0 {
                    self
                } else {
                    self * alpha
                }
            }

            /// Returns `self` when its magnitude is greater than zero and zero
            /// otherwise.
            #[inline(always)]
            fn __relu(self) -> Self {
                let norm = self.norm();
                if norm > 0.0 {
                    self
                } else {
                    Complex::<$type>::new(0.0, 0.0)
                }
            }

            /// Like `__relu`, but values whose magnitude exceeds 6 are rescaled
            /// to magnitude 6.
            #[inline(always)]
            fn __relu6(self) -> Self {
                let norm = self.norm();
                if norm > 6.0 {
                    self * (6.0 / norm)
                } else if norm > 0.0 {
                    self
                } else {
                    Complex::<$type>::new(0.0, 0.0)
                }
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __copysign(self, _: Self) -> Self {
                panic!("copysign is not supported for complex numbers")
            }
        }

        // Bitwise operations act on the raw bit patterns of the real and
        // imaginary parts separately.
        impl BitWiseOut2 for Complex<$type> {
            /// Bitwise AND of the component bit patterns.
            #[inline(always)]
            fn __bitand(self, rhs: Self) -> Self {
                Complex::<$type>::new(
                    $type::from_bits(self.re.to_bits() & rhs.re.to_bits()),
                    $type::from_bits(self.im.to_bits() & rhs.im.to_bits()),
                )
            }

            /// Bitwise OR of the component bit patterns.
            #[inline(always)]
            fn __bitor(self, rhs: Self) -> Self {
                Complex::<$type>::new(
                    $type::from_bits(self.re.to_bits() | rhs.re.to_bits()),
                    $type::from_bits(self.im.to_bits() | rhs.im.to_bits()),
                )
            }

            /// Bitwise XOR of the component bit patterns.
            #[inline(always)]
            fn __bitxor(self, rhs: Self) -> Self {
                Complex::<$type>::new(
                    $type::from_bits(self.re.to_bits() ^ rhs.re.to_bits()),
                    $type::from_bits(self.im.to_bits() ^ rhs.im.to_bits()),
                )
            }

            /// Bitwise NOT of the component bit patterns.
            #[inline(always)]
            fn __not(self) -> Self {
                Complex::<$type>::new(
                    $type::from_bits(!self.re.to_bits()),
                    $type::from_bits(!self.im.to_bits()),
                )
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __shl(self, _: Self) -> Self {
                panic!("shift left is not supported for complex numbers")
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __shr(self, _: Self) -> Self {
                panic!("shift right is not supported for complex numbers")
            }
        }

        impl Eval2 for Complex<$type> {
            type Output = bool;

            /// Delegates to `Complex::is_nan`.
            #[inline(always)]
            fn __is_nan(&self) -> Self::Output {
                self.is_nan()
            }

            /// True when the magnitude is non-zero and the value is not NaN.
            #[inline(always)]
            fn __is_true(&self) -> Self::Output {
                self.norm() != 0.0 && !self.is_nan()
            }

            /// Delegates to `Complex::is_infinite`.
            #[inline(always)]
            fn __is_inf(&self) -> Self::Output {
                self.is_infinite()
            }
        }

        impl FloatOutUnary2 for Complex<$type> {
            /// `e^self`.
            #[inline(always)]
            fn __exp(self) -> Self {
                self.exp()
            }

            /// `e^self - 1`.
            #[inline(always)]
            fn __expm1(self) -> Self {
                self.exp() - 1.0
            }

            /// `2^self`.
            #[inline(always)]
            fn __exp2(self) -> Self {
                self.exp2()
            }

            /// Natural logarithm.
            #[inline(always)]
            fn __ln(self) -> Self {
                self.ln()
            }

            /// `ln(1 + self)`.
            #[inline(always)]
            fn __log1p(self) -> Self {
                // The shift by one goes inside the logarithm; `ln(self) + 1`
                // would be a different function.
                (self + 1.0).ln()
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __celu(self, _: Self) -> Self {
                panic!("celu is not supported for complex numbers")
            }

            /// Base-2 logarithm.
            #[inline(always)]
            fn __log2(self) -> Self {
                self.log2()
            }

            /// Base-10 logarithm.
            #[inline(always)]
            fn __log10(self) -> Self {
                self.log10()
            }

            /// Square root.
            #[inline(always)]
            fn __sqrt(self) -> Self {
                self.sqrt()
            }

            /// Sine.
            #[inline(always)]
            fn __sin(self) -> Self {
                self.sin()
            }

            /// Cosine.
            #[inline(always)]
            fn __cos(self) -> Self {
                self.cos()
            }

            /// Tangent.
            #[inline(always)]
            fn __tan(self) -> Self {
                self.tan()
            }

            /// Inverse sine.
            #[inline(always)]
            fn __asin(self) -> Self {
                self.asin()
            }

            /// Inverse cosine.
            #[inline(always)]
            fn __acos(self) -> Self {
                self.acos()
            }

            /// Inverse tangent.
            #[inline(always)]
            fn __atan(self) -> Self {
                self.atan()
            }

            /// Hyperbolic sine.
            #[inline(always)]
            fn __sinh(self) -> Self {
                self.sinh()
            }

            /// Hyperbolic cosine.
            #[inline(always)]
            fn __cosh(self) -> Self {
                self.cosh()
            }

            /// Hyperbolic tangent.
            #[inline(always)]
            fn __tanh(self) -> Self {
                self.tanh()
            }

            /// Inverse hyperbolic sine.
            #[inline(always)]
            fn __asinh(self) -> Self {
                self.asinh()
            }

            /// Inverse hyperbolic cosine.
            #[inline(always)]
            fn __acosh(self) -> Self {
                self.acosh()
            }

            /// Inverse hyperbolic tangent.
            #[inline(always)]
            fn __atanh(self) -> Self {
                self.atanh()
            }

            /// Reciprocal.
            #[inline(always)]
            fn __recip(self) -> Self {
                self.recip()
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __erf(self) -> Self {
                panic!("erf is not supported for complex numbers")
            }

            /// Logistic sigmoid: `1 / (1 + e^(-self))`.
            #[inline(always)]
            fn __sigmoid(self) -> Self {
                1.0 / (1.0 + (-self).exp())
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __elu(self, _: Self) -> Self {
                panic!("elu is not supported for complex numbers")
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __gelu(self) -> Self {
                panic!("gelu is not supported for complex numbers")
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __selu(self, _: Self, _: Self) -> Self {
                panic!("selu is not supported for complex numbers")
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __hard_sigmoid(self) -> Self {
                panic!("hard sigmoid is not supported for complex numbers")
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __hard_swish(self) -> Self {
                panic!("hard swish is not supported for complex numbers")
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __softplus(self) -> Self {
                panic!("softplus is not supported for complex numbers")
            }

            /// Softsign: `self / (1 + |self|)`.
            #[inline(always)]
            fn __softsign(self) -> Self {
                self / (1.0 + self.abs())
            }

            /// Mish: `self * tanh(ln(1 + e^self))`.
            #[inline(always)]
            fn __mish(self) -> Self {
                self * ((1.0 + self.exp()).ln()).tanh()
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __cbrt(self) -> Self {
                panic!("cbrt is not supported for complex numbers")
            }

            /// Returns `(sin(self), cos(self))`.
            #[inline(always)]
            fn __sincos(self) -> (Self, Self) {
                (self.sin(), self.cos())
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __atan2(self, _: Self) -> Self {
                panic!("atan2 is not supported for complex numbers")
            }

            /// Not available for complex numbers; always panics.
            #[inline(always)]
            fn __exp10(self) -> Self {
                panic!("exp10 is not supported for complex numbers")
            }
        }
    };
}
573
// Generate the trait implementations for `Complex<f32>` and `Complex<f64>`.
impl_complex!(f32);
impl_complex!(f64);