#version 450
// Fused MoE grouped quant matvec (Q4_0): y[s, r] = sum_k W[ids[s], r, k] * x[s, k], reading the
// per-expert slice from a GGML Q4_0 weight bank [E, n, k] in VRAM. Router gather + per-expert GEMM
// in one dispatch (no CPU expert loop). Q4_0 block = 18 bytes = { f16 d ; u8 qs[16] }; low nibble
// of qs[j] -> weight j, high -> weight j+16, dequant (nibble-8)*d. One invocation = one y element.
// NOTE(review): no float16_t type appears in the visible code -- the f16 block scale is decoded
// with unpackHalf2x16 -- so this `require` looks unnecessary; confirm before removing it.
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
// 64 invocations per workgroup along x; main() indexes outputs by gl_GlobalInvocationID.x only.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// Expert weight bank as raw Q4_0 bytes, exposed as 32-bit words; rdbyte() picks bytes out of it.
layout(set = 0, binding = 0) readonly buffer W { uint w[]; }; // expert bank, raw Q4_0 bytes
// Activations: main() reads k consecutive floats starting at x[s * k] for slot s.
layout(set = 0, binding = 1) readonly buffer X { float x[]; }; // [S, k] activations
// Expert index per slot. It is used unchecked to address the weight bank, and the expert count
// is not passed to the shader -- presumably the host guarantees valid ids; verify against caller.
layout(set = 0, binding = 2) readonly buffer Ids { uint ids[]; }; // [S] expert id per slot
// Outputs: invocation gid writes y[gid], where gid = s * n + r.
layout(set = 0, binding = 3) writeonly buffer Y { float y[]; }; // [S, n] outputs
// n     = output rows per expert (length of each y row),
// k     = reduction length; main() processes k / 32 whole blocks, so any remainder is ignored,
// nrows = number of slots S; nrows * n is the number of outputs written.
layout(push_constant) uniform Pc { uint n; uint k; uint nrows; }; // k mult of 32
// Fetch the byte at byte offset `bo` of the weight buffer, zero-extended to a uint.
// The buffer is addressed as 32-bit words with byte 0 in the least-significant bits.
uint rdbyte(uint bo) {
    uint word = w[bo >> 2u];        // 32-bit word that contains the byte
    uint shift = (bo & 3u) << 3u;   // bit position of the byte inside that word: 0, 8, 16 or 24
    return (word >> shift) & 0xFFu;
}
// Decode the little-endian IEEE half-precision value stored at byte offset `bo` of the weight
// buffer and return it widened to float. Works for any byte offset, including one whose two
// bytes fall in different 32-bit words, because each byte is fetched separately.
float rdscale(uint bo) {
    // Low byte sits at `bo`, high byte at `bo + 1`; together they form the 16-bit half pattern.
    uint halfBits = (rdbyte(bo + 1u) << 8u) | rdbyte(bo);
    // unpackHalf2x16 decodes the low 16 bits into .x; the upper 16 bits are zero and unused.
    vec2 decoded = unpackHalf2x16(halfBits);
    return decoded.x;
}
// One invocation produces one output element y[slot * n + row]: the dot product of row `row`
// of expert ids[slot]'s Q4_0 weight matrix with the activation vector of `slot`.
void main() {
    const uint BLOCK_BYTES = 18u;   // 2-byte f16 scale followed by 16 bytes of packed nibbles
    const uint BLOCK_WEIGHTS = 32u; // weights encoded by one block
    const uint PACKED_BYTES = 16u;  // each packed byte carries two 4-bit weights

    uint outIdx = gl_GlobalInvocationID.x;
    // Invocations past the last output (dispatch rounded up to the workgroup size) do nothing.
    if (outIdx < nrows * n) {
        uint slot = outIdx / n;
        uint row = outIdx % n;
        uint blocksPerRow = k / BLOCK_WEIGHTS;

        // Byte offset of the first block of the selected expert's row, and index of the first
        // activation belonging to this slot.
        uint wOff = (ids[slot] * n + row) * blocksPerRow * BLOCK_BYTES;
        uint xOff = slot * k;

        float dotSum = 0.0;
        for (uint remaining = blocksPerRow; remaining > 0u; remaining--) {
            float scale = rdscale(wOff);
            uint nibbles = wOff + 2u; // packed nibbles start right after the 2-byte scale

            // Accumulate the unscaled integer-weight dot product for this block; the scale is
            // applied once per block afterwards.
            float blockSum = 0.0;
            for (uint i = 0u; i < PACKED_BYTES; ++i) {
                uint packed = rdbyte(nibbles + i);
                // Low nibble is weight i, high nibble is weight i + 16; both carry a bias of 8.
                float wLow = float(int(packed & 0x0Fu) - 8);
                float wHigh = float(int(packed >> 4u) - 8);
                blockSum += wLow * x[xOff + i];
                blockSum += wHigh * x[xOff + i + 16u];
            }
            dotSum += scale * blockSum;

            // Step to the next block of weights and the matching 32 activations.
            wOff += BLOCK_BYTES;
            xOff += BLOCK_WEIGHTS;
        }
        y[outIdx] = dotSum;
    }
}