hpt_types/scalars/
_f16.rs1use num_traits::real::Real;
2
3use crate::{
4 convertion::Convertor,
5 type_promote::{
6 BitWiseOut2, Eval2, FloatOutBinary2, FloatOutUnary2, NormalOut2, NormalOutUnary2,
7 },
8};
9
10impl FloatOutBinary2 for half::f16 {
11 #[inline(always)]
12 fn __div(self, rhs: Self) -> Self {
13 self / rhs
14 }
15 #[inline(always)]
16 fn __log(self, base: Self) -> Self {
17 self.log(base)
18 }
19 #[inline(always)]
20 fn __hypot(self, rhs: Self) -> Self {
21 self.hypot(rhs)
22 }
23}
24
25impl NormalOut2 for half::f16 {
26 #[inline(always)]
27 fn __add(self, rhs: Self) -> Self {
28 self + rhs
29 }
30
31 #[inline(always)]
32 fn __sub(self, rhs: Self) -> Self {
33 self - rhs
34 }
35
36 #[inline(always)]
37 fn __mul_add(self, a: Self, b: Self) -> Self {
38 (self * a) + b
39 }
40
41 #[inline(always)]
42 fn __mul(self, rhs: Self) -> Self {
43 self * rhs
44 }
45
46 #[inline(always)]
47 fn __pow(self, rhs: Self) -> Self {
48 self.powf(rhs)
49 }
50
51 #[inline(always)]
52 fn __rem(self, rhs: Self) -> Self {
53 self % rhs
54 }
55
56 #[inline(always)]
57 fn __max(self, rhs: Self) -> Self {
58 self.max(rhs)
59 }
60
61 #[inline(always)]
62 fn __min(self, rhs: Self) -> Self {
63 self.min(rhs)
64 }
65
66 #[inline(always)]
67 fn __clamp(self, min: Self, max: Self) -> Self {
68 self.clamp(min, max)
69 }
70}
71
72impl NormalOutUnary2 for half::f16 {
73 #[inline(always)]
74 fn __square(self) -> Self {
75 self * self
76 }
77
78 #[inline(always)]
79 fn __abs(self) -> Self {
80 self.abs()
81 }
82
83 #[inline(always)]
84 fn __ceil(self) -> Self {
85 self.ceil()
86 }
87
88 #[inline(always)]
89 fn __floor(self) -> Self {
90 self.floor()
91 }
92
93 #[inline(always)]
94 fn __neg(self) -> Self {
95 -self
96 }
97
98 #[inline(always)]
99 fn __round(self) -> Self {
100 self.round()
101 }
102
103 #[inline(always)]
104 fn __signum(self) -> Self {
105 self.signum()
106 }
107
108 #[inline(always)]
109 fn __trunc(self) -> Self {
110 self.trunc()
111 }
112
113 #[inline(always)]
114 fn __leaky_relu(self, alpha: Self) -> Self {
115 self.max(half::f16::from_f32_const(0.0)) + alpha * self.min(half::f16::from_f32_const(0.0))
116 }
117
118 #[inline(always)]
119 fn __relu(self) -> Self {
120 self.max(half::f16::from_f32_const(0.0))
121 }
122
123 #[inline(always)]
124 fn __relu6(self) -> Self {
125 self.min(half::f16::from_f32_const(6.0))
126 .max(half::f16::from_f32_const(0.0))
127 }
128
129 #[inline(always)]
130 fn __copysign(self, rhs: Self) -> Self {
131 self.copysign(rhs)
132 }
133}
134
135impl BitWiseOut2 for half::f16 {
136 #[inline(always)]
137 fn __bitand(self, rhs: Self) -> Self {
138 half::f16::from_bits(self.to_bits() & rhs.to_bits())
139 }
140
141 #[inline(always)]
142 fn __bitor(self, rhs: Self) -> Self {
143 half::f16::from_bits(self.to_bits() | rhs.to_bits())
144 }
145
146 #[inline(always)]
147 fn __bitxor(self, rhs: Self) -> Self {
148 half::f16::from_bits(self.to_bits() ^ rhs.to_bits())
149 }
150
151 #[inline(always)]
152 fn __not(self) -> Self {
153 half::f16::from_bits(!self.to_bits())
154 }
155
156 #[inline(always)]
157 fn __shl(self, _: Self) -> Self {
158 panic!("Shift operations are not supported for half::f16")
159 }
160
161 #[inline(always)]
162 fn __shr(self, _: Self) -> Self {
163 panic!("Shift operations are not supported for half::f16")
164 }
165}
166
167impl Eval2 for half::f16 {
168 type Output = bool;
169 #[inline(always)]
170 fn __is_nan(&self) -> Self::Output {
171 self.is_nan()
172 }
173
174 #[inline(always)]
175 fn __is_true(&self) -> Self::Output {
176 *self != half::f16::from_f32_const(0.0) && !self.is_nan()
177 }
178
179 #[inline(always)]
180 fn __is_inf(&self) -> Self::Output {
181 self.is_infinite()
182 }
183}
184
185impl FloatOutUnary2 for half::f16 {
186 #[inline(always)]
187 fn __exp(self) -> Self {
188 self.exp()
189 }
190 #[inline(always)]
191 fn __expm1(self) -> Self {
192 self.to_f32().__expm1().to_f16()
193 }
194 #[inline(always)]
195 fn __exp2(self) -> Self {
196 self.exp2()
197 }
198 #[inline(always)]
199 fn __ln(self) -> Self {
200 self.ln()
201 }
202 #[inline(always)]
203 fn __log1p(self) -> Self {
204 self.to_f32().__log1p().to_f16()
205 }
206 #[inline(always)]
207 fn __celu(self, alpha: Self) -> Self {
208 self.to_f32().__celu(alpha.to_f32()).to_f16()
209 }
210 #[inline(always)]
211 fn __log2(self) -> Self {
212 self.log2()
213 }
214 #[inline(always)]
215 fn __log10(self) -> Self {
216 self.log10()
217 }
218 #[inline(always)]
219 fn __sqrt(self) -> Self {
220 self.sqrt()
221 }
222 #[inline(always)]
223 fn __sin(self) -> Self {
224 self.sin()
225 }
226 #[inline(always)]
227 fn __cos(self) -> Self {
228 self.cos()
229 }
230 #[inline(always)]
231 fn __tan(self) -> Self {
232 self.tan()
233 }
234 #[inline(always)]
235 fn __asin(self) -> Self {
236 self.asin()
237 }
238 #[inline(always)]
239 fn __acos(self) -> Self {
240 self.acos()
241 }
242 #[inline(always)]
243 fn __atan(self) -> Self {
244 self.atan()
245 }
246 #[inline(always)]
247 fn __sinh(self) -> Self {
248 self.sinh()
249 }
250 #[inline(always)]
251 fn __cosh(self) -> Self {
252 self.cosh()
253 }
254 #[inline(always)]
255 fn __tanh(self) -> Self {
256 self.tanh()
257 }
258 #[inline(always)]
259 fn __asinh(self) -> Self {
260 self.asinh()
261 }
262 #[inline(always)]
263 fn __acosh(self) -> Self {
264 self.acosh()
265 }
266 #[inline(always)]
267 fn __atanh(self) -> Self {
268 self.atanh()
269 }
270 #[inline(always)]
271 fn __recip(self) -> Self {
272 self.recip()
273 }
274 #[inline(always)]
275 fn __erf(self) -> Self {
276 self.to_f32().__erf().to_f16()
277 }
278 #[inline(always)]
279 fn __sigmoid(self) -> Self {
280 self.to_f32().__sigmoid().to_f16()
281 }
282 #[inline(always)]
283 fn __elu(self, alpha: Self) -> Self {
284 self.to_f32().__elu(alpha.to_f32()).to_f16()
285 }
286 #[inline(always)]
287 fn __gelu(self) -> Self {
288 self.to_f32().__gelu().to_f16()
289 }
290 #[inline(always)]
291 fn __selu(self, alpha: Self, scale: Self) -> Self {
292 self.to_f32()
293 .__selu(alpha.to_f32(), scale.to_f32())
294 .to_f16()
295 }
296 #[inline(always)]
297 fn __hard_sigmoid(self) -> Self {
298 self.to_f32().__hard_sigmoid().to_f16()
299 }
300 #[inline(always)]
301 fn __hard_swish(self) -> Self {
302 self.to_f32().__hard_swish().to_f16()
303 }
304 #[inline(always)]
305 fn __softplus(self) -> Self {
306 self.to_f32().__softplus().to_f16()
307 }
308 #[inline(always)]
309 fn __softsign(self) -> Self {
310 self.to_f32().__softsign().to_f16()
311 }
312 #[inline(always)]
313 fn __mish(self) -> Self {
314 self.to_f32().__mish().to_f16()
315 }
316 #[inline(always)]
317 fn __cbrt(self) -> Self {
318 self.to_f32().__cbrt().to_f16()
319 }
320 #[inline(always)]
321 fn __sincos(self) -> (Self, Self) {
322 let res = self.to_f32().sin_cos();
323 (res.0.to_f16(), res.1.to_f16())
324 }
325 #[inline(always)]
326 fn __atan2(self, rhs: Self) -> Self {
327 self.atan2(rhs)
328 }
329 #[inline(always)]
330 fn __exp10(self) -> Self {
331 self.to_f32().__exp10().to_f16()
332 }
333}