#![allow(clippy::excessive_precision)]

use rten_simd::ops::{FloatOps, IntOps, NumOps};
use rten_simd::{Isa, Simd, SimdUnaryOp};

/// Reciprocal of ln(2), used to rewrite `e^x` as `2^(x / ln(2))`.
const INV_LOG2: f32 = std::f32::consts::LOG2_E;

/// 1.5 * 2^23. Adding and then subtracting this constant rounds a
/// small-magnitude f32 to the nearest integer.
const ROUNDING_MAGIC: f32 = 12582912.;

/// -ln(2), split into high and low parts for extra precision.
const LOG2_HI: f32 = -6.93145752e-1;
const LOG2_LO: f32 = -1.42860677e-6;

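// Coefficients of a degree-6 polynomial approximation of `e^r` for `r` in
// the reduced range `[-ln(2)/2, ln(2)/2]`.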
const EXP_POLY_0: f32 = 1.0;
const EXP_POLY_1: f32 = 1.0;
const EXP_POLY_2: f32 = 4.99999851e-1;
const EXP_POLY_3: f32 = 1.66664720e-1;
const EXP_POLY_4: f32 = 4.16695364e-2;
const EXP_POLY_5: f32 = 8.37312452e-3;
const EXP_POLY_6: f32 = 1.37805939e-3;

/// Vectorized exponential function (`e^x`) for f32.
#[derive(Default)]
pub struct Exp {}

impl SimdUnaryOp<f32> for Exp {
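    // The input is split as `x = j * ln(2) + r` with `|r| <= ln(2) / 2`, so
    // that `e^x = 2^j * e^r`. `e^r` is approximated with a degree-6
    // polynomial and `2^j` is applied by manipulating the f32 exponent bits.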
    #[inline(always)]
    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
        let ops = isa.f32();
        let int_ops = isa.i32();

        let x = x.same_cast();

        let inv_log_2 = ops.splat(INV_LOG2);
        let rounding_magic = ops.splat(ROUNDING_MAGIC);
        let ln2_hi = ops.splat(LOG2_HI);
        let ln2_lo = ops.splat(LOG2_LO);

        let p6 = ops.splat(EXP_POLY_6);
        let p5 = ops.splat(EXP_POLY_5);
        let p4 = ops.splat(EXP_POLY_4);
        let p3 = ops.splat(EXP_POLY_3);
        let p2 = ops.splat(EXP_POLY_2);
        let p1 = ops.splat(EXP_POLY_1);
        let p0 = ops.splat(EXP_POLY_0);

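        // Range reduction: compute `j = round(x / ln(2))` via the
        // rounding-magic trick, then `r = x - j * ln(2)` using the split
        // hi/lo representation of ln(2) for extra precision.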
        let j = ops.mul_add(x, inv_log_2, rounding_magic);
        let j = ops.sub(j, rounding_magic);
        let r = ops.mul_add(j, ln2_hi, x);
        let r = ops.mul_add(j, ln2_lo, r);
        let k = ops.to_int_trunc(j);

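        // Evaluate the polynomial approximation of `e^r` using Horner's
        // method.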
        let mut tmp = p6;
        tmp = ops.mul_add(tmp, r, p5);
        tmp = ops.mul_add(tmp, r, p4);
        tmp = ops.mul_add(tmp, r, p3);
        tmp = ops.mul_add(tmp, r, p2);
        tmp = ops.mul_add(tmp, r, p1);
        let r = ops.mul_add(tmp, r, p0);

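        // Compute `2^k` as a product of two powers of two, `s * t`, chosen
        // so that each factor is representable as a normal f32 even when
        // `2^k` itself is near the limits of the f32 range.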
        let ia = int_ops.gt(k, int_ops.zero());
        let x7f = int_ops.splat(0x7f000000);
        #[allow(overflowing_literals)]
        let x83 = int_ops.splat(0x83000000);
        let ia = int_ops.select(int_ops.zero(), x83, ia);
        let is = int_ops.add(ia, x7f);

        let it = int_ops.shift_left::<23>(k);
        let it = int_ops.sub(it, ia);

        let s: I::F32 = is.reinterpret_cast();
        let t: I::F32 = it.reinterpret_cast();
        let r = ops.mul(r, s);
        let r = ops.mul(r, t);

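        // Handle out-of-range inputs: large positive inputs produce
        // +infinity and large negative inputs produce zero.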
        let overflow_mask = ops.ge(x, ops.splat(104.0));
        let underflow_mask = ops.le(x, ops.splat(-104.0));
        let r = ops.select(ops.splat(f32::INFINITY), r, overflow_mask);
        ops.select(ops.zero(), r, underflow_mask).same_cast()
    }
}

/// Inputs below this cutoff produce results smaller than (roughly) the
/// smallest normal f32; [`ReducedRangeExp`] flushes them to zero.
const EXP_LOWER_CUTOFF: f32 = -126.5 * std::f32::consts::LN_2 + 0.01;

/// Variant of [`Exp`] which supports only a reduced range of inputs and
/// omits the overflow handling performed by [`Exp`].
#[derive(Default)]
pub struct ReducedRangeExp {}

impl SimdUnaryOp<f32> for ReducedRangeExp {
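    // The range reduction and polynomial evaluation mirror `Exp`; only the
    // reconstruction of `2^k` and the handling of out-of-range inputs differ.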
    #[inline(always)]
    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
        let ops = isa.f32();
        let int_ops = isa.i32();

        let x = x.same_cast();

        let inv_log_2 = ops.splat(INV_LOG2);
        let rounding_magic = ops.splat(ROUNDING_MAGIC);
        let ln2_hi = ops.splat(LOG2_HI);
        let ln2_lo = ops.splat(LOG2_LO);

        let p6 = ops.splat(EXP_POLY_6);
        let p5 = ops.splat(EXP_POLY_5);
        let p4 = ops.splat(EXP_POLY_4);
        let p3 = ops.splat(EXP_POLY_3);
        let p2 = ops.splat(EXP_POLY_2);
        let p1 = ops.splat(EXP_POLY_1);
        let p0 = ops.splat(EXP_POLY_0);

        let j = ops.mul_add(x, inv_log_2, rounding_magic);
        let j = ops.sub(j, rounding_magic);
        let r = ops.mul_add(j, ln2_hi, x);
        let r = ops.mul_add(j, ln2_lo, r);
        let k = ops.to_int_trunc(j);

        let mut tmp = p6;
        tmp = ops.mul_add(tmp, r, p5);
        tmp = ops.mul_add(tmp, r, p4);
        tmp = ops.mul_add(tmp, r, p3);
        tmp = ops.mul_add(tmp, r, p2);
        tmp = ops.mul_add(tmp, r, p1);
        let r = ops.mul_add(tmp, r, p0);

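        // Scale by `2^k` by adding the f32 exponent bias to `k` and shifting
        // it into the exponent bits. Unlike `Exp`, this assumes `2^k` is
        // representable as a normal f32.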
        let exponent_bias = int_ops.splat(127);
        let k_pow2 = int_ops.shift_left::<23>(int_ops.add(k, exponent_bias));
        let k_pow2: I::F32 = k_pow2.reinterpret_cast();
        let r = ops.mul(r, k_pow2);

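        // Flush the result to zero for inputs below the supported range.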
        let underflow_mask = ops.lt(x, ops.splat(EXP_LOWER_CUTOFF));
        ops.select(ops.zero(), r, underflow_mask).same_cast()
    }
}

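/// Vectorized sigmoid function: `sigmoid(x) = 1 / (1 + e^-x)`.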
#[derive(Default)]
pub struct Sigmoid {}

impl SimdUnaryOp<f32> for Sigmoid {
    #[inline(always)]
    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
        let ops = isa.f32();
        let x = x.same_cast();

        let denom = ops.add(ops.one(), Exp::apply(isa, ops.neg(x)));
        ops.reciprocal(denom).same_cast()
    }
}

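/// Vectorized SiLU function: `silu(x) = x * sigmoid(x)`.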
pub struct Silu {}

impl SimdUnaryOp<f32> for Silu {
    #[inline(always)]
    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
        let ops = isa.f32();
        let x = x.same_cast();

        ops.mul(x, Sigmoid::apply(isa, x)).same_cast()
    }
}

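/// Vectorized Swish function: `swish(x) = x * sigmoid(beta * x)`.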
pub struct Swish {
    pub beta: f32,
}

impl SimdUnaryOp<f32> for Swish {
    #[inline(always)]
    fn eval<I: Isa, S: Simd<Elem = f32, Isa = I>>(&self, isa: I, x: S) -> S {
        let ops = isa.f32();
        let x = x.same_cast();

        let beta = ops.splat(self.beta);
        ops.mul(x, Sigmoid::apply(isa, ops.mul(x, beta)))
            .same_cast()
    }
}

#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use rten_simd::SimdUnaryOp;

    use super::{ReducedRangeExp, EXP_LOWER_CUTOFF};
    use crate::testing::{
        arange, benchmark_op, check_f32s_are_equal_ulps, check_with_all_f32s, AsUninit,
    };
    use crate::{Exp, Sigmoid, Silu, Swish};

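    /// Maximum allowed difference, in ULPs, between the vectorized `exp`
    /// ops and `f32::exp`.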
    const MAX_EXP_ERROR_ULPS: f32 = 1.0;

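    /// Maximum allowed difference, in ULPs, between the vectorized
    /// sigmoid-based ops and their scalar references.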
    const MAX_SIGMOID_ERROR_ULPS: f32 = 4.0;

    fn reference_sigmoid(x: f32) -> f32 {
        1. / (1. + (-x).exp())
    }

    fn reference_silu(x: f32) -> f32 {
        x * reference_sigmoid(x)
    }

    fn reference_swish(x: f32, beta: f32) -> f32 {
        x * reference_sigmoid(beta * x)
    }

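    /// Run `simd_op` and `reference_op` over `values` and check that
    /// corresponding results differ by at most `max_error_ulps`.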
    fn check_simd_vs_reference<
        F: Fn(&[f32], &mut [MaybeUninit<f32>]),
        R: Fn(f32) -> f32,
        I: Iterator<Item = f32>,
    >(
        simd_op: F,
        reference_op: R,
        max_error_ulps: f32,
        values: I,
    ) {
        let cases: Vec<_> = values.collect();
        let expected: Vec<_> = cases.iter().copied().map(reference_op).collect();
        let mut actual = cases.clone();

        simd_op(&cases, actual.as_mut_slice().as_uninit());

        let results = cases
            .iter()
            .zip(actual.iter().zip(expected.iter()))
            .map(|(x, (actual, expected))| (*x, *actual, *expected));
        check_f32s_are_equal_ulps(results, max_error_ulps);
    }

    #[test]
    fn test_exp_basic() {
        let cases = [-2.0f32, -1., -0.5, -0.1, 0., 0.1, 0.5, 1., 2., -105., 105.];

        let exp_op = Exp {};
        for case in cases {
            let expected = case.exp();
            let actual = exp_op.scalar_eval(case);
            let diff = (expected - actual).abs();

            if actual.is_infinite() || expected.is_infinite() {
                assert_eq!(actual, expected);
            } else {
                assert_eq!(diff, 0.);
            };
        }
    }

    #[test]
    fn test_exp() {
        check_simd_vs_reference(
            |src, dest| Exp {}.map(src, dest),
            f32::exp,
            MAX_EXP_ERROR_ULPS,
            arange(-6., 6., 0.001f32),
        );
    }

    #[test]
    fn test_reduced_range_exp() {
        check_simd_vs_reference(
            |src, dest| ReducedRangeExp {}.map(src, dest),
            f32::exp,
            MAX_EXP_ERROR_ULPS,
            arange(EXP_LOWER_CUTOFF, 0., 0.015f32),
        );
    }

    #[test]
    #[ignore]
    fn test_exp_exhaustive() {
        let exp_op = Exp {};
        check_with_all_f32s(
            |x| (exp_op.scalar_eval(x), x.exp()),
            MAX_EXP_ERROR_ULPS,
            "testing exp",
        );
        check_with_all_f32s(
            |x| {
                let mut y = [0.; 1];
                exp_op.map(&[x], y.as_mut().as_uninit());
                (y[0], x.exp())
            },
            MAX_EXP_ERROR_ULPS,
            "testing vec_expf",
        );
    }

    #[test]
    fn test_sigmoid() {
        check_simd_vs_reference(
            |src, dest| Sigmoid {}.map(src, dest),
            reference_sigmoid,
            MAX_SIGMOID_ERROR_ULPS,
            arange(-6., 6., 0.001f32),
        );
    }

    #[test]
    #[ignore]
    fn test_sigmoid_exhaustive() {
        check_with_all_f32s(
            |x| {
                let mut y = [0.; 1];
                Sigmoid {}.map(&[x], y.as_mut().as_uninit());
                (y[0], reference_sigmoid(x))
            },
            MAX_SIGMOID_ERROR_ULPS,
            "testing vec_sigmoid",
        );
    }

    #[test]
    fn test_silu() {
        check_simd_vs_reference(
            |src, dest| Silu {}.map(src, dest),
            reference_silu,
            MAX_SIGMOID_ERROR_ULPS,
            arange(-6., 6., 0.001f32),
        );
    }

    #[test]
    fn test_swish() {
        let beta = 1.7;
        check_simd_vs_reference(
            |src, dest| Swish { beta }.map(src, dest),
            |x| reference_swish(x, beta),
            MAX_SIGMOID_ERROR_ULPS,
            arange(-6., 6., 0.001f32),
        )
    }

    #[test]
    #[ignore]
    fn bench_exp() {
        benchmark_op(
            |xs, ys| xs.iter().zip(ys.iter_mut()).for_each(|(x, y)| *y = x.exp()),
            |xs, ys| Exp {}.map(xs, ys),
        );
    }

    #[test]
    #[ignore]
    fn bench_sigmoid() {
        benchmark_op(
            |xs, ys| {
                xs.iter()
                    .zip(ys.iter_mut())
                    .for_each(|(x, y)| *y = reference_sigmoid(*x))
            },
            |xs, ys| Sigmoid {}.map(xs, ys),
        );
    }
}