#version 450
// Naive row-major matmul C[MxN] = A[MxK] * B[KxN]. Correctness-first; a tiled +
// cooperative-matrix variant comes later for perf. One invocation per output cell.
// Workgroup of 16x16 invocations. The host presumably dispatches
// ceil(N/16) x ceil(M/16) groups so the grid covers the whole output — TODO confirm
// against the host-side dispatch code.
layout(local_size_x = 16, local_size_y = 16) in;
// Input A: M x K, row-major (element (r, c) lives at a[r * k + c]).
layout(set = 0, binding = 0) readonly buffer A { float a[]; };
// Input B: K x N, row-major (element (r, c) lives at b[r * n + c]).
layout(set = 0, binding = 1) readonly buffer B { float b[]; };
// Output C: M x N, row-major. main() only stores to it, hence writeonly.
layout(set = 0, binding = 2) writeonly buffer C { float c[]; };
// Matrix dimensions for this dispatch: C[m x n] = A[m x k] * B[k x n].
layout(push_constant) uniform Pc { uint m; uint k; uint n; };
// Computes one element of C = A * B. The invocation's global x index selects the
// output column and its global y index selects the output row. The result is
// the dot product of row `y` of A with column `x` of B.
void main() {
    uvec2 cell = gl_GlobalInvocationID.xy;

    // Invocations that fall outside the m x n output (when the dispatch
    // over-covers it) have nothing to compute.
    if (cell.y >= m || cell.x >= n) {
        return;
    }

    // Offset of the first element of this row of A; walking t steps along the row.
    uint aRowBase = cell.y * k;
    float sum = 0.0;
    for (uint t = 0u; t < k; ++t) {
        // A is walked along the row, B down the column (stride n).
        sum += a[aRowBase + t] * b[t * n + cell.x];
    }

    c[cell.y * n + cell.x] = sum;
}