rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// Rotary position embedding, honoring the pairing flavor (RopeStyle):
//   style 0 = NeoX (HF rotate-half): dim i pairs with dim i + n_rot/2.
//   style 1 = GptJ (llama.cpp NORM / interleaved): dims pair as (2i, 2i+1).
// Both styles read cos/sin for frequency i in [0, n_rot/2) from the table row
// starting at row*tab_half, where row is the global token index when per_token
// is set and the sequence position otherwise. Dims [n_rot, head_dim) are copied
// through unchanged (partial rotary). One invocation handles one (token, head)
// pair. Input rows of x are src_row_stride apart; the output is dense, with
// rows `hidden` apart.
// Workgroup size: 64 invocations along x. main() bounds-checks its global
// index against batch*seq*nh, so surplus invocations in the last workgroup
// return without touching memory.
layout(local_size_x = 64) in;

// Single storage buffer holding every tensor this kernel touches (x, the cos
// and sin tables, and the output) as floats. Each tensor is located by an
// element offset into data[] supplied through the push constants.
layout(std430, binding = 0) buffer Arena { float data[]; };

// Kernel parameters. Every offset and stride is counted in float elements of
// data[], not bytes. Field order is the host-visible layout; do not reorder.
layout(push_constant) uniform PC {
    uint batch;         // outer dimension; batch*seq*nh invocations do work
    uint seq;           // tokens per batch entry (token index = bi*seq + si)
    uint hidden;        // output row stride: elements per token in the output
    uint head_dim;      // elements per head; head hi starts at hi*head_dim in a row
    uint n_rot;         // leading dims of each head that are rotated (n_rot/2 pairs)
    uint nh;            // heads = hidden / head_dim
    uint tab_half;      // head_dim / 2; used as the row stride of the cos/sin tables
    uint src_row_stride; // input row stride: elements between consecutive tokens in x
    uint per_token;     // nonzero ⇒ cos/sin row indexed by global token, else by seq pos
    uint style;         // 0 = NeoX; any other value takes the GptJ (interleaved) path
    uint x_off;         // start of the input tensor within data[]
    uint cos_off;       // start of the cos table within data[]
    uint sin_off;       // start of the sin table within data[]
    uint out_off;       // start of the output tensor within data[]
} pc;

// Applies rotary position embedding to one (token, head) slice of x and writes
// the result to the dense output.
//
// Work mapping: global invocation gid handles token gid / nh and head gid % nh.
// Invocations with gid >= batch*seq*nh return immediately.
//
// For each frequency i in [0, n_rot/2) the pair (a, b) is rotated:
//   out[a] = x[a]*cos[i] - x[b]*sin[i]
//   out[b] = x[b]*cos[i] + x[a]*sin[i]
// with (a, b) = (i, i + n_rot/2) when style == 0 and (2i, 2i+1) otherwise.
// Every dim not covered by a pair, up to head_dim, is copied through unchanged.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = pc.batch * pc.seq * pc.nh;
    if (gid >= total) { return; }

    uint token = gid / pc.nh;     // global token index = bi*seq + si
    uint head  = gid % pc.nh;
    uint pos   = token % pc.seq;  // position of the token inside its sequence

    // cos/sin row: the global token when per_token is set, else the position.
    uint tab_row = (pc.per_token != 0u) ? token : pos;
    uint cos_base = pc.cos_off + tab_row * pc.tab_half;
    uint sin_base = pc.sin_off + tab_row * pc.tab_half;

    // token == bi*seq + si, so token*stride equals bi*seq*stride + si*stride.
    uint src = pc.x_off   + token * pc.src_row_stride + head * pc.head_dim;
    uint dst = pc.out_off + token * pc.hidden         + head * pc.head_dim;

    uint rot_half = pc.n_rot / 2u;

    // The pairing rule does not depend on i, so pick it once outside the loop.
    bool neox = (pc.style == 0u);
    uint pair_step = neox ? 1u : 2u;        // distance between successive `a`
    uint pair_gap  = neox ? rot_half : 1u;  // distance from `a` to its partner

    for (uint i = 0u; i < rot_half; i++) {
        float cv = data[cos_base + i];
        float sv = data[sin_base + i];
        uint a = i * pair_step;
        uint b = a + pair_gap;
        // Read both inputs before writing either output so the pair is rotated
        // from its original values even if the output slice coincides with
        // the input slice.
        float x1 = data[src + a];
        float x2 = data[src + b];
        data[dst + a] = x1 * cv - x2 * sv;
        data[dst + b] = x2 * cv + x1 * sv;
    }

    // Pass-through starts at 2*rot_half rather than n_rot: the two are equal
    // for even n_rot, and for odd n_rot this also copies the unpaired dim
    // n_rot-1, which would otherwise never be written to the output.
    for (uint j = 2u * rot_half; j < pc.head_dim; j++) {
        data[dst + j] = data[src + j];
    }
}