#version 450
// RMS norm over the last axis (length n), matching the CPU reference:
// y = x / sqrt(mean(x^2) + eps) * gamma + beta
// Inputs: x, gamma, beta (Op::RmsNorm carries all three). One invocation per row.
// 64 invocations per workgroup; each invocation normalizes one whole row, so the
// host presumably dispatches ceil(rows / 64) workgroups in x — TODO confirm
// against the dispatch code. Surplus invocations exit via the row >= rows check.
layout(local_size_x = 64) in;
// Single storage buffer holding every tensor this op touches. All *_off values
// below are indices into data[] counted in floats, not byte offsets.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Per-dispatch parameters. NOTE(review): presumably mirrored by a host-side
// struct — keep field order and types in sync (seven 4-byte scalars).
layout(push_constant) uniform PC {
uint rows;      // number of rows; one invocation per row
uint n;         // row length; also the length of the gamma and beta vectors
uint x_off;     // start of x: row r occupies data[x_off + r*n .. x_off + r*n + n-1]
uint gamma_off; // start of the n-element scale vector
uint beta_off;  // start of the n-element shift vector
uint out_off;   // start of the output, laid out like x (row r at out_off + r*n)
float eps;      // added to mean(x^2) before the inverse square root
} pc;
// Normalizes row gl_GlobalInvocationID.x of x into out:
//   out[i] = x[i] * inversesqrt(sum(x^2) / n + eps) * gamma[i] + beta[i]
// Pass 1 reads the entire row before pass 2 writes anything, and pass 2 reads
// each x element before writing the output element at the same index, so
// running in place (out_off == x_off) is safe. With n == 0 nothing is written.
void main() {
    uint r = gl_GlobalInvocationID.x;
    // Invocations past the last row (padding in the final workgroup) do nothing.
    if (r < pc.rows) {
        uint len = pc.n;
        uint rowStart = r * len;
        uint src = pc.x_off + rowStart;   // first x element of this row
        uint dst = pc.out_off + rowStart; // first output element of this row

        // Pass 1: sum of squares over the row.
        float sumSq = 0.0;
        for (uint k = 0u; k < len; ++k) {
            float xv = data[src + k];
            sumSq += xv * xv;
        }

        float meanSq = sumSq / float(len);
        float scale = inversesqrt(meanSq + pc.eps);

        // Pass 2: scale, then apply the per-column affine (gamma, beta).
        for (uint k = 0u; k < len; ++k) {
            float xv = data[src + k];
            data[dst + k] = xv * scale * data[pc.gamma_off + k] + data[pc.beta_off + k];
        }
    }
}