#version 450
// Layer norm over the last axis (length n):
// y = (x - mean) / sqrt(var + eps) * gamma + beta
// One invocation per row. has_beta=0 ⇒ beta term skipped.
// Work-group size: 64 invocations along x. main() maps each global x index to one row.
layout(local_size_x = 64) in;
// Single float storage buffer. Input, gamma, beta and output are all addressed
// inside it through the *_off push constants below.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Per-dispatch parameters. Every *_off value is used as an element index into
// data[] (i.e. counted in floats, not bytes).
layout(push_constant) uniform PC {
uint rows;      // number of rows; invocations whose global x index is >= rows return immediately
uint n;         // length of each row (the normalized axis); row r starts at element r * n
uint x_off;     // index of the first input element in data[]
uint gamma_off; // index of the first gamma (scale) element; n values, indexed by column only, so shared by all rows
uint beta_off;  // index of the first beta (shift) element; only read when has_beta != 0
uint out_off;   // index of the first output element in data[]; output uses the same row * n layout as the input
uint has_beta;  // 0 => beta is not read and 0.0 is added instead; any non-zero value enables beta
float eps;      // added to the variance before the inverse square root
} pc;
// Entry point: the invocation with global x index r normalizes row r.
//
// The row is read three times from data[]:
//   1. sum the elements to obtain the mean,
//   2. sum the squared deviations from the mean and divide by n to obtain the variance,
//   3. write (x - mean) * inversesqrt(var + eps) * gamma + beta to the output
//      region, where beta is replaced by 0.0 when pc.has_beta == 0.
void main() {
    uint rowIdx = gl_GlobalInvocationID.x;
    // Invocations beyond the last row have no work.
    if (rowIdx >= pc.rows) {
        return;
    }

    // First element of this row in the input and output regions.
    uint rowStart = rowIdx * pc.n;
    uint src = pc.x_off + rowStart;
    uint dst = pc.out_off + rowStart;

    // Pass 1: mean of the row.
    float mu = 0.0;
    for (uint k = 0u; k < pc.n; ++k) {
        mu += data[src + k];
    }
    mu /= float(pc.n);

    // Pass 2: variance as the sum of squared deviations divided by n.
    float sqDev = 0.0;
    for (uint k = 0u; k < pc.n; ++k) {
        float delta = data[src + k] - mu;
        sqDev += delta * delta;
    }
    sqDev /= float(pc.n);

    // Reciprocal standard deviation; eps is added to the variance first.
    float rstd = inversesqrt(sqDev + pc.eps);

    // Pass 3: normalize, scale by gamma, and add beta when enabled.
    for (uint k = 0u; k < pc.n; ++k) {
        float gain = data[pc.gamma_off + k];
        float shift = 0.0;
        if (pc.has_beta != 0u) {
            shift = data[pc.beta_off + k];
        }
        data[dst + k] = (data[src + k] - mu) * rstd * gain + shift;
    }
}