hpt_types/scalars/_f32.rs

1use crate::type_promote::{
2    BitWiseOut2, Eval2, FloatOutBinary2, FloatOutUnary2, NormalOut2, NormalOutUnary2,
3};
4impl FloatOutBinary2 for f32 {
5    #[inline(always)]
6    fn __div(self, rhs: Self) -> Self {
7        self / rhs
8    }
9    #[inline(always)]
10    fn __log(self, base: Self) -> Self {
11        self.log(base)
12    }
13    #[inline(always)]
14    fn __hypot(self, rhs: Self) -> Self {
15        self.hypot(rhs)
16    }
17}
18
19impl NormalOut2 for f32 {
20    #[inline(always)]
21    fn __add(self, rhs: Self) -> Self {
22        self + rhs
23    }
24
25    #[inline(always)]
26    fn __sub(self, rhs: Self) -> Self {
27        self - rhs
28    }
29
30    #[inline(always)]
31    fn __mul_add(self, a: Self, b: Self) -> Self {
32        #[cfg(target_feature = "fma")]
33        return self.mul_add(a, b);
34        #[cfg(not(target_feature = "fma"))]
35        return std::hint::black_box((self * a) + b);
36    }
37
38    #[inline(always)]
39    fn __mul(self, rhs: Self) -> Self {
40        self * rhs
41    }
42
43    #[inline(always)]
44    fn __pow(self, rhs: Self) -> Self {
45        self.powf(rhs)
46    }
47
48    #[inline(always)]
49    fn __rem(self, rhs: Self) -> Self {
50        self % rhs
51    }
52
53    #[inline(always)]
54    fn __max(self, rhs: Self) -> Self {
55        self.max(rhs)
56    }
57
58    #[inline(always)]
59    fn __min(self, rhs: Self) -> Self {
60        self.min(rhs)
61    }
62
63    #[inline(always)]
64    fn __clamp(self, min: Self, max: Self) -> Self {
65        self.clamp(min, max)
66    }
67}
68
69impl NormalOutUnary2 for f32 {
70    #[inline(always)]
71    fn __square(self) -> Self {
72        self * self
73    }
74
75    #[inline(always)]
76    fn __abs(self) -> Self {
77        self.abs()
78    }
79
80    #[inline(always)]
81    fn __ceil(self) -> Self {
82        self.ceil()
83    }
84
85    #[inline(always)]
86    fn __floor(self) -> Self {
87        self.floor()
88    }
89
90    #[inline(always)]
91    fn __neg(self) -> Self {
92        -self
93    }
94
95    #[inline(always)]
96    fn __round(self) -> Self {
97        self.round()
98    }
99
100    #[inline(always)]
101    fn __signum(self) -> Self {
102        self.signum()
103    }
104
105    #[inline(always)]
106    fn __trunc(self) -> Self {
107        self.trunc()
108    }
109
110    #[inline(always)]
111    fn __leaky_relu(self, alpha: Self) -> Self {
112        self.max(0.0) + alpha * self.min(0.0)
113    }
114
115    #[inline(always)]
116    fn __relu(self) -> Self {
117        self.max(0.0)
118    }
119
120    #[inline(always)]
121    fn __relu6(self) -> Self {
122        self.max(0.0).min(6.0)
123    }
124
125    #[inline(always)]
126    fn __copysign(self, rhs: Self) -> Self {
127        self.copysign(rhs)
128    }
129}
130
131impl BitWiseOut2 for f32 {
132    #[inline(always)]
133    fn __bitand(self, rhs: Self) -> Self {
134        f32::from_bits(self.to_bits() & rhs.to_bits())
135    }
136
137    #[inline(always)]
138    fn __bitor(self, rhs: Self) -> Self {
139        f32::from_bits(self.to_bits() | rhs.to_bits())
140    }
141
142    #[inline(always)]
143    fn __bitxor(self, rhs: Self) -> Self {
144        f32::from_bits(self.to_bits() ^ rhs.to_bits())
145    }
146
147    #[inline(always)]
148    fn __not(self) -> Self {
149        f32::from_bits(!self.to_bits())
150    }
151
152    #[inline(always)]
153    fn __shl(self, _: Self) -> Self {
154        panic!("Shift operations are not supported for f32")
155    }
156
157    #[inline(always)]
158    fn __shr(self, _: Self) -> Self {
159        panic!("Shift operations are not supported for f32")
160    }
161}
162
163impl Eval2 for f32 {
164    type Output = bool;
165    #[inline(always)]
166    fn __is_nan(&self) -> Self::Output {
167        self.is_nan()
168    }
169
170    #[inline(always)]
171    fn __is_true(&self) -> Self::Output {
172        *self != 0.0 && !self.is_nan()
173    }
174
175    #[inline(always)]
176    fn __is_inf(&self) -> Self::Output {
177        self.is_infinite()
178    }
179}
180
181impl FloatOutUnary2 for f32 {
182    #[inline(always)]
183    fn __exp(self) -> Self {
184        self.exp()
185    }
186    #[inline(always)]
187    fn __expm1(self) -> Self {
188        self.exp_m1()
189    }
190    #[inline(always)]
191    fn __exp2(self) -> Self {
192        self.exp2()
193    }
194    #[inline(always)]
195    fn __ln(self) -> Self {
196        self.ln()
197    }
198    #[inline(always)]
199    fn __log1p(self) -> Self {
200        self.ln_1p()
201    }
202    #[inline(always)]
203    fn __celu(self, scale: Self) -> Self {
204        let gt_mask = (self > 0.0) as i32 as f32;
205        gt_mask * self + (1.0 - gt_mask) * (scale * (self.exp() - 1.0))
206    }
207    #[inline(always)]
208    fn __log2(self) -> Self {
209        self.log2()
210    }
211    #[inline(always)]
212    fn __log10(self) -> Self {
213        self.log10()
214    }
215    #[inline(always)]
216    fn __sqrt(self) -> Self {
217        self.sqrt()
218    }
219    #[inline(always)]
220    fn __sin(self) -> Self {
221        self.sin()
222    }
223    #[inline(always)]
224    fn __cos(self) -> Self {
225        self.cos()
226    }
227    #[inline(always)]
228    fn __tan(self) -> Self {
229        self.tan()
230    }
231    #[inline(always)]
232    fn __asin(self) -> Self {
233        self.asin()
234    }
235    #[inline(always)]
236    fn __acos(self) -> Self {
237        self.acos()
238    }
239    #[inline(always)]
240    fn __atan(self) -> Self {
241        self.atan()
242    }
243    #[inline(always)]
244    fn __sinh(self) -> Self {
245        self.sinh()
246    }
247    #[inline(always)]
248    fn __cosh(self) -> Self {
249        self.cosh()
250    }
251    #[inline(always)]
252    fn __tanh(self) -> Self {
253        self.tanh()
254    }
255    #[inline(always)]
256    fn __asinh(self) -> Self {
257        self.asinh()
258    }
259    #[inline(always)]
260    fn __acosh(self) -> Self {
261        self.acosh()
262    }
263    #[inline(always)]
264    fn __atanh(self) -> Self {
265        self.atanh()
266    }
267    #[inline(always)]
268    fn __recip(self) -> Self {
269        self.recip()
270    }
271    #[inline(always)]
272    fn __erf(self) -> Self {
273        libm::erff(self)
274    }
275
276    #[inline(always)]
277    fn __sigmoid(self) -> Self {
278        1.0 / (1.0 + (-self).exp())
279    }
280
281    fn __elu(self, alpha: Self) -> Self {
282        self.max(0.0) + alpha * (self.exp() - 1.0).min(0.0)
283    }
284
285    fn __gelu(self) -> Self {
286        0.5 * self * (libm::erff(self * std::f32::consts::FRAC_1_SQRT_2) + 1.0)
287    }
288
289    fn __selu(self, alpha: Self, scale: Self) -> Self {
290        scale * (self.max(0.0) + alpha * (self.exp() - 1.0).min(0.0))
291    }
292
293    fn __hard_sigmoid(self) -> Self {
294        let result = self * (1.0 / 6.0) + 0.5;
295        result.min(1.0).max(0.0)
296    }
297
298    fn __hard_swish(self) -> Self {
299        self * ((self + 3.0).clamp(0.0, 6.0) / 6.0)
300    }
301
302    fn __softplus(self) -> Self {
303        (1.0 + self.exp()).ln()
304    }
305
306    fn __softsign(self) -> Self {
307        self / (1.0 + self.abs())
308    }
309
310    fn __mish(self) -> Self {
311        self * ((1.0 + self.exp()).ln()).tanh()
312    }
313
314    fn __cbrt(self) -> Self {
315        libm::cbrtf(self)
316    }
317
318    fn __sincos(self) -> (Self, Self) {
319        self.sin_cos()
320    }
321
322    fn __atan2(self, rhs: Self) -> Self {
323        self.atan2(rhs)
324    }
325
326    fn __exp10(self) -> Self {
327        10f32.powf(self)
328    }
329}