rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// ArgMax / ArgMin along an axis (viewed as [outer, axis_len, inner]). Output is
// the f32-encoded index of the extreme element. op: 0 = argmax, 1 = argmin.
// Ties break to the smaller index (matches NumPy / the CPU reference).
// One-dimensional workgroups of 256 invocations. main() produces at most one
// output element per invocation and exits early for surplus invocations.
layout(local_size_x = 256) in;

// Single storage buffer of floats at binding 0. main() both reads its input
// from and writes its result into this one array, starting at the element
// indices pc.in_off and pc.out_off respectively.
layout(std430, binding = 0) buffer Arena { float data[]; };

// Push constants describing one ArgMax / ArgMin dispatch over a tensor viewed
// as [outer, axis_len, inner]. The two offsets are indices into the float
// array `data` (they are used directly as array subscripts), not byte offsets.
layout(push_constant) uniform PC {
    uint outer;     // extent of the dimensions before the reduced axis
    uint axis_len;  // extent of the reduced axis
    uint inner;     // extent after the reduced axis; also the stride between consecutive axis elements
    uint in_off;    // index in data[] of the first input element
    uint out_off;   // index in data[] of the first output element (outer * inner results are written)
    uint op;        // 0 selects argmax; main() treats any other value as argmin (documented value: 1)
} pc;

// Entry point. Each invocation owns one (outer, inner) coordinate, scans the
// axis_len elements along the reduced axis for that coordinate, and stores the
// position of the extreme value as a float at data[pc.out_off + invocation].
void main() {
    uint slot = gl_GlobalInvocationID.x;

    // The dispatch is rounded up to whole workgroups; surplus invocations have
    // no output element and must not touch the buffer.
    if (slot >= pc.outer * pc.inner) {
        return;
    }

    // Split the flat invocation index into its outer and inner coordinates.
    uint outer_idx = slot / pc.inner;
    uint inner_idx = slot % pc.inner;

    // Running index of the element under inspection; begins at axis position 0
    // and advances by the axis stride (pc.inner) on every step.
    uint cursor = pc.in_off + outer_idx * pc.axis_len * pc.inner + inner_idx;

    bool want_max = (pc.op == 0u);
    float extreme = data[cursor];
    uint winner = 0u;

    for (uint pos = 1u; pos < pc.axis_len; ++pos) {
        cursor += pc.inner;
        float candidate = data[cursor];

        // Strict comparison: an equal value never displaces the current
        // winner, so the smallest index among equal extremes is kept.
        bool wins;
        if (want_max) {
            wins = candidate > extreme;
        } else {
            wins = candidate < extreme;
        }

        if (wins) {
            extreme = candidate;
            winner = pos;
        }
    }

    // The index is stored as a float because the arena holds only floats
    // (a 32-bit float represents integers exactly only up to 2^24).
    data[pc.out_off + slot] = float(winner);
}