rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// Gather along `axis`: data viewed as [out_outer, axis_dim, out_inner], index
// tensor flattened to n_idx elements (f32-encoded). Output [out_outer, n_idx,
// out_inner]:  out[o, t, i] = data[o, idx[t], i].
//
// Each invocation handles at most one output element, identified by its flat
// global x id; invocations past the end of the output do nothing.
layout(local_size_x = 256) in;  // 256 invocations per workgroup along x

// Single storage buffer shared by the source tensor, the index tensor and the
// output. Each region is located by a float-element offset from the push
// constants below (offsets are added directly to the index into `data`).
layout(std430, binding = 0) buffer Arena { float data[]; };

// Per-dispatch parameters. Field order and types define the push-constant
// layout seen by the host, so they must not be reordered.
layout(push_constant) uniform PC {
    uint out_outer;  // product of the dimensions before the gather axis
    uint n_idx;      // number of (flattened) indices; size of the output's middle dim
    uint out_inner;  // product of the dimensions after the gather axis
    uint axis_dim;   // size of the gather axis in the source tensor
    uint data_off;   // element offset of the source tensor inside `data`
    uint idx_off;    // element offset of the f32-encoded index tensor inside `data`
    uint out_off;    // element offset of the output tensor inside `data`
} pc;

// Gathers one output element per invocation.
//
// The flat invocation id is decomposed as gid = (o * n_idx + t) * out_inner + i.
// The t-th index is read as a float, rounded to the nearest integer
// (floor(x + 0.5)) and clamped to [0, axis_dim - 1]; the selected source
// element data[o, ix, i] is then copied to out[o, t, i].
//
// Out-of-range, infinite and NaN index values are handled explicitly: values
// below zero and NaN select element 0, values past the end select the last
// element. If the gather axis is empty there is no valid source element, so
// nothing is written.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = pc.out_outer * pc.n_idx * pc.out_inner;
    // Tail guard for the last workgroup. A non-zero `total` also guarantees
    // that the divisors used below (out_inner, n_idx) are non-zero.
    if (gid >= total) { return; }
    // Empty gather axis: `axis_dim - 1` would wrap around and every read
    // would be out of bounds.
    if (pc.axis_dim == 0u) { return; }

    uint i = gid % pc.out_inner;
    uint t = (gid / pc.out_inner) % pc.n_idx;
    uint o = gid / (pc.n_idx * pc.out_inner);

    uint last = pc.axis_dim - 1u;
    float rounded = floor(data[pc.idx_off + t] + 0.5);

    // Range-check while the value is still a float: converting a NaN, an
    // infinity, a negative value or anything >= 2^32 to an integer type is
    // undefined in GLSL, so the conversion is only performed once the value
    // is known to fit. NaN fails both comparisons and falls through to 0.
    uint ix = 0u;
    if (rounded >= 4294967296.0) {
        ix = last;
    } else if (rounded > 0.0) {
        // Exact clamp in the integer domain (no float rounding of `last`).
        ix = min(uint(rounded), last);
    }

    uint src = (o * pc.axis_dim + ix) * pc.out_inner + i;
    data[pc.out_off + gid] = data[pc.data_off + src];
}