#version 450
// Fused GGUF K-quant dequant + GEMV (decode, m == 1) for Op::DequantMatMul.
//
// Computes out[1, n] = x[1, k] @ Wᵀ where W is [n, k] row-major in packed
// GGUF block layout (the same orientation as rlx-cpu `gguf_matmul_bt`). One
// invocation owns one output column `j`: it walks the k/256 super-blocks that
// make up row `j`, dequantizes each 256-element block into a private array
// (mirroring the verified rlx-gguf `dequant_q4_k_block` / `dequant_q6_k_block`
// reference math), and accumulates `Σ x[p] · w[j,p]`.
//
// The f32-uniform arena is viewed here as raw 32-bit words so we can pull the
// packed weight bytes (and f16 scales) out of the same buffer the f32
// activations live in. Scheme: 0 = Q4_K (144 B / 256 elems), 1 = Q6_K (210 B).
// 64 invocations per workgroup along x; main() maps each one to an output index.
layout(local_size_x = 64) in;
// The single storage buffer, viewed as raw 32-bit words. main() reads the
// activations and the packed weights from it and writes the result back into
// it; each region is located by a word offset from the push constants below.
layout(std430, binding = 0) buffer Arena { uint data[]; };
// Per-dispatch parameters. Every *_word field is an index into `data[]`
// (units of 32-bit words), not a byte offset.
layout(push_constant) uniform PC {
uint n; // output columns (rows of W)
uint k; // contraction dim (multiple of 256)
uint x_word; // f32-word offset of x activations
uint w_word; // f32-word offset of the packed weight bytes (16-B aligned)
uint out_word; // f32-word offset of the output
uint scheme; // 0 = Q4_K, 1 = Q6_K (main() takes the Q6_K path for any non-zero value)
} pc;
// Fetch the byte at weight-relative byte offset `rel`, zero-extended to uint.
//
// Addressing is done in words: the word holding the byte is
// `data[pc.w_word + rel / 4]` and the byte lane inside it is `rel % 4`
// (lane 0 = least-significant byte). An absolute byte address
// (`w_word * 4 + rel`) is deliberately never formed: it would wrap a u32 as
// soon as the arena grows beyond 4 GiB, and weights stored past that point
// would decode as garbage (→ NaN logits). The word index stays below 4 Gi
// even for a ~16 GiB arena. This is exact because slots are at least 4-byte
// aligned, so the lane of the absolute address equals `rel % 4`.
uint rd_byte(uint rel) {
    uint word_index = pc.w_word + rel / 4u;
    uint lane_shift = (rel % 4u) << 3u; // 0, 8, 16 or 24 bits
    uint word = data[word_index];
    return (word >> lane_shift) & 255u;
}
// Decode the little-endian f16 stored at weight-relative byte offset `rel`
// and return it widened to f32. `rel` must be 2-byte aligned so that both
// bytes of the value sit inside the same 32-bit word.
float rd_f16(uint rel) {
    uint word_index = pc.w_word + rel / 4u;
    uint lane_shift = (rel % 4u) << 3u; // 0 or 16 bits for an aligned `rel`
    uint half_bits = (data[word_index] >> lane_shift) & 0xFFFFu;
    // unpackHalf2x16 decodes the low 16 bits into .x; the upper half is zero.
    vec2 unpacked = unpackHalf2x16(half_bits);
    return unpacked.x;
}
// Interpret a byte value (0..255) as a two's-complement i8: inputs with bit 7
// set map to `b - 256`, all others are returned unchanged.
int sx8(uint b) {
    int value = int(b);
    if ((b & 0x80u) != 0u) {
        value -= 256;
    }
    return value;
}
// Unpack the 6-bit (scale, min) pair for sub-block `j` (0..7) of a Q4_K
// super-block — same bit interleave as the reference get_scale_min_k4.
//
// `sc_base` is the byte offset of the 12-byte scales region RELATIVE to the
// weight base: it is handed straight to rd_byte, which adds `pc.w_word`.
// Sub-blocks 0..3 keep their values in the low 6 bits of bytes j and j + 4.
// Sub-blocks 4..7 take their low 4 bits from the two nibbles of byte j + 4
// and their top 2 bits from bits 6..7 of bytes j - 4 (scale) and j (min).
void scale_min_k4(uint j, uint sc_base, out uint sc, out uint mn) {
    // Both layouts need the bytes at j and j + 4.
    uint byte_j = rd_byte(sc_base + j);
    uint byte_j4 = rd_byte(sc_base + j + 4u);
    if (j >= 4u) {
        uint byte_jm4 = rd_byte(sc_base + j - 4u);
        sc = (byte_j4 & 15u) | ((byte_jm4 >> 6u) << 4u);
        mn = (byte_j4 >> 4u) | ((byte_j >> 6u) << 4u);
        return;
    }
    sc = byte_j & 63u;
    mn = byte_j4 & 63u;
}
// Entry point: invocation `j` computes out[j] = Σ_p x[p] · W[j, p] for one row
// of the packed weight matrix.
//
// For each of the k/256 super-blocks in row j, the block is dequantized into
// the private `blk` array (Q4_K when pc.scheme == 0, otherwise the Q6_K path)
// and then dotted with the matching 256 activations. Any remainder of k beyond
// a multiple of 256 is dropped by the integer division below.
// NOTE(review): `base` is a 32-bit byte offset relative to the weight, so a
// single weight tensor is assumed to be smaller than 4 GiB — confirm host-side.
void main() {
uint j = gl_GlobalInvocationID.x;
// Invocations at or beyond n (e.g. padding in the last workgroup) exit
// without touching the output.
if (j >= pc.n) { return; }
uint blocks_per_row = pc.k / 256u;
// Bytes per 256-element super-block; any non-zero scheme is sized as Q6_K.
uint block_bytes = (pc.scheme == 0u) ? 144u : 210u;
// Per-invocation scratch for one dequantized super-block. Whichever branch
// runs below rewrites all 256 entries before they are read.
float blk[256];
float acc = 0.0;
for (uint r = 0u; r < blocks_per_row; r++) {
uint gbi = j * blocks_per_row + r; // global super-block index
uint base = gbi * block_bytes; // weight-relative byte of block
if (pc.scheme == 0u) {
// ── Q4_K: d(f16), dmin(f16), scales[12], qs[128] ──────────────
float d = rd_f16(base);
float dmin = rd_f16(base + 2u);
uint sc_base = base + 4u;
uint qs_base = base + 16u;
uint out_i = 0u;
// Byte cursor into qs (advances 32 per pass); not a scale index.
uint is = 0u;
// Each pass covers two 32-element sub-blocks (jj, jj + 1) that share the
// same 32 qs bytes: low nibbles feed sub-block jj, high nibbles jj + 1.
for (uint jj = 0u; jj < 8u; jj += 2u) {
uint sc0; uint m0; uint sc1; uint m1;
scale_min_k4(jj, sc_base, sc0, m0);
scale_min_k4(jj + 1u, sc_base, sc1, m1);
// Effective per-sub-block scale and offset: w = d·sc·q − dmin·m.
float d0 = d * float(sc0);
float m0f = dmin * float(m0);
float d1 = d * float(sc1);
float m1f = dmin * float(m1);
for (uint l = 0u; l < 32u; l++) {
uint q = rd_byte(qs_base + is + l);
blk[out_i] = d0 * float(q & 0x0Fu) - m0f;
out_i++;
}
for (uint l = 0u; l < 32u; l++) {
uint q = rd_byte(qs_base + is + l);
blk[out_i] = d1 * float(q >> 4u) - m1f;
out_i++;
}
is += 32u;
}
} else {
// ── Q6_K: ql[128], qh[64], sc[16] (i8), d(f16) ────────────────
uint ql_base = base;
uint qh_base = base + 128u;
uint sc_base = base + 192u;
float d = rd_f16(base + 208u);
// Two halves of 128 outputs; each consumes 64 ql bytes, 32 qh bytes and
// 8 signed scales.
for (uint h = 0u; h < 2u; h++) {
uint dst_base = h * 128u;
uint ql_off = h * 64u;
uint qh_off = h * 32u;
uint sc_off = h * 8u;
// Each l yields four outputs (l, l+32, l+64, l+96): the nibbles of
// ql[l] and ql[l+32] give the low 4 bits and the four 2-bit fields of
// qh[l] give the top 2 bits of each 6-bit value.
for (uint l = 0u; l < 32u; l++) {
// One scale per 16 outputs: isb picks the first or second scale of
// each group of 32.
uint isb = l / 16u;
uint qhb = rd_byte(qh_base + qh_off + l);
uint lo0 = rd_byte(ql_base + ql_off + l);
uint lo1 = rd_byte(ql_base + ql_off + l + 32u);
// Assemble the 6-bit values and re-center them from 0..63 to -32..31.
int q1 = int((lo0 & 0x0Fu) | ((qhb & 3u) << 4u)) - 32;
int q2 = int((lo1 & 0x0Fu) | (((qhb >> 2u) & 3u) << 4u)) - 32;
int q3 = int((lo0 >> 4u) | (((qhb >> 4u) & 3u) << 4u)) - 32;
int q4 = int((lo1 >> 4u) | (((qhb >> 6u) & 3u) << 4u)) - 32;
// w = d · scale(i8) · q; the scale byte is sign-extended by sx8.
blk[dst_base + l] = d * float(sx8(rd_byte(sc_base + sc_off + isb))) * float(q1);
blk[dst_base + l + 32u] = d * float(sx8(rd_byte(sc_base + sc_off + isb + 2u))) * float(q2);
blk[dst_base + l + 64u] = d * float(sx8(rd_byte(sc_base + sc_off + isb + 4u))) * float(q3);
blk[dst_base + l + 96u] = d * float(sx8(rd_byte(sc_base + sc_off + isb + 6u))) * float(q4);
}
}
}
// Dot the dequantized block with activations x[r*256 .. r*256 + 255],
// reinterpreting the arena words as f32.
uint xb = pc.x_word + r * 256u;
for (uint t = 0u; t < 256u; t++) {
acc += uintBitsToFloat(data[xb + t]) * blk[t];
}
}
// Store the f32 result as its raw bit pattern in the uint-typed arena.
data[pc.out_word + j] = floatBitsToUint(acc);
}