#version 450
// Matrix-vector multiply: out[i] = sum_j A[i,j] * x[j]
// A: [m, n] row-major, x: [n], out: [m]
// One thread per output row.
// Workgroup size: 256 invocations along x (y and z default to 1).
// NOTE(review): presumably the host dispatches ceil(m / 256) workgroups in x — confirm against the caller.
layout(local_size_x = 256) in;
// Matrix A as a flat float array; main() reads element (i, j) at a[i * n + j].
layout(set = 0, binding = 0) readonly buffer InputA { float a[]; };
// Input vector; main() reads x[j] for j in [0, n).
layout(set = 0, binding = 1) readonly buffer InputX { float x[]; };
// Output vector; invocation i writes result[i] when i < m.
layout(set = 0, binding = 2) writeonly buffer Output { float result[]; };
// Problem dimensions, supplied by the host as push constants.
layout(push_constant) uniform Params {
int m; // rows of A / output dimension
int n; // cols of A / input dimension
};
// Entry point: invocation `row` computes the dot product of row `row` of A
// with x and stores it in result[row].
// NOTE(review): adjacent invocations read A at addresses n floats apart; if
// this shader is a bottleneck, a workgroup-per-row reduction may access
// memory more efficiently — verify by profiling.
void main() {
    uint row = gl_GlobalInvocationID.x;

    // Invocations whose index is not below m do no work.
    if (row < uint(m)) {
        // Flat offset of the first element of this row (32-bit unsigned arithmetic).
        uint idx = row * uint(n);
        float sum = 0.0;

        // The counter is signed and compared against n directly, so a
        // non-positive n yields zero iterations and a result of 0.0.
        for (int col = 0; col < n; ++col, ++idx) {
            sum += a[idx] * x[col];
        }

        result[row] = sum;
    }
}