#version 450
// Q4_K matrix-vector product reading the *native GGML* Q4_K super-block format straight from a GPU
// buffer (no CPU dequant). One GGML BlockQ4K = 144 bytes = { f16 d ; f16 dmin ; u8 scales[12] ;
// u8 qs[128] } holding 256 weights in 8 sub-blocks of 32, each sub-block carrying a 6-bit scale and
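// a 6-bit min packed into the 12-byte `scales`. Dequant uses the same formula and element order as
// k_quants.rs BlockQ4K::to_float: weight = d*sc*(nibble) - dmin*m, where (sc,m) come from
// get_scale_min_k4. (Results are not guaranteed bit-identical: the GPU compiler may fuse the
// multiply-subtract into an FMA unless the expressions are marked `precise`.)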
// One invocation = one output row dotted with the activation. ~4.5 bits/weight => ~7x less decode
// traffic than the f32 dequant round-trip.
// NOTE(review): nothing in this shader declares float16_t / f16vec types; the only f16 handling is
// unpackHalf2x16, which is core GLSL since 4.20. This `require` therefore looks unnecessary and may
// make compilation fail on front-ends or devices without the extension. Confirm, then consider dropping it.
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
// 64 invocations per workgroup. main() maps gl_GlobalInvocationID.x to one output row, so the host
// must dispatch at least ceil(nout / 64) workgroups in x to cover every row.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// Weight bytes are viewed as 32-bit words; individual bytes are pulled out little-endian by rdbyte.
// Rows are laid out back-to-back, each row being k/256 consecutive 144-byte BlockQ4K super-blocks.
layout(set = 0, binding = 0) readonly buffer W { uint w[]; }; // raw Q4_K blocks, 144 B each
layout(set = 0, binding = 1) readonly buffer X { float x[]; }; // activation vector, length k
layout(set = 0, binding = 2) writeonly buffer Y { float y[]; }; // output, length nout
// nout = number of output rows; k = row length in weights. k must be a multiple of 256: main() walks
// k / 256 whole super-blocks, so any remainder would be silently ignored.
layout(push_constant) uniform Pc { uint nout; uint k; }; // k is a multiple of 256
// Return the byte at byte offset `bo` of the weight buffer (0..255). The buffer is addressed in
// 32-bit words, so select word bo/4 and shift byte bo%4 (little-endian within the word) to bits 0..7.
uint rdbyte(uint bo) {
    uint word = w[bo >> 2u];
    uint shift = (bo & 3u) << 3u;
    return (word >> shift) & 0xFFu;
}
// Decode the little-endian IEEE-754 half stored in the two bytes at byte offset `bo` (which need not
// be word-aligned) and return it widened to f32.
float rdscale(uint bo) {
    uint bits16 = (rdbyte(bo + 1u) << 8u) | rdbyte(bo);
    vec2 pair = unpackHalf2x16(bits16); // upper half of the input is zero, so only .x is meaningful
    return pair.x;
}
// get_scale_min_k4(j, scales): unpack the 6-bit scale `sc` and 6-bit min `m` of sub-block j (0..7)
// from the 12 packed scale bytes starting at byte offset `sbase`. Mirrors quantized/utils.rs.
//   j < 4 : sc = scales[j] & 63, m = scales[j+4] & 63.
//   j >= 4: low 4 bits of sc/m are the low/high nibble of scales[j+4]; their top 2 bits are the
//           top 2 bits of scales[j-4] (for sc) and scales[j] (for m).
void scale_min(uint sbase, uint j, out uint sc, out uint m) {
    if (j >= 4u) {
        uint nibbles = rdbyte(sbase + j + 4u);
        uint scTop = rdbyte(sbase + j - 4u) >> 6u;
        uint mTop = rdbyte(sbase + j) >> 6u;
        sc = (scTop << 4u) | (nibbles & 0x0Fu);
        m = (mTop << 4u) | (nibbles >> 4u);
        return;
    }
    sc = rdbyte(sbase + j) & 0x3Fu;
    m = rdbyte(sbase + j + 4u) & 0x3Fu;
}
// Entry point: invocation n computes y[n] = dot(dequant(row n of W), x).
// Row n is nsb = k/256 consecutive BlockQ4K super-blocks of 144 bytes. 144 is a multiple of 4, so every
// super-block starts on a 32-bit word boundary, and so do its {d, dmin} pair (offset 0) and its qs[]
// array (offset 16). Those fields are therefore fetched as whole words instead of byte by byte. The
// original issued four separate loads of the same qs word.
// The weights are visited in exactly the order of BlockQ4K::to_float: four 64-groups per super-block,
// each = 32 low nibbles (sub-block `is`) then 32 high nibbles (sub-block is+1). The float
// accumulation order is unchanged.
// NOTE(review): byte offsets are 32-bit, so a weight buffer larger than 4 GiB would overflow them.
void main() {
    uint n = gl_GlobalInvocationID.x;
    if (n >= nout) {
        return; // surplus invocations in the last workgroup have no row
    }
    uint nsb = k / 256u; // super-blocks per row (a remainder of k would be ignored)
    uint rowbase = n * nsb * 144u; // byte offset of row n
    float acc = 0.0;
    for (uint sb = 0u; sb < nsb; sb++) {
        uint bb = rowbase + sb * 144u; // byte offset of this super-block (word-aligned)
        // First word = f16 d in bits 0..15, f16 dmin in bits 16..31.
        vec2 dm = unpackHalf2x16(w[bb >> 2u]);
        float d = dm.x;
        float dmin = dm.y;
        uint sbase = bb + 4u; // byte offset of scales[12]
        uint qword = (bb + 16u) >> 2u; // word index of qs[128]
        uint xsb = sb * 256u; // activation base for this super-block
        uint is = 0u;
        for (uint g = 0u; g < 4u; g++) {
            uint sc1; uint m1; scale_min(sbase, is, sc1, m1);
            uint sc2; uint m2; scale_min(sbase, is + 1u, sc2, m2);
            float d1 = d * float(sc1); float mm1 = dmin * float(m1);
            float d2 = d * float(sc2); float mm2 = dmin * float(m2);
            uint xlo = xsb + g * 64u; // activations paired with the 32 low nibbles
            uint xhi = xlo + 32u; // activations paired with the 32 high nibbles
            uint gword = qword + g * 8u; // the 32 qs bytes of this group = 8 words
            for (uint l4 = 0u; l4 < 32u; l4 += 4u) {
                uint qw = w[gword + (l4 >> 2u)]; // qs bytes l4..l4+3, little-endian
                for (uint b = 0u; b < 4u; b++) {
                    uint q = (qw >> (b * 8u)) & 0xFFu;
                    uint l = l4 + b;
                    float lo = float(q & 0x0Fu);
                    float hi = float(q >> 4u);
                    acc += (d1 * lo - mm1) * x[xlo + l]; // low nibble -> elem 64g+l
                    acc += (d2 * hi - mm2) * x[xhi + l]; // high nibble -> elem 64g+32+l
                }
            }
            is += 2u;
        }
    }
    y[n] = acc;
}