#version 450
// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
// position (candle / SAM semantics). x [N,C,H,W], gamma/beta [C]. One
// invocation per (n, h, w) position.
// Workgroup size: 64 invocations along x. main() uses only the x component of
// the global invocation ID.
layout(local_size_x = 64) in;
// Single storage buffer indexed in float elements. main() reads the input and
// the gamma/beta parameters from it and writes the result back into it; each
// region is located through the *_off push constants below.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Per-dispatch parameters. All *_off values are indices into data[] (float
// elements, not byte offsets).
layout(push_constant) uniform PC {
uint positions; // N * H * W: number of positions; invocations with index >= this exit early
uint c; // channels: length of the axis that is normalized over
uint hw; // H * W: also the element stride between consecutive channels
uint x_off; // start of the input x in data[]
uint gamma_off; // start of the per-channel scale (indexed by channel)
uint beta_off; // start of the per-channel shift (indexed by channel)
uint out_off; // start of the output in data[], addressed the same way as x
float eps; // added to the variance before the inverse square root
} pc;
// Entry point. Each invocation handles one (n, h, w) position and normalizes
// its channel vector:
//   out = (x - mean) * inversesqrt(var + eps) * gamma + beta
// where mean and var (divided by c, i.e. the biased estimate) are taken over
// the pc.c channels of that position.
void main() {
    uint pos = gl_GlobalInvocationID.x;
    // Invocations past the last position have nothing to do.
    if (pos >= pc.positions) {
        return;
    }

    // Split the flat position index into a batch index and a spatial index.
    uint batch = pos / pc.hw;
    uint spatial = pos % pc.hw;

    // Offset of channel 0 for this position inside an NCHW tensor; channel k
    // lives `k * stride` elements further on.
    uint first = batch * pc.c * pc.hw + spatial;
    uint stride = pc.hw;
    uint src = pc.x_off + first;
    uint dst = pc.out_off + first;

    // Pass 1: mean over channels.
    float mu = 0.0;
    for (uint k = 0u; k < pc.c; ++k) {
        mu += data[src + k * stride];
    }
    mu /= float(pc.c);

    // Pass 2: variance over channels, accumulated from centered values.
    float sigma2 = 0.0;
    for (uint k = 0u; k < pc.c; ++k) {
        float centered = data[src + k * stride] - mu;
        sigma2 += centered * centered;
    }
    sigma2 /= float(pc.c);

    float rstd = inversesqrt(sigma2 + pc.eps);

    // Pass 3: normalize, then apply the per-channel scale and shift.
    for (uint k = 0u; k < pc.c; ++k) {
        float scale = data[pc.gamma_off + k];
        float shift = data[pc.beta_off + k];
        uint at = k * stride;
        data[dst + at] = (data[src + at] - mu) * rstd * scale + shift;
    }
}