#version 450
// Tiled matrix multiplication compute shader.
// Computes C = A * B where A is MxK, B is KxN and C is MxN. All three matrices
// are flat, row-major float arrays: element (r, c) of A is a[r * K + c], and
// elements of B and C are addressed with a row stride of N.
// Each 16x16 workgroup stages one 16x16 tile of A and one of B in shared
// memory per step along K; each invocation accumulates one element of C.
// NOTE(review): presumably dispatched with ceil(N / 16) x ceil(M / 16)
// workgroups (x spans columns of C, y spans rows) -- confirm against the host.

// Workgroup shape. The literal 16 must stay in sync with the extents of
// tileA/tileB below and with the tile arithmetic inside main().
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
// Left operand A (M rows, K columns). Only read by main().
layout(set = 0, binding = 0) buffer InputBufferA {
float a[];
};
// Right operand B (K rows, N columns). Only read by main().
layout(set = 0, binding = 1) buffer InputBufferB {
float b[];
};
// Product C (M rows, N columns). Only written by main(); every in-range
// element is overwritten, never accumulated into.
layout(set = 0, binding = 2) buffer OutputBuffer {
float result[];
};
// Matrix dimensions supplied by the host.
// NOTE(review): no explicit std140/std430 qualifier is given on any block;
// the host presumably uploads three consecutive 32-bit unsigned integers in
// the order M, N, K -- confirm against the host-side struct.
layout(set = 0, binding = 3) uniform UniformBuffer {
uint M; // Number of rows in A and in C
uint N; // Number of columns in B and in C
uint K; // Shared dimension: columns of A, rows of B
};
// Per-workgroup staging tiles, indexed [local row][local column].
// tileA holds a 16x16 window of A and tileB a 16x16 window of B for the
// K-step currently being processed; slots outside the matrices hold 0.0.
shared float tileA[16][16];
shared float tileB[16][16];
// Entry point: one invocation per element of C.
// The workgroup walks the K dimension in steps of one tile. At every step all
// invocations cooperatively copy a tile of A and a tile of B into shared
// memory (zero-filling anything outside the matrices), synchronise, and then
// each invocation adds its partial dot product for that step.
void main() {
    // Tile edge length; matches the workgroup size and the shared array extents.
    const uint TILE = 16u;

    // .x selects a column, .y selects a row.
    uvec2 lid = gl_LocalInvocationID.xy;
    uvec2 gid = gl_GlobalInvocationID.xy;
    uint row = gid.y;
    uint col = gid.x;

    // These do not change across tiles, so evaluate them once.
    bool rowInRange = row < M;
    bool colInRange = col < N;

    float acc = 0.0;
    uint tileCount = (K + (TILE - 1u)) / TILE;

    for (uint t = 0u; t < tileCount; ++t) {
        uint kBase = t * TILE;

        // Element of A in this invocation's row, at the column picked by lid.x.
        uint kForA = kBase + lid.x;
        float aElem = 0.0;
        if (rowInRange && kForA < K) {
            aElem = a[row * K + kForA];
        }
        tileA[lid.y][lid.x] = aElem;

        // Element of B in this invocation's column, at the row picked by lid.y.
        uint kForB = kBase + lid.y;
        float bElem = 0.0;
        if (kForB < K && colInRange) {
            bElem = b[kForB * N + col];
        }
        tileB[lid.y][lid.x] = bElem;

        // Every invocation must finish storing before any of them reads a tile.
        // Out-of-range invocations still reach this barrier: control flow here
        // is uniform across the workgroup.
        barrier();

        // Partial dot product over this tile; zero padding contributes nothing.
        for (uint i = 0u; i < TILE; ++i) {
            acc += tileA[lid.y][i] * tileB[i][lid.x];
        }

        // Do not let the next iteration overwrite tiles that are still in use.
        barrier();
    }

    // Only invocations that map to a real element of C store a result.
    if (rowInRange && colInRange) {
        result[row * N + col] = acc;
    }
}