#version 450
// Fused SwiGLU activation: o[i] = silu(a[i]) * b[i] = (a[i] / (1 + exp(-a[i]))) * b[i].
// Doing this in one dispatch instead of a separate silu followed by a mul saves one dispatch and
// one intermediate buffer per MLP per layer (per-op dispatch overhead dominates the forward pass).
// Workgroup shape: 64 invocations along x (y and z default to 1).
layout(local_size_x = 64) in;
// Operand that is passed through silu; must hold at least n floats.
// Presumably the MLP gate projection output -- confirm against the caller.
layout(set = 0, binding = 0) readonly buffer A { float a[]; };
// Operand multiplied elementwise by silu(a); must hold at least n floats.
layout(set = 0, binding = 1) readonly buffer B { float b[]; };
// Result buffer: o[i] = silu(a[i]) * b[i] for i < n; must hold at least n floats.
layout(set = 0, binding = 2) writeonly buffer Out { float o[]; };
// n: number of elements to process (indices 0 .. n-1).
layout(push_constant) uniform Pc { uint n; };
// Computes o[i] = silu(a[i]) * b[i] for every i in [0, n).
//
// Indexing: each invocation starts at its global x index and advances by the total
// number of invocations in the dispatch (gl_NumWorkGroups.x * local_size_x). All n
// elements are therefore written even when the host dispatches fewer than
// ceil(n / 64) workgroups, e.g. because n exceeds maxComputeWorkGroupCount[0] * 64.
// That limit is only guaranteed to be >= 65535, i.e. ~4.2M elements.
// With a full ceil(n / 64) dispatch the loop body runs at most once per invocation,
// exactly like a single bounds-checked index.
// NOTE: assumes one dispatch with base workgroup 0 (a vkCmdDispatchBase split would
// make ranges overlap), and that n + stride fits in 32 bits so i cannot wrap.
void main() {
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
    for (uint i = gl_GlobalInvocationID.x; i < n; i += stride) {
        float x = a[i];
        // silu(x) = x * sigmoid(x) = x / (1 + e^-x).
        // For x < -80 the activation is forced to 0. Below that point e^-x exceeds
        // ~5.5e34 and reaches +inf once x < ~-88.7, but Vulkan only specifies division
        // precision for |divisor| <= 2^126 (~8.5e37). The exact value there is at most
        // ~1.5e-33 in magnitude, so 0 is accurate to that level.
        // NaN inputs fail the comparison and still propagate through the division.
        float act = (x < -80.0) ? 0.0 : x / (1.0 + exp(-x));
        o[i] = act * b[i];
    }
}