#version 450
// LayerNorm: out[i] = (x[i] - mean) / sqrt(var + eps) * weight[i] + bias[i]
// Single-workgroup approach for vectors up to ~16K elements.
// NOTE(review): nothing in this shader reads gl_WorkGroupID, so every
// workgroup of a dispatch would process the same n elements; presumably the
// host dispatches exactly one workgroup -- confirm against the caller.

// 256 invocations per workgroup; the size of `sdata` below and the literal
// strides/reduction widths in main() are written to match this value.
layout(local_size_x = 256) in;

// Storage buffers. Each one is indexed by element position i in [0, n).
// x: values to normalize (read only).
layout(set = 0, binding = 0) readonly buffer Input { float x[]; };
// w: per-element scale multiplied into the normalized value (read only).
layout(set = 0, binding = 1) readonly buffer Weight { float w[]; };
// b: per-element offset added after scaling (read only).
layout(set = 0, binding = 2) readonly buffer Bias { float b[]; };
// result: normalized, scaled and shifted output (write only).
layout(set = 0, binding = 3) writeonly buffer Output { float result[]; };

layout(push_constant) uniform Params {
// Number of elements to process; used both as the loop bound and as the
// divisor for the mean and the variance.
int n;
// Added to the variance before the reciprocal square root is taken.
float eps;
};

// Workgroup-shared scratch for the sum reductions: one partial sum per
// invocation, combined pairwise until the total is in sdata[0].
shared float sdata[256];
// Sums `value` over every invocation of the workgroup and returns the total
// to all of them.
//
// Each invocation stores its contribution in sdata[tid]; the active half of
// the array is then folded onto the lower half repeatedly until the total is
// in sdata[0]. Must be called from uniform control flow because it executes
// barrier(). Assumes gl_WorkGroupSize.x is a power of two no larger than the
// length of `sdata`.
float workgroupSum(float value, uint tid) {
    sdata[tid] = value;
    barrier();
    for (uint s = gl_WorkGroupSize.x >> 1u; s > 0u; s >>= 1u) {
        if (tid < s) sdata[tid] += sdata[tid + s];
        barrier();
    }
    float total = sdata[0];
    // Every invocation must have read sdata[0] before a later call is
    // allowed to overwrite the scratch array.
    barrier();
    return total;
}

// LayerNorm over the first n elements of x:
//   result[i] = (x[i] - mean) * inversesqrt(var + eps) * w[i] + b[i]
// where mean and var are taken over those n elements and var divides by n
// (not n - 1).
//
// Invocation `tid` handles elements tid, tid + S, tid + 2S, ... where S is
// the workgroup size. Nothing here depends on gl_WorkGroupID, so each
// workgroup of a dispatch performs the same computation.
void main() {
    // n is a signed push constant. Without this guard a negative value would
    // be converted to a huge unsigned loop bound (out-of-bounds buffer
    // access) and n == 0 would evaluate 0/0 for the mean. The test is uniform
    // across the workgroup, so no invocation is left waiting at a barrier.
    if (n <= 0) return;

    uint tid = gl_LocalInvocationID.x;
    uint count = uint(n);
    uint stride = gl_WorkGroupSize.x;

    // Pass 1: mean.
    float local_sum = 0.0;
    for (uint i = tid; i < count; i += stride) {
        local_sum += x[i];
    }
    float mean = workgroupSum(local_sum, tid) / float(n);

    // Pass 2: variance as the mean of squared deviations. This needs a second
    // sweep over x but avoids the cancellation of the E[x^2] - mean^2 form.
    float local_var = 0.0;
    for (uint i = tid; i < count; i += stride) {
        float d = x[i] - mean;
        local_var += d * d;
    }
    float inv_std = inversesqrt(workgroupSum(local_var, tid) / float(n) + eps);

    // Pass 3: normalize, then apply the per-element scale and offset.
    for (uint i = tid; i < count; i += stride) {
        float normed = (x[i] - mean) * inv_std;
        result[i] = normed * w[i] + b[i];
    }
}