rlx-vulkan 0.2.10

Native Vulkan compute backend for RLX (raw `ash` + embedded SPIR-V compute kernels)
Documentation
#version 450
// Fused GGUF K-quant dequant + GEMV (decode, m == 1) for Op::DequantMatMul.
//
// Computes out[1, n] = x[1, k] @ Wᵀ where W is [n, k] row-major in packed
// GGUF block layout (the same orientation as rlx-cpu `gguf_matmul_bt`). One
// invocation owns one output column `j`: it walks the k/256 super-blocks that
// make up row `j`, dequantizes each 256-element block into a private array
// (mirroring the verified rlx-gguf `dequant_q4_k_block` / `dequant_q6_k_block`
// reference math), and accumulates `Σ x[p] · w[j,p]`.
//
// The arena otherwise holds f32 data; here it is viewed as raw 32-bit words.
// That lets us pull the packed weight bytes out of the same buffer the f32
// activations live in, including the f16 super-block scales (`d`, plus `dmin`
// for Q4_K). Scheme: 0 = Q4_K (144 B / 256 elems), 1 = Q6_K (210 B / 256 elems).
// 64 invocations per workgroup. Each invocation computes one output column
// (see main). Presumably the host dispatches ceil(n / 64) workgroups
// — TODO confirm against the dispatch site.
layout(local_size_x = 64) in;

// The single storage buffer: the whole arena as raw 32-bit words.
// - f32 values are read with uintBitsToFloat and written with floatBitsToUint.
// - Packed quantized bytes are sliced out by rd_byte / rd_f16.
// x, W and out all live in this buffer, at the word offsets given in `pc`.
layout(std430, binding = 0) buffer Arena { uint data[]; };

// Per-dispatch parameters. Member order and types define the push-constant
// layout, so they must match the host-side struct exactly.
// NOTE(review): presumably six consecutive u32s on the host — confirm.
layout(push_constant) uniform PC {
    uint n;         // output columns (rows of W)
    uint k;         // contraction dim (multiple of 256; any remainder is ignored)
    uint x_word;    // f32-word offset of x activations
    uint w_word;    // f32-word offset of the packed weight bytes (16-B aligned)
    uint out_word;  // f32-word offset of the output
    uint scheme;    // 0 = Q4_K, 1 = Q6_K (any non-zero value is decoded as Q6_K)
} pc;

// Fetch byte `rel` of the current weight, i.e. the byte at weight-relative
// offset `rel` from the weight's storage start at arena word `pc.w_word`.
//
// Why word-relative:
// - An absolute byte address `pc.w_word * 4 + rel` would wrap in u32 once the
//   arena exceeds 4 GiB. Reads of weights past that point would then return
//   garbage (→ NaN logits).
// - Indexing word `pc.w_word + rel / 4` only needs the *word* index to fit in
//   u32, which holds for arenas below ~16 GiB.
// - Slots are at least 4-byte aligned, so `pc.w_word * 4` is the exact byte
//   base. The byte lane inside the fetched word is therefore just `rel % 4`.
// Note that `rel` itself is still a u32 byte count within the weight.
uint rd_byte(uint rel) {
    uint word = data[pc.w_word + rel / 4u];
    int lane_bit = int((rel % 4u) * 8u);   // 0, 8, 16 or 24
    return bitfieldExtract(word, lane_bit, 8);
}

// Fetch the little-endian f16 at weight-relative byte offset `rel` and widen
// it to f32. It uses the same word-relative addressing as rd_byte.
// Callers keep `rel` 2-byte aligned, so the half never straddles two words.
float rd_f16(uint rel) {
    uint shift = (rel & 3u) << 3u;
    uint bits = data[pc.w_word + (rel >> 2u)] >> shift;
    return unpackHalf2x16(bits & 0xFFFFu).x;
}

// Sign-extend a byte value `b` (0..255, as produced by rd_byte) to int,
// i.e. an i8 cast: 0x00..0x7F -> 0..127, 0x80..0xFF -> -128..-1.
int sx8(uint b) {
    int v = int(b);
    return ((b & 0x80u) != 0u) ? v - 256 : v;
}

// Decode the 6-bit (scale, min) pair of Q4_K sub-block `j` (0..7).
// Mirrors get_scale_min_k4.
//
// Encoding of the 12 scale bytes:
// - Sub-blocks 0..3 store scale and min directly in the low 6 bits of bytes
//   j and j+4.
// - Sub-blocks 4..7 keep their low 4 bits in byte j+4 (scale in the low
//   nibble, min in the high nibble). Their top 2 bits come from bits 6-7 of
//   byte j-4 (scale) and byte j (min).
//
// `sc_base` is the WEIGHT-RELATIVE byte offset of the 12-byte scales region,
// as taken by rd_byte. It is not an absolute arena byte address.
void scale_min_k4(uint j, uint sc_base, out uint sc, out uint mn) {
    // Byte j+4 is needed by both encodings.
    uint hi = rd_byte(sc_base + j + 4u);
    if (j < 4u) {
        sc = rd_byte(sc_base + j) & 63u;
        mn = hi & 63u;
    } else {
        sc = (hi & 0x0Fu) | ((rd_byte(sc_base + j - 4u) >> 6u) << 4u);
        mn = (hi >> 4u) | ((rd_byte(sc_base + j) >> 6u) << 4u);
    }
}

// Entry point: invocation `j` computes out[j] = Σ_p x[p] · W[j, p].
//
// Each 256-element super-block is decoded and multiplied into `acc` element
// by element, in ascending element order (0..255). The summation order and
// per-element arithmetic are the same as first decoding the block into a
// scratch array and then taking the dot product. It avoids the 256-float
// per-invocation private array, which is typically spilled to scratch memory.
//
// Notes:
//  * Any `pc.scheme` other than 0 is decoded as Q6_K.
//  * Only the first (k / 256) * 256 elements of x and of each W row are used.
//  * `base` is a u32 byte offset relative to the weight. It would wrap for a
//    single weight tensor larger than 4 GiB (rd_byte only protects the
//    arena-level offset).
void main() {
    uint j = gl_GlobalInvocationID.x;
    if (j >= pc.n) { return; }  // invocations past n (partially filled last workgroup)

    uint blocks_per_row = pc.k / 256u;
    uint block_bytes = (pc.scheme == 0u) ? 144u : 210u;

    float acc = 0.0;

    for (uint r = 0u; r < blocks_per_row; r++) {
        uint gbi = j * blocks_per_row + r;   // global super-block index
        uint base = gbi * block_bytes;       // weight-relative byte of block
        uint xb = pc.x_word + r * 256u;      // x word holding element 0 of this block

        if (pc.scheme == 0u) {
            // ── Q4_K: d(f16), dmin(f16), scales[12], qs[128] ──────────────
            float d = rd_f16(base);
            float dmin = rd_f16(base + 2u);
            uint sc_base = base + 4u;
            uint qs_base = base + 16u;
            // Sub-blocks jj and jj+1 share 32 qs bytes:
            // - low nibbles give elements [32*jj, 32*jj + 32);
            // - high nibbles give the following 32 elements.
            for (uint jj = 0u; jj < 8u; jj += 2u) {
                uint sc0; uint m0; uint sc1; uint m1;
                scale_min_k4(jj, sc_base, sc0, m0);
                scale_min_k4(jj + 1u, sc_base, sc1, m1);
                float d0 = d * float(sc0);
                float m0f = dmin * float(m0);
                float d1 = d * float(sc1);
                float m1f = dmin * float(m1);
                uint qs = qs_base + jj * 16u;   // 32 qs bytes per sub-block pair
                uint x0 = xb + jj * 32u;        // element 32*jj of this block
                for (uint l = 0u; l < 32u; l++) {
                    uint q = rd_byte(qs + l);
                    acc += uintBitsToFloat(data[x0 + l]) * (d0 * float(q & 0x0Fu) - m0f);
                }
                for (uint l = 0u; l < 32u; l++) {
                    uint q = rd_byte(qs + l);
                    acc += uintBitsToFloat(data[x0 + 32u + l]) * (d1 * float(q >> 4u) - m1f);
                }
            }
        } else {
            // ── Q6_K: ql[128], qh[64], sc[16] (i8), d(f16) ────────────────
            uint ql_base = base;
            uint qh_base = base + 128u;
            uint sc_base = base + 192u;
            float d = rd_f16(base + 208u);
            // Two 128-element halves. Half h uses ql[64h .. 64h+64),
            // qh[32h .. 32h+32) and sc[8h .. 8h+8).
            for (uint h = 0u; h < 2u; h++) {
                uint ql_h = ql_base + h * 64u;
                uint qh_h = qh_base + h * 32u;
                uint sc_h = sc_base + h * 8u;
                // Quarter g holds elements [128h + 32g, 128h + 32g + 32).
                // With l = 0..31 and isb = l / 16:
                //   g = 0: low  nibble of ql[l],      qh bits 0-1, scale sc[isb]
                //   g = 1: low  nibble of ql[l + 32], qh bits 2-3, scale sc[isb + 2]
                //   g = 2: high nibble of ql[l],      qh bits 4-5, scale sc[isb + 4]
                //   g = 3: high nibble of ql[l + 32], qh bits 6-7, scale sc[isb + 6]
                for (uint g = 0u; g < 4u; g++) {
                    uint ql_g = ql_h + (g & 1u) * 32u;
                    uint nib_shift = (g >> 1u) * 4u;
                    uint qh_shift = g * 2u;
                    uint x_g = xb + h * 128u + g * 32u;
                    for (uint isb = 0u; isb < 2u; isb++) {
                        // One i8 scale per 16 elements: fetch it once, not per element.
                        float ds = d * float(sx8(rd_byte(sc_h + isb + 2u * g)));
                        for (uint l = isb * 16u; l < isb * 16u + 16u; l++) {
                            uint lo = (rd_byte(ql_g + l) >> nib_shift) & 0x0Fu;
                            uint hi = (rd_byte(qh_h + l) >> qh_shift) & 3u;
                            int q = int(lo | (hi << 4u)) - 32;
                            acc += uintBitsToFloat(data[x_g + l]) * (ds * float(q));
                        }
                    }
                }
            }
        }
    }

    data[pc.out_word + j] = floatBitsToUint(acc);
}