rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// 2-D pooling on NCHW. kind: 2 = max, 1 = avg (ReduceOp::Max / ::Mean). One
// invocation per output element. Avg divides by the valid (in-bounds) count.
// 64 invocations per workgroup along x. main() bounds-checks gid against the
// output element count, so the host presumably rounds the dispatch up to a
// multiple of 64.
layout(local_size_x = 64) in;

// Single float storage buffer shared by input and output; the input and output
// tensors are located inside it by the x_off / out_off element offsets below.
layout(std430, binding = 0) buffer Arena { float data[]; };

// Pooling parameters pushed by the host. NOTE(review): field order/size is
// assumed to mirror the host-side push-constant struct — do not reorder.
layout(push_constant) uniform PC {
    uint nn; uint cc; uint hh; uint ww;   // input N,C,H,W
    uint ho; uint wo;                     // output H,W
    // window height/width, stride along H/W, leading padding along H/W
    uint kh; uint kw; uint sh; uint sw; uint ph; uint pw;
    uint x_off; uint out_off;             // element (not byte) offsets into data[]: input, output
    uint kind;                            // 2 = max, 1 = avg
} pc;

// Computes one pooled output element per invocation.
//
// gid enumerates the NCHW output tensor [nn, cc, ho, wo] in row-major order.
// The window for output (ohi, owi) starts at input (ohi*sh - ph, owi*sw - pw)
// and spans kh x kw positions; positions outside [0,hh) x [0,ww) are skipped,
// so implicit padding contributes neither to the max nor to the avg count.
//
// kind == 2: max over the valid window. The accumulator is seeded from the
//            first valid element so an all -inf window yields -inf (a fixed
//            -FLT_MAX seed would clamp it). A window with no valid element
//            writes -FLT_MAX.
// kind == 1: mean over the valid elements; 0.0 if there are none.
// Any other kind writes the plain (undivided) sum.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = pc.nn * pc.cc * pc.ho * pc.wo;
    if (gid >= total) { return; }  // dispatch may be rounded up to local_size_x

    uint owi = gid % pc.wo;
    uint ohi = (gid / pc.wo) % pc.ho;
    // Flat (batch, channel) plane index; equals n * cc + c for this output.
    uint plane = gid / (pc.wo * pc.ho);
    uint plane_base = pc.x_off + plane * pc.hh * pc.ww;

    bool is_max = (pc.kind == 2u);
    int h0 = int(ohi * pc.sh) - int(pc.ph);
    int w0 = int(owi * pc.sw) - int(pc.pw);
    // -FLT_MAX only survives for max when the window has no valid element.
    float acc = is_max ? -3.402823466e38 : 0.0;
    uint cnt = 0u;
    for (uint i = 0u; i < pc.kh; i++) {
        int h = h0 + int(i);
        if (h < 0 || h >= int(pc.hh)) { continue; }
        uint row_base = plane_base + uint(h) * pc.ww;
        for (uint j = 0u; j < pc.kw; j++) {
            int w = w0 + int(j);
            if (w < 0 || w >= int(pc.ww)) { continue; }
            float v = data[row_base + uint(w)];
            if (is_max) {
                // Seed from the first real element so -inf inputs propagate.
                acc = (cnt == 0u) ? v : max(acc, v);
            } else {
                acc += v;
            }
            cnt++;
        }
    }
    if (pc.kind == 1u) { acc = (cnt > 0u) ? acc / float(cnt) : 0.0; }
    data[pc.out_off + gid] = acc;
}