#version 450
// Register-blocked batched GEMM, NT layout: C[bt](m x n) = A[bt](m x k) * W[bt](n x k)^T.
// W is a Linear weight in its natural row-major [n,k] layout, so no transpose copy is needed
// (the staged B tile reads Bs[k,n] = W[n,k]). Each thread holds a TM x TN register micro-tile.
#extension GL_EXT_control_flow_attributes : enable
// One workgroup is a flat run of 256 invocations; main() remaps the linear id
// onto a (BM/TM) x (BN/TN) thread grid.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
// A: left operand. main() reads it as a[bt*m*k + row*k + kIdx].
layout(set = 0, binding = 0) readonly buffer A { float a[]; };
// B: weight operand. main() reads it as b[bt*k*n + col*k + kIdx], i.e. indexed [n,k].
layout(set = 0, binding = 1) readonly buffer B { float b[]; };
// C: result. main() writes it as c[bt*m*n + row*n + col].
layout(set = 0, binding = 2) writeonly buffer C { float c[]; };
// Problem extents: batch count plus the m, k, n sizes used in the index math above.
layout(push_constant) uniform Pc { uint batch; uint m; uint k; uint n; };
// Output tile computed by one workgroup is BM rows x BN columns.
const uint BM = 64u;
const uint BN = 64u;
// Depth of the k-slice staged into shared memory per outer-loop step.
const uint BK = 8u;
// Each invocation accumulates a TM x TN block of the output tile.
const uint TM = 4u;
const uint TN = 4u;
// threads = (BM/TM) * (BN/TN) = 16 * 16 = 256 (matches local_size_x)
// Staged A tile, laid out [row][k]: As[row * BK + kk].
shared float As[BM * BK];
// Staged B tile, laid out [k][col]: Bs[kk * BN + col].
shared float Bs[BK * BN];
// Computes one BM x BN tile of C for batch index gl_WorkGroupID.z.
//   gl_WorkGroupID.y selects the tile row, gl_WorkGroupID.x the tile column.
//   The k dimension is walked in BK-wide slices: each slice of A and B is staged
//   into shared memory (zero-padded past the matrix edges), then every invocation
//   accumulates its TM x TN micro-tile from the staged data.
// Note: all buffer offsets are 32-bit uint arithmetic, so batch*m*k, batch*k*n and
// batch*m*n must each fit in a uint.
void main() {
    uint bt = gl_WorkGroupID.z;
    // Reject out-of-range batches before touching any buffer. bt is identical for
    // every invocation in the workgroup, so the whole group exits together and the
    // barrier() calls below are never reached in divergent control flow.
    if (bt >= batch) return;

    uint rowBase = gl_WorkGroupID.y * BM;
    uint colBase = gl_WorkGroupID.x * BN;
    uint tid = gl_LocalInvocationID.x;
    // Stride of the cooperative staging loops; tied to the declared local size
    // instead of repeating the literal.
    uint nThreads = gl_WorkGroupSize.x;
    uint threadCol = tid % (BN / TN); // 0..15
    uint threadRow = tid / (BN / TN); // 0..15

    // Per-batch base offsets into A, B and C.
    uint ao = bt * m * k;
    uint bo = bt * k * n;
    uint co = bt * m * n;

    float acc[TM][TN];
    [[unroll]] for (uint i = 0u; i < TM; i++)
        [[unroll]] for (uint j = 0u; j < TN; j++)
            acc[i][j] = 0.0;

    uint ntiles = (k + BK - 1u) / BK;
    for (uint t = 0u; t < ntiles; t++) {
        uint kBase = t * BK;

        // Stage the A slice: As[r * BK + cc] = A[rowBase + r, kBase + cc], 0 outside A.
        for (uint i = tid; i < BM * BK; i += nThreads) {
            uint r = i / BK;
            uint cc = i % BK;
            uint gr = rowBase + r;
            uint gc = kBase + cc;
            As[i] = (gr < m && gc < k) ? a[ao + gr * k + gc] : 0.0;
        }

        // Stage the B slice transposed on the fly:
        // Bs[r * BN + cc] = W[colBase + cc, kBase + r], 0 outside W.
        for (uint i = tid; i < BK * BN; i += nThreads) {
            uint r = i / BN;
            uint cc = i % BN;
            uint gr = kBase + r;
            uint gc = colBase + cc;
            Bs[i] = (gr < k && gc < n) ? b[bo + gc * k + gr] : 0.0;
        }

        // Both tiles must be fully written before anyone reads them.
        barrier();

        [[unroll]] for (uint kk = 0u; kk < BK; kk++) {
            float aReg[TM];
            float bReg[TN];
            [[unroll]] for (uint i = 0u; i < TM; i++)
                aReg[i] = As[(threadRow * TM + i) * BK + kk];
            [[unroll]] for (uint j = 0u; j < TN; j++)
                bReg[j] = Bs[kk * BN + threadCol * TN + j];
            // Outer product of the A column fragment and the B row fragment.
            [[unroll]] for (uint i = 0u; i < TM; i++)
                [[unroll]] for (uint j = 0u; j < TN; j++)
                    acc[i][j] += aReg[i] * bReg[j];
        }

        // Everyone must finish reading before the next slice overwrites the tiles.
        barrier();
    }

    // Write back the micro-tile, skipping rows/columns past the edge of C.
    [[unroll]] for (uint i = 0u; i < TM; i++) {
        uint gr = rowBase + threadRow * TM + i;
        if (gr >= m) continue;
        [[unroll]] for (uint j = 0u; j < TN; j++) {
            uint gc = colBase + threadCol * TN + j;
            if (gc < n)
                c[co + gr * n + gc] = acc[i][j];
        }
    }
}