#version 450
// Vector-matrix multiply: result[j] = sum_i(a[i] * W[i,j]) for j in [0, n)
// a: [k]; b: W as [k, n] (GGUF column-major, i fastest: W[i,j] at b[i + j*k]); result: [n]
// Workgroup of 256 invocations along x; main() maps each global invocation to one output element.
layout(local_size_x = 256) in;
// a: input vector, read at indices 0..k-1.
layout(set = 0, binding = 0) readonly buffer InputA { float a[]; };
// b: matrix W with element W[i,j] stored at b[i + j*k], i.e. k*n floats are read.
layout(set = 0, binding = 1) readonly buffer InputB { float b[]; };
// result: output vector, written at indices 0..n-1.
layout(set = 0, binding = 2) writeonly buffer Output { float result[]; };
// Matrix dimensions, provided by the host as push constants.
layout(push_constant) uniform Params {
int k; // inner dimension: length of a, rows of W
int n; // output dimension: length of result, columns of W
};
// One invocation per output element: invocation `col` computes the dot
// product of `a` with column `col` of W, i.e.
//   result[col] = sum over row in [0, k) of a[row] * b[row + col*k].
// Invocations whose index is >= n write nothing.
void main() {
    uint col = gl_GlobalInvocationID.x;
    if (col < uint(n)) {
        // Column `col` of W is contiguous in b, starting at col*k.
        uint colBase = col * uint(k);
        float sum = 0.0;
        // Signed counter so a non-positive k yields an empty sum.
        for (int row = 0; row < k; ++row) {
            sum += a[row] * b[colBase + uint(row)];
        }
        result[col] = sum;
    }
}