#version 450
// Register-blocked batched GEMM: C[bt] = A[bt](m x k) * B[bt](k x n), row-major, contiguous.
// 2D block-tiling: each workgroup computes a BM x BN output tile; each of the 256 threads computes
// a TM x TN micro-tile held in registers. Reusing the staged A/B tiles across the micro-tile raises
// arithmetic intensity far above the 1-output-per-thread `bmm` kernel -> much higher prefill FLOPS.
#extension GL_EXT_control_flow_attributes : enable
// One-dimensional workgroup of 256 invocations; main() derives a 16 x 16 thread grid from the flat index.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
// Input matrices A and B (read-only) and output matrix C (write-only), each a flat float array.
// main() indexes them as batch-major, row-major: a[bt*m*k + row*k + col], b[bt*k*n + row*n + col],
// c[bt*m*n + row*n + col].
layout(set = 0, binding = 0) readonly buffer A { float a[]; };
layout(set = 0, binding = 1) readonly buffer B { float b[]; };
layout(set = 0, binding = 2) writeonly buffer C { float c[]; };
// Problem size: `batch` matrices, A is m x k, B is k x n, C is m x n.
layout(push_constant) uniform Pc { uint batch; uint m; uint k; uint n; };
// Workgroup output tile is BM rows x BN columns of C.
const uint BM = 64u;
const uint BN = 64u;
// Number of k-indices staged into shared memory per iteration of the tile loop.
const uint BK = 8u;
// Each invocation accumulates a TM x TN block of outputs.
const uint TM = 4u;
const uint TN = 4u;
// threads = (BM/TM) * (BN/TN) = 16 * 16 = 256 (matches local_size_x)
// Staged A tile, BM rows x BK columns, stored row-major (index = row * BK + col).
shared float As[BM * BK];
// Staged B tile, BK rows x BN columns, stored row-major (index = row * BN + col).
shared float Bs[BK * BN];
// Entry point. Workgroup (x, y, z) computes the BM x BN tile of C[z] whose top-left element is
// (row = y * BM, col = x * BN). Each invocation accumulates a TM x TN micro-tile in `acc` while the
// workgroup walks the k dimension in steps of BK, staging the A and B slices in shared memory.
// Elements outside the m/k/n bounds are staged as 0.0 and never written back, so sizes need not be
// multiples of the tile dimensions.
void main() {
    uint bt = gl_WorkGroupID.z;
    uint rowBase = gl_WorkGroupID.y * BM;
    uint colBase = gl_WorkGroupID.x * BN;

    // Bail out before touching any buffer. The batch check must precede the loads: with
    // bt >= batch the offsets below point past the end of A and B, and the per-element bounds
    // checks (gr < m, gc < k, ...) would not catch that. Tiles lying entirely outside C are
    // skipped as well since they would write nothing. All three conditions depend only on
    // gl_WorkGroupID and push constants, so the whole workgroup takes the same branch and no
    // invocation is left waiting at a barrier().
    if (bt >= batch || rowBase >= m || colBase >= n) return;

    uint tid = gl_LocalInvocationID.x;
    uint threadCol = tid % (BN / TN); // 0..15
    uint threadRow = tid / (BN / TN); // 0..15

    // Element offsets of this batch's matrices inside the flat buffers.
    uint ao = bt * m * k;
    uint bo = bt * k * n;
    uint co = bt * m * n;

    float acc[TM][TN];
    [[unroll]] for (uint i = 0u; i < TM; i++)
        [[unroll]] for (uint j = 0u; j < TN; j++)
            acc[i][j] = 0.0;

    uint ntiles = (k + BK - 1u) / BK; // ceil(k / BK)
    for (uint t = 0u; t < ntiles; t++) {
        uint kBase = t * BK;

        // Cooperative load of the A slice: 256 invocations stride over BM * BK elements.
        for (uint i = tid; i < BM * BK; i += 256u) {
            uint r = i / BK;
            uint cc = i % BK;
            uint gr = rowBase + r;
            uint gc = kBase + cc;
            As[i] = (gr < m && gc < k) ? a[ao + gr * k + gc] : 0.0;
        }

        // Cooperative load of the B slice: 256 invocations stride over BK * BN elements.
        for (uint i = tid; i < BK * BN; i += 256u) {
            uint r = i / BN;
            uint cc = i % BN;
            uint gr = kBase + r;
            uint gc = colBase + cc;
            Bs[i] = (gr < k && gc < n) ? b[bo + gr * n + gc] : 0.0;
        }

        // Both tiles must be fully written before any invocation reads them.
        barrier();

        [[unroll]] for (uint kk = 0u; kk < BK; kk++) {
            // Copy this k-index's A column fragment and B row fragment into locals, then
            // accumulate their outer product into the micro-tile.
            float aReg[TM];
            float bReg[TN];
            [[unroll]] for (uint i = 0u; i < TM; i++)
                aReg[i] = As[(threadRow * TM + i) * BK + kk];
            [[unroll]] for (uint j = 0u; j < TN; j++)
                bReg[j] = Bs[kk * BN + threadCol * TN + j];
            [[unroll]] for (uint i = 0u; i < TM; i++)
                [[unroll]] for (uint j = 0u; j < TN; j++)
                    acc[i][j] += aReg[i] * bReg[j];
        }

        // Do not let the next iteration overwrite the tiles while others are still reading.
        barrier();
    }

    // Write back the micro-tile, skipping rows/columns that fall outside C.
    [[unroll]] for (uint i = 0u; i < TM; i++) {
        uint gr = rowBase + threadRow * TM + i;
        if (gr >= m) continue;
        [[unroll]] for (uint j = 0u; j < TN; j++) {
            uint gc = colBase + threadCol * TN + j;
            if (gc < n)
                c[co + gr * n + gc] = acc[i][j];
        }
    }
}