rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// Direct 2-D convolution on NCHW tensors:
//   x      [N, Cin, H, W]
//   weight [Cout, Cin/groups, kH, kW]
//   bias   [Cout]              (optional, only read when has_bias != 0)
//   out    [N, Cout, Ho, Wo]
// Supports stride, padding, dilation and grouped convolution. Work is split
// across invocations by flat output-element index (x dimension only).
layout(local_size_x = 64) in;

// Single storage buffer holding every tensor; each tensor is located by an
// element (float, not byte) offset into data[] supplied via push constants.
layout(std430, binding = 0) buffer Arena { float data[]; };

// Push-constant block. Member order fixes each field's offset in the
// push-constant range, so it must match the struct the host pushes
// (host side not visible here) -- do not reorder or insert fields casually.
layout(push_constant) uniform PC {
    uint nn; uint cin; uint hh; uint ww;        // input: batch, channels, height, width
    uint cout; uint kh; uint kw;                // output channels; kernel height, width
    uint oh; uint ow;                           // output height, width
    uint sh; uint sw; uint ph; uint pw; uint dh; uint dw; // stride, padding, dilation (h, w)
    uint groups; uint has_bias;                 // channel groups; nonzero => start from bias[co]
    uint x_off; uint w_off; uint b_off; uint out_off; // element offsets of x, weight, bias, out in data[]
} pc;

// Computes out[n, co, ohi, owi] = bias[co] (if has_bias) +
//   sum_{cl, i, j} x[n, ci_start + cl, h, w] * weight[co, cl, i, j]
// where h = ohi*sh - ph + i*dh and w = owi*sw - pw + j*dw. Taps that fall in
// the padding region (outside [0, H) x [0, W)) contribute nothing.
//
// Each invocation starts at its global x index and advances by the total
// number of x invocations in the dispatch. All outputs are therefore written
// even when the host launches fewer than ceil(total / 64) workgroups, e.g.
// because of the maxComputeWorkGroupCount[0] limit. With a full-size dispatch
// the loop body runs at most once per invocation.
//
// NOTE(review): assumes groups >= 1 and that cin and cout are multiples of
// groups (integer division by zero is undefined) -- confirm host validates.
void main() {
    uint total = pc.nn * pc.cout * pc.oh * pc.ow;
    uint first = gl_GlobalInvocationID.x;
    // Invocations with no work leave before any division is evaluated.
    if (first >= total) { return; }
    uint stride = gl_NumWorkGroups.x * gl_WorkGroupSize.x;

    // Quantities that are the same for every output element.
    uint cin_pg = pc.cin / pc.groups;   // input channels per group
    uint cout_pg = pc.cout / pc.groups; // output channels per group
    uint plane = pc.hh * pc.ww;         // elements in one input channel
    uint ksize = pc.kh * pc.kw;         // elements in one (co, cl) kernel slice

    for (uint gid = first; gid < total; gid += stride) {
        uint owi = gid % pc.ow;
        uint ohi = (gid / pc.ow) % pc.oh;
        uint co = (gid / (pc.ow * pc.oh)) % pc.cout;
        uint n = gid / (pc.ow * pc.oh * pc.cout);

        // First input channel of the group that output channel `co` belongs to.
        uint ci_start = (co / cout_pg) * cin_pg;
        // Top-left corner of the receptive field; negative inside padding.
        int h0 = int(ohi * pc.sh) - int(pc.ph);
        int w0 = int(owi * pc.sw) - int(pc.pw);

        float acc = (pc.has_bias != 0u) ? data[pc.b_off + co] : 0.0;
        for (uint cl = 0u; cl < cin_pg; cl++) {
            // Base of input channel (n, ci_start + cl) and weight slice (co, cl).
            uint x_base = pc.x_off + (n * pc.cin + ci_start + cl) * plane;
            uint w_base = pc.w_off + (co * cin_pg + cl) * ksize;
            for (uint i = 0u; i < pc.kh; i++) {
                int h = h0 + int(i * pc.dh);
                if (h < 0 || h >= int(pc.hh)) { continue; }
                uint x_row = x_base + uint(h) * pc.ww;
                uint w_row = w_base + i * pc.kw;
                for (uint j = 0u; j < pc.kw; j++) {
                    int w = w0 + int(j * pc.dw);
                    if (w < 0 || w >= int(pc.ww)) { continue; }
                    acc += data[x_row + uint(w)] * data[w_row + j];
                }
            }
        }
        data[pc.out_off + gid] = acc;
    }
}