#version 450
// Fused matrix multiplication + bias addition + ReLU activation
// Computes: ReLU(A * B + bias)
// Fusing the three operations into one dispatch avoids two extra
// global-memory round trips for the intermediate results.
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
// Input matrix A (M x K, row-major). Never written by this shader.
layout(set = 0, binding = 0) readonly buffer InputBufferA {
float a[];
};
// Input matrix B (K x N, row-major). Never written by this shader.
layout(set = 0, binding = 1) readonly buffer InputBufferB {
float b[];
};
// Bias vector, broadcast across the rows of the output.
layout(set = 0, binding = 2) readonly buffer BiasBuffer {
float bias[];
};
// Output matrix C (M x N, row-major). Only written, never read.
layout(set = 0, binding = 3) writeonly buffer OutputBuffer {
float result[];
};
layout(set = 0, binding = 4) uniform UniformBuffer {
uint M; // Rows of A, rows of C
uint N; // Cols of B, cols of C
uint K; // Cols of A, rows of B
uint bias_size; // Size of bias vector (should equal N)
};
// Shared memory tiles for cooperative loading (one 16x16 tile per operand)
shared float tileA[16][16];
shared float tileB[16][16];
shared float tileBias[16]; // Bias slice for this workgroup's 16 output columns
void main() {
// Global output coordinates: this invocation computes C[globalRow][globalCol].
uint globalRow = gl_GlobalInvocationID.y;
uint globalCol = gl_GlobalInvocationID.x;
// Local thread coordinates within the 16x16 workgroup.
uint localRow = gl_LocalInvocationID.y;
uint localCol = gl_LocalInvocationID.x;
// Cooperatively load the bias slice for this workgroup's 16 output columns
// (only row 0 of the workgroup does the loads). Columns at or past bias_size
// read 0.0 so a short bias vector cannot cause an out-of-bounds read.
if (localRow == 0 && globalCol < N) {
tileBias[localCol] = (globalCol < bias_size) ? bias[globalCol] : 0.0;
}
// BUGFIX: synchronize here so the tileBias write above is visible to every
// thread even when K == 0 (numTiles == 0), in which case the barriers inside
// the tiling loop below never execute and the final tileBias read would
// otherwise be an unsynchronized read of shared memory.
barrier();
// Accumulator for the dot product of row(A) with col(B).
float matmul_sum = 0.0;
// Number of 16-wide tiles needed to cover the K dimension (ceiling division).
uint numTiles = (K + 15) / 16;
// March the tiles across the K dimension. NOTE: every thread — including
// those outside the M x N bounds — must run this loop so the barrier()
// calls stay in uniform control flow across the workgroup.
for (uint tile = 0; tile < numTiles; tile++) {
// Cooperative load of a 16x16 tile of A; out-of-range elements are zeroed
// so they contribute nothing to the dot product.
uint aRow = globalRow;
uint aCol = tile * 16 + localCol;
tileA[localRow][localCol] = (aRow < M && aCol < K) ? a[aRow * K + aCol] : 0.0;
// Cooperative load of a 16x16 tile of B, likewise zero-padded.
uint bRow = tile * 16 + localRow;
uint bCol = globalCol;
tileB[localRow][localCol] = (bRow < K && bCol < N) ? b[bRow * N + bCol] : 0.0;
// Wait until the whole tile is resident in shared memory before using it.
barrier();
// Partial dot product over this tile's 16-element slice of K.
for (uint k = 0; k < 16; k++) {
matmul_sum += tileA[localRow][k] * tileB[k][localCol];
}
// Ensure all reads finish before the next iteration overwrites the tiles.
barrier();
}
// Fused epilogue: bias add + ReLU, then store. Only in-bounds threads write.
if (globalRow < M && globalCol < N) {
// Add bias (broadcast across rows).
float biased_result = matmul_sum + tileBias[localCol];
// Apply ReLU activation: max(0, x).
float activated_result = max(0.0, biased_result);
// Store final result.
result[globalRow * N + globalCol] = activated_result;
}
}