hpt_types/type_promote.rs
use crate::into_scalar::Cast;
use crate::into_vec::IntoVec;
#[cfg(any(
    all(not(target_feature = "avx2"), target_feature = "sse"),
    target_arch = "arm",
    target_arch = "aarch64",
    target_feature = "neon"
))]
use crate::simd::_128bit::*;
#[cfg(target_feature = "avx2")]
use crate::simd::_256bit::*;
#[cfg(target_feature = "avx512f")]
use crate::simd::_512bit::*;
use crate::traits::SimdMath;
use crate::vectors::traits::SimdCompare;
use crate::vectors::traits::VecTrait;
use half::bf16;
use half::f16;
use hpt_macros::{
    float_out_binary, float_out_binary_simd_with_lhs_scalar, float_out_binary_simd_with_rhs_scalar,
    float_out_unary, impl_bitwise_out, impl_cmp, impl_eval, impl_normal_out_binary,
    impl_normal_out_simd, impl_normal_out_simd_with_lhs_scalar,
    impl_normal_out_simd_with_rhs_scalar, impl_normal_out_unary, impl_normal_out_unary_simd,
    simd_cmp, simd_eval, simd_float_out_unary,
};
use num_complex::{Complex32, Complex64};
use num_traits::float::Float;
#[cfg(feature = "cuda")]
mod cuda_imports {
    use super::*;
    use crate::cuda_types::scalar::Scalar;
    use hpt_macros::{
        float_out_binary_cuda, float_out_unary_cuda, impl_cmp_cuda, impl_cuda_bitwise_out,
        impl_cuda_normal_out_binary, impl_normal_out_unary_cuda,
    };
    float_out_binary_cuda!();
    impl_cuda_normal_out_binary!();
    impl_normal_out_unary_cuda!();
    impl_cuda_bitwise_out!();
    impl_cmp_cuda!();
    float_out_unary_cuda!();
}

use hpt_macros::{float_out_binary_simd, simd_bitwise};

/// this trait is used to perform type promotion in dynamic graph
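///
/// # Example
///
/// A minimal sketch of an implementation for a hypothetical newtype wrapper; the
/// real impls for primitive and SIMD types are generated by the macros in this
/// module, so this only illustrates the expected shape of the trait:
///
/// ```ignore
/// struct Wrapped(f32);
///
/// impl FloatOutBinary for Wrapped {
///     type Output = Wrapped;
///     fn _div(self, rhs: Self) -> Self::Output {
///         Wrapped(self.0 / rhs.0)
///     }
///     fn _log(self, base: Self) -> Self::Output {
///         Wrapped(self.0.log(base.0))
///     }
///     fn _hypot(self, rhs: Self) -> Self::Output {
///         Wrapped(self.0.hypot(rhs.0))
///     }
/// }
/// ```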
pub trait FloatOutBinary<RHS = Self> {
    /// the output type
    type Output;
    /// perform a / b
    fn _div(self, rhs: RHS) -> Self::Output;
    /// perform log<sub>b</sub>(x)
    fn _log(self, base: RHS) -> Self::Output;
    /// perform hypot(x, y)
    fn _hypot(self, rhs: RHS) -> Self::Output;
}

/// this trait is used to perform type promotion for float out binary operations
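///
/// # Example
///
/// A hypothetical promotion rule for a local wrapper type, shown only to illustrate
/// how the trait is read; the actual promotion rules for the built-in types are
/// generated by the macros in this module:
///
/// ```ignore
/// struct WrappedI32(i32);
///
/// // dividing an i32-backed value by an f32 promotes the result to f32
/// impl FloatOutBinaryPromote<f32> for WrappedI32 {
///     type Output = f32;
/// }
/// ```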
pub trait FloatOutBinaryPromote<RHS = Self> {
    /// the output type
    type Output;
}

/// internal trait for float out binary
pub trait FloatOutBinary2 {
    /// perform a / b
    fn __div(self, rhs: Self) -> Self;
    /// perform log<sub>b</sub>(x)
    fn __log(self, base: Self) -> Self;
    /// perform hypot(x, y)
    fn __hypot(self, rhs: Self) -> Self;
}

float_out_binary!();
float_out_binary_simd!();
float_out_binary_simd_with_rhs_scalar!();
float_out_binary_simd_with_lhs_scalar!();

/// this trait is used to perform normal operations that don't require type promotion
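///
/// # Example
///
/// A small generic helper written purely against this trait (a sketch; `clamp_to`
/// is not part of the crate):
///
/// ```ignore
/// fn clamp_to<T>(x: T, min: T, max: T) -> T::Output
/// where
///     T: NormalOut<T>,
/// {
///     x._clamp(min, max)
/// }
/// ```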
pub trait NormalOut<RHS = Self> {
    /// the output type
    type Output;
    /// perform a + b
    fn _add(self, rhs: RHS) -> Self::Output;
    /// perform a - b
    fn _sub(self, rhs: RHS) -> Self::Output;
    /// perform self * a + b, fused multiply add
    /// when the hardware supports it, this can speed up the calculation and reduce the rounding error
    fn _mul_add(self, a: RHS, b: RHS) -> Self::Output;
    /// perform a * b
    fn _mul(self, rhs: RHS) -> Self::Output;
    /// perform a<sup>b</sup>
    fn _pow(self, rhs: RHS) -> Self::Output;
    /// perform a % b
    fn _rem(self, rhs: RHS) -> Self::Output;
    /// perform max(x, y)
    fn _max(self, rhs: RHS) -> Self::Output;
    /// perform min(x, y)
    fn _min(self, rhs: RHS) -> Self::Output;
    /// restrict the value of x to the range [min, max]
    fn _clamp(self, min: RHS, max: RHS) -> Self::Output;
}

/// internal trait for normal out
pub trait NormalOut2 {
    /// perform a + b
    fn __add(self, rhs: Self) -> Self;
    /// perform a - b
    fn __sub(self, rhs: Self) -> Self;
    /// perform self * a + b, fused multiply add
    /// when the hardware supports it, this can speed up the calculation and reduce the rounding error
    fn __mul_add(self, a: Self, b: Self) -> Self;
    /// perform a * b
    fn __mul(self, rhs: Self) -> Self;
    /// perform a<sup>b</sup>
    fn __pow(self, rhs: Self) -> Self;
    /// perform a % b
    fn __rem(self, rhs: Self) -> Self;
    /// perform max(x, y)
    fn __max(self, rhs: Self) -> Self;
    /// perform min(x, y)
    fn __min(self, rhs: Self) -> Self;
    /// restrict the value of x to the range [min, max]
    fn __clamp(self, min: Self, max: Self) -> Self;
}

/// this trait is used to perform type promotion for normal out operations
pub trait NormalOutPromote<RHS = Self> {
    /// the output type
    type Output;
}

impl_normal_out_binary!();

impl_normal_out_simd!();

impl_normal_out_simd_with_rhs_scalar!();

impl_normal_out_simd_with_lhs_scalar!();

//~^ NormalOutUnary is not implemented for {Self}
/// this trait is used to perform normal unary operations that don't require type promotion
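///
/// # Example
///
/// A sketch of the intended semantics, assuming the macro-generated impl for `f32`
/// with the usual floating-point behaviour (not a guaranteed doctest):
///
/// ```ignore
/// assert_eq!((-2.0f32)._relu(), 0.0);           // negative inputs clamp to 0
/// assert_eq!(7.5f32._relu6(), 6.0);             // relu6 saturates at 6
/// assert_eq!((-4.0f32)._leaky_relu(0.5), -2.0); // alpha scales negative inputs
/// ```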
pub trait NormalOutUnary {
    /// perform x<sup>2</sup>
    fn _square(self) -> Self;
    /// perform |x|
    fn _abs(self) -> Self;
    /// perform ⌈x⌉
    fn _ceil(self) -> Self;
    /// perform ⌊x⌋
    fn _floor(self) -> Self;
    /// perform -x
    fn _neg(self) -> Self;
    /// perform rounding
    fn _round(self) -> Self;
    /// get the sign of x
    fn _signum(self) -> Self;
    /// perform truncation
    fn _trunc(self) -> Self;

    /// Perform the leaky ReLU (Rectified Linear Unit) activation function.
    ///
    /// Formula: f(x) = x if x > 0 else alpha * x
    fn _leaky_relu(self, alpha: Self) -> Self;

    /// Perform the ReLU (Rectified Linear Unit) activation function.
    ///
    /// Formula: f(x) = max(0, x)
    fn _relu(self) -> Self;

    /// Perform the ReLU6 activation function.
    ///
    /// Formula: f(x) = min(6, max(0, x))
    fn _relu6(self) -> Self;

    /// Perform the copysign function.
    ///
    /// Formula: f(x, y) = |x| * sign(y)
    fn _copysign(self, rhs: Self) -> Self;
}

/// internal trait for normal out unary
pub trait NormalOutUnary2 {
    /// perform x<sup>2</sup>
    fn __square(self) -> Self;
    /// perform |x|
    fn __abs(self) -> Self;
    /// perform ⌈x⌉
    fn __ceil(self) -> Self;
    /// perform ⌊x⌋
    fn __floor(self) -> Self;
    /// perform -x
    fn __neg(self) -> Self;
    /// perform rounding
    fn __round(self) -> Self;
    /// get the sign of x
    fn __signum(self) -> Self;
    /// perform truncation
    fn __trunc(self) -> Self;
    /// Perform the leaky ReLU (Rectified Linear Unit) activation function.
    ///
    /// Formula: f(x) = x if x > 0 else alpha * x
    fn __leaky_relu(self, alpha: Self) -> Self;

    /// Perform the ReLU (Rectified Linear Unit) activation function.
    ///
    /// Formula: f(x) = max(0, x)
    fn __relu(self) -> Self;

    /// Perform the ReLU6 activation function.
    ///
    /// Formula: f(x) = min(6, max(0, x))
    fn __relu6(self) -> Self;

    /// Perform the copysign function.
    ///
    /// Formula: f(x, y) = |x| * sign(y)
    fn __copysign(self, rhs: Self) -> Self;
}

impl_normal_out_unary!();

impl_normal_out_unary_simd!();

/// this trait is used to perform bitwise operations
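///
/// # Example
///
/// A small generic helper written against this trait (a sketch; `mask_low_byte`
/// is not part of the crate):
///
/// ```ignore
/// fn mask_low_byte<T>(x: T, mask: T) -> T::Output
/// where
///     T: BitWiseOut<T>,
/// {
///     x._bitand(mask)
/// }
///
/// // e.g. mask_low_byte(0x1234i32, 0xFFi32) keeps only the low byte (0x34),
/// // assuming the macro-generated impl for `i32` returns `i32`.
/// ```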
pub trait BitWiseOut<RHS = Self> {
    /// the output type
    type Output;
    /// perform a & b
    fn _bitand(self, rhs: RHS) -> Self::Output;
    /// perform a | b
    fn _bitor(self, rhs: RHS) -> Self::Output;
    /// perform a ^ b
    fn _bitxor(self, rhs: RHS) -> Self::Output;
    /// perform !a
    fn _not(self) -> Self::Output;
    /// perform a << b
    fn _shl(self, rhs: RHS) -> Self::Output;
    /// perform a >> b
    fn _shr(self, rhs: RHS) -> Self::Output;
}

/// internal trait for bitwise out
pub trait BitWiseOut2 {
    /// perform a & b
    fn __bitand(self, rhs: Self) -> Self;
    /// perform a | b
    fn __bitor(self, rhs: Self) -> Self;
    /// perform a ^ b
    fn __bitxor(self, rhs: Self) -> Self;
    /// perform !a
    fn __not(self) -> Self;
    /// perform a << b
    fn __shl(self, rhs: Self) -> Self;
    /// perform a >> b
    fn __shr(self, rhs: Self) -> Self;
}

impl_bitwise_out!();

simd_bitwise!();

/// this trait is used to perform comparison operations
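///
/// # Example
///
/// A generic helper that only relies on this trait (a sketch; `in_range` is not
/// part of the crate, and the `Output = bool` bound is an assumption made for
/// the example):
///
/// ```ignore
/// fn in_range<T>(x: T, lo: T, hi: T) -> bool
/// where
///     T: Cmp<T, Output = bool> + Copy,
/// {
///     x._ge(lo) && x._le(hi)
/// }
/// ```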
pub trait Cmp<RHS = Self> {
    /// the output type
    type Output;
    /// perform a == b
    fn _eq(self, rhs: RHS) -> Self::Output;
    /// perform a != b
    fn _ne(self, rhs: RHS) -> Self::Output;
    /// perform a < b
    fn _lt(self, rhs: RHS) -> Self::Output;
    /// perform a <= b
    fn _le(self, rhs: RHS) -> Self::Output;
    /// perform a > b
    fn _gt(self, rhs: RHS) -> Self::Output;
    /// perform a >= b
    fn _ge(self, rhs: RHS) -> Self::Output;
}
impl_cmp!();

/// this trait is used to perform comparison operations on simd
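///
/// # Example
///
/// A generic helper that forwards to the lane-wise comparison (a sketch; `eq_mask`
/// is not part of the crate). For the macro-generated impls, each lane of the
/// returned mask is typically all-ones for `true` and all-zeros for `false`:
///
/// ```ignore
/// fn eq_mask<V>(a: V, b: V) -> V::Output
/// where
///     V: SimdCmp<V>,
/// {
///     a._eq(b)
/// }
/// ```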
pub trait SimdCmp<RHS = Self> {
    /// the output type
    type Output;
    /// perform a == b, return a mask
    ///
    /// # Note
    ///
    /// The mask may not be a boolean value; its type depends on the byte width of the SIMD lanes
    fn _eq(self, rhs: RHS) -> Self::Output;
    /// perform a != b, return a mask
    ///
    /// # Note
    ///
    /// The mask may not be a boolean value; its type depends on the byte width of the SIMD lanes
    fn _ne(self, rhs: RHS) -> Self::Output;
    /// perform a < b, return a mask
    ///
    /// # Note
    ///
    /// The mask may not be a boolean value; its type depends on the byte width of the SIMD lanes
    fn _lt(self, rhs: RHS) -> Self::Output;
    /// perform a <= b, return a mask
    ///
    /// # Note
    ///
    /// The mask may not be a boolean value; its type depends on the byte width of the SIMD lanes
    fn _le(self, rhs: RHS) -> Self::Output;
    /// perform a > b, return a mask
    ///
    /// # Note
    ///
    /// The mask may not be a boolean value; its type depends on the byte width of the SIMD lanes
    fn _gt(self, rhs: RHS) -> Self::Output;
    /// perform a >= b, return a mask
    ///
    /// # Note
    ///
    /// The mask may not be a boolean value; its type depends on the byte width of the SIMD lanes
    fn _ge(self, rhs: RHS) -> Self::Output;
}

/// this trait is used to perform type promotion for simd comparison operations
pub trait SimdCmpPromote<RHS = Self> {
    /// the output type
    type Output;
}

simd_cmp!();

/// this trait is used to perform evaluation operations
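///
/// # Example
///
/// A generic helper that only relies on this trait (a sketch; `any_nan` is not
/// part of the crate, and the `Output = bool` bound is an assumption made for
/// the example):
///
/// ```ignore
/// fn any_nan<T>(values: &[T]) -> bool
/// where
///     T: Eval<Output = bool>,
/// {
///     values.iter().any(|v| v._is_nan())
/// }
/// ```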
pub trait Eval {
    /// the output type
    type Output;
    /// check if the value is nan
    fn _is_nan(&self) -> Self::Output;
    /// check if the value is true (non-zero)
    fn _is_true(&self) -> Self::Output;
    /// check if the value is infinite
    fn _is_inf(&self) -> Self::Output;
}

/// internal trait for eval
pub trait Eval2 {
    /// the output type
    type Output;
    /// check if the value is nan
    fn __is_nan(&self) -> Self::Output;
    /// check if the value is true (non-zero)
    fn __is_true(&self) -> Self::Output;
    /// check if the value is infinite
    fn __is_inf(&self) -> Self::Output;
}

impl_eval!();
simd_eval!();

//~^ FloatOutUnary is not implemented for {Self}
/// This trait is used to perform various unary floating-point operations.
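///
/// # Example
///
/// A generic helper composed purely from this trait's methods (a sketch;
/// `log_sigmoid` is not part of the crate):
///
/// ```ignore
/// fn log_sigmoid<T>(x: T) -> <T::Output as FloatOutUnary>::Output
/// where
///     T: FloatOutUnary,
///     T::Output: FloatOutUnary,
/// {
///     // ln(1 / (1 + e^-x))
///     x._sigmoid()._ln()
/// }
/// ```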
pub trait FloatOutUnary {
    /// The output type.
    type Output;

    /// Perform the natural exponential function: e<sup>x</sup>.
    fn _exp(self) -> Self::Output;

    /// Perform the natural exponential function: e<sup>x</sup> - 1.
    fn _expm1(self) -> Self::Output;

    /// Perform the base-2 exponential function: 2<sup>x</sup>.
    fn _exp2(self) -> Self::Output;

    /// Perform the base-10 exponential function: 10<sup>x</sup>.
    fn _exp10(self) -> Self::Output;

    /// Perform the natural logarithm: ln(x).
    fn _ln(self) -> Self::Output;

    /// Perform the natural logarithm: ln(x + 1).
    fn _log1p(self) -> Self::Output;

    /// Perform the CELU (Continuously Differentiable Exponential Linear Unit) activation function.
    ///
    /// Formula: f(x) = max(0, x) + min(0, alpha * (e<sup>(x / alpha)</sup> - 1))
    fn _celu(self, alpha: Self::Output) -> Self::Output;

    /// Perform the base-2 logarithm: log<sub>2</sub>(x).
    fn _log2(self) -> Self::Output;

    /// Perform the base-10 logarithm: log<sub>10</sub>(x).
    fn _log10(self) -> Self::Output;

    /// Perform the square root: √x.
    fn _sqrt(self) -> Self::Output;

    /// Perform the sine function: sin(x).
    fn _sin(self) -> Self::Output;

    /// Perform the cosine function: cos(x).
    fn _cos(self) -> Self::Output;

    /// Perform the sine and cosine functions: sin(x) and cos(x).
    fn _sincos(self) -> (Self::Output, Self::Output);

    /// Perform the tangent function: tan(x).
    fn _tan(self) -> Self::Output;

    /// Perform the inverse sine (arcsin) function: asin(x).
    fn _asin(self) -> Self::Output;

    /// Perform the inverse cosine (arccos) function: acos(x).
    fn _acos(self) -> Self::Output;

    /// Perform the inverse tangent (arctan) function: atan(x).
    fn _atan(self) -> Self::Output;

    /// Perform the inverse tangent function: atan2(y, x).
    fn _atan2(self, rhs: Self::Output) -> Self::Output;

    /// Perform the hyperbolic sine function: sinh(x).
    fn _sinh(self) -> Self::Output;

    /// Perform the hyperbolic cosine function: cosh(x).
    fn _cosh(self) -> Self::Output;

    /// Perform the hyperbolic tangent function: tanh(x).
    fn _tanh(self) -> Self::Output;

    /// Perform the inverse hyperbolic sine (arsinh) function: asinh(x).
    fn _asinh(self) -> Self::Output;

    /// Perform the inverse hyperbolic cosine (arcosh) function: acosh(x).
    fn _acosh(self) -> Self::Output;

    /// Perform the inverse hyperbolic tangent (artanh) function: atanh(x).
    fn _atanh(self) -> Self::Output;

    /// Perform the reciprocal function: 1 / x.
    fn _recip(self) -> Self::Output;

    /// Perform the error function (erf).
    fn _erf(self) -> Self::Output;

    /// Perform the sigmoid function: 1 / (1 + e<sup>-x</sup>).
    fn _sigmoid(self) -> Self::Output;

    /// Perform the ELU (Exponential Linear Unit) activation function.
    ///
    /// Formula: f(x) = x if x > 0 else alpha * (e<sup>x</sup> - 1)
    fn _elu(self, alpha: Self::Output) -> Self::Output;

    /// Perform the GELU (Gaussian Error Linear Unit) activation function.
    fn _gelu(self) -> Self::Output;

    /// Perform the SELU (Scaled Exponential Linear Unit) activation function.
    ///
    /// Formula: f(x) = scale * (x if x > 0 else alpha * (e<sup>x</sup> - 1))
    fn _selu(self, alpha: Self::Output, scale: Self::Output) -> Self::Output;

    /// Perform the hard sigmoid activation function.
    ///
    /// Formula: f(x) = min(1, max(0, 0.2 * x + 0.5))
    fn _hard_sigmoid(self) -> Self::Output;

    /// Perform the hard swish activation function.
    ///
    /// Formula: f(x) = x * min(1, max(0, 0.2 * x + 0.5))
    fn _hard_swish(self) -> Self::Output;

    /// Perform the softplus activation function.
    ///
    /// Formula: f(x) = ln(1 + e<sup>x</sup>)
    fn _softplus(self) -> Self::Output;

    /// Perform the softsign activation function.
    ///
    /// Formula: f(x) = x / (1 + |x|)
    fn _softsign(self) -> Self::Output;

    /// Perform the mish activation function.
    ///
    /// Formula: f(x) = x * tanh(ln(1 + e<sup>x</sup>))
    fn _mish(self) -> Self::Output;

    /// Perform the cube root function: ∛x.
    fn _cbrt(self) -> Self::Output;
}

/// internal trait for float out unary
pub trait FloatOutUnary2 {
    /// Perform the natural exponential function: e<sup>x</sup>.
    fn __exp(self) -> Self;

    /// Perform the natural exponential function: e<sup>x</sup> - 1.
    fn __expm1(self) -> Self;

    /// Perform the base-2 exponential function: 2<sup>x</sup>.
    fn __exp2(self) -> Self;

    /// Perform the base-10 exponential function: 10<sup>x</sup>.
    fn __exp10(self) -> Self;

    /// Perform the natural logarithm: ln(x).
    fn __ln(self) -> Self;

    /// Perform the natural logarithm: ln(x + 1).
    fn __log1p(self) -> Self;

    /// Perform the CELU (Continuously Differentiable Exponential Linear Unit) activation function.
    ///
    /// Formula: f(x) = max(0, x) + min(0, alpha * (e<sup>(x / alpha)</sup> - 1))
    fn __celu(self, alpha: Self) -> Self;

    /// Perform the base-2 logarithm: log<sub>2</sub>(x).
    fn __log2(self) -> Self;

    /// Perform the base-10 logarithm: log<sub>10</sub>(x).
    fn __log10(self) -> Self;

    /// Perform the square root: √x.
    fn __sqrt(self) -> Self;

    /// Perform the sine function: sin(x).
    fn __sin(self) -> Self;

    /// Perform the cosine function: cos(x).
    fn __cos(self) -> Self;

    /// Perform the sine and cosine functions: sin(x) and cos(x).
    fn __sincos(self) -> (Self, Self)
    where
        Self: Sized;

    /// Perform the tangent function: tan(x).
    fn __tan(self) -> Self;

    /// Perform the inverse sine (arcsin) function: asin(x).
    fn __asin(self) -> Self;

    /// Perform the inverse cosine (arccos) function: acos(x).
    fn __acos(self) -> Self;

    /// Perform the inverse tangent (arctan) function: atan(x).
    fn __atan(self) -> Self;

    /// Perform the inverse tangent function: atan2(y, x).
    fn __atan2(self, rhs: Self) -> Self;

    /// Perform the hyperbolic sine function: sinh(x).
    fn __sinh(self) -> Self;

    /// Perform the hyperbolic cosine function: cosh(x).
    fn __cosh(self) -> Self;

    /// Perform the hyperbolic tangent function: tanh(x).
    fn __tanh(self) -> Self;

    /// Perform the inverse hyperbolic sine (arsinh) function: asinh(x).
    fn __asinh(self) -> Self;

    /// Perform the inverse hyperbolic cosine (arcosh) function: acosh(x).
    fn __acosh(self) -> Self;

    /// Perform the inverse hyperbolic tangent (artanh) function: atanh(x).
    fn __atanh(self) -> Self;

    /// Perform the reciprocal function: 1 / x.
    fn __recip(self) -> Self;

    /// Perform the error function (erf).
    fn __erf(self) -> Self;

    /// Perform the sigmoid function: 1 / (1 + e<sup>-x</sup>).
    fn __sigmoid(self) -> Self;

    /// Perform the ELU (Exponential Linear Unit) activation function.
    ///
    /// Formula: f(x) = x if x > 0 else alpha * (e<sup>x</sup> - 1)
    fn __elu(self, alpha: Self) -> Self;

    /// Perform the GELU (Gaussian Error Linear Unit) activation function.
    fn __gelu(self) -> Self;

    /// Perform the SELU (Scaled Exponential Linear Unit) activation function.
    ///
    /// Formula: f(x) = scale * (x if x > 0 else alpha * (e<sup>x</sup> - 1))
    fn __selu(self, alpha: Self, scale: Self) -> Self;

    /// Perform the hard sigmoid activation function.
    ///
    /// Formula: f(x) = min(1, max(0, 0.2 * x + 0.5))
    fn __hard_sigmoid(self) -> Self;

    /// Perform the hard swish activation function.
    ///
    /// Formula: f(x) = x * min(1, max(0, 0.2 * x + 0.5))
    fn __hard_swish(self) -> Self;

    /// Perform the softplus activation function.
    ///
    /// Formula: f(x) = ln(1 + e<sup>x</sup>)
    fn __softplus(self) -> Self;

    /// Perform the softsign activation function.
    ///
    /// Formula: f(x) = x / (1 + |x|)
    fn __softsign(self) -> Self;

    /// Perform the mish activation function.
    ///
    /// Formula: f(x) = x * tanh(ln(1 + e<sup>x</sup>))
    fn __mish(self) -> Self;

    /// Perform the cube root function: ∛x.
    fn __cbrt(self) -> Self;
}

/// this trait is used to perform type promotion for float out unary operations
pub trait FloatOutUnaryPromote {
    /// the output type
    type Output;
}

float_out_unary!();

simd_float_out_unary!();