#version 450
// Fused MoE grouped quant matvec (Q4_K): y[s, r] = sum_k W[ids[s], r, k] * x[s, k], reading the
// per-expert slice from a GGML Q4_K weight bank [E, n, k] in VRAM. The router gather (ids[s]) and
// the per-slot expert matvec run in one dispatch. Q4_K super-block = 144 bytes (256 weights, 8
// sub-blocks of 32, 6-bit packed scales/mins); dequant = d*sc*nibble - dmin*m, byte-exact with
// BlockQ4K::to_float. One invocation = one y element; k must be a multiple of 256.
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
// 64 invocations per workgroup along x. main() returns early for gid >= nrows * n, so the host
// presumably dispatches ceil(nrows * n / 64) workgroups (TODO confirm against the dispatcher).
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// The byte-oriented Q4_K bank is exposed as 32-bit words; rdbyte() extracts byte (bo & 3) of word
// (bo >> 2), taking byte 0 from the least-significant 8 bits.
layout(set = 0, binding = 0) readonly buffer W { uint w[]; }; // expert bank, raw Q4_K bytes
layout(set = 0, binding = 1) readonly buffer X { float x[]; }; // [S, k] activations
layout(set = 0, binding = 2) readonly buffer Ids { uint ids[]; }; // [S] expert id per slot
layout(set = 0, binding = 3) writeonly buffer Y { float y[]; }; // [S, n] outputs
// n     = output rows per expert (row length of y).
// k     = input width; must be a multiple of 256 (the Q4_K super-block size). Any remainder is
//         silently ignored because main() uses nsb = k / 256.
// nrows = number of slots S (rows of x / ids / y).
layout(push_constant) uniform Pc { uint n; uint k; uint nrows; }; // k mult of 256
// Returns the byte at byte offset `bo` of the weight bank as a value in [0, 255].
// Byte (bo & 3) of word (bo >> 2) is selected, counting from the least-significant byte.
uint rdbyte(uint bo) {
    uint word = w[bo >> 2u];
    uint shift = (bo & 3u) << 3u;
    return (word >> shift) & 0xFFu;
}
// Reads the 16-bit IEEE half stored at byte offset `bo` (low byte first) and returns it
// widened to float.
float rdscale(uint bo) {
    uint bits16 = rdbyte(bo) | (rdbyte(bo + 1u) << 8u);
    vec2 pair = unpackHalf2x16(bits16);
    return pair.x;
}
// Unpacks the 6-bit scale `sc` and 6-bit min `m` of sub-block j (0..7) from the 12-byte packed
// scales array that starts at byte offset `sbase`.
//  - j < 4:  sc = low 6 bits of byte j, m = low 6 bits of byte j + 4.
//  - j >= 4: the low 4 bits of both come from byte j + 4 (sc in the low nibble, m in the high
//            nibble); the top 2 bits come from bits 6..7 of byte j - 4 (sc) and byte j (m).
void scale_min(uint sbase, uint j, out uint sc, out uint m) {
    if (j >= 4u) {
        uint nib = rdbyte(sbase + j + 4u);
        uint sc_top = rdbyte(sbase + j - 4u) >> 6u;
        uint m_top = rdbyte(sbase + j) >> 6u;
        sc = (sc_top << 4u) | (nib & 0x0Fu);
        m = (m_top << 4u) | (nib >> 4u);
        return;
    }
    sc = rdbyte(sbase + j) & 0x3Fu;
    m = rdbyte(sbase + j + 4u) & 0x3Fu;
}
// Entry point: each invocation computes one output y[s, r], where gid = s * n + r.
//   s = slot in [0, nrows): selects activations x[s, :] and expert ids[s].
//   r = output row in [0, n) of that expert's weight matrix.
// Row r of expert e is nsb = k / 256 consecutive 144-byte Q4_K super-blocks, laid out as:
//   bytes [0, 2) d (half), [2, 4) dmin (half), [4, 16) packed 6-bit scales/mins,
//   [16, 144) 128 quant bytes.
// The 256 weights of a super-block form 4 groups of 64. In group g, the low nibbles of its 32
// quant bytes are weights 0..31 (sub-block 2g) and the high nibbles are weights 32..63
// (sub-block 2g + 1).
// Every super-block offset is a multiple of 144, so d/dmin and the quant bytes are 4-byte
// aligned and can be fetched a whole 32-bit word at a time. Terms are still accumulated per quant
// byte in order: low-nibble term, then high-nibble term.
void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint total = nrows * n;
    if (gid >= total) {
        return; // tail of the last workgroup
    }
    uint s = gid / n;
    uint r = gid - s * n;
    // NOTE(review): ids[s] is not range-checked against the bank's expert count (which this shader
    // cannot see); an out-of-range id makes the reads below fall outside the expert's slice.
    uint expert = ids[s];
    uint nsb = k / 256u;                           // super-blocks per row (k % 256 is ignored)
    uint rowbase = (expert * n + r) * nsb * 144u;  // byte offset of W[expert, r, 0]
    uint xbase = s * k;
    float acc = 0.0;
    for (uint sb = 0u; sb < nsb; sb++) {
        uint bb = rowbase + sb * 144u;             // multiple of 144, hence word aligned
        // d occupies the low half and dmin the high half of the super-block's first word.
        vec2 dm = unpackHalf2x16(w[bb >> 2u]);
        float d = dm.x;
        float dmin = dm.y;
        uint sbase = bb + 4u;                      // byte offset of the 12 packed scale bytes
        uint qword = (bb + 16u) >> 2u;             // word index of the first quant byte
        uint xsb = xbase + sb * 256u;
        for (uint g = 0u; g < 4u; g++) {
            uint sc1; uint m1; scale_min(sbase, 2u * g, sc1, m1);
            uint sc2; uint m2; scale_min(sbase, 2u * g + 1u, sc2, m2);
            float d1 = d * float(sc1); float mm1 = dmin * float(m1);
            float d2 = d * float(sc2); float mm2 = dmin * float(m2);
            uint xlo = xsb + g * 64u;              // activations for the low nibbles (sub-block 2g)
            uint xhi = xlo + 32u;                  // activations for the high nibbles (sub-block 2g+1)
            uint gword = qword + g * 8u;           // 32 quant bytes per group = 8 words
            for (uint wi = 0u; wi < 8u; wi++) {
                uint qw = w[gword + wi];           // one load serves 4 quant bytes
                for (uint b = 0u; b < 4u; b++) {
                    uint q = bitfieldExtract(qw, int(b * 8u), 8);
                    uint l = wi * 4u + b;          // byte index within the group
                    acc += (d1 * float(q & 0x0Fu) - mm1) * x[xlo + l];
                    acc += (d2 * float(q >> 4u) - mm2) * x[xhi + l];
                }
            }
        }
    }
    y[gid] = acc;
}