hpt_types/scalars/
_bf16.rs

1use num_traits::real::Real;
2
3use crate::{
4    convertion::Convertor,
5    type_promote::{
6        BitWiseOut2, Eval2, FloatOutBinary2, FloatOutUnary2, NormalOut2, NormalOutUnary2,
7    },
8};
9
10impl FloatOutBinary2 for half::bf16 {
11    #[inline(always)]
12    fn __div(self, rhs: Self) -> Self {
13        self / rhs
14    }
15    #[inline(always)]
16    fn __log(self, base: Self) -> Self {
17        self.log(base)
18    }
19    #[inline(always)]
20    fn __hypot(self, rhs: Self) -> Self {
21        self.hypot(rhs)
22    }
23}
24
25impl NormalOut2 for half::bf16 {
26    #[inline(always)]
27    fn __add(self, rhs: Self) -> Self {
28        self + rhs
29    }
30
31    #[inline(always)]
32    fn __sub(self, rhs: Self) -> Self {
33        self - rhs
34    }
35
36    #[inline(always)]
37    fn __mul_add(self, a: Self, b: Self) -> Self {
38        (self * a) + b
39    }
40
41    #[inline(always)]
42    fn __mul(self, rhs: Self) -> Self {
43        self * rhs
44    }
45
46    #[inline(always)]
47    fn __pow(self, rhs: Self) -> Self {
48        self.powf(rhs)
49    }
50
51    #[inline(always)]
52    fn __rem(self, rhs: Self) -> Self {
53        self % rhs
54    }
55
56    #[inline(always)]
57    fn __max(self, rhs: Self) -> Self {
58        self.max(rhs)
59    }
60
61    #[inline(always)]
62    fn __min(self, rhs: Self) -> Self {
63        self.min(rhs)
64    }
65
66    #[inline(always)]
67    fn __clamp(self, min: Self, max: Self) -> Self {
68        self.clamp(min, max)
69    }
70}
71
72impl NormalOutUnary2 for half::bf16 {
73    #[inline(always)]
74    fn __square(self) -> Self {
75        self * self
76    }
77
78    #[inline(always)]
79    fn __abs(self) -> Self {
80        self.abs()
81    }
82
83    #[inline(always)]
84    fn __ceil(self) -> Self {
85        self.ceil()
86    }
87
88    #[inline(always)]
89    fn __floor(self) -> Self {
90        self.floor()
91    }
92
93    #[inline(always)]
94    fn __neg(self) -> Self {
95        -self
96    }
97
98    #[inline(always)]
99    fn __round(self) -> Self {
100        self.round()
101    }
102
103    #[inline(always)]
104    fn __signum(self) -> Self {
105        self.signum()
106    }
107
108    #[inline(always)]
109    fn __trunc(self) -> Self {
110        self.trunc()
111    }
112
113    #[inline(always)]
114    fn __leaky_relu(self, alpha: Self) -> Self {
115        self.max(half::bf16::from_f32_const(0.0))
116            + alpha * self.min(half::bf16::from_f32_const(0.0))
117    }
118
119    #[inline(always)]
120    fn __relu(self) -> Self {
121        self.max(half::bf16::from_f32_const(0.0))
122    }
123
124    #[inline(always)]
125    fn __relu6(self) -> Self {
126        self.min(half::bf16::from_f32_const(6.0))
127            .max(half::bf16::from_f32_const(0.0))
128    }
129
130    #[inline(always)]
131    fn __copysign(self, rhs: Self) -> Self {
132        self.copysign(rhs)
133    }
134}
135
136impl BitWiseOut2 for half::bf16 {
137    #[inline(always)]
138    fn __bitand(self, rhs: Self) -> Self {
139        half::bf16::from_bits(self.to_bits() & rhs.to_bits())
140    }
141
142    #[inline(always)]
143    fn __bitor(self, rhs: Self) -> Self {
144        half::bf16::from_bits(self.to_bits() | rhs.to_bits())
145    }
146
147    #[inline(always)]
148    fn __bitxor(self, rhs: Self) -> Self {
149        half::bf16::from_bits(self.to_bits() ^ rhs.to_bits())
150    }
151
152    #[inline(always)]
153    fn __not(self) -> Self {
154        half::bf16::from_bits(!self.to_bits())
155    }
156
157    #[inline(always)]
158    fn __shl(self, _: Self) -> Self {
159        panic!("Shift operations are not supported for half::bf16")
160    }
161
162    #[inline(always)]
163    fn __shr(self, _: Self) -> Self {
164        panic!("Shift operations are not supported for half::bf16")
165    }
166}
167
168impl Eval2 for half::bf16 {
169    type Output = bool;
170    #[inline(always)]
171    fn __is_nan(&self) -> Self::Output {
172        self.is_nan()
173    }
174
175    #[inline(always)]
176    fn __is_true(&self) -> Self::Output {
177        *self != half::bf16::from_f32_const(0.0) && !self.is_nan()
178    }
179
180    #[inline(always)]
181    fn __is_inf(&self) -> Self::Output {
182        self.is_infinite()
183    }
184}
185
186impl FloatOutUnary2 for half::bf16 {
187    #[inline(always)]
188    fn __exp(self) -> Self {
189        self.exp()
190    }
191    #[inline(always)]
192    fn __expm1(self) -> Self {
193        self.to_f32().__expm1().to_bf16()
194    }
195    #[inline(always)]
196    fn __exp2(self) -> Self {
197        self.exp2()
198    }
199    #[inline(always)]
200    fn __ln(self) -> Self {
201        self.ln()
202    }
203    #[inline(always)]
204    fn __log1p(self) -> Self {
205        self.to_f32().__log1p().to_bf16()
206    }
207    #[inline(always)]
208    fn __celu(self, alpha: Self) -> Self {
209        self.to_f32().__celu(alpha.to_f32()).to_bf16()
210    }
211    #[inline(always)]
212    fn __log2(self) -> Self {
213        self.log2()
214    }
215    #[inline(always)]
216    fn __log10(self) -> Self {
217        self.log10()
218    }
219    #[inline(always)]
220    fn __sqrt(self) -> Self {
221        self.sqrt()
222    }
223    #[inline(always)]
224    fn __sin(self) -> Self {
225        self.sin()
226    }
227    #[inline(always)]
228    fn __cos(self) -> Self {
229        self.cos()
230    }
231    #[inline(always)]
232    fn __tan(self) -> Self {
233        self.tan()
234    }
235    #[inline(always)]
236    fn __asin(self) -> Self {
237        self.asin()
238    }
239    #[inline(always)]
240    fn __acos(self) -> Self {
241        self.acos()
242    }
243    #[inline(always)]
244    fn __atan(self) -> Self {
245        self.atan()
246    }
247    #[inline(always)]
248    fn __sinh(self) -> Self {
249        self.sinh()
250    }
251    #[inline(always)]
252    fn __cosh(self) -> Self {
253        self.cosh()
254    }
255    #[inline(always)]
256    fn __tanh(self) -> Self {
257        self.tanh()
258    }
259    #[inline(always)]
260    fn __asinh(self) -> Self {
261        self.asinh()
262    }
263    #[inline(always)]
264    fn __acosh(self) -> Self {
265        self.acosh()
266    }
267    #[inline(always)]
268    fn __atanh(self) -> Self {
269        self.atanh()
270    }
271    #[inline(always)]
272    fn __recip(self) -> Self {
273        self.recip()
274    }
275    #[inline(always)]
276    fn __erf(self) -> Self {
277        self.to_f32().__erf().to_bf16()
278    }
279    #[inline(always)]
280    fn __sigmoid(self) -> Self {
281        self.to_f32().__sigmoid().to_bf16()
282    }
283    #[inline(always)]
284    fn __elu(self, alpha: Self) -> Self {
285        self.to_f32().__elu(alpha.to_f32()).to_bf16()
286    }
287    #[inline(always)]
288    fn __gelu(self) -> Self {
289        self.to_f32().__gelu().to_bf16()
290    }
291    #[inline(always)]
292    fn __selu(self, alpha: Self, scale: Self) -> Self {
293        self.to_f32()
294            .__selu(alpha.to_f32(), scale.to_f32())
295            .to_bf16()
296    }
297    #[inline(always)]
298    fn __hard_sigmoid(self) -> Self {
299        self.to_f32().__hard_sigmoid().to_bf16()
300    }
301    #[inline(always)]
302    fn __hard_swish(self) -> Self {
303        self.to_f32().__hard_swish().to_bf16()
304    }
305    #[inline(always)]
306    fn __softplus(self) -> Self {
307        self.to_f32().__softplus().to_bf16()
308    }
309    #[inline(always)]
310    fn __softsign(self) -> Self {
311        self.to_f32().__softsign().to_bf16()
312    }
313    #[inline(always)]
314    fn __mish(self) -> Self {
315        self.to_f32().__mish().to_bf16()
316    }
317    #[inline(always)]
318    fn __cbrt(self) -> Self {
319        self.to_f32().__cbrt().to_bf16()
320    }
321
322    #[inline(always)]
323    fn __sincos(self) -> (Self, Self) {
324        let res = self.to_f32().sin_cos();
325        (res.0.to_bf16(), res.1.to_bf16())
326    }
327
328    #[inline(always)]
329    fn __atan2(self, rhs: Self) -> Self {
330        self.atan2(rhs)
331    }
332
333    #[inline(always)]
334    fn __exp10(self) -> Self {
335        self.to_f32().__exp10().to_bf16()
336    }
337}