rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// Top-K indices along the last axis (f32-encoded), sorted descending, ties
// broken by smaller index (torch.topk largest=True, sorted=True). One
// invocation per row; array-free selection that carries the previous pick so
// duplicate values still advance. Input [..., n] → output [..., k].
//
// Cost is k full scans of the row per invocation (O(n * k) reads), with no
// scratch storage. Indices are written as floats, so they are only exact
// while n stays within the 24-bit integer range of an f32.

// 64 invocations per workgroup; invocations whose row index is >= pc.rows
// exit immediately, so the dispatch may be rounded up to a multiple of 64.
layout(local_size_x = 64) in;

// Single storage buffer used for both the input rows and the output indices;
// all offsets below are in float elements, not bytes.
layout(std430, binding = 0) buffer Arena { float data[]; };

layout(push_constant) uniform PC {
    uint rows;    // number of rows to process (one invocation each)
    uint n;       // elements per input row (length of the last axis)
    uint k;       // indices written per row
    uint in_off;  // element offset of the first input row within data[]
    uint out_off; // element offset of the first output row within data[]
} pc;

// Returns true when element (a, ai) ranks strictly before element (b, bi) in
// the selection order: NaN first, then larger value, then smaller index.
//
// Treating NaN as the greatest value makes the order total. With plain
// float comparisons every test against NaN is false, so a NaN in the row
// could never be ranked relative to its neighbours and the selection stalled.
// On implementations without NaN support isnan() always returns false and
// this reduces to (value desc, index asc).
bool ranks_before(float a, uint ai, float b, uint bi) {
    bool a_nan = isnan(a);
    bool b_nan = isnan(b);
    if (a_nan != b_nan) { return a_nan; }
    if (!a_nan && a != b) { return a > b; }
    return ai < bi;
}

// One invocation per row. Writes, as floats, the indices of the k
// top-ranked elements of the row to data[out_off + row * k ...], best first.
//
// Each of the k passes rescans the whole row and keeps the best element that
// ranks strictly after the previous pick, so no per-row scratch array is
// needed and equal values are emitted in ascending index order.
//
// If the row has fewer than k elements (k > n), the slots past the last real
// pick are filled with 0.0, which is what earlier versions of this kernel
// wrote in that case; the host is expected to validate k <= n.
void main() {
    uint row = gl_GlobalInvocationID.x;
    if (row >= pc.rows) { return; }
    uint in_base = pc.in_off + row * pc.n;
    uint out_base = pc.out_off + row * pc.k;

    // Previous pick; only meaningful once have_prev is set.
    float prev_val = 0.0;
    uint prev_idx = 0u;
    bool have_prev = false;

    for (uint s = 0u; s < pc.k; s++) {
        // Best candidate of this pass; only meaningful once found is set.
        float best = 0.0;
        uint best_i = 0u;
        bool found = false;

        for (uint i = 0u; i < pc.n; i++) {
            float v = data[in_base + i];
            // Skip everything already emitted: a candidate must rank
            // strictly after the previous pick.
            if (have_prev && !ranks_before(prev_val, prev_idx, v, i)) {
                continue;
            }
            if (!found || ranks_before(v, i, best, best_i)) {
                best = v;
                best_i = i;
                found = true;
            }
        }

        if (!found) {
            // Row exhausted: no later pass can find anything either, so
            // fill the remaining slots deterministically and stop.
            for (uint t = s; t < pc.k; t++) {
                data[out_base + t] = 0.0;
            }
            return;
        }

        data[out_base + s] = float(best_i);
        prev_val = best;
        prev_idx = best_i;
        have_prev = true;
    }
}