1#![allow(clippy::excessive_precision)]
4
5use rten_simd::ops::{FloatOps, IntOps, NumOps};
6use rten_simd::{Isa, Simd, SimdUnaryOp};
7
/// `1 / ln(2)` (equivalently `log2(e)`), so `x / ln(2)` can be computed as a
/// multiplication.
const INV_LOG2: f32 = std::f32::consts::LOG2_E;

/// `1.5 * 2^23`. Adding this to a float of magnitude below `2^22` and then
/// subtracting it again rounds the value to the nearest integer.
const ROUNDING_MAGIC: f32 = 12_582_912.0;

/// High part of `-ln(2)`. Applied together with `LOG2_LO` so that `-ln(2)` is
/// represented with more precision than a single `f32` can hold.
const LOG2_HI: f32 = -6.93145752e-1;

/// Low-order correction to `LOG2_HI`; `LOG2_HI + LOG2_LO` approximates `-ln(2)`.
const LOG2_LO: f32 = -1.42860677e-6;

// Coefficients of the degree-6 polynomial used to approximate `exp(r)` on the
// reduced argument `r`, listed from the constant term upwards. They are close to
// the Taylor coefficients 1, 1, 1/2, 1/6, 1/24, 1/120, 1/720.
const EXP_POLY_0: f32 = 1.0;
const EXP_POLY_1: f32 = 1.0;
const EXP_POLY_2: f32 = 4.99999851e-1;
const EXP_POLY_3: f32 = 1.66664720e-1;
const EXP_POLY_4: f32 = 4.16695364e-2;
const EXP_POLY_5: f32 = 8.37312452e-3;
const EXP_POLY_6: f32 = 1.37805939e-3;

/// Vectorized natural exponential function, `exp(x)`.
#[derive(Default)]
pub struct Exp {}
33
34impl SimdUnaryOp<f32> for Exp {
62 #[inline(always)]
63 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
64 let ops = isa.f32();
65 let int_ops = isa.i32();
66
67 let inv_log_2 = ops.splat(INV_LOG2);
69 let rounding_magic = ops.splat(ROUNDING_MAGIC);
70 let ln2_hi = ops.splat(LOG2_HI);
71 let ln2_lo = ops.splat(LOG2_LO);
72
73 let p6 = ops.splat(EXP_POLY_6);
74 let p5 = ops.splat(EXP_POLY_5);
75 let p4 = ops.splat(EXP_POLY_4);
76 let p3 = ops.splat(EXP_POLY_3);
77 let p2 = ops.splat(EXP_POLY_2);
78 let p1 = ops.splat(EXP_POLY_1);
79 let p0 = ops.splat(EXP_POLY_0);
80
81 let j = ops.mul_add(x, inv_log_2, rounding_magic);
83 let j = ops.sub(j, rounding_magic);
84 let r = ops.mul_add(j, ln2_hi, x);
85 let r = ops.mul_add(j, ln2_lo, r);
86 let k = ops.to_int_trunc(j);
87
88 let mut tmp = p6;
90 tmp = ops.mul_add(tmp, r, p5);
91 tmp = ops.mul_add(tmp, r, p4);
92 tmp = ops.mul_add(tmp, r, p3);
93 tmp = ops.mul_add(tmp, r, p2);
94 tmp = ops.mul_add(tmp, r, p1);
95 let r = ops.mul_add(tmp, r, p0);
96
97 let ia = int_ops.gt(k, int_ops.zero());
107 let x7f = int_ops.splat(0x7f000000);
108 #[allow(overflowing_literals)]
109 let x83 = int_ops.splat(0x83000000);
110 let ia = int_ops.select(int_ops.zero(), x83, ia);
111 let is = int_ops.add(ia, x7f);
112
113 let it = int_ops.shift_left::<23>(k);
114 let it = int_ops.sub(it, ia);
115
116 let s: I::F32 = is.reinterpret_cast();
117 let t: I::F32 = it.reinterpret_cast();
118 let r = ops.mul(r, s);
119 let r = ops.mul(r, t);
120
121 let overflow_mask = ops.ge(x, ops.splat(104.0));
123 let underflow_mask = ops.le(x, ops.splat(-104.0));
124 let r = ops.select(ops.splat(f32::INFINITY), r, overflow_mask);
125 ops.select(ops.zero(), r, underflow_mask)
126 }
127}
128
/// Inputs below this value make `ReducedRangeExp` return 0.
///
/// Equal to `-126.5 * ln(2)` plus a small margin. Any input at or above it
/// rounds to `k = round(x / ln(2)) >= -126`, which keeps the biased exponent
/// `k + 127` used by `ReducedRangeExp` at 1 or more.
const EXP_LOWER_CUTOFF: f32 = -126.5 * std::f32::consts::LN_2 + 0.01;

/// Vectorized `exp(x)` for a reduced input range.
///
/// Uses the same range reduction and polynomial as `Exp`, but with a simpler
/// `2^k` scaling step and no overflow handling at the upper end.
#[derive(Default)]
pub struct ReducedRangeExp {}
139
140impl SimdUnaryOp<f32> for ReducedRangeExp {
141 #[inline(always)]
142 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
143 let ops = isa.f32();
144 let int_ops = isa.i32();
145
146 let inv_log_2 = ops.splat(INV_LOG2);
148 let rounding_magic = ops.splat(ROUNDING_MAGIC);
149 let ln2_hi = ops.splat(LOG2_HI);
150 let ln2_lo = ops.splat(LOG2_LO);
151
152 let p6 = ops.splat(EXP_POLY_6);
153 let p5 = ops.splat(EXP_POLY_5);
154 let p4 = ops.splat(EXP_POLY_4);
155 let p3 = ops.splat(EXP_POLY_3);
156 let p2 = ops.splat(EXP_POLY_2);
157 let p1 = ops.splat(EXP_POLY_1);
158 let p0 = ops.splat(EXP_POLY_0);
159
160 let j = ops.mul_add(x, inv_log_2, rounding_magic);
164 let j = ops.sub(j, rounding_magic);
165 let r = ops.mul_add(j, ln2_hi, x);
166 let r = ops.mul_add(j, ln2_lo, r);
167 let k = ops.to_int_trunc(j);
168
169 let mut tmp = p6;
171 tmp = ops.mul_add(tmp, r, p5);
172 tmp = ops.mul_add(tmp, r, p4);
173 tmp = ops.mul_add(tmp, r, p3);
174 tmp = ops.mul_add(tmp, r, p2);
175 tmp = ops.mul_add(tmp, r, p1);
176 let r = ops.mul_add(tmp, r, p0);
177
178 let exponent_bias = int_ops.splat(127);
183 let k_pow2 = int_ops.shift_left::<23>(int_ops.add(k, exponent_bias));
184 let k_pow2: I::F32 = k_pow2.reinterpret_cast();
185 let r = ops.mul(r, k_pow2);
186
187 let underflow_mask = ops.lt(x, ops.splat(EXP_LOWER_CUTOFF));
189 ops.select(ops.zero(), r, underflow_mask)
190 }
191}
192
193#[derive(Default)]
201pub struct Sigmoid {}
202
203impl SimdUnaryOp<f32> for Sigmoid {
204 #[inline(always)]
205 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
206 let ops = isa.f32();
207
208 let denom = ops.add(ops.one(), Exp::apply(isa, ops.neg(x)));
210 ops.reciprocal(denom)
211 }
212}
213
214pub struct Silu {}
218
219impl SimdUnaryOp<f32> for Silu {
220 #[inline(always)]
221 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
222 let ops = isa.f32();
223
224 let denom = ops.add(ops.one(), Exp::apply(isa, ops.neg(x)));
226 ops.div(x, denom)
227 }
228}
229
230pub struct Swish {
234 pub beta: f32,
235}
236
237impl SimdUnaryOp<f32> for Swish {
238 #[inline(always)]
239 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
240 let ops = isa.f32();
241
242 let beta = ops.splat(self.beta);
243 ops.mul(x, Sigmoid::apply(isa, ops.mul(x, beta)))
244 }
245}
246
247pub struct Elu {
251 pub alpha: f32,
252}
253
254impl SimdUnaryOp<f32> for Elu {
255 #[inline(always)]
256 fn eval<I: Isa>(&self, isa: I, x: I::F32) -> I::F32 {
257 let ops = isa.f32();
266 let x_pos = ops.ge(x, ops.zero());
267 let x_exp = ops.mul(
268 ops.splat(self.alpha),
269 ops.sub(Exp::apply(isa, x), ops.splat(1.)),
270 );
271 ops.select(x, x_exp, x_pos)
272 }
273}
274
#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use super::{EXP_LOWER_CUTOFF, ReducedRangeExp};
    use crate::testing::{AllF32s, Tolerance, UnaryOpTester, arange, benchmark_op};
    use crate::{Elu, Exp, Sigmoid, Silu, Swish};

    /// Maximum error, in ULPs, tolerated for `Exp` and `ReducedRangeExp`.
    const MAX_EXP_ERROR_ULPS: f32 = 1.0;

    /// Maximum error, in ULPs, tolerated for `Sigmoid`, `Silu` and `Swish`.
    const MAX_SIGMOID_ERROR_ULPS: f32 = 4.0;

    fn reference_elu(x: f32, alpha: f32) -> f32 {
        if x >= 0. { x } else { alpha * (x.exp() - 1.) }
    }

    fn reference_sigmoid(x: f32) -> f32 {
        1. / (1. + (-x).exp())
    }

    fn reference_silu(x: f32) -> f32 {
        x * reference_sigmoid(x)
    }

    fn reference_swish(x: f32, beta: f32) -> f32 {
        x * reference_sigmoid(beta * x)
    }

    #[test]
    fn test_exp_basic() {
        // Typical inputs of both signs with magnitudes above and below ln(2),
        // zero, and inputs beyond the +/-104 saturation cutoffs used by `Exp`.
        let cases = [-2.0f32, -1., -0.5, -0.1, 0., 0.1, 0.5, 1., 2., -105., 105.];

        let exp_op = Exp {};
        for case in cases {
            let expected = case.exp();
            let actual = exp_op.scalar_eval(case);

            if actual.is_infinite() || expected.is_infinite() {
                assert_eq!(actual, expected, "exp({case})");
            } else {
                // The rest of this suite only requires `Exp` to be within
                // `MAX_EXP_ERROR_ULPS` of `f32::exp`, so exact equality is too
                // strict. For a normal `expected`, one ULP is at most
                // `|expected| * f32::EPSILON`.
                let diff = (expected - actual).abs();
                let max_diff = expected.abs() * f32::EPSILON * MAX_EXP_ERROR_ULPS;
                assert!(
                    diff <= max_diff,
                    "exp({case}): expected {expected}, got {actual}, diff {diff} > {max_diff}"
                );
            }
        }
    }

    #[test]
    fn test_exp() {
        let test = UnaryOpTester {
            reference: f32::exp,
            simd: Exp {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_EXP_ERROR_ULPS),
        };
        test.run();
    }

    #[test]
    fn test_reduced_range_exp() {
        let test = UnaryOpTester {
            reference: f32::exp,
            simd: ReducedRangeExp {},
            range: arange(EXP_LOWER_CUTOFF, 0., 0.015),
            tolerance: Tolerance::Ulp(MAX_EXP_ERROR_ULPS),
        };
        test.run();
    }

    #[test]
    fn test_elu() {
        let alpha = 0.5;
        let test = UnaryOpTester {
            reference: |x| reference_elu(x, alpha),
            simd: Elu { alpha },
            range: [-2., -1., 0., 1., 2.].into_iter(),
            tolerance: Tolerance::Ulp(1.0),
        };
        test.run();
    }

    // Checks every f32 value; too slow to run by default.
    #[test]
    #[ignore]
    fn test_exp_exhaustive() {
        let test = UnaryOpTester {
            reference: f32::exp,
            simd: Exp {},
            range: AllF32s::new(),
            tolerance: Tolerance::Ulp(MAX_EXP_ERROR_ULPS),
        };
        test.run_with_progress();
    }

    #[test]
    fn test_sigmoid() {
        let test = UnaryOpTester {
            reference: reference_sigmoid,
            simd: Sigmoid {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        };
        test.run();
    }

    // Checks every f32 value; too slow to run by default.
    #[test]
    #[ignore]
    fn test_sigmoid_exhaustive() {
        let test = UnaryOpTester {
            reference: reference_sigmoid,
            simd: Sigmoid {},
            range: AllF32s::new(),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        };
        test.run_with_progress();
    }

    #[test]
    fn test_silu() {
        let test = UnaryOpTester {
            reference: reference_silu,
            simd: Silu {},
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        };
        test.run();
    }

    #[test]
    fn test_swish() {
        let beta = 1.7;
        let test = UnaryOpTester {
            reference: |x| reference_swish(x, beta),
            simd: Swish { beta },
            range: arange(-6., 6., 0.001),
            tolerance: Tolerance::Ulp(MAX_SIGMOID_ERROR_ULPS),
        };
        test.run();
    }

    #[test]
    #[ignore]
    fn bench_elu() {
        let alpha = 0.5;
        benchmark_op(
            |xs, ys| {
                xs.iter()
                    .zip(ys.iter_mut())
                    .for_each(|(x, y)| *y = reference_elu(*x, alpha))
            },
            |xs, ys| {
                Elu { alpha }.map(xs, ys);
            },
        );
    }

    #[test]
    #[ignore]
    fn bench_exp() {
        benchmark_op(
            |xs, ys| xs.iter().zip(ys.iter_mut()).for_each(|(x, y)| *y = x.exp()),
            |xs, ys| {
                Exp {}.map(xs, ys);
            },
        );
    }

    #[test]
    #[ignore]
    fn bench_sigmoid() {
        benchmark_op(
            |xs, ys| {
                xs.iter()
                    .zip(ys.iter_mut())
                    .for_each(|(x, y)| *y = reference_sigmoid(*x))
            },
            |xs, ys| {
                Sigmoid {}.map(xs, ys);
            },
        );
    }
}