hpt_types/scalars/
_f16.rs

1use num_traits::real::Real;
2
3use crate::{
4    convertion::Convertor,
5    type_promote::{
6        BitWiseOut2, Eval2, FloatOutBinary2, FloatOutUnary2, NormalOut2, NormalOutUnary2,
7    },
8};
9
10impl FloatOutBinary2 for half::f16 {
11    #[inline(always)]
12    fn __div(self, rhs: Self) -> Self {
13        self / rhs
14    }
15    #[inline(always)]
16    fn __log(self, base: Self) -> Self {
17        self.log(base)
18    }
19    #[inline(always)]
20    fn __hypot(self, rhs: Self) -> Self {
21        self.hypot(rhs)
22    }
23}
24
25impl NormalOut2 for half::f16 {
26    #[inline(always)]
27    fn __add(self, rhs: Self) -> Self {
28        self + rhs
29    }
30
31    #[inline(always)]
32    fn __sub(self, rhs: Self) -> Self {
33        self - rhs
34    }
35
36    #[inline(always)]
37    fn __mul_add(self, a: Self, b: Self) -> Self {
38        (self * a) + b
39    }
40
41    #[inline(always)]
42    fn __mul(self, rhs: Self) -> Self {
43        self * rhs
44    }
45
46    #[inline(always)]
47    fn __pow(self, rhs: Self) -> Self {
48        self.powf(rhs)
49    }
50
51    #[inline(always)]
52    fn __rem(self, rhs: Self) -> Self {
53        self % rhs
54    }
55
56    #[inline(always)]
57    fn __max(self, rhs: Self) -> Self {
58        self.max(rhs)
59    }
60
61    #[inline(always)]
62    fn __min(self, rhs: Self) -> Self {
63        self.min(rhs)
64    }
65
66    #[inline(always)]
67    fn __clamp(self, min: Self, max: Self) -> Self {
68        self.clamp(min, max)
69    }
70}
71
72impl NormalOutUnary2 for half::f16 {
73    #[inline(always)]
74    fn __square(self) -> Self {
75        self * self
76    }
77
78    #[inline(always)]
79    fn __abs(self) -> Self {
80        self.abs()
81    }
82
83    #[inline(always)]
84    fn __ceil(self) -> Self {
85        self.ceil()
86    }
87
88    #[inline(always)]
89    fn __floor(self) -> Self {
90        self.floor()
91    }
92
93    #[inline(always)]
94    fn __neg(self) -> Self {
95        -self
96    }
97
98    #[inline(always)]
99    fn __round(self) -> Self {
100        self.round()
101    }
102
103    #[inline(always)]
104    fn __signum(self) -> Self {
105        self.signum()
106    }
107
108    #[inline(always)]
109    fn __trunc(self) -> Self {
110        self.trunc()
111    }
112
113    #[inline(always)]
114    fn __leaky_relu(self, alpha: Self) -> Self {
115        self.max(half::f16::from_f32_const(0.0)) + alpha * self.min(half::f16::from_f32_const(0.0))
116    }
117
118    #[inline(always)]
119    fn __relu(self) -> Self {
120        self.max(half::f16::from_f32_const(0.0))
121    }
122
123    #[inline(always)]
124    fn __relu6(self) -> Self {
125        self.min(half::f16::from_f32_const(6.0))
126            .max(half::f16::from_f32_const(0.0))
127    }
128
129    #[inline(always)]
130    fn __copysign(self, rhs: Self) -> Self {
131        self.copysign(rhs)
132    }
133}
134
135impl BitWiseOut2 for half::f16 {
136    #[inline(always)]
137    fn __bitand(self, rhs: Self) -> Self {
138        half::f16::from_bits(self.to_bits() & rhs.to_bits())
139    }
140
141    #[inline(always)]
142    fn __bitor(self, rhs: Self) -> Self {
143        half::f16::from_bits(self.to_bits() | rhs.to_bits())
144    }
145
146    #[inline(always)]
147    fn __bitxor(self, rhs: Self) -> Self {
148        half::f16::from_bits(self.to_bits() ^ rhs.to_bits())
149    }
150
151    #[inline(always)]
152    fn __not(self) -> Self {
153        half::f16::from_bits(!self.to_bits())
154    }
155
156    #[inline(always)]
157    fn __shl(self, _: Self) -> Self {
158        panic!("Shift operations are not supported for half::f16")
159    }
160
161    #[inline(always)]
162    fn __shr(self, _: Self) -> Self {
163        panic!("Shift operations are not supported for half::f16")
164    }
165}
166
167impl Eval2 for half::f16 {
168    type Output = bool;
169    #[inline(always)]
170    fn __is_nan(&self) -> Self::Output {
171        self.is_nan()
172    }
173
174    #[inline(always)]
175    fn __is_true(&self) -> Self::Output {
176        *self != half::f16::from_f32_const(0.0) && !self.is_nan()
177    }
178
179    #[inline(always)]
180    fn __is_inf(&self) -> Self::Output {
181        self.is_infinite()
182    }
183}
184
185impl FloatOutUnary2 for half::f16 {
186    #[inline(always)]
187    fn __exp(self) -> Self {
188        self.exp()
189    }
190    #[inline(always)]
191    fn __expm1(self) -> Self {
192        self.to_f32().__expm1().to_f16()
193    }
194    #[inline(always)]
195    fn __exp2(self) -> Self {
196        self.exp2()
197    }
198    #[inline(always)]
199    fn __ln(self) -> Self {
200        self.ln()
201    }
202    #[inline(always)]
203    fn __log1p(self) -> Self {
204        self.to_f32().__log1p().to_f16()
205    }
206    #[inline(always)]
207    fn __celu(self, alpha: Self) -> Self {
208        self.to_f32().__celu(alpha.to_f32()).to_f16()
209    }
210    #[inline(always)]
211    fn __log2(self) -> Self {
212        self.log2()
213    }
214    #[inline(always)]
215    fn __log10(self) -> Self {
216        self.log10()
217    }
218    #[inline(always)]
219    fn __sqrt(self) -> Self {
220        self.sqrt()
221    }
222    #[inline(always)]
223    fn __sin(self) -> Self {
224        self.sin()
225    }
226    #[inline(always)]
227    fn __cos(self) -> Self {
228        self.cos()
229    }
230    #[inline(always)]
231    fn __tan(self) -> Self {
232        self.tan()
233    }
234    #[inline(always)]
235    fn __asin(self) -> Self {
236        self.asin()
237    }
238    #[inline(always)]
239    fn __acos(self) -> Self {
240        self.acos()
241    }
242    #[inline(always)]
243    fn __atan(self) -> Self {
244        self.atan()
245    }
246    #[inline(always)]
247    fn __sinh(self) -> Self {
248        self.sinh()
249    }
250    #[inline(always)]
251    fn __cosh(self) -> Self {
252        self.cosh()
253    }
254    #[inline(always)]
255    fn __tanh(self) -> Self {
256        self.tanh()
257    }
258    #[inline(always)]
259    fn __asinh(self) -> Self {
260        self.asinh()
261    }
262    #[inline(always)]
263    fn __acosh(self) -> Self {
264        self.acosh()
265    }
266    #[inline(always)]
267    fn __atanh(self) -> Self {
268        self.atanh()
269    }
270    #[inline(always)]
271    fn __recip(self) -> Self {
272        self.recip()
273    }
274    #[inline(always)]
275    fn __erf(self) -> Self {
276        self.to_f32().__erf().to_f16()
277    }
278    #[inline(always)]
279    fn __sigmoid(self) -> Self {
280        self.to_f32().__sigmoid().to_f16()
281    }
282    #[inline(always)]
283    fn __elu(self, alpha: Self) -> Self {
284        self.to_f32().__elu(alpha.to_f32()).to_f16()
285    }
286    #[inline(always)]
287    fn __gelu(self) -> Self {
288        self.to_f32().__gelu().to_f16()
289    }
290    #[inline(always)]
291    fn __selu(self, alpha: Self, scale: Self) -> Self {
292        self.to_f32()
293            .__selu(alpha.to_f32(), scale.to_f32())
294            .to_f16()
295    }
296    #[inline(always)]
297    fn __hard_sigmoid(self) -> Self {
298        self.to_f32().__hard_sigmoid().to_f16()
299    }
300    #[inline(always)]
301    fn __hard_swish(self) -> Self {
302        self.to_f32().__hard_swish().to_f16()
303    }
304    #[inline(always)]
305    fn __softplus(self) -> Self {
306        self.to_f32().__softplus().to_f16()
307    }
308    #[inline(always)]
309    fn __softsign(self) -> Self {
310        self.to_f32().__softsign().to_f16()
311    }
312    #[inline(always)]
313    fn __mish(self) -> Self {
314        self.to_f32().__mish().to_f16()
315    }
316    #[inline(always)]
317    fn __cbrt(self) -> Self {
318        self.to_f32().__cbrt().to_f16()
319    }
320    #[inline(always)]
321    fn __sincos(self) -> (Self, Self) {
322        let res = self.to_f32().sin_cos();
323        (res.0.to_f16(), res.1.to_f16())
324    }
325    #[inline(always)]
326    fn __atan2(self, rhs: Self) -> Self {
327        self.atan2(rhs)
328    }
329    #[inline(always)]
330    fn __exp10(self) -> Self {
331        self.to_f32().__exp10().to_f16()
332    }
333}