rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// Mamba selective-scan (plan #15), matching the CPU reference recurrence:
//   state[c,n] = exp(Δ·A[c,n])·state[c,n] + Δ·B[s,n]·x[s,c]
//   y[s,c]     = Σ_n C[s,n]·state[c,n]
// Inputs: x,Δ [B,S,H], A [H,N], B,C [B,S,N]. One invocation per (batch, channel),
// walking the sequence sequentially while channels/batches run in parallel.
// In the shapes, the leading "B" is the batch dimension; the tensor named B is a
// separate input. The state starts at zero for every (batch, channel), and the
// output y is written with the same [B,S,H] layout as x and Δ.
// All tensors are dense row-major views into one float arena (binding 0); their
// positions are given as element offsets in the push constants below.
layout(local_size_x = 64) in;
// NOTE(review): presumably the host dispatches ceil(bb*hh / 64) workgroups in x;
// invocations past bb*hh are expected to exit immediately.

// Single storage buffer holding every input and output tensor as 32-bit floats.
layout(std430, binding = 0) buffer Arena { float data[]; };

// Upper bound on the state size N: each invocation keeps its whole state vector
// in a fixed-size local array of this length, so larger N cannot be handled.
const uint MAX_N = 256u;

// Push-constant block. Member order and types define the byte layout, which must
// match the host-side push-constant struct exactly.
// All *_off fields are element (float) offsets into data[], not byte offsets.
layout(push_constant) uniform PC {
    uint bb;       // batch
    uint ss;       // seq
    uint hh;       // hidden channels
    uint nn;       // state size
    uint x_off;    // x      [B,S,H]
    uint delta_off; // Δ     [B,S,H]
    uint a_off;    // A      [H,N]
    uint b_off;    // B      [B,S,N]
    uint c_off;    // C      [B,S,N]
    uint out_off;  // y (out) [B,S,H]
} pc;

// Entry point: one invocation handles one (batch bi, channel ci) pair and walks
// the whole sequence in order, carrying the N-element recurrent state in a
// per-invocation array. Invocations never communicate (no shared memory, no
// barriers), so early returns are safe.
//
// For every sequence position si it writes
//   out[bi, si, ci] = Σ_n C[bi,si,n] · state_si[ci,n]
// where state_si is the state after applying the update for position si.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = pc.bb * pc.hh;
    // Invocations beyond the (batch × channel) grid have no work to do.
    if (gid >= total) { return; }
    uint bi = gid / pc.hh;
    uint ci = gid % pc.hh;

    // Loop-invariant index bases, hoisted out of the sequence/state loops.
    uint seq_base = bi * pc.ss;          // first (batch, seq) row of this batch
    uint a_row = pc.a_off + ci * pc.nn;  // A[ci, 0]

    if (pc.nn > MAX_N) {
        // The local state array is fixed at MAX_N entries, so this state size
        // cannot be computed here. Poison this invocation's outputs with NaN
        // instead of silently leaving stale arena contents that look valid.
        float nan_val = uintBitsToFloat(0x7fc00000u);
        for (uint si = 0u; si < pc.ss; si++) {
            data[pc.out_off + (seq_base + si) * pc.hh + ci] = nan_val;
        }
        return;
    }

    float state[MAX_N];
    for (uint n = 0u; n < pc.nn; n++) { state[n] = 0.0; }

    for (uint si = 0u; si < pc.ss; si++) {
        uint row = seq_base + si;
        uint bsh = row * pc.hh + ci;  // element of the [B,S,H] tensors (x, Δ, out)
        uint bsn = row * pc.nn;       // start of this row in the [B,S,N] tensors (B, C)
        float d = data[pc.delta_off + bsh];
        float xv = data[pc.x_off + bsh];
        float acc = 0.0;
        for (uint n = 0u; n < pc.nn; n++) {
            // Discretised decay exp(Δ·A[ci,n]) applied to the previous state,
            // plus the input injection Δ·B[n]·x.
            float da = exp(d * data[a_row + n]);
            float sv = da * state[n] + d * data[pc.b_off + bsn + n] * xv;
            state[n] = sv;
            // Readout uses the freshly updated state for this position.
            acc += data[pc.c_off + bsn + n] * sv;
        }
        data[pc.out_off + bsh] = acc;
    }
}