#version 450
// Batched tiled matmul: C[bt] = A[bt](m x k) * B[bt](k x n), row-major, contiguous.
// Uses 16x16 shared-memory tiles: each workgroup loads a tile of A and a tile of B from
// global memory once and reuses it for all 16 outputs that need it, cutting global memory
// traffic by roughly a factor of TILE compared with a naive one-read-per-output kernel.
// One invocation per output element; each 16x16 workgroup covers one 16x16 tile of C.
// The literal 16s here and in As/Bs must stay equal to TILE below.
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
// Inputs hold all batches back to back: batch bt of A starts at bt*m*k, of B at bt*k*n.
layout(set = 0, binding = 0) readonly buffer A { float a[]; };
layout(set = 0, binding = 1) readonly buffer B { float b[]; };
// Output: batch bt of C starts at bt*m*n.
layout(set = 0, binding = 2) writeonly buffer C { float c[]; };
// batch: number of matrix products; m, k, n: dimensions in C = A(m x k) * B(k x n).
layout(push_constant) uniform Pc { uint batch; uint m; uint k; uint n; };
// Tile edge length used by the K-loop in main().
const uint TILE = 16u;
// Workgroup-shared staging tiles: As holds a TILE x TILE slice of A, Bs a slice of B.
shared float As[16][16];
shared float Bs[16][16];
// Computes one element of the output per invocation:
//   C[bt][row][col] = sum_j A[bt][row][j] * B[bt][j][col]
// Invocation mapping: x -> output column, y -> output row, z -> batch index.
// NOTE(review): assumes the host dispatches ceil(n/16) x ceil(m/16) x batch workgroups — confirm.
// Out-of-range invocations (row >= m, col >= n, or bt >= batch) still run the tile loop so
// that every invocation in the workgroup reaches both barrier() calls; they load zeros
// and skip the final store.
void main() {
    uint bt = gl_GlobalInvocationID.z;
    uint row = gl_GlobalInvocationID.y;
    uint col = gl_GlobalInvocationID.x;
    uint lr = gl_LocalInvocationID.y;
    uint lc = gl_LocalInvocationID.x;
    // The batch guard must also cover the loads: without it, an over-dispatched z slice
    // would read a[]/b[] past the end of the last batch matrix.
    bool inBatch = bt < batch;
    uint ao = bt * m * k;
    uint bo = bt * k * n;
    uint co = bt * m * n;
    float acc = 0.0;
    uint ntiles = (k + TILE - 1u) / TILE;
    for (uint t = 0u; t < ntiles; t++) {
        uint acol = t * TILE + lc;
        uint brow = t * TILE + lr;
        // Zero-pad out-of-range elements so the inner loop can always run TILE steps
        // without changing the sum.
        As[lr][lc] = (inBatch && row < m && acol < k) ? a[ao + row * k + acol] : 0.0;
        Bs[lr][lc] = (inBatch && brow < k && col < n) ? b[bo + brow * n + col] : 0.0;
        // Whole tile written before any invocation reads it.
        barrier();
        for (uint kk = 0u; kk < TILE; kk++) {
            acc += As[lr][kk] * Bs[kk][lc];
        }
        // Everyone finished reading before the next iteration overwrites the tile.
        barrier();
    }
    if (inBatch && row < m && col < n) {
        c[co + row * n + col] = acc;
    }
}