#version 450
// Batched matmul C[b,m,n] = sum_k A[b,m,k] * B[b,k,n]. One output element per
// invocation. This maximal-occupancy mapping (m*n invocations per batch, i.e.
// m*n*batch in total) is what Apple GPUs want — they hide global-memory
// latency with thread parallelism, so shared-
// memory tiling and register blocking both *regress* here (they cut occupancy
// and add barrier/register pressure that MoltenVK translates poorly). The
// throughput ceiling above this is the Metal matrix units (simdgroup_matmul),
// which need VK_KHR_cooperative_matrix — absent in current MoltenVK. Batch
// strides allow broadcasting a non-batched operand (stride 0).
//
// Shape: fully general — each invocation bounds-checks its (row, col, batch),
// so any M/N/K/batch is correct. This is the portability-driver (MoltenVK)
// fallback; native drivers use the tiled fp32 / cooperative-matrix kernels.
// 16x16 workgroups: x spans output columns (n), y spans output rows (m), and
// z (local size left at the default of 1) indexes the batch. Presumably the
// host dispatches ceil(n/16) x ceil(m/16) x batch workgroups — confirm against
// the host code; out-of-range invocations are discarded in main().
layout(local_size_x = 16, local_size_y = 16) in;
// One storage buffer holds A, B and C. Every offset and stride below is
// measured in float elements (indices into data[]), not bytes. C is written
// into the same buffer A and B are read from, so the C range must not overlap
// A or B, or invocations would read values other invocations are overwriting.
layout(std430, binding = 0) buffer Arena { float data[]; };
// Push-constant parameters. The field order defines the push-constant byte
// layout, so it must match the host-side struct exactly.
layout(push_constant) uniform PC {
uint m; // rows of A and C
uint k; // inner dimension: columns of A, rows of B
uint n; // columns of B and C (also the row pitch of B and C)
uint a_off; // element offset of batch 0 of A in data[]
uint b_off; // element offset of batch 0 of B in data[]
uint c_off; // element offset of batch 0 of C in data[]
uint batch; // number of batches actually computed
uint a_bs; // A batch stride (0 ⇒ broadcast)
uint b_bs; // B batch stride (0 ⇒ broadcast)
uint c_bs; // C batch stride
// NOTE(review): a c_bs of 0 would make every batch write the same C slice
// concurrently — presumably never used; verify on the host side.
} pc;
// Each invocation produces exactly one element C[slice][outRow][outCol]: the
// dot product of row outRow of A with column outCol of B, both taken from
// batch `slice`. A, B and C are row-major regions of the shared data[] arena,
// located through the element offsets and batch strides in the push constants.
void main() {
    uvec3 gid = gl_GlobalInvocationID;
    uint outCol = gid.x; // position along n
    uint outRow = gid.y; // position along m
    uint slice = gid.z;  // batch index

    // Partial workgroups on the grid edge overshoot the output; skip them.
    bool inRange = outCol < pc.n && outRow < pc.m && slice < pc.batch;
    if (!inRange) {
        return;
    }

    // A is walked along its row (unit stride), B down its column (stride n).
    // A zero batch stride makes every batch read the same operand (broadcast).
    uint aIdx = pc.a_off + slice * pc.a_bs + outRow * pc.k;
    uint bIdx = pc.b_off + slice * pc.b_bs + outCol;
    float sum = 0.0;
    for (uint i = 0u; i < pc.k; ++i) {
        sum += data[aIdx] * data[bIdx];
        aIdx += 1u;
        bIdx += pc.n;
    }

    uint cIdx = pc.c_off + slice * pc.c_bs + outRow * pc.n + outCol;
    data[cIdx] = sum;
}