hpt_types/scalars/
impls.rs

1use crate::type_promote::FloatOutUnary2;
2use crate::type_promote::{BitWiseOut2, Eval2, FloatOutBinary2, NormalOut2, NormalOutUnary2};
3use num_complex::ComplexFloat;
4macro_rules! impl_int_traits {
5    ($type:ty, [$($abs:tt)*], [$($neg:tt)*], [$($signum:tt)*]) => {
6        impl FloatOutBinary2 for $type {
7            #[inline(always)]
8            fn __div(self, rhs: Self) -> Self {
9                if rhs == 0 {
10                    panic!("Division by zero for {}", stringify!($typ>));
11                } else {
12                    self / rhs
13                }
14            }
15            #[inline(always)]
16            fn __log(self, _: Self) -> Self {
17                panic!("Logarithm operation is not supported for {}", stringify!($type));
18            }
19            #[inline(always)]
20            fn __hypot(self, _: Self) -> Self {
21                panic!("Hypot operation is not supported for {}", stringify!($type));
22            }
23        }
24
25        impl NormalOut2 for $type {
26            #[inline(always)]
27            fn __add(self, rhs: Self) -> Self {
28                self.wrapping_add(rhs)
29            }
30
31            #[inline(always)]
32            fn __sub(self, rhs: Self) -> Self {
33                self.wrapping_sub(rhs)
34            }
35
36            #[inline(always)]
37            fn __mul_add(self, a: Self, b: Self) -> Self {
38                (self * a) + b
39            }
40
41            #[inline(always)]
42            fn __mul(self, rhs: Self) -> Self {
43                self.wrapping_mul(rhs)
44            }
45
46            #[inline(always)]
47            fn __pow(self, rhs: Self) -> Self {
48                self.pow(rhs as u32)
49            }
50
51            #[inline(always)]
52            fn __rem(self, rhs: Self) -> Self {
53                self.wrapping_rem(rhs)
54            }
55
56            #[inline(always)]
57            fn __max(self, rhs: Self) -> Self {
58                self.max(rhs)
59            }
60
61            #[inline(always)]
62            fn __min(self, rhs: Self) -> Self {
63                self.min(rhs)
64            }
65
66            #[inline(always)]
67            fn __clamp(self, min: Self, max: Self) -> Self {
68                self.clamp(min, max)
69            }
70        }
71
72        impl NormalOutUnary2 for $type {
73            #[inline(always)]
74            fn __square(self) -> Self {
75                self.wrapping_mul(self)
76            }
77
78            #[inline(always)]
79            fn __abs(self) -> Self {
80                self$($abs)*
81            }
82
83            #[inline(always)]
84            fn __ceil(self) -> Self {
85                self
86            }
87
88            #[inline(always)]
89            fn __floor(self) -> Self {
90                self
91            }
92
93            #[inline(always)]
94            fn __neg(self) -> Self {
95                $($neg)*self
96            }
97
98            #[inline(always)]
99            fn __round(self) -> Self {
100                self
101            }
102
103            #[inline(always)]
104            fn __signum(self) -> Self {
105                self$($signum)*
106            }
107
108            #[inline(always)]
109            fn __trunc(self) -> Self {
110                self
111            }
112
113            #[inline(always)]
114            fn __leaky_relu(self, alpha: Self) -> Self {
115                self.max(0) + alpha * self.min(0)
116            }
117
118            #[inline(always)]
119            fn __relu(self) -> Self {
120                self.max(0)
121            }
122
123            #[inline(always)]
124            fn __relu6(self) -> Self {
125                self.min(6).max(0)
126            }
127
128            #[inline(always)]
129            fn __copysign(self, _: Self) -> Self {
130                panic!("copysign is not supported for integer types")
131            }
132        }
133
134        impl BitWiseOut2 for $type {
135            #[inline(always)]
136            fn __bitand(self, rhs: Self) -> Self {
137                self & rhs
138            }
139
140            #[inline(always)]
141            fn __bitor(self, rhs: Self) -> Self {
142                self | rhs
143            }
144
145            #[inline(always)]
146            fn __bitxor(self, rhs: Self) -> Self {
147                self ^ rhs
148            }
149
150            #[inline(always)]
151            fn __not(self) -> Self {
152                !self
153            }
154
155            #[inline(always)]
156            fn __shl(self, rhs: Self) -> Self {
157                self.wrapping_shl(rhs as u32)
158            }
159
160            #[inline(always)]
161            fn __shr(self, rhs: Self) -> Self {
162                self.wrapping_shr(rhs as u32)
163            }
164        }
165
166        impl Eval2 for $type {
167            type Output = bool;
168            #[inline(always)]
169            fn __is_nan(&self) -> Self::Output {
170                false
171            }
172
173            #[inline(always)]
174            fn __is_true(&self) -> Self::Output {
175                *self != 0
176            }
177
178            #[inline(always)]
179            fn __is_inf(&self) -> Self::Output {
180                false
181            }
182        }
183    };
184}
185
186impl_int_traits!(i8, [.abs()], [-], [.signum()]);
187impl_int_traits!(i16, [.abs()], [-], [.signum()]);
188impl_int_traits!(i32, [.abs()], [-], [.signum()]);
189impl_int_traits!(i64, [.abs()], [-], [.signum()]);
190impl_int_traits!(i128, [.abs()], [-], [.signum()]);
191impl_int_traits!(isize, [.abs()], [-], [.signum()]);
192impl_int_traits!(u8, [], [], []);
193impl_int_traits!(u16, [], [], []);
194impl_int_traits!(u32, [], [], []);
195impl_int_traits!(u64, [], [], []);
196impl_int_traits!(u128, [], [], []);
197impl_int_traits!(usize, [], [], []);
198
199use num_complex::Complex;
200macro_rules! impl_complex {
201    ($type:ident) => {
202        impl FloatOutBinary2 for Complex<$type> {
203            #[inline(always)]
204            fn __div(self, rhs: Self) -> Self {
205                self / rhs
206            }
207            #[inline(always)]
208            fn __log(self, base: Self) -> Self {
209                self.log(base.re)
210            }
211            #[inline(always)]
212            fn __hypot(self, _: Self) -> Self {
213                panic!("Hypot operation is not supported for complex numbers");
214            }
215        }
216
217        impl NormalOut2 for Complex<$type> {
218            #[inline(always)]
219            fn __add(self, rhs: Self) -> Self {
220                self + rhs
221            }
222
223            #[inline(always)]
224            fn __sub(self, rhs: Self) -> Self {
225                self - rhs
226            }
227
228            #[inline(always)]
229            fn __mul_add(self, a: Self, b: Self) -> Self {
230                (self * a) + b
231            }
232
233            #[inline(always)]
234            fn __mul(self, rhs: Self) -> Self {
235                self * rhs
236            }
237
238            #[inline(always)]
239            fn __pow(self, rhs: Self) -> Self {
240                self.powf(rhs.re)
241            }
242
243            #[inline(always)]
244            fn __rem(self, rhs: Self) -> Self {
245                self % rhs
246            }
247
248            #[inline(always)]
249            fn __max(self, rhs: Self) -> Self {
250                if self.norm() >= rhs.norm() {
251                    self
252                } else {
253                    rhs
254                }
255            }
256
257            #[inline(always)]
258            fn __min(self, rhs: Self) -> Self {
259                if self.norm() <= rhs.norm() {
260                    self
261                } else {
262                    rhs
263                }
264            }
265
266            #[inline(always)]
267            fn __clamp(self, min: Self, max: Self) -> Self {
268                let norm = self.norm();
269                if norm < min.norm() {
270                    self * (min.norm() / norm)
271                } else if norm > max.norm() {
272                    self * (max.norm() / norm)
273                } else {
274                    self
275                }
276            }
277        }
278
279        impl NormalOutUnary2 for Complex<$type> {
280            #[inline(always)]
281            fn __square(self) -> Self {
282                self * self
283            }
284
285            #[inline(always)]
286            fn __abs(self) -> Self {
287                self.abs().into()
288            }
289
290            #[inline(always)]
291            fn __ceil(self) -> Self {
292                Complex::<$type>::new(self.re.ceil(), self.im.ceil())
293            }
294
295            #[inline(always)]
296            fn __floor(self) -> Self {
297                Complex::<$type>::new(self.re.floor(), self.im.floor())
298            }
299
300            #[inline(always)]
301            fn __neg(self) -> Self {
302                -self
303            }
304
305            #[inline(always)]
306            fn __round(self) -> Self {
307                Complex::<$type>::new(self.re.round(), self.im.round())
308            }
309
310            #[inline(always)]
311            fn __signum(self) -> Self {
312                if self == Complex::<$type>::new(0.0, 0.0) {
313                    self
314                } else {
315                    self / Complex::<$type>::from(self.norm())
316                }
317            }
318
319            #[inline(always)]
320            fn __trunc(self) -> Self {
321                Complex::<$type>::new(self.re.trunc(), self.im.trunc())
322            }
323
324            #[inline(always)]
325            fn __leaky_relu(self, alpha: Self) -> Self {
326                let norm = self.norm();
327                if norm > 0.0 {
328                    self
329                } else {
330                    self * alpha
331                }
332            }
333
334            #[inline(always)]
335            fn __relu(self) -> Self {
336                let norm = self.norm();
337                if norm > 0.0 {
338                    self
339                } else {
340                    Complex::<$type>::new(0.0, 0.0)
341                }
342            }
343
344            #[inline(always)]
345            fn __relu6(self) -> Self {
346                let norm = self.norm();
347                if norm > 6.0 {
348                    self * (6.0 / norm)
349                } else if norm > 0.0 {
350                    self
351                } else {
352                    Complex::<$type>::new(0.0, 0.0)
353                }
354            }
355
356            #[inline(always)]
357            fn __copysign(self, _: Self) -> Self {
358                panic!("copysign is not supported for complex numbers")
359            }
360        }
361
362        impl BitWiseOut2 for Complex<$type> {
363            #[inline(always)]
364            fn __bitand(self, rhs: Self) -> Self {
365                Complex::<$type>::new(
366                    $type::from_bits(self.re.to_bits() & rhs.re.to_bits()),
367                    $type::from_bits(self.im.to_bits() & rhs.im.to_bits()),
368                )
369            }
370
371            #[inline(always)]
372            fn __bitor(self, rhs: Self) -> Self {
373                Complex::<$type>::new(
374                    $type::from_bits(self.re.to_bits() | rhs.re.to_bits()),
375                    $type::from_bits(self.im.to_bits() | rhs.im.to_bits()),
376                )
377            }
378
379            #[inline(always)]
380            fn __bitxor(self, rhs: Self) -> Self {
381                Complex::<$type>::new(
382                    $type::from_bits(self.re.to_bits() ^ rhs.re.to_bits()),
383                    $type::from_bits(self.im.to_bits() ^ rhs.im.to_bits()),
384                )
385            }
386
387            #[inline(always)]
388            fn __not(self) -> Self {
389                Complex::<$type>::new(
390                    $type::from_bits(!self.re.to_bits()),
391                    $type::from_bits(!self.im.to_bits()),
392                )
393            }
394
395            #[inline(always)]
396            fn __shl(self, _: Self) -> Self {
397                panic!("shift left is not supported for complex numbers")
398            }
399
400            #[inline(always)]
401            fn __shr(self, _: Self) -> Self {
402                panic!("shift right is not supported for complex numbers")
403            }
404        }
405
406        impl Eval2 for Complex<$type> {
407            type Output = bool;
408            #[inline(always)]
409            fn __is_nan(&self) -> Self::Output {
410                self.is_nan()
411            }
412
413            #[inline(always)]
414            fn __is_true(&self) -> Self::Output {
415                self.norm() != 0.0 && !self.is_nan()
416            }
417
418            #[inline(always)]
419            fn __is_inf(&self) -> bool {
420                self.is_infinite()
421            }
422        }
423
424        impl FloatOutUnary2 for Complex<$type> {
425            #[inline(always)]
426            fn __exp(self) -> Self {
427                self.exp()
428            }
429            #[inline(always)]
430            fn __expm1(self) -> Self {
431                self.exp() - 1.0
432            }
433            #[inline(always)]
434            fn __exp2(self) -> Self {
435                self.exp2()
436            }
437            #[inline(always)]
438            fn __ln(self) -> Self {
439                self.ln()
440            }
441            #[inline(always)]
442            fn __log1p(self) -> Self {
443                self.ln() + 1.0
444            }
445            #[inline(always)]
446            fn __celu(self, _: Self) -> Self {
447                panic!("celu is not supported for complex numbers")
448            }
449            #[inline(always)]
450            fn __log2(self) -> Self {
451                self.log2()
452            }
453            #[inline(always)]
454            fn __log10(self) -> Self {
455                self.log10()
456            }
457            #[inline(always)]
458            fn __sqrt(self) -> Self {
459                self.sqrt()
460            }
461            #[inline(always)]
462            fn __sin(self) -> Self {
463                self.sin()
464            }
465            #[inline(always)]
466            fn __cos(self) -> Self {
467                self.cos()
468            }
469            #[inline(always)]
470            fn __tan(self) -> Self {
471                self.tan()
472            }
473            #[inline(always)]
474            fn __asin(self) -> Self {
475                self.asin()
476            }
477            #[inline(always)]
478            fn __acos(self) -> Self {
479                self.acos()
480            }
481            #[inline(always)]
482            fn __atan(self) -> Self {
483                self.atan()
484            }
485            #[inline(always)]
486            fn __sinh(self) -> Self {
487                self.sinh()
488            }
489            #[inline(always)]
490            fn __cosh(self) -> Self {
491                self.cosh()
492            }
493            #[inline(always)]
494            fn __tanh(self) -> Self {
495                self.tanh()
496            }
497            #[inline(always)]
498            fn __asinh(self) -> Self {
499                self.asinh()
500            }
501            #[inline(always)]
502            fn __acosh(self) -> Self {
503                self.acosh()
504            }
505            #[inline(always)]
506            fn __atanh(self) -> Self {
507                self.atanh()
508            }
509            #[inline(always)]
510            fn __recip(self) -> Self {
511                self.recip()
512            }
513            #[inline(always)]
514            fn __erf(self) -> Self {
515                panic!("erf is not supported for complex numbers")
516            }
517
518            #[inline(always)]
519            fn __sigmoid(self) -> Self {
520                1.0 / (1.0 + (-self).exp())
521            }
522
523            fn __elu(self, _: Self) -> Self {
524                panic!("elu is not supported for complex numbers")
525            }
526
527            fn __gelu(self) -> Self {
528                panic!("gelu is not supported for complex numbers")
529            }
530
531            fn __selu(self, _: Self, _: Self) -> Self {
532                panic!("selu is not supported for complex numbers")
533            }
534
535            fn __hard_sigmoid(self) -> Self {
536                panic!("hard sigmoid is not supported for complex numbers")
537            }
538
539            fn __hard_swish(self) -> Self {
540                panic!("hard swish is not supported for complex numbers")
541            }
542
543            fn __softplus(self) -> Self {
544                panic!("softplus is not supported for complex numbers")
545            }
546
547            fn __softsign(self) -> Self {
548                self / (1.0 + self.abs())
549            }
550
551            fn __mish(self) -> Self {
552                self * ((1.0 + self.exp()).ln()).tanh()
553            }
554
555            fn __cbrt(self) -> Self {
556                panic!("cbrt is not supported for complex numbers")
557            }
558
559            fn __sincos(self) -> (Self, Self) {
560                (self.sin(), self.cos())
561            }
562
563            fn __atan2(self, _: Self) -> Self {
564                panic!("atan2 is not supported for complex numbers")
565            }
566
567            fn __exp10(self) -> Self {
568                panic!("exp10 is not supported for complex numbers")
569            }
570        }
571    };
572}
573
574impl_complex!(f32);
575impl_complex!(f64);