rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// LayerNorm2d on NCHW: normalize across the channel axis at each spatial
// position (candle / SAM semantics). x [N,C,H,W], gamma/beta [C]. One
// invocation per (n, h, w) position.
//
// For each position the shader subtracts the channel mean, divides by
// sqrt(biased variance + eps) (variance divides by C, not C-1), then applies
// the per-channel affine gamma * y + beta.
layout(local_size_x = 64) in;

// Single storage buffer holding input, gamma, beta and output. Every *_off
// push constant below is an offset in float elements (not bytes) into data[].
layout(std430, binding = 0) buffer Arena { float data[]; };

// Push-constant block: eight 4-byte members, 32 bytes total. The byte offsets
// follow declaration order (positions @0, c @4, hw @8, x_off @12,
// gamma_off @16, beta_off @20, out_off @24, eps @28). The host must write them
// in exactly this order.
layout(push_constant) uniform PC {
    uint positions; // N * H * W
    uint c;         // channels
    uint hw;        // H * W
    uint x_off;
    uint gamma_off;
    uint beta_off;
    uint out_off;
    float eps;
} pc;

// Normalize the pc.c values that share one (n, h, w) position of an NCHW
// tensor stored in data[]. The steps are: subtract their mean, scale by
// 1/sqrt(biased variance + eps), then apply per-channel gamma/beta.
// Each invocation owns one flattened position index; invocations past
// pc.positions (workgroup padding) exit immediately.
void main() {
    uint pos = gl_GlobalInvocationID.x;
    if (pos >= pc.positions) {
        return;
    }

    // Split the flattened position into batch index and in-plane offset.
    uint batch = pos / pc.hw;
    uint spatial = pos % pc.hw;
    // Channel k of this position lives at first + k * hw (channel stride = H*W).
    uint first = batch * pc.c * pc.hw + spatial;
    uint src = pc.x_off + first;
    uint dst = pc.out_off + first;
    uint stride = pc.hw;
    float count = float(pc.c);

    // Pass 1: arithmetic mean over the channel axis.
    float sum = 0.0;
    uint at = src;
    for (uint k = 0u; k < pc.c; ++k) {
        sum += data[at];
        at += stride;
    }
    float mu = sum / count;

    // Pass 2: biased variance (divide by C) around that mean.
    float sq = 0.0;
    at = src;
    for (uint k = 0u; k < pc.c; ++k) {
        float diff = data[at] - mu;
        sq += diff * diff;
        at += stride;
    }
    float rstd = inversesqrt(sq / count + pc.eps);

    // Pass 3: normalize, then apply the per-channel affine parameters.
    // The input is read before the output write, so out_off == x_off is safe.
    uint rd = src;
    uint wr = dst;
    for (uint k = 0u; k < pc.c; ++k) {
        float scale = data[pc.gamma_off + k];
        float shift = data[pc.beta_off + k];
        data[wr] = (data[rd] - mu) * rstd * scale + shift;
        rd += stride;
        wr += stride;
    }
}