#version 450
// GELU activation compute shader
// Computes: result[i] = 0.5 * input[i] * (1.0 + tanh(sqrt(2/π) * (input[i] + 0.044715 * input[i]^3)))
// Uses the tanh-based approximation of GELU (rather than the exact erf-based definition) for computational efficiency
// Workgroup shape: 64 invocations along x; y and z are unused (size 1).
// NOTE(review): the host presumably dispatches ceil(element_count / 64)
// workgroups along x -- confirm against the calling code.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// Source values (descriptor set 0, binding 0). main() reads one float per
// invocation from this block and never writes to it.
// The array is runtime-sized: its length is taken from the bound buffer.
layout(set = 0, binding = 0) buffer InputBuffer {
float input_data[];
};
// Destination values (descriptor set 0, binding 1). main() writes one float
// per invocation to this block and never reads from it.
// The array is runtime-sized: its length is taken from the bound buffer.
layout(set = 0, binding = 1) buffer OutputBuffer {
float result[];
};
// Tanh-approximated GELU of a single value.
//
//   GELU(x) = 0.5 * x * (1 + tanh(v)),   v = sqrt(2/pi) * (x + 0.044715 * x^3)
//
// Since 0.5 * (1 + tanh(v)) == sigmoid(2v), this is evaluated as
// x * sigmoid(2v). Forming "1 + tanh(v)" directly cancels catastrophically
// when tanh(v) is close to -1 (negative x), which destroys the relative
// accuracy of the small negative tail of GELU; the sigmoid form does not.
float gelu_tanh(float x) {
    const float sqrt_2_over_pi = 0.7978845608028654; // sqrt(2/pi)
    const float gelu_coeff = 0.044715;

    float x_cubed = x * x * x;
    float v = sqrt_2_over_pi * (x + gelu_coeff * x_cubed);

    // Only ever exponentiate a non-positive number: e lies in [0, 1], so the
    // exponential cannot overflow no matter how large |v| gets (including
    // the case where x^3 itself overflowed).
    float e = exp(-2.0 * abs(v));

    // sigmoid(2v) = 1 / (1 + exp(-2v))            for v >= 0
    //             = exp(2v) / (1 + exp(2v))       for v <  0
    // Both cases share the denominator (1 + e); only the numerator differs,
    // so a simple select replaces the saturation branch.
    float sigmoid_2v = ((v >= 0.0) ? 1.0 : e) / (1.0 + e);

    return x * sigmoid_2v;
}

// Entry point: each invocation handles the element at its global x index,
// computing result[i] = GELU(input_data[i]).
void main() {
    uint index = gl_GlobalInvocationID.x;

    // The dispatch may cover more invocations than there are elements, and
    // the two buffers may differ in size: only touch indices valid in both.
    uint element_count = min(uint(input_data.length()), uint(result.length()));
    if (index >= element_count) {
        return;
    }

    result[index] = gelu_tanh(input_data[index]);
}