#version 450
// Q8_0 matrix-vector product (the decode/memory-bound path): y[n] = sum_k W[n,k]*x[k] with W
// stored quantized. Each row is K/32 blocks; per block = 1 fp16 scale + 32 int8, packed into
// 9 u32 (u32[0] = scale in low f16 via packHalf2x16; u32[1..8] = 32 int8, 4 per word). Reading
// ~1.125 B/weight instead of 4 cuts memory traffic ~3.5x — that is the lever on this ~256 GB/s
// APU, where decode is bandwidth-bound. One invocation computes one output element.
// NOTE(review): nothing in the code visible here declares or uses an explicit float16 type; the
// scale is decoded with unpackHalf2x16, which returns 32-bit floats. This `require` may be
// unnecessary — confirm nothing else depends on it before removing.
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
// 64 invocations per workgroup along x; main() maps each invocation's global x index to one output row.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// Shader interface (descriptor set 0):
//   binding 0: W, read-only quantized weights addressed as u32 words
//   binding 1: X, read-only activation vector
//   binding 2: Y, write-only output vector
layout(set = 0, binding = 0) readonly buffer W { uint w[]; }; // 9 u32 per 32-weight block: 1 scale word + 8 words of packed int8
layout(set = 0, binding = 1) readonly buffer X { float x[]; }; // activation vector, length k
layout(set = 0, binding = 2) writeonly buffer Y { float y[]; }; // output, length nout
// Push constants: nout = number of output rows (invocations with index >= nout exit without writing),
// k = input length. main() derives the per-row block count as k / 32, so k must be a multiple of 32.
layout(push_constant) uniform Pc { uint nout; uint k; }; // k is a multiple of 32
// Entry point: each invocation produces one element of y.
//
// Row `row` of W occupies (k / 32) consecutive blocks of 9 u32 words, starting at word
// row * (k / 32) * 9. Within a block, word 0 carries the fp16 scale in its low 16 bits and
// words 1..8 each carry four int8 weights, lowest byte first. The result written is
//   y[row] = sum over blocks of scale * (sum over the block's 32 weights of q * x).
// Invocations whose index is >= nout do nothing.
//
// NOTE(review): neighbouring invocations start (k / 32) * 9 words apart in W, so their weight
// reads are strided by a whole row — worth checking against the target's coalescing behaviour.
void main() {
    uint row = gl_GlobalInvocationID.x;
    if (row < nout) {
        uint blocksPerRow = k / 32u;
        uint blockPos = row * blocksPerRow * 9u; // word index of the current block's scale word
        uint xPos = 0u;                          // index of the next activation to consume
        float total = 0.0;

        for (uint remaining = blocksPerRow; remaining > 0u; remaining--) {
            // Only the low half of the scale word is meaningful; .x selects it.
            float blockScale = unpackHalf2x16(w[blockPos]).x;
            float partial = 0.0;

            for (uint i = 1u; i <= 8u; i++) {
                // Reinterpret as signed so bitfieldExtract sign-extends each 8-bit lane.
                int lanes = int(w[blockPos + i]);
                for (int shift = 0; shift < 32; shift += 8) {
                    partial += float(bitfieldExtract(lanes, shift, 8)) * x[xPos];
                    xPos++;
                }
            }

            // Apply the per-block scale once, after the 32 integer-valued products are summed.
            total += blockScale * partial;
            blockPos += 9u;
        }

        y[row] = total;
    }
}