rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// Scaled dot-product attention with an online (flash-style) softmax, intended to
// mirror the CPU reference semantics. One invocation computes one output row for
// a single (batch, head, query) triple; invocations with index >= b*nh*q_s exit.
//
// Layouts (selected by `bhsd`):
//   bhsd != 0 => Q, K, V and O are packed [B,H,S,D] (rows of dh floats).
//   bhsd == 0 => Q, K, V are [B,S,H,D] with per-tensor row strides qrs/krs/vrs
//                (in floats); O is written packed [B,S,H,D] with row stride nh*dh.
// KV heads are assumed equal to query heads (GQA expanded upstream): K and V are
// indexed with the query's head index.
//
// With abs_q = q_offset + q, where q_offset = k_s - q_s when k_s > q_s (KV cache)
// and 0 otherwise, mask_kind selects:
//   0 none
//   1 causal          key ki masked when ki > abs_q
//   2 sliding window  keys outside [abs_q - window, abs_q] masked
//   3 custom          per-(b,k) value at mask_off; masked when value < thr
//   4 bias            additive per-(b,h,q,k) value at mask_off (nothing masked)
// Masked keys receive the score `neg` and still take part in the softmax update.

// Only the x dimension is used: main() derives its work item from
// gl_GlobalInvocationID.x alone.
layout(local_size_x = 64) in;

// Single float storage buffer holding every tensor (inputs, mask/bias and output).
// Tensors are addressed by element (float) offsets passed in the push constants.
layout(std430, binding = 0) buffer Arena { float data[]; };

// Upper bound on the head dimension; sizes the per-invocation accumulator array.
// Invocations see dh > MAX_DH and return without writing any output.
const uint MAX_DH = 256u;

// Push constants. Field order defines the byte layout the host must match:
// 19 four-byte scalars, tightly packed under the default std430 push-constant
// layout (76 bytes), within Vulkan's guaranteed minimum maxPushConstantsSize of 128.
layout(push_constant) uniform PC {
    uint b;          // batch size
    uint nh;         // number of heads (shared by Q and K/V)
    uint q_s;        // query sequence length
    uint k_s;        // key/value sequence length
    uint dh;         // head dimension (no output is written if dh > MAX_DH)
    uint q_off;      // element offset of Q in `data`
    uint k_off;      // element offset of K in `data`
    uint v_off;      // element offset of V in `data`
    uint o_off;      // element offset of the output O in `data`
    uint qrs;        // [B,S,H,D] row strides in floats; unused when bhsd != 0
    uint krs;
    uint vrs;
    uint bhsd;       // layout flag: nonzero => [B,H,S,D], zero => [B,S,H,D]
    uint mask_kind;  // 0 none, 1 causal, 2 sliding window, 3 custom, 4 bias
    uint mask_off;   // element offset of the mask (kind 3: [B,K]) or bias (kind 4: [B,H,Q,K])
    uint window;     // sliding-window reach: keys abs_q - window .. abs_q are kept
    float scale;     // multiplier applied to q.k before the bias is added
    float neg;       // masked-out score value
    float thr;       // custom-mask threshold: values below it are masked
} pc;

// Entry point: computes O[b, h, q, :] for one (batch, head, query) triple.
// Keys are streamed one at a time. A running maximum (row_max) and a running
// softmax denominator (denom) let each key be folded into the weighted sum of
// V rows (out_acc) by rescaling the previous state, so the full score row is
// never stored.
void main() {
    uint flat_id = gl_GlobalInvocationID.x;

    // Guard clauses: padding invocations past the last query row, and head
    // dimensions the fixed-size accumulator cannot hold, write nothing.
    if (flat_id >= pc.b * pc.nh * pc.q_s || pc.dh > MAX_DH) {
        return;
    }

    // flat_id = ((batch * nh) + head) * q_s + query
    uint query = flat_id % pc.q_s;
    uint head = (flat_id / pc.q_s) % pc.nh;
    uint batch = flat_id / (pc.q_s * pc.nh);
    uint bh = batch * pc.nh + head;

    // First element of this query's Q row and O row, first K/V row of this
    // (batch, head), and the element step between consecutive keys.
    uint q_base;
    uint o_base;
    uint k_row0;
    uint v_row0;
    uint k_step;
    uint v_step;
    if (pc.bhsd == 0u) {
        // [B,S,H,D]: strided Q/K/V rows; O rows are packed (nh * dh floats).
        uint seq_row = batch * pc.q_s + query;
        q_base = pc.q_off + seq_row * pc.qrs + head * pc.dh;
        o_base = pc.o_off + seq_row * (pc.nh * pc.dh) + head * pc.dh;
        k_row0 = pc.k_off + batch * pc.k_s * pc.krs + head * pc.dh;
        v_row0 = pc.v_off + batch * pc.k_s * pc.vrs + head * pc.dh;
        k_step = pc.krs;
        v_step = pc.vrs;
    } else {
        // [B,H,S,D]: every row is dh contiguous floats.
        uint qo_row = (bh * pc.q_s + query) * pc.dh;
        q_base = pc.q_off + qo_row;
        o_base = pc.o_off + qo_row;
        k_row0 = pc.k_off + bh * pc.k_s * pc.dh;
        v_row0 = pc.v_off + bh * pc.k_s * pc.dh;
        k_step = pc.dh;
        v_step = pc.dh;
    }

    // Queries are right-aligned to the keys when k_s > q_s (KV cache).
    uint abs_q = query + ((pc.k_s > pc.q_s) ? (pc.k_s - pc.q_s) : 0u);
    uint win_lo = (abs_q > pc.window) ? (abs_q - pc.window) : 0u;
    uint keep_row = pc.mask_off + batch * pc.k_s;                  // kind 3: [B,K]
    uint bias_row = pc.mask_off + (bh * pc.q_s + query) * pc.k_s;  // kind 4: [B,H,Q,K]

    float out_acc[MAX_DH];
    for (uint d = 0u; d < pc.dh; ++d) {
        out_acc[d] = 0.0;
    }
    // Finite start (-FLT_MAX): keeps the running max finite, so a -inf masked
    // score yields exp(-inf) = 0 rather than NaN from (-inf) - (-inf).
    float row_max = -3.402823466e38;
    float denom = 0.0;

    for (uint key = 0u; key < pc.k_s; ++key) {
        bool drop = false;
        float add = 0.0;
        switch (pc.mask_kind) {
        case 1u:   // causal
            drop = key > abs_q;
            break;
        case 2u:   // sliding window
            drop = key < win_lo || key > abs_q;
            break;
        case 3u:   // custom per-(b,k) threshold
            drop = data[keep_row + key] < pc.thr;
            break;
        case 4u:   // additive bias per (b,h,q,k)
            add = data[bias_row + key];
            break;
        default:   // 0 (or unknown): no mask
            break;
        }

        uint k_at = k_row0 + key * k_step;
        uint v_at = v_row0 + key * v_step;

        float s = pc.neg;
        if (!drop) {
            float qk = 0.0;
            for (uint d = 0u; d < pc.dh; ++d) {
                qk += data[q_base + d] * data[k_at + d];
            }
            s = qk * pc.scale + add;
        }

        // Rescale the running state to the new maximum, then fold in this key.
        float next_max = max(row_max, s);
        float rescale = exp(row_max - next_max);
        float w = exp(s - next_max);
        denom = denom * rescale + w;
        for (uint d = 0u; d < pc.dh; ++d) {
            out_acc[d] = out_acc[d] * rescale + w * data[v_at + d];
        }
        row_max = next_max;
    }

    // Normalize by the softmax denominator; a row with denom == 0 (e.g. k_s == 0)
    // is written as zeros instead of dividing by zero.
    float norm = 0.0;
    if (denom > 0.0) {
        norm = 1.0 / denom;
    }
    for (uint d = 0u; d < pc.dh; ++d) {
        data[o_base + d] = out_acc[d] * norm;
    }
}