#version 450
#include "types.glsl"
#include "generic_unary_head.glsl"
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
// Absolute value of x, computed with the GLSL built-in abs().
float op_abs(float x) {
    const float magnitude = abs(x);
    return magnitude;
}
// Sign of x via the GLSL built-in sign(): 1.0 for x > 0, 0.0 for x == 0,
// -1.0 for x < 0.
float op_sgn(float x) {
    const float direction = sign(x);
    return direction;
}
// Arithmetic negation of x.
float op_neg(float x) {
    const float negated = -x;
    return negated;
}
// Unit step: 1.0 when x >= 0, otherwise 0.0 (zero itself maps to 1.0).
float op_step(float x) {
    if (x >= 0.0f) {
        return 1.0f;
    }
    return 0.0f;
}
// Hyperbolic tangent written through the identity
//   tanh(x) = 1 - 2 / (exp(2x) + 1)
// rather than the built-in tanh().
float op_tanh(float x) {
    const float exp_2x = exp(2.0f*x);
    const float denom = exp_2x + 1.0f;
    return 1.0f - 2.0f / denom;
}
// ELU with unit alpha: identity for x >= 0, exp(x) - 1 for x < 0.
float op_elu(float x) {
    if (x < 0.0f) {
        return exp(x) - 1.0f;
    }
    return x;
}
// Rectified linear unit: the larger of x and 0.
float op_relu(float x) {
    const float rectified = max(x, 0.0f);
    return rectified;
}
// Logistic sigmoid: 1 / (1 + exp(-x)).
float op_sigmoid(float x) {
    const float decay = exp(-x);
    const float denom = 1.0f + decay;
    return 1.0f / denom;
}
// Tanh-form GELU approximation:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))
// The (1 + tanh(v)) factor is evaluated as 2 - 2 / (exp(2v) + 1).
float op_gelu(float x) {
    const float GELU_COEF_A = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

    const float cubic_term = 1.0f + GELU_COEF_A*x*x;
    const float inner = SQRT_2_OVER_PI*x*cubic_term;

    const float exp_2v = exp(2.0f * inner);
    const float one_plus_tanh = 2.0f - 2.0f / (exp_2v + 1.0f);
    return 0.5f*x*one_plus_tanh;
}
// Sigmoid-based GELU approximation: x * sigmoid(1.702 * x), with the sign of
// the coefficient folded into the exponent.
float op_gelu_quick(float x) {
    const float GELU_QUICK_COEF = -1.702f;
    const float decay = exp(GELU_QUICK_COEF * x);
    const float gate = 1.0f / (1.0f + decay);
    return x * gate;
}
// SiLU (swish): x / (1 + exp(-x)).
float op_silu(float x) {
    const float denom = 1.0f + exp(-x);
    return x / denom;
}
// Hard swish: x scaled by (x + 3) / 6 limited to the range [0, 1].
float op_hardswish(float x) {
    const float ramp = (x + 3.0f) / 6.0f;
    const float gate = min(1.0f, max(0.0f, ramp));
    return x * gate;
}
// Hard sigmoid: (x + 3) / 6 limited to the range [0, 1].
float op_hardsigmoid(float x) {
    const float ramp = (x + 3.0f) / 6.0f;
    const float lower_bounded = max(0.0f, ramp);
    return min(1.0f, lower_bounded);
}
// Natural exponential e^x via the GLSL built-in exp().
float op_exp(float x) {
    const float result = exp(x);
    return result;
}
// exp(x) - 1 with better accuracy near zero.
//
// Subtracting 1 from exp(x) cancels most significant bits when x is small, so
// for |x| <= 1/4 a degree-6 Taylor polynomial is evaluated in Horner form
// instead. The first dropped term, x^7/5040, is below 1.3e-8 on that interval
// (roughly 0.5 ulp of expm1(0.25)); a host-side f32 model stays within 2 ulps
// over the interval. Just outside the cutoff the plain exp(x) - 1 path is
// about 1 ulp off at +0.25 and 2 ulps off at -0.25.
float op_expm1(float x) {
    if (abs(x) <= 0.25f) {
        // Horner evaluation, innermost coefficient first.
        float acc = (1.0f/120.0f) + x * (1.0f/720.0f);
        acc = (1.0f/24.0f) + x * acc;
        acc = (1.0f/6.0f) + x * acc;
        acc = 0.5f + x * acc;
        acc = 1.0f + x * acc;
        return x * acc;
    }
    return exp(x) - 1.0f;
}
// Softplus: log(1 + exp(x)). Inputs above 20 are returned unchanged instead of
// going through exp/log.
float op_softplus(float x) {
    if (x > 20.0f) {
        return x;
    }
    return log(1.0f + exp(x));
}
// Erf-form GELU: 0.5 * a * (1 + erf(a / sqrt(2))).
//
// erf is approximated with a five-term polynomial in t = 1 / (1 + p*|z|)
// multiplied by exp(-z^2) (Abramowitz and Stegun formula 7.1.26, or the
// similar Hastings approximation); the sign is re-applied afterwards.
float op_gelu_erf(float a) {
    const float p_erf = 0.3275911f;
    const float a1_erf = 0.254829592f;
    const float a2_erf = -0.284496736f;
    const float a3_erf = 1.421413741f;
    const float a4_erf = -1.453152027f;
    const float a5_erf = 1.061405429f;
    const float SQRT_2_INV = 0.70710678118654752440084436210484f;

    // Argument of erf, split into sign and magnitude.
    const float z = a * SQRT_2_INV;
    const float z_sign = sign(z);
    const float z_abs = abs(z);

    const float t = 1.0f / (1.0f + p_erf * z_abs);

    // Horner evaluation of the polynomial in t, highest coefficient first.
    float poly = a5_erf * t + a4_erf;
    poly = poly * t + a3_erf;
    poly = poly * t + a2_erf;
    poly = poly * t + a1_erf;

    const float erf_abs = 1.0f - poly * t * exp(-z_abs * z_abs);
    return 0.5f * a * (1.0f + z_sign * erf_abs);
}
// xIELU activation, parameterised by four values read from the push-constant
// block `p`: param1 (alpha_n), param2 (alpha_p), param3 (beta), param4 (eps).
//
//   x >  0 : alpha_p * x^2 + beta * x
//   x <= 0 : (expm1(min(x, eps)) - x) * alpha_n + beta * x
float op_xielu(float x) {
    const float neg_scale = p.param1;
    const float pos_scale = p.param2;
    const float slope = p.param3;
    const float upper_clip = p.param4;

    float result;
    if (x > 0.0f) {
        result = pos_scale * x * x + slope * x;
    } else {
        // The exponent argument is limited from above by eps.
        const float clipped = min(x, upper_clip);
        result = (op_expm1(clipped) - x) * neg_scale + slope * x;
    }
    return result;
}
// Largest integer value not greater than x (GLSL built-in floor()).
float op_floor(float x) {
    const float lowered = floor(x);
    return lowered;
}
// Smallest integer value not less than x (GLSL built-in ceil()).
float op_ceil(float x) {
    const float raised = ceil(x);
    return raised;
}
// Round x to the nearest integer value, with halfway cases rounded away from
// zero (the behaviour of C's roundf).
//
// floor(x + 0.5f) / ceil(x - 0.5f) is not used because the addition itself
// rounds in float precision: 0.49999997f + 0.5f becomes 1.0f (giving 1 instead
// of 0), and odd integers at or above 2^23 are pushed to the neighbouring even
// value. The built-in round() is not used either, since GLSL leaves its
// handling of exact .5 cases to the implementation.
//
// Instead the integer part is taken with trunc() and the fractional part
// x - trunc(x), which is exactly representable for any finite float, is
// compared against 0.5. Values with no fractional part are returned unchanged.
// NaN and infinity are expected to pass through unchanged, because the
// comparison is then false and trunc(x) is returned.
float op_round(float x) {
    const float whole = trunc(x);
    const float frac = abs(x - whole);
    // sign(x) is +/-1 whenever frac >= 0.5, so this steps away from zero.
    return frac >= 0.5f ? whole + sign(x) : whole;
}
// Integer part of x, rounding toward zero (GLSL built-in trunc()).
float op_trunc(float x) {
    const float whole = trunc(x);
    return whole;
}
// Shader entry point: applies the unary operator selected by the OP macro to
// one element per invocation.
//
// The element index comes from get_idx(); invocations whose index is not below
// p.ne do nothing. Source and destination positions are derived from the
// get_aoffset()/src0_idx() and get_doffset()/dst_idx() helpers, which are
// defined outside this file. The input is converted to float before OP is
// applied and the result is converted to D_TYPE on store.
void main() {
    const uint elem = get_idx();

    if (elem < p.ne) {
        const uint src_pos = get_aoffset() + src0_idx(elem);
        const uint dst_pos = get_doffset() + dst_idx(elem);

        const float input_val = float(data_a[src_pos]);
        data_d[dst_pos] = D_TYPE(OP(input_val));
    }
}