#version 450
// Register-blocked coopmat matmul, NT layout: C(m x n,f32) = A(m x k,f16) * W(n x k,f16)^T.
// W is a Linear weight in natural row-major [n,k]; the B fragment loads it column-major (stride k)
// so no transpose copy is needed. A load and the rest match bmm_coopmat_rb.
// One subgroup (subgroup 0 of each 64-invocation workgroup; the others exit immediately)
// computes an RM x RN grid of 16x16 output tiles, holding RM*RN accumulator fragments in
// registers. Per K-step it loads RM A-fragments + RN B-fragments and issues RM*RN MulAdds,
// so each A-fragment is reused RN times and each B-fragment RM times -- far higher
// arithmetic intensity than the naive 1-tile kernel (which reloads A/B for every tile).
// m, n, k must be multiples of 16 (caller guarantees); partial RM x RN grids at the matrix
// edge are handled by per-fragment bounds checks (the conditions are subgroup-uniform).
#extension GL_KHR_cooperative_matrix : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : enable
// 64 invocations per workgroup. main() lets only subgroup 0 do any work, so when the
// device's subgroup size is smaller than 64 the remaining subgroups exit immediately.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// A: per-batch m x k f16 matrix, row-major (row stride k), batch stride m*k elements.
layout(set = 0, binding = 0) readonly buffer A { float16_t a[]; };
// W ("B"): per-batch Linear weight stored [n, k] row-major, batch stride k*n elements.
// main() reads 16x16 blocks of it column-major with stride k, i.e. as W^T.
layout(set = 0, binding = 1) readonly buffer B { float16_t b[]; };
// C: per-batch m x n f32 output, row-major (row stride n), batch stride m*n elements.
layout(set = 0, binding = 2) writeonly buffer C { float c[]; };
// Problem sizes. m, k, n must be multiples of T (caller guarantee per the header).
// `batch` is not read by this shader; the batch index comes from gl_WorkGroupID.z.
// NOTE(review): presumably dispatched as (ceil(n/16/RN), ceil(m/16/RM), batch) -- confirm on host.
layout(push_constant) uniform Pc { uint batch; uint m; uint k; uint n; };
// Edge length of one cooperative-matrix fragment (16x16 tiles, matching the coopmat types).
const uint T = 16u;
const uint RM = 4u; // 16x16 output tiles per workgroup, row direction
const uint RN = 4u; // 16x16 output tiles per workgroup, col direction
// Computes this workgroup's RM x RN block of 16x16 output tiles of
// C[batch] = A[batch] * W[batch]^T, accumulating in f32 cooperative matrices.
// Workgroup x/y select the tile-column/tile-row block; workgroup z selects the batch.
void main() {
    // Guard: subgroup 0 owns the whole tile grid. gl_SubgroupID is subgroup-uniform,
    // so any subgroup that leaves does so as a whole, and the cooperative-matrix ops
    // below always run with a fully active subgroup.
    if (gl_SubgroupID > 0u) {
        return;
    }

    // Per-batch base element offsets into each buffer.
    uint batchIdx = gl_WorkGroupID.z;
    uint aBase = batchIdx * m * k; // A slice: m x k
    uint wBase = batchIdx * k * n; // W slice: n x k
    uint cBase = batchIdx * m * n; // C slice: m x n

    // Origin of this workgroup's tile grid, in tile (16-element) units.
    uint tileRow0 = gl_WorkGroupID.y * RM;
    uint tileCol0 = gl_WorkGroupID.x * RN;

    // Matrix extents in tiles.
    uint rowTiles = m / T;
    uint colTiles = n / T;
    uint depthTiles = k / T;

    // Which tile rows/cols of the grid fall inside the matrix. At the right/bottom edge
    // the grid may be partial. These flags are the same for every lane of the subgroup,
    // so branching on them keeps cooperative-matrix ops in subgroup-uniform control flow.
    bool rowLive[RM];
    bool colLive[RN];
    [[unroll]] for (uint ri = 0u; ri < RM; ++ri) {
        rowLive[ri] = (tileRow0 + ri) < rowTiles;
    }
    [[unroll]] for (uint ci = 0u; ci < RN; ++ci) {
        colLive[ci] = (tileCol0 + ci) < colTiles;
    }

    // Zero all accumulators.
    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> accum[RM][RN];
    [[unroll]] for (uint ri = 0u; ri < RM; ++ri) {
        [[unroll]] for (uint ci = 0u; ci < RN; ++ci) {
            accum[ri][ci] = coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);
        }
    }

    // March along K one 16-wide slab at a time.
    for (uint kStep = 0u; kStep < depthTiles; ++kStep) {
        uint kElem = kStep * T; // first k-column of this slab

        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> fragA[RM];
        coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> fragW[RN];

        // A fragments: 16x16 row-major blocks of A (row stride k).
        [[unroll]] for (uint ri = 0u; ri < RM; ++ri) {
            if (rowLive[ri]) {
                coopMatLoad(fragA[ri], a, aBase + (tileRow0 + ri) * T * k + kElem, k,
                            gl_CooperativeMatrixLayoutRowMajor);
            }
        }

        // B fragments: W is stored [n, k] row-major. Reading it column-major with stride k
        // yields the corresponding 16x16 block of W^T without any transpose copy.
        [[unroll]] for (uint ci = 0u; ci < RN; ++ci) {
            if (colLive[ci]) {
                coopMatLoad(fragW[ci], b, wBase + (tileCol0 + ci) * T * k + kElem, k,
                            gl_CooperativeMatrixLayoutColumnMajor);
            }
        }

        // Every live (row, col) pair reuses the fragments loaded above.
        [[unroll]] for (uint ri = 0u; ri < RM; ++ri) {
            [[unroll]] for (uint ci = 0u; ci < RN; ++ci) {
                if (rowLive[ri] && colLive[ci]) {
                    accum[ri][ci] = coopMatMulAdd(fragA[ri], fragW[ci], accum[ri][ci]);
                }
            }
        }
    }

    // Write the live tiles back to C (row-major, row stride n).
    [[unroll]] for (uint ri = 0u; ri < RM; ++ri) {
        [[unroll]] for (uint ci = 0u; ci < RN; ++ci) {
            if (rowLive[ri] && colLive[ci]) {
                coopMatStore(accum[ri][ci], c,
                             cBase + (tileRow0 + ri) * T * n + (tileCol0 + ci) * T, n,
                             gl_CooperativeMatrixLayoutRowMajor);
            }
        }
    }
}