hpt_types/
type_promote.rs

1use crate::into_scalar::Cast;
2use crate::into_vec::IntoVec;
3#[cfg(any(
4    all(not(target_feature = "avx2"), target_feature = "sse"),
5    target_arch = "arm",
6    target_arch = "aarch64",
7    target_feature = "neon"
8))]
9use crate::simd::_128bit::*;
10#[cfg(target_feature = "avx2")]
11use crate::simd::_256bit::*;
12#[cfg(target_feature = "avx512f")]
13use crate::simd::_512bit::*;
14use crate::traits::SimdMath;
15use crate::vectors::traits::SimdCompare;
16use crate::vectors::traits::VecTrait;
17use half::bf16;
18use half::f16;
19use hpt_macros::{
20    float_out_binary, float_out_binary_simd_with_lhs_scalar, float_out_binary_simd_with_rhs_scalar,
21    float_out_unary, impl_bitwise_out, impl_cmp, impl_eval, impl_normal_out_binary,
22    impl_normal_out_simd, impl_normal_out_simd_with_lhs_scalar,
23    impl_normal_out_simd_with_rhs_scalar, impl_normal_out_unary, impl_normal_out_unary_simd,
24    simd_cmp, simd_eval, simd_float_out_unary,
25};
26use num_complex::{Complex32, Complex64};
27use num_traits::float::Float;
// CUDA flavours of the operator traits; this module only exists when the
// `cuda` feature is enabled.
#[cfg(feature = "cuda")]
mod cuda_imports {
    use super::*;
    // Presumably referenced by the code the macros below expand to.
    use crate::cuda_types::scalar::Scalar;
    use hpt_macros::float_out_binary_cuda;
    use hpt_macros::float_out_unary_cuda;
    use hpt_macros::impl_cmp_cuda;
    use hpt_macros::impl_cuda_bitwise_out;
    use hpt_macros::impl_cuda_normal_out_binary;
    use hpt_macros::impl_normal_out_unary_cuda;

    // One generator per operator family (see `hpt_macros` for what each emits).
    float_out_binary_cuda! {}
    impl_cuda_normal_out_binary! {}
    impl_normal_out_unary_cuda! {}
    impl_cuda_bitwise_out! {}
    impl_cmp_cuda! {}
    float_out_unary_cuda! {}
}
43
44use hpt_macros::{float_out_binary_simd, simd_bitwise};
45
/// Binary operations of the "float out" family: division, logarithm with an
/// explicit base, and hypotenuse.
///
/// `RHS` defaults to `Self`. The implementor picks the result type through
/// `Self::Output`, which is how operands are type-promoted in the dynamic graph.
pub trait FloatOutBinary<RHS = Self> {
    /// Result type shared by every method of this trait.
    type Output;
    /// Quotient `self / rhs`.
    fn _div(self, rhs: RHS) -> Self::Output;
    /// Logarithm of `self` in base `base`, i.e. log<sub>base</sub>(self).
    fn _log(self, base: RHS) -> Self::Output;
    /// Hypotenuse `hypot(self, rhs)`.
    fn _hypot(self, rhs: RHS) -> Self::Output;
}
57
/// Type-level helper that names the promoted result type of a float-out binary
/// operation between `Self` and `RHS` (`RHS` defaults to `Self`).
pub trait FloatOutBinaryPromote<RHS = Self> {
    /// The promoted result type.
    type Output;
}
63
/// Internal trait behind the float-out binary operations. Both operands and
/// the result have type `Self`.
pub trait FloatOutBinary2 {
    /// Quotient `self / rhs`.
    fn __div(self, rhs: Self) -> Self;
    /// Logarithm of `self` in base `base`, i.e. log<sub>base</sub>(self).
    fn __log(self, base: Self) -> Self;
    /// Hypotenuse `hypot(self, rhs)`.
    fn __hypot(self, rhs: Self) -> Self;
}
73
74float_out_binary!();
75float_out_binary_simd!();
76float_out_binary_simd_with_rhs_scalar!();
77float_out_binary_simd_with_lhs_scalar!();
78
/// Arithmetic and ordering operations that do not require type promotion.
///
/// `RHS` defaults to `Self`. Every method yields `Self::Output`.
pub trait NormalOut<RHS = Self> {
    /// Result type shared by every method of this trait.
    type Output;
    /// Sum `self + rhs`.
    fn _add(self, rhs: RHS) -> Self::Output;
    /// Difference `self - rhs`.
    fn _sub(self, rhs: RHS) -> Self::Output;
    /// Fused multiply-add `self * a + b`.
    /// Where the hardware has FMA this can be faster and rounds only once.
    fn _mul_add(self, a: RHS, b: RHS) -> Self::Output;
    /// Product `self * rhs`.
    fn _mul(self, rhs: RHS) -> Self::Output;
    /// Power `self`<sup>`rhs`</sup>.
    fn _pow(self, rhs: RHS) -> Self::Output;
    /// Remainder `self % rhs`.
    fn _rem(self, rhs: RHS) -> Self::Output;
    /// The larger of `self` and `rhs`.
    fn _max(self, rhs: RHS) -> Self::Output;
    /// The smaller of `self` and `rhs`.
    fn _min(self, rhs: RHS) -> Self::Output;
    /// `self` limited to the closed range `[min, max]`.
    fn _clamp(self, min: RHS, max: RHS) -> Self::Output;
}
103
/// Internal trait behind the "normal out" binary operations. Operands and
/// result all have type `Self`.
pub trait NormalOut2 {
    /// Sum `self + rhs`.
    fn __add(self, rhs: Self) -> Self;
    /// Difference `self - rhs`.
    fn __sub(self, rhs: Self) -> Self;
    /// Fused multiply-add `self * a + b`.
    /// Where the hardware has FMA this can be faster and rounds only once.
    fn __mul_add(self, a: Self, b: Self) -> Self;
    /// Product `self * rhs`.
    fn __mul(self, rhs: Self) -> Self;
    /// Power `self`<sup>`rhs`</sup>.
    fn __pow(self, rhs: Self) -> Self;
    /// Remainder `self % rhs`.
    fn __rem(self, rhs: Self) -> Self;
    /// The larger of `self` and `rhs`.
    fn __max(self, rhs: Self) -> Self;
    /// The smaller of `self` and `rhs`.
    fn __min(self, rhs: Self) -> Self;
    /// `self` limited to the closed range `[min, max]`.
    fn __clamp(self, min: Self, max: Self) -> Self;
}
126
/// Type-level helper that names the promoted result type of a "normal out"
/// operation between `Self` and `RHS` (`RHS` defaults to `Self`).
pub trait NormalOutPromote<RHS = Self> {
    /// The promoted result type.
    type Output;
}
132
133impl_normal_out_binary!();
134
135impl_normal_out_simd!();
136
137impl_normal_out_simd_with_rhs_scalar!();
138
139impl_normal_out_simd_with_lhs_scalar!();
140
//~^ NormalOutUnary is not implemented for {Self}
/// this trait is used to perform normal unary operations that don't require type promotion
///
/// Every method returns the same type as its input (`Self`).
pub trait NormalOutUnary {
    /// perform x<sup>2</sup>
    fn _square(self) -> Self;
    /// perform |x|
    fn _abs(self) -> Self;
    /// perform &lceil;x&rceil;
    fn _ceil(self) -> Self;
    /// perform &lfloor;x&rfloor;
    fn _floor(self) -> Self;
    /// perform -x
    fn _neg(self) -> Self;
    /// perform rounding
    fn _round(self) -> Self;
    /// get the sign of x
    fn _signum(self) -> Self;
    /// perform truncation
    fn _trunc(self) -> Self;

    /// Perform the leaky ReLU (Rectified Linear Unit) activation function.
    ///
    /// Formula: f(x) = x if x > 0 else alpha * x
    fn _leaky_relu(self, alpha: Self) -> Self;

    /// Perform the ReLU (Rectified Linear Unit) activation function.
    ///
    /// Formula: f(x) = max(0, x)
    fn _relu(self) -> Self;

    /// Perform the ReLU6 activation function.
    ///
    /// Formula: f(x) = min(6, max(0, x))
    fn _relu6(self) -> Self;

    /// Perform the copysign function: the magnitude of `self` combined with the
    /// sign of `rhs`.
    ///
    /// Formula: f(x, y) = |x| with the sign of y, e.g. f(-3, 1) = 3 and f(3, -1) = -3.
    /// This is not `x * sign(y)`, which would give f(-3, 1) = -3.
    fn _copysign(self, rhs: Self) -> Self;
}
181
182/// internal trait for normal out unary
183pub trait NormalOutUnary2 {
184    /// perform x<sup>2</sup>
185    fn __square(self) -> Self;
186    /// perform |x|
187    fn __abs(self) -> Self;
188    /// perform &lceil;x&rceil;
189    fn __ceil(self) -> Self;
190    /// perform &lfloor;x&rfloor;
191    fn __floor(self) -> Self;
192    /// perform -x
193    fn __neg(self) -> Self;
194    /// perform rounding
195    fn __round(self) -> Self;
196    /// get the sign of x
197    fn __signum(self) -> Self;
198    /// perform truncation
199    fn __trunc(self) -> Self;
200    /// Perform the leaky ReLU (Rectified Linear Unit) activation function.
201    ///
202    /// Formula: f(x) = x if x > 0 else alpha * x
203    fn __leaky_relu(self, alpha: Self) -> Self;
204
205    /// Perform the ReLU (Rectified Linear Unit) activation function.
206    ///
207    /// Formula: f(x) = max(0, x)
208    fn __relu(self) -> Self;
209
210    /// Perform the ReLU6 activation function.
211    ///
212    /// Formula: f(x) = min(6, max(0, x))
213    fn __relu6(self) -> Self;
214
215    /// Perform the copysign function.
216    ///
217    /// Formula: f(x, y) = x * sign(y)
218    fn __copysign(self, rhs: Self) -> Self;
219}
220
221impl_normal_out_unary!();
222
223impl_normal_out_unary_simd!();
224
/// Bitwise operators: and, or, xor, not, and shifts.
///
/// `RHS` defaults to `Self`. Every method yields `Self::Output`.
pub trait BitWiseOut<RHS = Self> {
    /// Result type shared by every method of this trait.
    type Output;
    /// Bitwise AND, `self & rhs`.
    fn _bitand(self, rhs: RHS) -> Self::Output;
    /// Bitwise OR, `self | rhs`.
    fn _bitor(self, rhs: RHS) -> Self::Output;
    /// Bitwise XOR, `self ^ rhs`.
    fn _bitxor(self, rhs: RHS) -> Self::Output;
    /// Bitwise NOT, `!self`.
    fn _not(self) -> Self::Output;
    /// Left shift, `self << rhs`.
    fn _shl(self, rhs: RHS) -> Self::Output;
    /// Right shift, `self >> rhs`.
    fn _shr(self, rhs: RHS) -> Self::Output;
}
242
/// Internal trait behind the bitwise operators. Operands and result all have
/// type `Self`.
pub trait BitWiseOut2 {
    /// Bitwise AND, `self & rhs`.
    fn __bitand(self, rhs: Self) -> Self;
    /// Bitwise OR, `self | rhs`.
    fn __bitor(self, rhs: Self) -> Self;
    /// Bitwise XOR, `self ^ rhs`.
    fn __bitxor(self, rhs: Self) -> Self;
    /// Bitwise NOT, `!self`.
    fn __not(self) -> Self;
    /// Left shift, `self << rhs`.
    fn __shl(self, rhs: Self) -> Self;
    /// Right shift, `self >> rhs`.
    fn __shr(self, rhs: Self) -> Self;
}
258
259impl_bitwise_out!();
260
261simd_bitwise!();
262
/// Comparison operators: equality and ordering.
///
/// `RHS` defaults to `Self`. Every method yields `Self::Output`.
pub trait Cmp<RHS = Self> {
    /// Result type shared by every method of this trait.
    type Output;
    /// Equality, `self == rhs`.
    fn _eq(self, rhs: RHS) -> Self::Output;
    /// Inequality, `self != rhs`.
    fn _ne(self, rhs: RHS) -> Self::Output;
    /// Strictly less, `self < rhs`.
    fn _lt(self, rhs: RHS) -> Self::Output;
    /// Less or equal, `self <= rhs`.
    fn _le(self, rhs: RHS) -> Self::Output;
    /// Strictly greater, `self > rhs`.
    fn _gt(self, rhs: RHS) -> Self::Output;
    /// Greater or equal, `self >= rhs`.
    fn _ge(self, rhs: RHS) -> Self::Output;
}
280impl_cmp!();
281
/// Lane-wise comparison operators for SIMD vectors.
///
/// # Note
///
/// Each method returns a mask rather than a boolean. The mask's type depends
/// on the byte width of the SIMD lanes, so it may not be a `bool`.
///
/// `RHS` defaults to `Self`.
pub trait SimdCmp<RHS = Self> {
    /// Mask type produced by every comparison.
    type Output;
    /// Lane-wise `self == rhs`, returned as a mask (see the trait-level note).
    fn _eq(self, rhs: RHS) -> Self::Output;
    /// Lane-wise `self != rhs`, returned as a mask (see the trait-level note).
    fn _ne(self, rhs: RHS) -> Self::Output;
    /// Lane-wise `self < rhs`, returned as a mask (see the trait-level note).
    fn _lt(self, rhs: RHS) -> Self::Output;
    /// Lane-wise `self <= rhs`, returned as a mask (see the trait-level note).
    fn _le(self, rhs: RHS) -> Self::Output;
    /// Lane-wise `self > rhs`, returned as a mask (see the trait-level note).
    fn _gt(self, rhs: RHS) -> Self::Output;
    /// Lane-wise `self >= rhs`, returned as a mask (see the trait-level note).
    fn _ge(self, rhs: RHS) -> Self::Output;
}
323
/// this trait is used to perform type promotion for simd comparison operations
///
/// It performs no comparison itself; `Output` only names the result type of
/// comparing a `Self` simd value with an `RHS` simd value (`RHS` defaults to `Self`).
pub trait SimdCmpPromote<RHS = Self> {
    /// the output type
    type Output;
}
329
330simd_cmp!();
331
/// this trait is used to perform evaluation operations: predicates that
/// classify a value (NaN, "true", infinite)
pub trait Eval {
    /// the output type
    type Output;
    /// check if the value is nan
    fn _is_nan(&self) -> Self::Output;
    /// check if the value is "true" (truthy)
    ///
    /// NOTE(review): presumably this means non-zero. The exact rule is defined
    /// where `impl_eval!` expands, so confirm it there. It is not a finiteness
    /// test, which the previous doc comment wrongly claimed.
    fn _is_true(&self) -> Self::Output;
    /// check if the value is infinite
    fn _is_inf(&self) -> Self::Output;
}
343
/// internal trait for eval: predicates that classify a value (NaN, "true", infinite)
pub trait Eval2 {
    /// the output type
    type Output;
    /// check if the value is nan
    fn __is_nan(&self) -> Self::Output;
    /// check if the value is "true" (truthy)
    ///
    /// NOTE(review): presumably this means non-zero. The exact rule is defined
    /// where `impl_eval!` / `simd_eval!` expand, so confirm it there. It is not a
    /// finiteness test, which the previous doc comment wrongly claimed.
    fn __is_true(&self) -> Self::Output;
    /// check if the value is infinite
    fn __is_inf(&self) -> Self::Output;
}
355
356impl_eval!();
357simd_eval!();
358
//~^ FloatOutUnary is not implemented for {Self}
/// Unary operations of the "float out" family: exponentials, logarithms,
/// trigonometric and hyperbolic functions, and common activation functions.
///
/// Every method returns `Self::Output`. Extra parameters such as `alpha` and
/// `scale` are also given in the output type.
pub trait FloatOutUnary {
    /// Result type shared by every method of this trait.
    type Output;

    /// Natural exponential, e<sup>x</sup>.
    fn _exp(self) -> Self::Output;

    /// Natural exponential minus one, e<sup>x</sup> - 1.
    fn _expm1(self) -> Self::Output;

    /// Base-2 exponential, 2<sup>x</sup>.
    fn _exp2(self) -> Self::Output;

    /// Base-10 exponential, 10<sup>x</sup>.
    fn _exp10(self) -> Self::Output;

    /// Natural logarithm, ln(x).
    fn _ln(self) -> Self::Output;

    /// Natural logarithm of one plus the input, ln(x + 1).
    fn _log1p(self) -> Self::Output;

    /// CELU (Continuously Differentiable Exponential Linear Unit) activation.
    ///
    /// Formula: f(x) = max(0, x) + min(0, alpha * (e<sup>(x / alpha)</sup> - 1))
    fn _celu(self, alpha: Self::Output) -> Self::Output;

    /// Base-2 logarithm, log<sub>2</sub>(x).
    fn _log2(self) -> Self::Output;

    /// Base-10 logarithm, log<sub>10</sub>(x).
    fn _log10(self) -> Self::Output;

    /// Square root, √x.
    fn _sqrt(self) -> Self::Output;

    /// Sine, sin(x).
    fn _sin(self) -> Self::Output;

    /// Cosine, cos(x).
    fn _cos(self) -> Self::Output;

    /// Sine and cosine computed together, as the pair (sin(x), cos(x)).
    fn _sincos(self) -> (Self::Output, Self::Output);

    /// Tangent, tan(x).
    fn _tan(self) -> Self::Output;

    /// Inverse sine (arcsin), asin(x).
    fn _asin(self) -> Self::Output;

    /// Inverse cosine (arccos), acos(x).
    fn _acos(self) -> Self::Output;

    /// Inverse tangent (arctan), atan(x).
    fn _atan(self) -> Self::Output;

    /// Two-argument inverse tangent, atan2(y, x).
    fn _atan2(self, rhs: Self::Output) -> Self::Output;

    /// Hyperbolic sine, sinh(x).
    fn _sinh(self) -> Self::Output;

    /// Hyperbolic cosine, cosh(x).
    fn _cosh(self) -> Self::Output;

    /// Hyperbolic tangent, tanh(x).
    fn _tanh(self) -> Self::Output;

    /// Inverse hyperbolic sine (arsinh), asinh(x).
    fn _asinh(self) -> Self::Output;

    /// Inverse hyperbolic cosine (arcosh), acosh(x).
    fn _acosh(self) -> Self::Output;

    /// Inverse hyperbolic tangent (artanh), atanh(x).
    fn _atanh(self) -> Self::Output;

    /// Reciprocal, 1 / x.
    fn _recip(self) -> Self::Output;

    /// Error function, erf(x).
    fn _erf(self) -> Self::Output;

    /// Logistic sigmoid, 1 / (1 + e<sup>-x</sup>).
    fn _sigmoid(self) -> Self::Output;

    /// ELU (Exponential Linear Unit) activation.
    ///
    /// Formula: f(x) = x if x > 0 else alpha * (e<sup>x</sup> - 1)
    fn _elu(self, alpha: Self::Output) -> Self::Output;

    /// GELU (Gaussian Error Linear Unit) activation.
    fn _gelu(self) -> Self::Output;

    /// SELU (Scaled Exponential Linear Unit) activation.
    ///
    /// Formula: f(x) = scale * (x if x > 0 else alpha * (e<sup>x</sup> - 1))
    fn _selu(self, alpha: Self::Output, scale: Self::Output) -> Self::Output;

    /// Hard sigmoid activation.
    ///
    /// Formula: f(x) = min(1, max(0, 0.2 * x + 0.5))
    fn _hard_sigmoid(self) -> Self::Output;

    /// Hard swish activation.
    ///
    /// Formula: f(x) = x * min(1, max(0, 0.2 * x + 0.5))
    fn _hard_swish(self) -> Self::Output;

    /// Softplus activation.
    ///
    /// Formula: f(x) = ln(1 + e<sup>x</sup>)
    fn _softplus(self) -> Self::Output;

    /// Softsign activation.
    ///
    /// Formula: f(x) = x / (1 + |x|)
    fn _softsign(self) -> Self::Output;

    /// Mish activation.
    ///
    /// Formula: f(x) = x * tanh(ln(1 + e<sup>x</sup>))
    fn _mish(self) -> Self::Output;

    /// Cube root, ∛x.
    fn _cbrt(self) -> Self::Output;
}
489
/// Internal trait behind the float-out unary operations. Input, extra
/// parameters, and result all have type `Self`.
pub trait FloatOutUnary2 {
    /// Natural exponential, e<sup>x</sup>.
    fn __exp(self) -> Self;

    /// Natural exponential minus one, e<sup>x</sup> - 1.
    fn __expm1(self) -> Self;

    /// Base-2 exponential, 2<sup>x</sup>.
    fn __exp2(self) -> Self;

    /// Base-10 exponential, 10<sup>x</sup>.
    fn __exp10(self) -> Self;

    /// Natural logarithm, ln(x).
    fn __ln(self) -> Self;

    /// Natural logarithm of one plus the input, ln(x + 1).
    fn __log1p(self) -> Self;

    /// CELU (Continuously Differentiable Exponential Linear Unit) activation.
    ///
    /// Formula: f(x) = max(0, x) + min(0, alpha * (e<sup>(x / alpha)</sup> - 1))
    fn __celu(self, alpha: Self) -> Self;

    /// Base-2 logarithm, log<sub>2</sub>(x).
    fn __log2(self) -> Self;

    /// Base-10 logarithm, log<sub>10</sub>(x).
    fn __log10(self) -> Self;

    /// Square root, √x.
    fn __sqrt(self) -> Self;

    /// Sine, sin(x).
    fn __sin(self) -> Self;

    /// Cosine, cos(x).
    fn __cos(self) -> Self;

    /// Sine and cosine computed together, as the pair (sin(x), cos(x)).
    /// The `Sized` bound is required because the result is a tuple of `Self`.
    fn __sincos(self) -> (Self, Self)
    where
        Self: Sized;

    /// Tangent, tan(x).
    fn __tan(self) -> Self;

    /// Inverse sine (arcsin), asin(x).
    fn __asin(self) -> Self;

    /// Inverse cosine (arccos), acos(x).
    fn __acos(self) -> Self;

    /// Inverse tangent (arctan), atan(x).
    fn __atan(self) -> Self;

    /// Two-argument inverse tangent, atan2(y, x).
    fn __atan2(self, rhs: Self) -> Self;

    /// Hyperbolic sine, sinh(x).
    fn __sinh(self) -> Self;

    /// Hyperbolic cosine, cosh(x).
    fn __cosh(self) -> Self;

    /// Hyperbolic tangent, tanh(x).
    fn __tanh(self) -> Self;

    /// Inverse hyperbolic sine (arsinh), asinh(x).
    fn __asinh(self) -> Self;

    /// Inverse hyperbolic cosine (arcosh), acosh(x).
    fn __acosh(self) -> Self;

    /// Inverse hyperbolic tangent (artanh), atanh(x).
    fn __atanh(self) -> Self;

    /// Reciprocal, 1 / x.
    fn __recip(self) -> Self;

    /// Error function, erf(x).
    fn __erf(self) -> Self;

    /// Logistic sigmoid, 1 / (1 + e<sup>-x</sup>).
    fn __sigmoid(self) -> Self;

    /// ELU (Exponential Linear Unit) activation.
    ///
    /// Formula: f(x) = x if x > 0 else alpha * (e<sup>x</sup> - 1)
    fn __elu(self, alpha: Self) -> Self;

    /// GELU (Gaussian Error Linear Unit) activation.
    fn __gelu(self) -> Self;

    /// SELU (Scaled Exponential Linear Unit) activation.
    ///
    /// Formula: f(x) = scale * (x if x > 0 else alpha * (e<sup>x</sup> - 1))
    fn __selu(self, alpha: Self, scale: Self) -> Self;

    /// Hard sigmoid activation.
    ///
    /// Formula: f(x) = min(1, max(0, 0.2 * x + 0.5))
    fn __hard_sigmoid(self) -> Self;

    /// Hard swish activation.
    ///
    /// Formula: f(x) = x * min(1, max(0, 0.2 * x + 0.5))
    fn __hard_swish(self) -> Self;

    /// Softplus activation.
    ///
    /// Formula: f(x) = ln(1 + e<sup>x</sup>)
    fn __softplus(self) -> Self;

    /// Softsign activation.
    ///
    /// Formula: f(x) = x / (1 + |x|)
    fn __softsign(self) -> Self;

    /// Mish activation.
    ///
    /// Formula: f(x) = x * tanh(ln(1 + e<sup>x</sup>))
    fn __mish(self) -> Self;

    /// Cube root, ∛x.
    fn __cbrt(self) -> Self;
}
618
/// Type-level helper that names the output type the float-out unary
/// operations produce for `Self`.
pub trait FloatOutUnaryPromote {
    /// The promoted result type.
    type Output;
}
624
625float_out_unary!();
626
627simd_float_out_unary!();