// Parameters consumed by the GEMM entry point `main` in this file.
// All strides are applied directly to StructuredBuffer<T> indices, i.e. they
// are expressed in elements of T, not bytes.
// NOTE(review): field order and types presumably have to match the host-side
// push-constant layout byte for byte -- confirm before reordering anything.
struct PushConstants
{
// Number of output rows; threads with out_row >= m write nothing.
uint m;
// Length of the reduction (inner) dimension.
uint k;
// Number of output columns; threads with out_col >= n write nothing.
uint n;
// Strides of buffer `a`. With trans_a == 0, stride_a0 multiplies the row (m)
// index and stride_a1 the k index; with trans_a != 0 the roles are swapped.
uint stride_a0;
uint stride_a1;
// Strides of buffer `b`. With trans_b == 0, stride_b0 multiplies the k index
// and stride_b1 the column (n) index; with trans_b != 0 the roles are swapped.
uint stride_b0;
uint stride_b1;
// Strides of the output `y`: stride_y0 multiplies the row index, stride_y1
// the column index.
uint stride_y0;
uint stride_y1;
// Non-zero selects the swapped stride assignment for `a` (see above).
uint trans_a;
// Non-zero selects the swapped stride assignment for `b` (see above).
uint trans_b;
// Scale factor applied to the accumulated product.
float alpha;
// Scale factor applied to the `c` element when has_c != 0.
float beta;
// Non-zero: add beta * c[row * n + col] to the result; zero: `c` is not read.
uint has_c;
}
// GEMM parameters for the current dispatch, bound as a Vulkan push-constant
// block and read by `main`.
[[vk::push_constant]]
PushConstants pc;
/// Tiled matrix multiply: y = alpha * (A x B) [+ beta * c].
///
/// Each 16x16 thread group produces one 16x16 tile of the output and each
/// thread owns a single output element (out_row, out_col). The K dimension is
/// walked in steps of TILE_SIZE; for every step the group stages one tile of
/// each operand in groupshared memory and every thread accumulates its partial
/// dot product from those tiles.
///
/// a, b     : input operands, addressed through the strides in `pc`;
///            pc.trans_a / pc.trans_b choose which of the two strides is
///            applied to the k index.
/// c        : optional addend, read only when pc.has_c != 0 and indexed as a
///            dense row-major m x n matrix (row * pc.n + col).
///            NOTE(review): no stride or broadcast is applied to `c` --
///            confirm callers always pass it in that shape.
/// y        : output, addressed through pc.stride_y0 / pc.stride_y1.
/// groupId  : x selects the output column tile, y the output row tile.
/// threadId : x is the column within the tile, y the row within the tile.
[shader("compute")]
[numthreads(16, 16, 1)]
void main<T : IArithmetic>(
    StructuredBuffer<T> a,
    StructuredBuffer<T> b,
    StructuredBuffer<T> c,
    RWStructuredBuffer<T> y,
    uint3 groupId: SV_GroupID, uint3 threadId: SV_GroupThreadID)
{
    // Edge length of a tile. Must stay equal to both numthreads dimensions
    // and to the extents of the groupshared arrays below.
    const uint TILE_SIZE = 16;

    uint thread_row = threadId.y;
    uint thread_col = threadId.x;
    uint out_row = groupId.y * TILE_SIZE + thread_row;
    uint out_col = groupId.x * TILE_SIZE + thread_col;

    // Loop-invariant bounds checks for this thread's output element.
    bool row_in_range = out_row < pc.m;
    bool col_in_range = out_col < pc.n;

    // Resolve the transpose flags once into "stride along k" and
    // "stride along the output dimension" for each operand.
    uint stride_a_k = (pc.trans_a != 0) ? pc.stride_a0 : pc.stride_a1;
    uint stride_a_m = (pc.trans_a != 0) ? pc.stride_a1 : pc.stride_a0;
    uint stride_b_k = (pc.trans_b != 0) ? pc.stride_b1 : pc.stride_b0;
    uint stride_b_n = (pc.trans_b != 0) ? pc.stride_b0 : pc.stride_b1;

    // One TILE_SIZE x TILE_SIZE tile per operand (16 * 16 elements each).
    static groupshared T tileA[16 * 16];
    static groupshared T tileB[16 * 16];
    uint idx = thread_row * TILE_SIZE + thread_col;

    T acc = T(0);

    // When pc.k == 0 this is 0, the loop body never runs and acc stays T(0),
    // so the epilogue below writes alpha * 0 (+ beta * c) without needing a
    // separate special case.
    uint num_k_tiles = (pc.k + TILE_SIZE - 1) / TILE_SIZE;
    for (uint k_tile = 0; k_tile < num_k_tiles; k_tile++)
    {
        uint k_offset = k_tile * TILE_SIZE;
        uint col_a = k_offset + thread_col;
        uint row_b = k_offset + thread_row;

        // Explicit branches keep the buffer read inside the bounds check;
        // out-of-range tile slots are zero-filled so they contribute nothing
        // to the dot product.
        T a_val = T(0);
        if (row_in_range && col_a < pc.k)
        {
            a_val = a[out_row * stride_a_m + col_a * stride_a_k];
        }
        T b_val = T(0);
        if (row_b < pc.k && col_in_range)
        {
            b_val = b[row_b * stride_b_k + out_col * stride_b_n];
        }
        tileA[idx] = a_val;
        tileB[idx] = b_val;

        // Every thread must have stored its elements before any thread reads.
        GroupMemoryBarrierWithGroupSync();

        for (uint ki = 0; ki < TILE_SIZE; ki++)
        {
            acc = acc + tileA[thread_row * TILE_SIZE + ki] * tileB[ki * TILE_SIZE + thread_col];
        }

        // Every thread must finish reading before the next iteration
        // overwrites the tiles.
        GroupMemoryBarrierWithGroupSync();
    }

    // Threads that fall outside the m x n output only helped fill the tiles.
    if (row_in_range && col_in_range)
    {
        // NOTE(review): T is only constrained to IArithmetic; confirm that
        // constructing T from the float alpha/beta is the intended conversion
        // for every T this shader is specialized with.
        T result = T(pc.alpha) * acc;
        if (pc.has_c != 0)
        {
            uint c_idx = out_row * pc.n + out_col;
            result = result + T(pc.beta) * c[c_idx];
        }
        uint y_idx = out_row * pc.stride_y0 + out_col * pc.stride_y1;
        y[y_idx] = result;
    }
}