#version 450
// Tiled matrix multiply: C[i,j] = sum_k A[i,k] * B[k,j]
// A: [m, K] row-major, B: [K, n] row-major, C: [m, n] row-major
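// Stages 16x16 tiles of A and B in shared memory so that each value fetched
// from the storage buffers is reused by 16 invocations of the workgroup.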
// Workgroup shape: 16 x 16 = 256 invocations, one per element of an output tile.
// main() indexes the shared tiles with gl_LocalInvocationID, so the tile
// dimensions declared below must stay equal to local_size_x / local_size_y.
layout(local_size_x = 16, local_size_y = 16) in;
// Left operand A, read as a[row * K + col]; indexed for row < m, col < K,
// so the bound range must hold at least m * K floats.
layout(set = 0, binding = 0) readonly buffer InputA { float a[]; };
// Right operand B, read as b[row * n + col]; indexed for row < K, col < n,
// so the bound range must hold at least K * n floats.
layout(set = 0, binding = 1) readonly buffer InputB { float b[]; };
// Product C, written as result[row * n + col] for row < m, col < n;
// the bound range must hold at least m * n floats.
layout(set = 0, binding = 2) writeonly buffer Output { float result[]; };
// Matrix dimensions supplied by the host as three consecutive 32-bit ints.
// main() converts them to uint, so they are expected to be non-negative.
layout(push_constant) uniform Params {
int m; // rows of A / rows of C
int K; // cols of A / rows of B
int n; // cols of B / cols of C
};
// Per-workgroup staging storage: the current 16x16 tile of A and of B.
// Each invocation fills one element of each tile, then reads a full row of
// tileA and a full column of tileB.
shared float tileA[16][16];
shared float tileB[16][16];
// Computes one element of C = A * B per invocation.
//
// gl_GlobalInvocationID.y selects the output row and .x the output column.
// The K dimension is walked in TILE-wide steps: for every step the workgroup
// cooperatively loads one tile of A and one tile of B into shared memory
// (zero-filling anything outside the matrices), synchronises, and each
// invocation accumulates the partial dot product for its own element.
//
// Invocations outside the m x n output must NOT return early: they still
// have to take part in the tile loads and reach every barrier() so that
// control flow stays uniform across the workgroup. They simply skip the
// final store.
void main() {
    // Tile edge length. Must equal local_size_x, local_size_y and both
    // dimensions of tileA / tileB.
    const uint TILE = 16u;

    // Clamp before converting to unsigned: a negative push constant would
    // otherwise wrap to a value near 2^32, which defeats every bounds check
    // below (out-of-bounds buffer access) and, for K, inflates the tile count
    // to hundreds of millions of iterations.
    uint rows  = uint(max(m, 0));
    uint inner = uint(max(K, 0));
    uint cols  = uint(max(n, 0));

    uint row = gl_GlobalInvocationID.y;
    uint col = gl_GlobalInvocationID.x;
    uint localRow = gl_LocalInvocationID.y;
    uint localCol = gl_LocalInvocationID.x;

    // Loop-invariant pieces of the per-tile bounds checks and addressing.
    // aRowBase is only used when rowInRange holds, so unsigned wrap-around
    // for out-of-range rows is harmless.
    bool rowInRange = row < rows;
    bool colInRange = col < cols;
    uint aRowBase = row * inner;

    float acc = 0.0;

    // Ceiling division so a trailing partial tile is still processed.
    uint numTiles = (inner + TILE - 1u) / TILE;
    for (uint t = 0u; t < numTiles; t++) {
        uint aCol = t * TILE + localCol;
        uint bRow = t * TILE + localRow;

        // Out-of-range elements are loaded as 0.0 so they contribute
        // nothing to the sum and no out-of-bounds read is issued.
        tileA[localRow][localCol] = (rowInRange && aCol < inner)
            ? a[aRowBase + aCol] : 0.0;
        tileB[localRow][localCol] = (bRow < inner && colInRange)
            ? b[bRow * cols + col] : 0.0;

        // Wait until the whole workgroup has finished filling both tiles.
        barrier();

        for (uint i = 0u; i < TILE; i++) {
            acc += tileA[localRow][i] * tileB[i][localCol];
        }

        // Wait until everyone is done reading before the next iteration
        // overwrites the tiles.
        barrier();
    }

    if (rowInRange && colInRange) {
        result[row * cols + col] = acc;
    }
}