hpt_types/scalars/
_bf16.rs1use num_traits::real::Real;
2
3use crate::{
4 convertion::Convertor,
5 type_promote::{
6 BitWiseOut2, Eval2, FloatOutBinary2, FloatOutUnary2, NormalOut2, NormalOutUnary2,
7 },
8};
9
10impl FloatOutBinary2 for half::bf16 {
11 #[inline(always)]
12 fn __div(self, rhs: Self) -> Self {
13 self / rhs
14 }
15 #[inline(always)]
16 fn __log(self, base: Self) -> Self {
17 self.log(base)
18 }
19 #[inline(always)]
20 fn __hypot(self, rhs: Self) -> Self {
21 self.hypot(rhs)
22 }
23}
24
25impl NormalOut2 for half::bf16 {
26 #[inline(always)]
27 fn __add(self, rhs: Self) -> Self {
28 self + rhs
29 }
30
31 #[inline(always)]
32 fn __sub(self, rhs: Self) -> Self {
33 self - rhs
34 }
35
36 #[inline(always)]
37 fn __mul_add(self, a: Self, b: Self) -> Self {
38 (self * a) + b
39 }
40
41 #[inline(always)]
42 fn __mul(self, rhs: Self) -> Self {
43 self * rhs
44 }
45
46 #[inline(always)]
47 fn __pow(self, rhs: Self) -> Self {
48 self.powf(rhs)
49 }
50
51 #[inline(always)]
52 fn __rem(self, rhs: Self) -> Self {
53 self % rhs
54 }
55
56 #[inline(always)]
57 fn __max(self, rhs: Self) -> Self {
58 self.max(rhs)
59 }
60
61 #[inline(always)]
62 fn __min(self, rhs: Self) -> Self {
63 self.min(rhs)
64 }
65
66 #[inline(always)]
67 fn __clamp(self, min: Self, max: Self) -> Self {
68 self.clamp(min, max)
69 }
70}
71
72impl NormalOutUnary2 for half::bf16 {
73 #[inline(always)]
74 fn __square(self) -> Self {
75 self * self
76 }
77
78 #[inline(always)]
79 fn __abs(self) -> Self {
80 self.abs()
81 }
82
83 #[inline(always)]
84 fn __ceil(self) -> Self {
85 self.ceil()
86 }
87
88 #[inline(always)]
89 fn __floor(self) -> Self {
90 self.floor()
91 }
92
93 #[inline(always)]
94 fn __neg(self) -> Self {
95 -self
96 }
97
98 #[inline(always)]
99 fn __round(self) -> Self {
100 self.round()
101 }
102
103 #[inline(always)]
104 fn __signum(self) -> Self {
105 self.signum()
106 }
107
108 #[inline(always)]
109 fn __trunc(self) -> Self {
110 self.trunc()
111 }
112
113 #[inline(always)]
114 fn __leaky_relu(self, alpha: Self) -> Self {
115 self.max(half::bf16::from_f32_const(0.0))
116 + alpha * self.min(half::bf16::from_f32_const(0.0))
117 }
118
119 #[inline(always)]
120 fn __relu(self) -> Self {
121 self.max(half::bf16::from_f32_const(0.0))
122 }
123
124 #[inline(always)]
125 fn __relu6(self) -> Self {
126 self.min(half::bf16::from_f32_const(6.0))
127 .max(half::bf16::from_f32_const(0.0))
128 }
129
130 #[inline(always)]
131 fn __copysign(self, rhs: Self) -> Self {
132 self.copysign(rhs)
133 }
134}
135
136impl BitWiseOut2 for half::bf16 {
137 #[inline(always)]
138 fn __bitand(self, rhs: Self) -> Self {
139 half::bf16::from_bits(self.to_bits() & rhs.to_bits())
140 }
141
142 #[inline(always)]
143 fn __bitor(self, rhs: Self) -> Self {
144 half::bf16::from_bits(self.to_bits() | rhs.to_bits())
145 }
146
147 #[inline(always)]
148 fn __bitxor(self, rhs: Self) -> Self {
149 half::bf16::from_bits(self.to_bits() ^ rhs.to_bits())
150 }
151
152 #[inline(always)]
153 fn __not(self) -> Self {
154 half::bf16::from_bits(!self.to_bits())
155 }
156
157 #[inline(always)]
158 fn __shl(self, _: Self) -> Self {
159 panic!("Shift operations are not supported for half::bf16")
160 }
161
162 #[inline(always)]
163 fn __shr(self, _: Self) -> Self {
164 panic!("Shift operations are not supported for half::bf16")
165 }
166}
167
168impl Eval2 for half::bf16 {
169 type Output = bool;
170 #[inline(always)]
171 fn __is_nan(&self) -> Self::Output {
172 self.is_nan()
173 }
174
175 #[inline(always)]
176 fn __is_true(&self) -> Self::Output {
177 *self != half::bf16::from_f32_const(0.0) && !self.is_nan()
178 }
179
180 #[inline(always)]
181 fn __is_inf(&self) -> Self::Output {
182 self.is_infinite()
183 }
184}
185
186impl FloatOutUnary2 for half::bf16 {
187 #[inline(always)]
188 fn __exp(self) -> Self {
189 self.exp()
190 }
191 #[inline(always)]
192 fn __expm1(self) -> Self {
193 self.to_f32().__expm1().to_bf16()
194 }
195 #[inline(always)]
196 fn __exp2(self) -> Self {
197 self.exp2()
198 }
199 #[inline(always)]
200 fn __ln(self) -> Self {
201 self.ln()
202 }
203 #[inline(always)]
204 fn __log1p(self) -> Self {
205 self.to_f32().__log1p().to_bf16()
206 }
207 #[inline(always)]
208 fn __celu(self, alpha: Self) -> Self {
209 self.to_f32().__celu(alpha.to_f32()).to_bf16()
210 }
211 #[inline(always)]
212 fn __log2(self) -> Self {
213 self.log2()
214 }
215 #[inline(always)]
216 fn __log10(self) -> Self {
217 self.log10()
218 }
219 #[inline(always)]
220 fn __sqrt(self) -> Self {
221 self.sqrt()
222 }
223 #[inline(always)]
224 fn __sin(self) -> Self {
225 self.sin()
226 }
227 #[inline(always)]
228 fn __cos(self) -> Self {
229 self.cos()
230 }
231 #[inline(always)]
232 fn __tan(self) -> Self {
233 self.tan()
234 }
235 #[inline(always)]
236 fn __asin(self) -> Self {
237 self.asin()
238 }
239 #[inline(always)]
240 fn __acos(self) -> Self {
241 self.acos()
242 }
243 #[inline(always)]
244 fn __atan(self) -> Self {
245 self.atan()
246 }
247 #[inline(always)]
248 fn __sinh(self) -> Self {
249 self.sinh()
250 }
251 #[inline(always)]
252 fn __cosh(self) -> Self {
253 self.cosh()
254 }
255 #[inline(always)]
256 fn __tanh(self) -> Self {
257 self.tanh()
258 }
259 #[inline(always)]
260 fn __asinh(self) -> Self {
261 self.asinh()
262 }
263 #[inline(always)]
264 fn __acosh(self) -> Self {
265 self.acosh()
266 }
267 #[inline(always)]
268 fn __atanh(self) -> Self {
269 self.atanh()
270 }
271 #[inline(always)]
272 fn __recip(self) -> Self {
273 self.recip()
274 }
275 #[inline(always)]
276 fn __erf(self) -> Self {
277 self.to_f32().__erf().to_bf16()
278 }
279 #[inline(always)]
280 fn __sigmoid(self) -> Self {
281 self.to_f32().__sigmoid().to_bf16()
282 }
283 #[inline(always)]
284 fn __elu(self, alpha: Self) -> Self {
285 self.to_f32().__elu(alpha.to_f32()).to_bf16()
286 }
287 #[inline(always)]
288 fn __gelu(self) -> Self {
289 self.to_f32().__gelu().to_bf16()
290 }
291 #[inline(always)]
292 fn __selu(self, alpha: Self, scale: Self) -> Self {
293 self.to_f32()
294 .__selu(alpha.to_f32(), scale.to_f32())
295 .to_bf16()
296 }
297 #[inline(always)]
298 fn __hard_sigmoid(self) -> Self {
299 self.to_f32().__hard_sigmoid().to_bf16()
300 }
301 #[inline(always)]
302 fn __hard_swish(self) -> Self {
303 self.to_f32().__hard_swish().to_bf16()
304 }
305 #[inline(always)]
306 fn __softplus(self) -> Self {
307 self.to_f32().__softplus().to_bf16()
308 }
309 #[inline(always)]
310 fn __softsign(self) -> Self {
311 self.to_f32().__softsign().to_bf16()
312 }
313 #[inline(always)]
314 fn __mish(self) -> Self {
315 self.to_f32().__mish().to_bf16()
316 }
317 #[inline(always)]
318 fn __cbrt(self) -> Self {
319 self.to_f32().__cbrt().to_bf16()
320 }
321
322 #[inline(always)]
323 fn __sincos(self) -> (Self, Self) {
324 let res = self.to_f32().sin_cos();
325 (res.0.to_bf16(), res.1.to_bf16())
326 }
327
328 #[inline(always)]
329 fn __atan2(self, rhs: Self) -> Self {
330 self.atan2(rhs)
331 }
332
333 #[inline(always)]
334 fn __exp10(self) -> Self {
335 self.to_f32().__exp10().to_bf16()
336 }
337}