// hpt_types/scalars/_f32.rs
use crate::type_promote::{
    BitWiseOut2, Eval2, FloatOutBinary2, FloatOutUnary2, NormalOut2, NormalOutUnary2,
};
4impl FloatOutBinary2 for f32 {
5 #[inline(always)]
6 fn __div(self, rhs: Self) -> Self {
7 self / rhs
8 }
9 #[inline(always)]
10 fn __log(self, base: Self) -> Self {
11 self.log(base)
12 }
13 #[inline(always)]
14 fn __hypot(self, rhs: Self) -> Self {
15 self.hypot(rhs)
16 }
17}
18
19impl NormalOut2 for f32 {
20 #[inline(always)]
21 fn __add(self, rhs: Self) -> Self {
22 self + rhs
23 }
24
25 #[inline(always)]
26 fn __sub(self, rhs: Self) -> Self {
27 self - rhs
28 }
29
30 #[inline(always)]
31 fn __mul_add(self, a: Self, b: Self) -> Self {
32 #[cfg(target_feature = "fma")]
33 return self.mul_add(a, b);
34 #[cfg(not(target_feature = "fma"))]
35 return std::hint::black_box((self * a) + b);
36 }
37
38 #[inline(always)]
39 fn __mul(self, rhs: Self) -> Self {
40 self * rhs
41 }
42
43 #[inline(always)]
44 fn __pow(self, rhs: Self) -> Self {
45 self.powf(rhs)
46 }
47
48 #[inline(always)]
49 fn __rem(self, rhs: Self) -> Self {
50 self % rhs
51 }
52
53 #[inline(always)]
54 fn __max(self, rhs: Self) -> Self {
55 self.max(rhs)
56 }
57
58 #[inline(always)]
59 fn __min(self, rhs: Self) -> Self {
60 self.min(rhs)
61 }
62
63 #[inline(always)]
64 fn __clamp(self, min: Self, max: Self) -> Self {
65 self.clamp(min, max)
66 }
67}
68
69impl NormalOutUnary2 for f32 {
70 #[inline(always)]
71 fn __square(self) -> Self {
72 self * self
73 }
74
75 #[inline(always)]
76 fn __abs(self) -> Self {
77 self.abs()
78 }
79
80 #[inline(always)]
81 fn __ceil(self) -> Self {
82 self.ceil()
83 }
84
85 #[inline(always)]
86 fn __floor(self) -> Self {
87 self.floor()
88 }
89
90 #[inline(always)]
91 fn __neg(self) -> Self {
92 -self
93 }
94
95 #[inline(always)]
96 fn __round(self) -> Self {
97 self.round()
98 }
99
100 #[inline(always)]
101 fn __signum(self) -> Self {
102 self.signum()
103 }
104
105 #[inline(always)]
106 fn __trunc(self) -> Self {
107 self.trunc()
108 }
109
110 #[inline(always)]
111 fn __leaky_relu(self, alpha: Self) -> Self {
112 self.max(0.0) + alpha * self.min(0.0)
113 }
114
115 #[inline(always)]
116 fn __relu(self) -> Self {
117 self.max(0.0)
118 }
119
120 #[inline(always)]
121 fn __relu6(self) -> Self {
122 self.max(0.0).min(6.0)
123 }
124
125 #[inline(always)]
126 fn __copysign(self, rhs: Self) -> Self {
127 self.copysign(rhs)
128 }
129}
130
131impl BitWiseOut2 for f32 {
132 #[inline(always)]
133 fn __bitand(self, rhs: Self) -> Self {
134 f32::from_bits(self.to_bits() & rhs.to_bits())
135 }
136
137 #[inline(always)]
138 fn __bitor(self, rhs: Self) -> Self {
139 f32::from_bits(self.to_bits() | rhs.to_bits())
140 }
141
142 #[inline(always)]
143 fn __bitxor(self, rhs: Self) -> Self {
144 f32::from_bits(self.to_bits() ^ rhs.to_bits())
145 }
146
147 #[inline(always)]
148 fn __not(self) -> Self {
149 f32::from_bits(!self.to_bits())
150 }
151
152 #[inline(always)]
153 fn __shl(self, _: Self) -> Self {
154 panic!("Shift operations are not supported for f32")
155 }
156
157 #[inline(always)]
158 fn __shr(self, _: Self) -> Self {
159 panic!("Shift operations are not supported for f32")
160 }
161}
162
163impl Eval2 for f32 {
164 type Output = bool;
165 #[inline(always)]
166 fn __is_nan(&self) -> Self::Output {
167 self.is_nan()
168 }
169
170 #[inline(always)]
171 fn __is_true(&self) -> Self::Output {
172 *self != 0.0 && !self.is_nan()
173 }
174
175 #[inline(always)]
176 fn __is_inf(&self) -> Self::Output {
177 self.is_infinite()
178 }
179}
180
181impl FloatOutUnary2 for f32 {
182 #[inline(always)]
183 fn __exp(self) -> Self {
184 self.exp()
185 }
186 #[inline(always)]
187 fn __expm1(self) -> Self {
188 self.exp_m1()
189 }
190 #[inline(always)]
191 fn __exp2(self) -> Self {
192 self.exp2()
193 }
194 #[inline(always)]
195 fn __ln(self) -> Self {
196 self.ln()
197 }
198 #[inline(always)]
199 fn __log1p(self) -> Self {
200 self.ln_1p()
201 }
202 #[inline(always)]
203 fn __celu(self, scale: Self) -> Self {
204 let gt_mask = (self > 0.0) as i32 as f32;
205 gt_mask * self + (1.0 - gt_mask) * (scale * (self.exp() - 1.0))
206 }
207 #[inline(always)]
208 fn __log2(self) -> Self {
209 self.log2()
210 }
211 #[inline(always)]
212 fn __log10(self) -> Self {
213 self.log10()
214 }
215 #[inline(always)]
216 fn __sqrt(self) -> Self {
217 self.sqrt()
218 }
219 #[inline(always)]
220 fn __sin(self) -> Self {
221 self.sin()
222 }
223 #[inline(always)]
224 fn __cos(self) -> Self {
225 self.cos()
226 }
227 #[inline(always)]
228 fn __tan(self) -> Self {
229 self.tan()
230 }
231 #[inline(always)]
232 fn __asin(self) -> Self {
233 self.asin()
234 }
235 #[inline(always)]
236 fn __acos(self) -> Self {
237 self.acos()
238 }
239 #[inline(always)]
240 fn __atan(self) -> Self {
241 self.atan()
242 }
243 #[inline(always)]
244 fn __sinh(self) -> Self {
245 self.sinh()
246 }
247 #[inline(always)]
248 fn __cosh(self) -> Self {
249 self.cosh()
250 }
251 #[inline(always)]
252 fn __tanh(self) -> Self {
253 self.tanh()
254 }
255 #[inline(always)]
256 fn __asinh(self) -> Self {
257 self.asinh()
258 }
259 #[inline(always)]
260 fn __acosh(self) -> Self {
261 self.acosh()
262 }
263 #[inline(always)]
264 fn __atanh(self) -> Self {
265 self.atanh()
266 }
267 #[inline(always)]
268 fn __recip(self) -> Self {
269 self.recip()
270 }
271 #[inline(always)]
272 fn __erf(self) -> Self {
273 libm::erff(self)
274 }
275
276 #[inline(always)]
277 fn __sigmoid(self) -> Self {
278 1.0 / (1.0 + (-self).exp())
279 }
280
281 fn __elu(self, alpha: Self) -> Self {
282 self.max(0.0) + alpha * (self.exp() - 1.0).min(0.0)
283 }
284
285 fn __gelu(self) -> Self {
286 0.5 * self * (libm::erff(self * std::f32::consts::FRAC_1_SQRT_2) + 1.0)
287 }
288
289 fn __selu(self, alpha: Self, scale: Self) -> Self {
290 scale * (self.max(0.0) + alpha * (self.exp() - 1.0).min(0.0))
291 }
292
293 fn __hard_sigmoid(self) -> Self {
294 let result = self * (1.0 / 6.0) + 0.5;
295 result.min(1.0).max(0.0)
296 }
297
298 fn __hard_swish(self) -> Self {
299 self * ((self + 3.0).clamp(0.0, 6.0) / 6.0)
300 }
301
302 fn __softplus(self) -> Self {
303 (1.0 + self.exp()).ln()
304 }
305
306 fn __softsign(self) -> Self {
307 self / (1.0 + self.abs())
308 }
309
310 fn __mish(self) -> Self {
311 self * ((1.0 + self.exp()).ln()).tanh()
312 }
313
314 fn __cbrt(self) -> Self {
315 libm::cbrtf(self)
316 }
317
318 fn __sincos(self) -> (Self, Self) {
319 self.sin_cos()
320 }
321
322 fn __atan2(self, rhs: Self) -> Self {
323 self.atan2(rhs)
324 }
325
326 fn __exp10(self) -> Self {
327 10f32.powf(self)
328 }
329}