#version 450
// Workgroup of 256 invocations along x. main() handles one element per
// invocation (indexed by gl_GlobalInvocationID.x), so covering all n elements
// requires the host to dispatch at least ceil(n / 256) workgroups in x.
layout(local_size_x = 256) in;
// Input values, read-only; main() reads x[i] for each processed index i < n.
layout(set = 0, binding = 0) readonly buffer Input { float x[]; };
// Output values, write-only; main() writes GELU(x[i]) to result[i], so this
// buffer must hold at least n floats.
layout(set = 0, binding = 1) writeonly buffer Output { float result[]; };
// Push constants supplied by the host at dispatch time.
layout(push_constant) uniform Params {
// Number of elements to process; main() uses it as the upper bound on
// gl_GlobalInvocationID.x. NOTE(review): declared signed; presumably always
// >= 0 — confirm on the host side.
int n;
};
// Element-wise GELU using the tanh approximation:
//   result[i] = 0.5 * x[i] * (1 + tanh(sqrt(2/pi) * (x[i] + 0.044715 * x[i]^3)))
// Each invocation processes the single element at gl_GlobalInvocationID.x;
// invocations at or beyond n (and all invocations when n <= 0) write nothing.
void main() {
    const uint idx = gl_GlobalInvocationID.x;

    // Bound check done explicitly in unsigned space. A bare `idx < n` relies on
    // an implicit int -> uint conversion of n, so a negative n would wrap to a
    // huge bound and every invocation would access x/result out of range.
    if (n <= 0 || idx >= uint(n)) {
        return;
    }

    const float kSqrt2OverPi = 0.7978845608; // sqrt(2 / pi)
    const float kCubicCoeff = 0.044715;

    float val = x[idx];
    float inner = kSqrt2OverPi * (val + kCubicCoeff * val * val * val);

    // tanh(+/-10) already rounds to +/-1.0 in single precision, so this clamp
    // does not change the result. It keeps very large arguments, including
    // +/-inf from an overflowed cubic term, away from tanh; presumably a guard
    // against tanh implementations that misbehave for large inputs.
    inner = clamp(inner, -10.0, 10.0);

    result[idx] = 0.5 * val * (1.0 + tanh(inner));
}