#version 450
// Batched matmul on the matrix cores via VK_KHR_cooperative_matrix.
// C[bt] (m x n, f32) = A[bt] (m x k, f16) * B[bt] (k x n, f16), row-major contiguous.
// One subgroup computes one 16x16 output tile, looping over K in 16-wide steps. Requires
// m, n, k to be multiples of 16 (the caller falls back to the tiled kernel otherwise).
#extension GL_KHR_cooperative_matrix : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_16bit_storage : require
// 64 lanes covers an AMD wave (wave32 -> 2 subgroups, wave64 -> 1); only subgroup 0 works the
// tile, so this is correct for either subgroup size without a specialization constant.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
// Storage buffers, all in descriptor set 0. Per the file header each one holds the
// per-batch matrices in row-major, contiguous order.
// binding 0: A operand, f16 elements, read-only.
layout(set = 0, binding = 0) readonly buffer A { float16_t a[]; };
// binding 1: B operand, f16 elements, read-only.
layout(set = 0, binding = 1) readonly buffer B { float16_t b[]; };
// binding 2: C result, f32 elements, write-only.
layout(set = 0, binding = 2) writeonly buffer C { float c[]; };
// Problem shape: A is m x k, B is k x n, C is m x n (element counts, see file header).
// NOTE(review): batch is presumably the number of matrices per buffer; main() takes the
// batch index from gl_WorkGroupID.z -- confirm against the host-side dispatch.
layout(push_constant) uniform Pc { uint batch; uint m; uint k; uint n; };
// Tile edge length in elements: converts tile indices to element offsets and sets the K step.
const uint T = 16u;
// Entry point. Each workgroup produces one T x T tile of C for one batch entry:
// gl_WorkGroupID.x is the tile column, .y the tile row, .z the batch index.
// All buffer offsets and strides below are in elements of the respective array.
void main() {
    // The cooperative-matrix operations are subgroup-scoped, so exactly one subgroup per
    // workgroup does the work. gl_SubgroupID is the same for every lane of a subgroup, so
    // this branch never splits a subgroup.
    if (gl_SubgroupID == 0u) {
        uint batchIdx = gl_WorkGroupID.z;
        uint rowOrigin = gl_WorkGroupID.y * T; // first output row covered by this tile
        uint colOrigin = gl_WorkGroupID.x * T; // first output column covered by this tile

        // Running element offsets of the current A and B tiles. A starts at the tile's first
        // row of this batch entry; B starts at the tile's first column of this batch entry.
        uint aCursor = batchIdx * m * k + rowOrigin * k;
        uint bCursor = batchIdx * k * n + colOrigin;
        // Moving down T rows of B skips T * n elements (row stride of B is n).
        uint bRowStep = T * n;

        // f32 accumulator tile, zero-initialised.
        coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> tileSum =
            coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);

        // Walk the K dimension one T-wide slab at a time (k / T slabs in total).
        for (uint remaining = k / T; remaining > 0u; --remaining) {
            coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> lhs;
            coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> rhs;
            coopMatLoad(lhs, a, aCursor, k, gl_CooperativeMatrixLayoutRowMajor);
            coopMatLoad(rhs, b, bCursor, n, gl_CooperativeMatrixLayoutRowMajor);
            tileSum = coopMatMulAdd(lhs, rhs, tileSum);
            aCursor += T;        // next T columns of A
            bCursor += bRowStep; // next T rows of B
        }

        // Write the finished tile into C at (rowOrigin, colOrigin) of this batch entry.
        uint cOffset = batchIdx * m * n + rowOrigin * n + colOrigin;
        coopMatStore(tileSum, c, cOffset, n, gl_CooperativeMatrixLayoutRowMajor);
    }
}